sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe.py
CHANGED
@@ -1,377 +1,907 @@
|
|
1
|
-
#
|
2
|
-
#
|
3
|
-
|
4
|
-
|
5
|
-
|
1
|
+
# coding=utf-8
|
2
|
+
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
|
3
|
+
#
|
4
|
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5
|
+
# and OPT implementations in this library. It has been modified from its
|
6
|
+
# original forms to accommodate minor architectural differences compared
|
7
|
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8
|
+
#
|
9
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10
|
+
# you may not use this file except in compliance with the License.
|
11
|
+
# You may obtain a copy of the License at
|
12
|
+
#
|
13
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14
|
+
#
|
15
|
+
# Unless required by applicable law or agreed to in writing, software
|
16
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18
|
+
# See the License for the specific language governing permissions and
|
19
|
+
# limitations under the License.
|
20
|
+
""" SGLang BailingMoE model."""
|
21
|
+
import logging
|
22
|
+
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
6
23
|
|
7
24
|
import torch
|
8
25
|
import torch.nn.functional as F
|
9
26
|
from torch import nn
|
10
|
-
from transformers
|
27
|
+
from transformers import PretrainedConfig
|
11
28
|
|
12
29
|
from sglang.srt.distributed import (
|
30
|
+
get_pp_group,
|
13
31
|
get_tensor_model_parallel_world_size,
|
32
|
+
parallel_state,
|
14
33
|
tensor_model_parallel_all_reduce,
|
15
34
|
)
|
35
|
+
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
36
|
+
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
|
37
|
+
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
|
16
38
|
from sglang.srt.layers.activation import SiluAndMul
|
39
|
+
from sglang.srt.layers.communicator import (
|
40
|
+
LayerCommunicator,
|
41
|
+
LayerScatterModes,
|
42
|
+
enable_moe_dense_fully_dp,
|
43
|
+
)
|
44
|
+
from sglang.srt.layers.dp_attention import (
|
45
|
+
get_attention_dp_size,
|
46
|
+
get_attention_tp_rank,
|
47
|
+
get_attention_tp_size,
|
48
|
+
is_dp_attention_enabled,
|
49
|
+
)
|
17
50
|
from sglang.srt.layers.layernorm import RMSNorm
|
18
51
|
from sglang.srt.layers.linear import (
|
19
52
|
MergedColumnParallelLinear,
|
20
53
|
QKVParallelLinear,
|
21
|
-
ReplicatedLinear,
|
22
54
|
RowParallelLinear,
|
23
55
|
)
|
24
56
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
25
|
-
from sglang.srt.layers.moe
|
57
|
+
from sglang.srt.layers.moe import get_moe_a2a_backend
|
58
|
+
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
59
|
+
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
60
|
+
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
26
61
|
from sglang.srt.layers.moe.topk import TopK
|
62
|
+
from sglang.srt.layers.moe.utils import DeepEPMode
|
27
63
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
28
64
|
from sglang.srt.layers.radix_attention import RadixAttention
|
29
65
|
from sglang.srt.layers.rotary_embedding import get_rope
|
66
|
+
from sglang.srt.layers.utils import PPMissingLayer
|
30
67
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
31
68
|
ParallelLMHead,
|
32
69
|
VocabParallelEmbedding,
|
33
70
|
)
|
34
|
-
from sglang.srt.
|
71
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
72
|
+
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
73
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
35
74
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
36
|
-
from sglang.srt.utils import
|
75
|
+
from sglang.srt.models.utils import (
|
76
|
+
create_fused_set_kv_buffer_arg,
|
77
|
+
enable_fused_set_kv_buffer,
|
78
|
+
)
|
79
|
+
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
|
37
80
|
|
81
|
+
LoraConfig = None
|
82
|
+
logger = logging.getLogger(__name__)
|
83
|
+
_is_cuda = is_cuda()
|
38
84
|
|
39
|
-
class BailingAttention(nn.Module):
|
40
85
|
|
86
|
+
class BailingMoEMLP(nn.Module):
|
41
87
|
def __init__(
|
42
88
|
self,
|
89
|
+
intermediate_size: int,
|
43
90
|
config: PretrainedConfig,
|
44
|
-
layer_id: int = 0,
|
45
91
|
quant_config: Optional[QuantizationConfig] = None,
|
92
|
+
reduce_results: Optional[bool] = True,
|
46
93
|
prefix: str = "",
|
47
|
-
|
94
|
+
tp_rank: Optional[int] = None,
|
95
|
+
tp_size: Optional[int] = None,
|
96
|
+
) -> None:
|
48
97
|
super().__init__()
|
49
|
-
self.
|
50
|
-
tp_size = get_tensor_model_parallel_world_size()
|
51
|
-
|
52
|
-
self.total_num_heads = config.num_attention_heads
|
53
|
-
self.total_num_kv_heads = config.num_key_value_heads
|
54
|
-
|
55
|
-
assert self.total_num_heads % tp_size == 0
|
56
|
-
assert self.total_num_kv_heads % tp_size == 0
|
57
|
-
|
58
|
-
self.num_heads = self.total_num_heads // tp_size
|
59
|
-
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
|
60
|
-
self.q_size = self.num_heads * self.head_dim
|
61
|
-
|
62
|
-
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
63
|
-
self.kv_size = self.num_kv_heads * self.head_dim
|
64
|
-
self.scale = self.head_dim**-0.5
|
65
|
-
|
66
|
-
self.query_key_value = QKVParallelLinear(
|
67
|
-
self.hidden_size,
|
68
|
-
self.head_dim,
|
69
|
-
self.total_num_heads,
|
70
|
-
self.total_num_kv_heads,
|
71
|
-
bias=(config.use_bias or config.use_qkv_bias),
|
72
|
-
quant_config=quant_config,
|
73
|
-
prefix=add_prefix("query_key_value", prefix),
|
74
|
-
)
|
98
|
+
self.tp_size = tp_size
|
75
99
|
|
76
|
-
self.
|
77
|
-
|
78
|
-
|
100
|
+
self.gate_up_proj = MergedColumnParallelLinear(
|
101
|
+
config.hidden_size,
|
102
|
+
[intermediate_size] * 2,
|
79
103
|
bias=config.use_bias,
|
80
104
|
quant_config=quant_config,
|
81
|
-
prefix=add_prefix("
|
105
|
+
prefix=add_prefix("gate_up_proj", prefix),
|
106
|
+
tp_rank=tp_rank,
|
107
|
+
tp_size=tp_size,
|
82
108
|
)
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
num_kv_heads=self.num_kv_heads,
|
89
|
-
layer_id=layer_id,
|
109
|
+
self.down_proj = RowParallelLinear(
|
110
|
+
intermediate_size,
|
111
|
+
config.hidden_size,
|
112
|
+
bias=config.use_bias,
|
113
|
+
reduce_results=reduce_results,
|
90
114
|
quant_config=quant_config,
|
91
|
-
prefix=add_prefix("
|
115
|
+
prefix=add_prefix("down_proj", prefix),
|
116
|
+
tp_rank=tp_rank,
|
117
|
+
tp_size=tp_size,
|
92
118
|
)
|
93
119
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
max_position=config.max_position_embeddings,
|
98
|
-
base=config.rope_theta,
|
99
|
-
is_neox_style=True,
|
100
|
-
rope_scaling=config.rope_scaling,
|
101
|
-
)
|
120
|
+
if config.hidden_act != "silu":
|
121
|
+
raise ValueError("Unsupported activation. Only silu is supported for now.")
|
122
|
+
self.act_fn = SiluAndMul()
|
102
123
|
|
103
124
|
def forward(
|
104
125
|
self,
|
105
126
|
hidden_states: torch.Tensor,
|
106
|
-
|
107
|
-
|
127
|
+
forward_batch: Optional[ForwardBatch] = None,
|
128
|
+
use_reduce_scatter: bool = False,
|
108
129
|
) -> torch.Tensor:
|
109
|
-
|
110
|
-
|
130
|
+
if (self.tp_size == 1) and hidden_states.shape[0] == 0:
|
131
|
+
return hidden_states
|
111
132
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
133
|
+
gate_up, _ = self.gate_up_proj(hidden_states)
|
134
|
+
hidden_states = self.act_fn(gate_up)
|
135
|
+
hidden_states, _ = self.down_proj(
|
136
|
+
hidden_states, skip_all_reduce=use_reduce_scatter
|
137
|
+
)
|
138
|
+
return hidden_states
|
116
139
|
|
117
140
|
|
118
|
-
class
|
141
|
+
class BailingMoEGate(nn.Module):
|
119
142
|
def __init__(
|
120
143
|
self,
|
121
|
-
|
122
|
-
|
123
|
-
quant_config: Optional[QuantizationConfig] = None,
|
124
|
-
reduce_results: Optional[bool] = True,
|
144
|
+
config,
|
145
|
+
params_dtype: Optional[torch.dtype] = None,
|
125
146
|
prefix: str = "",
|
126
|
-
)
|
147
|
+
):
|
127
148
|
super().__init__()
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
intermediate_size,
|
137
|
-
config.hidden_size,
|
138
|
-
bias=config.use_bias,
|
139
|
-
quant_config=quant_config,
|
140
|
-
reduce_results=reduce_results,
|
141
|
-
prefix=add_prefix("down_proj", prefix),
|
149
|
+
if params_dtype is None:
|
150
|
+
params_dtype = torch.get_default_dtype()
|
151
|
+
self.params_dtype = params_dtype
|
152
|
+
self.weight = nn.Parameter(
|
153
|
+
torch.empty(
|
154
|
+
(config.num_experts, config.hidden_size),
|
155
|
+
dtype=self.params_dtype,
|
156
|
+
),
|
142
157
|
)
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
return x
|
158
|
+
if getattr(config, "moe_router_enable_expert_bias", False):
|
159
|
+
self.expert_bias = nn.Parameter(
|
160
|
+
torch.empty((config.num_experts,), dtype=torch.float32),
|
161
|
+
)
|
162
|
+
else:
|
163
|
+
self.expert_bias = None
|
150
164
|
|
165
|
+
def forward(self, hidden_states):
|
166
|
+
logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
|
167
|
+
hidden_states.dtype
|
168
|
+
)
|
169
|
+
return logits
|
151
170
|
|
152
|
-
class BailingMoE(nn.Module):
|
153
171
|
|
172
|
+
class BailingMoESparseMoeBlock(nn.Module):
|
154
173
|
def __init__(
|
155
174
|
self,
|
156
|
-
config: PretrainedConfig,
|
157
175
|
layer_id: int,
|
176
|
+
config: PretrainedConfig,
|
158
177
|
quant_config: Optional[QuantizationConfig] = None,
|
178
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
159
179
|
prefix: str = "",
|
160
180
|
):
|
161
181
|
super().__init__()
|
182
|
+
self.layer_id = layer_id
|
183
|
+
self.alt_stream = alt_stream
|
162
184
|
self.tp_size = get_tensor_model_parallel_world_size()
|
163
|
-
self.num_experts = config.num_experts
|
164
185
|
self.top_k = config.num_experts_per_tok
|
186
|
+
self.norm_topk_prob = config.norm_topk_prob
|
165
187
|
self.hidden_size = config.hidden_size
|
166
188
|
self.num_shared_experts = config.num_shared_experts
|
167
|
-
self.
|
168
|
-
self.
|
189
|
+
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
|
190
|
+
self.score_function = getattr(config, "score_function", None)
|
169
191
|
|
170
|
-
|
171
|
-
|
192
|
+
if config.hidden_act != "silu":
|
193
|
+
raise ValueError(
|
194
|
+
f"Unsupported activation: {config.hidden_act}. "
|
195
|
+
"Only silu is supported for now."
|
196
|
+
)
|
197
|
+
|
198
|
+
# Gate always runs at half / full precision for now.
|
199
|
+
router_dtype = getattr(config, "router_dtype", None)
|
200
|
+
if router_dtype is None:
|
201
|
+
self.router_dtype = None
|
202
|
+
elif router_dtype == "fp32":
|
203
|
+
self.router_dtype = torch.float32
|
204
|
+
else:
|
205
|
+
self.router_dtype = torch.bfloat16
|
206
|
+
|
207
|
+
# TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
|
208
|
+
assert global_server_args_dict["ep_num_redundant_experts"] == 0
|
209
|
+
# check group topk
|
210
|
+
self.num_expert_group = getattr(config, "n_group", 0)
|
211
|
+
self.topk_group = getattr(config, "topk_group", 0)
|
212
|
+
if self.num_expert_group > 0 or self.topk_group > 0:
|
213
|
+
assert (
|
214
|
+
self.num_expert_group > 0
|
215
|
+
and 0 < self.topk_group <= self.num_expert_group
|
216
|
+
)
|
217
|
+
self.use_grouped_topk = True
|
218
|
+
else:
|
219
|
+
self.num_expert_group = self.topk_group = None
|
220
|
+
self.use_grouped_topk = False
|
221
|
+
|
222
|
+
self.num_experts = (
|
223
|
+
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
|
224
|
+
)
|
225
|
+
|
226
|
+
self.gate = BailingMoEGate(
|
227
|
+
config=config,
|
228
|
+
params_dtype=self.router_dtype,
|
229
|
+
prefix=add_prefix("gate", prefix),
|
230
|
+
)
|
231
|
+
self.correction_bias = (
|
232
|
+
self.gate.expert_bias.data if self.gate.expert_bias is not None else None
|
172
233
|
)
|
173
234
|
|
174
|
-
self.
|
235
|
+
if self.score_function is not None:
|
236
|
+
assert (
|
237
|
+
self.score_function == "softmax" and self.correction_bias is None
|
238
|
+
) or (
|
239
|
+
self.score_function == "sigmoid" and self.correction_bias is not None
|
240
|
+
), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)"
|
175
241
|
|
176
|
-
self.
|
242
|
+
self.topk = TopK(
|
243
|
+
top_k=self.top_k,
|
244
|
+
renormalize=self.norm_topk_prob,
|
245
|
+
use_grouped_topk=self.use_grouped_topk,
|
246
|
+
num_expert_group=self.num_expert_group,
|
247
|
+
# num_fused_shared_experts=self.num_fused_shared_experts,
|
248
|
+
topk_group=self.topk_group,
|
249
|
+
correction_bias=self.correction_bias,
|
250
|
+
routed_scaling_factor=self.routed_scaling_factor,
|
251
|
+
)
|
252
|
+
|
253
|
+
self.experts = get_moe_impl_class(quant_config)(
|
177
254
|
num_experts=self.num_experts,
|
178
255
|
top_k=self.top_k,
|
179
|
-
layer_id=layer_id,
|
180
|
-
hidden_size=
|
181
|
-
intermediate_size=
|
182
|
-
reduce_results=False,
|
256
|
+
layer_id=self.layer_id,
|
257
|
+
hidden_size=config.hidden_size,
|
258
|
+
intermediate_size=config.moe_intermediate_size,
|
183
259
|
quant_config=quant_config,
|
260
|
+
routed_scaling_factor=self.routed_scaling_factor,
|
184
261
|
prefix=add_prefix("experts", prefix),
|
185
262
|
)
|
186
|
-
|
187
|
-
if
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
263
|
+
# shared expert
|
264
|
+
if config.num_shared_experts is not None:
|
265
|
+
if hasattr(config, "moe_shared_expert_intermediate_size"):
|
266
|
+
intermediate_size = config.moe_shared_expert_intermediate_size
|
267
|
+
else:
|
268
|
+
intermediate_size = config.moe_intermediate_size
|
269
|
+
intermediate_size *= config.num_shared_experts
|
270
|
+
# disable tp for shared experts when enable deepep moe
|
271
|
+
self.shared_experts = BailingMoEMLP(
|
272
|
+
intermediate_size=intermediate_size,
|
193
273
|
config=config,
|
194
274
|
quant_config=quant_config,
|
195
275
|
reduce_results=False,
|
196
276
|
prefix=add_prefix("shared_experts", prefix),
|
277
|
+
**(
|
278
|
+
dict(tp_rank=0, tp_size=1)
|
279
|
+
if get_moe_a2a_backend().is_deepep()
|
280
|
+
else {}
|
281
|
+
),
|
197
282
|
)
|
283
|
+
# dispatcher
|
284
|
+
if get_moe_a2a_backend().is_deepep():
|
285
|
+
# TODO: we will support tp < ep in the future
|
286
|
+
self.ep_size = get_tensor_model_parallel_world_size()
|
287
|
+
|
288
|
+
self.deepep_dispatcher = DeepEPDispatcher(
|
289
|
+
group=parallel_state.get_tp_group().device_group,
|
290
|
+
router_topk=self.top_k,
|
291
|
+
permute_fusion=True,
|
292
|
+
num_experts=self.num_experts,
|
293
|
+
num_local_experts=config.num_experts // self.tp_size,
|
294
|
+
hidden_size=config.hidden_size,
|
295
|
+
params_dtype=config.torch_dtype,
|
296
|
+
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
297
|
+
async_finish=True, # TODO
|
298
|
+
return_recv_hook=True,
|
299
|
+
)
|
300
|
+
|
301
|
+
def forward(
|
302
|
+
self,
|
303
|
+
hidden_states: torch.Tensor,
|
304
|
+
forward_batch: Optional[ForwardBatch] = None,
|
305
|
+
use_reduce_scatter: bool = False,
|
306
|
+
) -> torch.Tensor:
|
307
|
+
if not get_moe_a2a_backend().is_deepep():
|
308
|
+
return self.forward_normal(hidden_states, use_reduce_scatter)
|
198
309
|
else:
|
199
|
-
self.
|
310
|
+
return self.forward_deepep(hidden_states, forward_batch)
|
200
311
|
|
201
|
-
def
|
202
|
-
|
203
|
-
|
312
|
+
def get_moe_weights(self):
|
313
|
+
return [
|
314
|
+
x.data
|
315
|
+
for name, x in self.experts.named_parameters()
|
316
|
+
if name not in ["correction_bias"]
|
317
|
+
]
|
204
318
|
|
319
|
+
def _forward_shared_experts(self, hidden_states: torch.Tensor):
|
205
320
|
shared_output = None
|
206
|
-
if self.
|
207
|
-
shared_output = self.shared_experts(
|
321
|
+
if self.num_shared_experts > 0:
|
322
|
+
shared_output = self.shared_experts(hidden_states)
|
323
|
+
return shared_output
|
208
324
|
|
209
|
-
|
210
|
-
|
211
|
-
|
325
|
+
def _forward_router_experts(self, hidden_states: torch.Tensor):
|
326
|
+
# router_logits: (num_tokens, n_experts)
|
327
|
+
router_logits = self.gate(hidden_states)
|
328
|
+
topk_output = self.topk(hidden_states, router_logits)
|
329
|
+
return self.experts(hidden_states, topk_output)
|
212
330
|
|
213
|
-
|
331
|
+
def forward_normal_dual_stream(
|
332
|
+
self,
|
333
|
+
hidden_states: torch.Tensor,
|
334
|
+
) -> torch.Tensor:
|
335
|
+
current_stream = torch.cuda.current_stream()
|
336
|
+
self.alt_stream.wait_stream(current_stream)
|
337
|
+
shared_output = self._forward_shared_experts(hidden_states.clone())
|
338
|
+
|
339
|
+
with torch.cuda.stream(self.alt_stream):
|
340
|
+
router_output = self._forward_router_experts(hidden_states)
|
341
|
+
current_stream.wait_stream(self.alt_stream)
|
342
|
+
|
343
|
+
return router_output, shared_output
|
344
|
+
|
345
|
+
def forward_normal(
|
346
|
+
self,
|
347
|
+
hidden_states: torch.Tensor,
|
348
|
+
use_reduce_scatter: bool = False,
|
349
|
+
) -> torch.Tensor:
|
350
|
+
num_tokens, hidden_size = hidden_states.shape
|
351
|
+
hidden_states = hidden_states.view(-1, hidden_size)
|
352
|
+
|
353
|
+
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
354
|
+
if (
|
355
|
+
self.alt_stream is not None
|
356
|
+
and hidden_states.shape[0] > 0
|
357
|
+
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
358
|
+
and get_is_capture_mode()
|
359
|
+
):
|
360
|
+
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
361
|
+
hidden_states
|
362
|
+
)
|
363
|
+
else:
|
364
|
+
shared_output = self._forward_shared_experts(hidden_states)
|
365
|
+
final_hidden_states = self._forward_router_experts(hidden_states)
|
366
|
+
|
367
|
+
if self.num_shared_experts > 0:
|
214
368
|
final_hidden_states = final_hidden_states + shared_output
|
215
369
|
|
216
|
-
if self.tp_size > 1:
|
370
|
+
if self.tp_size > 1 and not use_reduce_scatter:
|
217
371
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
372
|
+
return final_hidden_states.view(num_tokens, hidden_size)
|
373
|
+
|
374
|
+
def forward_deepep(
|
375
|
+
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
376
|
+
) -> torch.Tensor:
|
377
|
+
shared_output = None
|
378
|
+
forward_mode = forward_batch.forward_mode
|
379
|
+
if is_non_idle_and_non_empty(forward_mode, hidden_states):
|
380
|
+
router_logits = self.gate(hidden_states)
|
381
|
+
if self.num_shared_experts > 0:
|
382
|
+
shared_output = self.shared_experts(hidden_states)
|
383
|
+
|
384
|
+
topk_weights, topk_idx, _ = self.topk(
|
385
|
+
hidden_states,
|
386
|
+
router_logits,
|
387
|
+
num_token_non_padded=forward_batch.num_token_non_padded,
|
388
|
+
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
389
|
+
layer_id=self.layer_id,
|
390
|
+
),
|
391
|
+
)
|
392
|
+
else:
|
393
|
+
topk_idx = torch.full(
|
394
|
+
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
395
|
+
)
|
396
|
+
topk_weights = torch.empty(
|
397
|
+
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
398
|
+
)
|
399
|
+
|
400
|
+
if self.ep_size > 1:
|
401
|
+
(
|
402
|
+
hidden_states,
|
403
|
+
topk_idx,
|
404
|
+
topk_weights,
|
405
|
+
reorder_topk_ids,
|
406
|
+
num_recv_tokens_per_expert,
|
407
|
+
seg_indptr,
|
408
|
+
masked_m,
|
409
|
+
expected_m,
|
410
|
+
) = self.deepep_dispatcher.dispatch(
|
411
|
+
hidden_states,
|
412
|
+
topk_idx,
|
413
|
+
topk_weights,
|
414
|
+
forward_batch=forward_batch,
|
415
|
+
)
|
416
|
+
|
417
|
+
final_hidden_states = self.experts(
|
418
|
+
hidden_states=hidden_states,
|
419
|
+
topk_idx=topk_idx,
|
420
|
+
topk_weights=topk_weights,
|
421
|
+
reorder_topk_ids=reorder_topk_ids,
|
422
|
+
seg_indptr=seg_indptr,
|
423
|
+
masked_m=masked_m,
|
424
|
+
expected_m=expected_m,
|
425
|
+
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
426
|
+
forward_batch=forward_batch,
|
427
|
+
)
|
428
|
+
if self.ep_size > 1:
|
429
|
+
final_hidden_states = self.deepep_dispatcher.combine(
|
430
|
+
final_hidden_states,
|
431
|
+
topk_idx,
|
432
|
+
topk_weights,
|
433
|
+
forward_batch=forward_batch,
|
434
|
+
)
|
218
435
|
|
219
|
-
|
436
|
+
final_hidden_states *= self.routed_scaling_factor
|
220
437
|
|
438
|
+
if shared_output is not None:
|
439
|
+
final_hidden_states = final_hidden_states + shared_output
|
440
|
+
return final_hidden_states
|
221
441
|
|
222
|
-
class BailingMoeBlock(nn.Module):
|
223
442
|
|
443
|
+
class BailingMoEAttention(nn.Module):
|
224
444
|
def __init__(
|
225
445
|
self,
|
226
446
|
config: PretrainedConfig,
|
227
|
-
layer_id: int,
|
447
|
+
layer_id: int = 0,
|
228
448
|
quant_config: Optional[QuantizationConfig] = None,
|
449
|
+
reduce_results: bool = True,
|
229
450
|
prefix: str = "",
|
451
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
230
452
|
):
|
231
453
|
super().__init__()
|
232
|
-
self.
|
233
|
-
self.
|
234
|
-
|
454
|
+
self.hidden_size = config.hidden_size
|
455
|
+
self.total_num_heads = config.num_attention_heads
|
456
|
+
self.total_kv_heads = config.num_key_value_heads
|
457
|
+
self.dp_size = get_attention_dp_size()
|
458
|
+
attn_tp_rank = get_attention_tp_rank()
|
459
|
+
attn_tp_size = get_attention_tp_size()
|
460
|
+
|
461
|
+
assert self.total_num_heads % attn_tp_size == 0
|
462
|
+
assert self.total_kv_heads % attn_tp_size == 0
|
463
|
+
assert self.total_num_heads >= self.total_kv_heads
|
464
|
+
|
465
|
+
self.num_heads = self.total_num_heads // attn_tp_size
|
466
|
+
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
|
467
|
+
self.q_size = self.head_dim * self.num_heads
|
468
|
+
|
469
|
+
self.num_kv_heads = self.total_kv_heads // attn_tp_size
|
470
|
+
self.kv_size = max(1, self.num_kv_heads * self.head_dim)
|
471
|
+
|
472
|
+
self.scale = self.head_dim**-0.5
|
473
|
+
|
474
|
+
self.use_qk_norm = getattr(config, "use_qk_norm", False)
|
475
|
+
|
476
|
+
self.query_key_value = QKVParallelLinear(
|
477
|
+
self.hidden_size,
|
478
|
+
self.head_dim,
|
479
|
+
self.total_num_heads,
|
480
|
+
self.total_kv_heads,
|
481
|
+
bias=(config.use_bias or config.use_qkv_bias),
|
482
|
+
quant_config=quant_config,
|
483
|
+
prefix=add_prefix("query_key_value", prefix),
|
484
|
+
tp_rank=attn_tp_rank,
|
485
|
+
tp_size=attn_tp_size,
|
486
|
+
)
|
487
|
+
|
488
|
+
if self.use_qk_norm:
|
489
|
+
self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
490
|
+
self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
491
|
+
|
492
|
+
self.dense = RowParallelLinear(
|
493
|
+
self.total_num_heads * self.head_dim,
|
494
|
+
self.hidden_size,
|
495
|
+
bias=config.use_bias,
|
496
|
+
quant_config=quant_config,
|
497
|
+
reduce_results=reduce_results,
|
498
|
+
prefix=add_prefix("dense", prefix),
|
499
|
+
tp_rank=attn_tp_rank,
|
500
|
+
tp_size=attn_tp_size,
|
235
501
|
)
|
236
|
-
|
237
|
-
|
502
|
+
|
503
|
+
if hasattr(config, "partial_rotary_factor"):
|
504
|
+
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
|
505
|
+
elif hasattr(config, "rotary_dim"):
|
506
|
+
self.rotary_dim = config.rotary_dim
|
507
|
+
else:
|
508
|
+
self.rotary_dim = self.head_dim
|
509
|
+
self.rotary_emb = get_rope(
|
510
|
+
self.head_dim,
|
511
|
+
rotary_dim=self.rotary_dim,
|
512
|
+
max_position=config.max_position_embeddings,
|
513
|
+
base=config.rope_theta,
|
514
|
+
rope_scaling=config.rope_scaling,
|
238
515
|
)
|
239
|
-
|
240
|
-
|
516
|
+
|
517
|
+
self.attn = RadixAttention(
|
518
|
+
self.num_heads,
|
519
|
+
self.head_dim,
|
520
|
+
self.scale,
|
521
|
+
num_kv_heads=self.num_kv_heads,
|
241
522
|
layer_id=layer_id,
|
242
|
-
|
243
|
-
prefix=add_prefix("mlp", prefix),
|
523
|
+
prefix=add_prefix("attn", prefix),
|
244
524
|
)
|
245
525
|
|
526
|
+
self.alt_stream = alt_stream
|
527
|
+
|
528
|
+
def _apply_qk_norm(
|
529
|
+
self, q: torch.Tensor, k: torch.Tensor
|
530
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
531
|
+
# overlap qk norm
|
532
|
+
if self.alt_stream is not None and get_is_capture_mode():
|
533
|
+
current_stream = torch.cuda.current_stream()
|
534
|
+
self.alt_stream.wait_stream(current_stream)
|
535
|
+
q_by_head = q.reshape(-1, self.head_dim)
|
536
|
+
q_by_head = self.query_layernorm(q_by_head)
|
537
|
+
with torch.cuda.stream(self.alt_stream):
|
538
|
+
k_by_head = k.reshape(-1, self.head_dim)
|
539
|
+
k_by_head = self.key_layernorm(k_by_head)
|
540
|
+
current_stream.wait_stream(self.alt_stream)
|
541
|
+
else:
|
542
|
+
q_by_head = q.reshape(-1, self.head_dim)
|
543
|
+
q_by_head = self.query_layernorm(q_by_head)
|
544
|
+
k_by_head = k.reshape(-1, self.head_dim)
|
545
|
+
k_by_head = self.key_layernorm(k_by_head)
|
546
|
+
q = q_by_head.view(q.shape)
|
547
|
+
k = k_by_head.view(k.shape)
|
548
|
+
return q, k
|
549
|
+
|
246
550
|
def forward(
|
247
551
|
self,
|
552
|
+
positions: torch.Tensor,
|
248
553
|
hidden_states: torch.Tensor,
|
249
|
-
position_ids: torch.Tensor,
|
250
|
-
residual: Optional[torch.Tensor],
|
251
554
|
forward_batch: ForwardBatch,
|
252
|
-
) ->
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
555
|
+
) -> torch.Tensor:
|
556
|
+
if hidden_states.shape[0] == 0:
|
557
|
+
return hidden_states
|
558
|
+
qkv, _ = self.query_key_value(hidden_states)
|
559
|
+
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
560
|
+
if self.use_qk_norm:
|
561
|
+
q, k = self._apply_qk_norm(q, k)
|
562
|
+
q, k = self.rotary_emb(
|
563
|
+
positions,
|
564
|
+
q,
|
565
|
+
k,
|
566
|
+
fused_set_kv_buffer_arg=(
|
567
|
+
create_fused_set_kv_buffer_arg(
|
568
|
+
value=v,
|
569
|
+
layer=self.attn,
|
570
|
+
forward_batch=forward_batch,
|
571
|
+
)
|
572
|
+
if enable_fused_set_kv_buffer(forward_batch)
|
573
|
+
else None
|
574
|
+
),
|
575
|
+
)
|
576
|
+
context_layer = self.attn(
|
577
|
+
q,
|
578
|
+
k,
|
579
|
+
v,
|
580
|
+
forward_batch,
|
581
|
+
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
|
582
|
+
)
|
583
|
+
attn_output, _ = self.dense(context_layer)
|
584
|
+
return attn_output
|
585
|
+
|
586
|
+
|
587
|
+
class BailingMoEBlock(nn.Module):
|
588
|
+
def __init__(
|
589
|
+
self,
|
590
|
+
config: PretrainedConfig,
|
591
|
+
layer_id: int = 0,
|
592
|
+
quant_config: Optional[QuantizationConfig] = None,
|
593
|
+
prefix: str = "",
|
594
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
595
|
+
):
|
596
|
+
super().__init__()
|
597
|
+
hidden_size = config.hidden_size
|
598
|
+
|
599
|
+
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
|
600
|
+
self.dp_size = get_attention_dp_size()
|
601
|
+
self.attention = BailingMoEAttention(
|
602
|
+
config,
|
603
|
+
layer_id,
|
604
|
+
quant_config,
|
605
|
+
reduce_results=False,
|
606
|
+
prefix=add_prefix("attention", prefix),
|
607
|
+
alt_stream=alt_stream,
|
608
|
+
)
|
609
|
+
self.layer_id = layer_id
|
610
|
+
self.attn_tp_size = get_attention_tp_size()
|
611
|
+
self.attn_tp_rank = get_attention_tp_rank()
|
612
|
+
|
613
|
+
self.is_layer_sparse = self._is_layer_sparse(
|
614
|
+
config, layer_id=layer_id, is_nextn=False
|
615
|
+
)
|
616
|
+
is_previous_layer_sparse = self._is_layer_sparse(
|
617
|
+
config, layer_id=layer_id - 1, is_nextn=False
|
618
|
+
)
|
619
|
+
|
620
|
+
self.layer_scatter_modes = LayerScatterModes.init_new(
|
621
|
+
layer_id=layer_id,
|
622
|
+
num_layers=config.num_hidden_layers,
|
623
|
+
is_layer_sparse=self.is_layer_sparse,
|
624
|
+
is_previous_layer_sparse=is_previous_layer_sparse,
|
625
|
+
)
|
626
|
+
|
627
|
+
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
628
|
+
|
629
|
+
if self.is_layer_sparse:
|
630
|
+
self.mlp = BailingMoESparseMoeBlock(
|
631
|
+
layer_id=layer_id,
|
632
|
+
config=config,
|
633
|
+
quant_config=quant_config,
|
634
|
+
alt_stream=alt_stream,
|
635
|
+
prefix=add_prefix("mlp", prefix),
|
636
|
+
)
|
257
637
|
else:
|
258
|
-
|
259
|
-
|
638
|
+
if enable_moe_dense_fully_dp():
|
639
|
+
mlp_tp_rank, mlp_tp_size = 0, 1
|
640
|
+
else:
|
641
|
+
mlp_tp_rank, mlp_tp_size = None, None
|
642
|
+
self.mlp = BailingMoEMLP(
|
643
|
+
intermediate_size=config.intermediate_size,
|
644
|
+
config=config,
|
645
|
+
quant_config=quant_config,
|
646
|
+
prefix=add_prefix("mlp", prefix),
|
647
|
+
tp_rank=mlp_tp_rank,
|
648
|
+
tp_size=mlp_tp_size,
|
260
649
|
)
|
261
650
|
|
262
|
-
|
263
|
-
|
264
|
-
|
651
|
+
self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
|
652
|
+
|
653
|
+
self.layer_communicator = LayerCommunicator(
|
654
|
+
layer_scatter_modes=self.layer_scatter_modes,
|
655
|
+
input_layernorm=self.input_layernorm,
|
656
|
+
post_attention_layernorm=self.post_attention_layernorm,
|
657
|
+
allow_reduce_scatter=True,
|
658
|
+
)
|
659
|
+
|
660
|
+
def _is_layer_sparse(
|
661
|
+
self, config: PretrainedConfig, layer_id: int, is_nextn: bool
|
662
|
+
) -> bool:
|
663
|
+
return is_nextn or (
|
664
|
+
config.num_experts is not None and layer_id >= config.first_k_dense_replace
|
665
|
+
)
|
666
|
+
|
667
|
+
def forward(
|
668
|
+
self,
|
669
|
+
positions: torch.Tensor,
|
670
|
+
hidden_states: torch.Tensor,
|
671
|
+
forward_batch: ForwardBatch,
|
672
|
+
residual: Optional[torch.Tensor],
|
673
|
+
) -> torch.Tensor:
|
674
|
+
hidden_states, residual = self.layer_communicator.prepare_attn(
|
675
|
+
hidden_states=hidden_states,
|
676
|
+
residual=residual,
|
677
|
+
forward_batch=forward_batch,
|
678
|
+
)
|
679
|
+
|
680
|
+
hidden_states = self.attention(
|
681
|
+
positions=positions,
|
682
|
+
hidden_states=hidden_states,
|
265
683
|
forward_batch=forward_batch,
|
266
684
|
)
|
267
685
|
|
268
|
-
|
269
|
-
|
270
|
-
|
686
|
+
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
687
|
+
hidden_states=hidden_states,
|
688
|
+
residual=residual,
|
689
|
+
forward_batch=forward_batch,
|
690
|
+
)
|
691
|
+
|
692
|
+
# For DP with padding, reduce scatter can be used instead of all-reduce.
|
693
|
+
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
694
|
+
forward_batch
|
695
|
+
)
|
696
|
+
|
697
|
+
hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
|
698
|
+
|
699
|
+
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
700
|
+
hidden_states=hidden_states,
|
701
|
+
residual=residual,
|
702
|
+
forward_batch=forward_batch,
|
271
703
|
)
|
272
|
-
mlp_output = self.mlp(normed_hidden_states)
|
273
704
|
|
274
|
-
return
|
705
|
+
return hidden_states, residual
|
275
706
|
|
276
707
|
|
277
|
-
class
|
708
|
+
class BailingMoEModel(nn.Module):
|
278
709
|
|
279
710
|
def __init__(
|
280
711
|
self,
|
281
712
|
config: PretrainedConfig,
|
282
713
|
quant_config: Optional[QuantizationConfig] = None,
|
714
|
+
alt_stream: Optional[torch.cuda.Stream] = None,
|
283
715
|
prefix: str = "",
|
284
716
|
):
|
285
717
|
super().__init__()
|
718
|
+
self.pp_group = get_pp_group()
|
286
719
|
self.config = config
|
287
|
-
self.padding_idx = config.pad_token_id
|
288
720
|
self.vocab_size = config.vocab_size
|
289
721
|
self.embed_dim = config.hidden_size
|
722
|
+
if self.pp_group.is_first_rank:
|
723
|
+
self.word_embeddings = VocabParallelEmbedding(
|
724
|
+
self.vocab_size,
|
725
|
+
self.embed_dim,
|
726
|
+
quant_config=quant_config,
|
727
|
+
prefix=add_prefix("word_embeddings", prefix),
|
728
|
+
enable_tp=not is_dp_attention_enabled(),
|
729
|
+
)
|
730
|
+
else:
|
731
|
+
self.word_embeddings = PPMissingLayer()
|
290
732
|
|
291
|
-
self.embed_tokens = VocabParallelEmbedding(
|
292
|
-
config.vocab_size,
|
293
|
-
config.hidden_size,
|
294
|
-
prefix=add_prefix("embed_tokens", prefix),
|
295
|
-
)
|
296
733
|
self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
|
297
734
|
|
298
|
-
self.layers = make_layers(
|
735
|
+
self.layers, self.start_layer, self.end_layer = make_layers(
|
299
736
|
config.num_hidden_layers,
|
300
|
-
lambda idx, prefix:
|
301
|
-
config=config,
|
737
|
+
lambda idx, prefix: BailingMoEBlock(
|
302
738
|
layer_id=idx,
|
739
|
+
config=config,
|
303
740
|
quant_config=quant_config,
|
304
741
|
prefix=prefix,
|
742
|
+
alt_stream=alt_stream,
|
305
743
|
),
|
744
|
+
pp_rank=self.pp_group.rank_in_group,
|
745
|
+
pp_size=self.pp_group.world_size,
|
306
746
|
prefix=add_prefix("layers", prefix),
|
307
747
|
)
|
308
|
-
|
309
|
-
|
748
|
+
if self.pp_group.is_last_rank:
|
749
|
+
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
|
750
|
+
else:
|
751
|
+
self.norm = PPMissingLayer(return_tuple=True)
|
310
752
|
|
311
753
|
def forward(
|
312
754
|
self,
|
313
755
|
input_ids: torch.Tensor,
|
314
|
-
|
756
|
+
positions: torch.Tensor,
|
315
757
|
forward_batch: ForwardBatch,
|
316
|
-
input_embeds:
|
317
|
-
|
318
|
-
|
319
|
-
|
758
|
+
input_embeds: torch.Tensor = None,
|
759
|
+
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
760
|
+
) -> Union[torch.Tensor, PPProxyTensors]:
|
761
|
+
if self.pp_group.is_first_rank:
|
762
|
+
if input_embeds is None:
|
763
|
+
hidden_states = self.word_embeddings(input_ids)
|
764
|
+
else:
|
765
|
+
hidden_states = input_embeds
|
766
|
+
residual = None
|
320
767
|
else:
|
321
|
-
|
768
|
+
assert pp_proxy_tensors is not None
|
769
|
+
hidden_states = pp_proxy_tensors["hidden_states"]
|
770
|
+
residual = pp_proxy_tensors["residual"]
|
322
771
|
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
hidden_states,
|
327
|
-
|
328
|
-
|
329
|
-
|
772
|
+
for i in range(self.start_layer, self.end_layer):
|
773
|
+
with get_global_expert_distribution_recorder().with_current_layer(i):
|
774
|
+
layer = self.layers[i]
|
775
|
+
hidden_states, residual = layer(
|
776
|
+
positions,
|
777
|
+
hidden_states,
|
778
|
+
forward_batch,
|
779
|
+
residual,
|
780
|
+
)
|
781
|
+
if not self.pp_group.is_last_rank:
|
782
|
+
return PPProxyTensors(
|
783
|
+
{
|
784
|
+
"hidden_states": hidden_states,
|
785
|
+
"residual": residual,
|
786
|
+
}
|
330
787
|
)
|
788
|
+
else:
|
789
|
+
if not forward_batch.forward_mode.is_idle():
|
790
|
+
if residual is None:
|
791
|
+
hidden_states = self.norm(hidden_states)
|
792
|
+
else:
|
793
|
+
hidden_states, _ = self.norm(hidden_states, residual)
|
794
|
+
return hidden_states
|
331
795
|
|
332
|
-
hidden_states, _ = self.norm(hidden_states, residual)
|
333
|
-
return hidden_states
|
334
|
-
|
335
|
-
|
336
|
-
class BailingMoeForCausalLM(nn.Module):
|
337
796
|
|
797
|
+
class BailingMoEForCausalLM(nn.Module):
|
338
798
|
def __init__(
|
339
799
|
self,
|
340
800
|
config: PretrainedConfig,
|
341
801
|
quant_config: Optional[QuantizationConfig] = None,
|
342
|
-
|
802
|
+
prefix: str = "",
|
803
|
+
):
|
343
804
|
super().__init__()
|
805
|
+
self.pp_group = get_pp_group()
|
344
806
|
self.config = config
|
345
|
-
self.
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
807
|
+
self.quant_config = quant_config
|
808
|
+
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
809
|
+
|
810
|
+
self.model = BailingMoEModel(
|
811
|
+
config,
|
812
|
+
quant_config,
|
813
|
+
alt_stream=alt_stream,
|
814
|
+
prefix=add_prefix("model", ""),
|
350
815
|
)
|
351
|
-
if config.tie_word_embeddings:
|
352
|
-
self.lm_head.weight = self.model.embed_tokens.weight
|
353
816
|
|
817
|
+
# tie_word_embeddings为true,复用tie_word_embeddings,反之是独立的
|
818
|
+
if config.tie_word_embeddings:
|
819
|
+
self.lm_head = self.model.word_embeddings
|
820
|
+
else:
|
821
|
+
# TODO something wrong with ParallelLMHead with DP attention enabled
|
822
|
+
self.lm_head = ParallelLMHead(
|
823
|
+
config.vocab_size,
|
824
|
+
config.hidden_size,
|
825
|
+
quant_config=quant_config,
|
826
|
+
prefix=add_prefix("lm_head", prefix),
|
827
|
+
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
828
|
+
)
|
354
829
|
self.logits_processor = LogitsProcessor(config)
|
355
830
|
|
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_embed_and_head(self):
+        """Used by the eagle_worker."""
+        return self.model.word_embeddings.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        """Used by the eagle_worker."""
+        del self.model.word_embeddings.weight
+        del self.lm_head.weight
+        self.model.word_embeddings.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
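The new get_embed_and_head / set_embed_and_head helpers exist so a speculative-decoding (EAGLE) worker can share the target model's embedding and LM-head tensors rather than keeping its own copies, as their docstrings note. A hedged sketch of that hand-off; target_model and draft_model are placeholders, not sglang internals:

    # Share embedding and head weights from the verified (target) model with a draft model.
    embed, head = target_model.get_embed_and_head()
    draft_model.set_embed_and_head(embed, head)  # drops the draft copies, then rebinds to the shared tensors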
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
-
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")

         stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]

+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "final_layernorm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
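For multi-token-prediction (nextn) checkpoints, the setup above derives the draft layer's index from the config: the single MTP layer sits after the regular decoder layers unless the checkpoint contains only that one layer. A worked example with assumed values (not from a real checkpoint):

    # Assume num_hidden_layers = 20 and num_nextn_predict_layers = 1.
    nextn_layer_id = 0 if 20 == 1 else 20                    # -> 20
    nextn_layer_prefix = f"model.layers.{nextn_layer_id}"    # -> "model.layers.20"
    # Keys under "model.layers.20.*" are remapped in the next hunk to either
    # "model.*" (enorm / hnorm / eh_proj / final_layernorm) or "model.decoder.*".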
@@ -381,39 +911,87 @@ class BailingMoeForCausalLM(nn.Module):

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            if (
+                ("v_head" in name)
+                or ("inv_freq" in name)
+                or (self.config.tie_word_embeddings and "lm_head" in name)
+            ):
+                continue

             if (
                 hasattr(self.config, "norm_head")
                 and self.config.norm_head
                 and "lm_head.weight" in name
             ):
+                import torch.nn.functional as F
+
                 loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)

-            if
-
+            if is_nextn:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")

             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name
-
-
-
-
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
             else:
-                for
-
-
-
-
-
-
-
-
-
-
-
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
                 else:
+                    # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name not in params_dict:
+                        continue

                     param = params_dict[name]
                     weight_loader = getattr(
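The loading loop in the hunk above renames checkpoint keys before looking them up in params_dict: gate_proj/up_proj shards fold into the fused gate_up_proj parameter, while per-expert weights go through expert_params_mapping with explicit shard_id and expert_id. A tiny illustration of the stacked-parameter rename; the key is a made-up example:

    # Hypothetical dense-MLP checkpoint key (expert weights are skipped here and
    # handled by expert_params_mapping instead).
    name = "model.layers.0.mlp.gate_proj.weight"
    for param_name, weight_name, shard_id in [("gate_up_proj", "gate_proj", 0),
                                              ("gate_up_proj", "up_proj", 1)]:
        if weight_name in name and "mlp.experts" not in name:
            name = name.replace(weight_name, param_name)  # -> "model.layers.0.mlp.gate_up_proj.weight"
            break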
@@ -421,5 +999,30 @@ class BailingMoeForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)

+        if not is_nextn:
+            self.routed_experts_weights_of_layer = {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if not isinstance(layer, PPMissingLayer)
+                and isinstance(layer.mlp, BailingMoESparseMoeBlock)
+            }
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        num_groups = getattr(config, "n_group", 0)
+        return ModelConfigForExpertLocation(
+            num_layers=config.num_hidden_layers,
+            num_logical_experts=config.num_experts,
+            num_groups=None if num_groups == 0 else num_groups,
+        )
+
+
+class BailingMoeForCausalLM(BailingMoEForCausalLM):
+    pass
+
+
+class BailingMoeV2ForCausalLM(BailingMoEForCausalLM):
+    pass
+

-EntryClass = BailingMoeForCausalLM
+EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM]