sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +73 -14
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +30 -7
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -12
 - sglang/srt/entrypoints/engine.py +31 -20
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +94 -94
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +23 -2
 - sglang/srt/eplb/expert_distribution.py +64 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +19 -8
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +24 -1
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +35 -6
 - sglang/srt/layers/logits_processor.py +9 -20
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +78 -289
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +35 -10
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +13 -84
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +130 -46
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +29 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +185 -144
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +165 -78
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +253 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +55 -14
 - sglang/srt/model_executor/model_runner.py +77 -170
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +296 -78
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +29 -197
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +23 -2
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +35 -5
 - sglang/srt/models/qwen3_moe.py +18 -12
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +459 -199
 - sglang/srt/single_batch_overlap.py +2 -4
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +142 -74
 - sglang/srt/utils/hf_transformers_utils.py +38 -12
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
 - sglang/srt/models/vila.py +0 -306
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/models/glm4.py CHANGED

@@ -15,46 +15,119 @@
 # Modeling from:
 # ./llama.py and
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
-"""Inference-only
+"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
 
-
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers import Glm4Config
 
-from sglang.srt.distributed import
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import
-
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import add_prefix, make_layers
 
+Glm4Config = None
+
+logger = logging.getLogger(__name__)
+
+
+class Glm4MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        return x
+
 
 class Glm4Attention(nn.Module):
     def __init__(
         self,
-
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
+        rope_theta: float = 1000000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 131072,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+        partial_rotary_factor: float = 0.5,
         prefix: str = "",
-    ):
+    ) -> None:
         super().__init__()
-        self.hidden_size =
+        self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads =
+        self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads =
+        self.total_num_kv_heads = num_kv_heads
         if self.total_num_kv_heads >= tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
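The new Glm4MLP above follows the usual gated-SiLU pattern: a merged gate/up projection followed by SiluAndMul and a down projection. Below is a minimal, single-GPU sketch of the same computation in plain PyTorch; ToyGlm4MLP is a made-up name, and it deliberately omits sglang's tensor-parallel layers, quantization, and the reduce-scatter path, assuming SiluAndMul applies SiLU to the first half of its input and multiplies by the second half.

```python
import torch
import torch.nn.functional as F
from torch import nn

class ToyGlm4MLP(nn.Module):
    """Single-GPU stand-in for Glm4MLP: no TP sharding, no quantization."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # One matmul produces both the gate and the up projection, mirroring
        # MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2).
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)  # what SiluAndMul + down_proj compute

x = torch.randn(2, 8, 64)
print(ToyGlm4MLP(64, 128)(x).shape)  # torch.Size([2, 8, 64])
```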
@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.rope_theta =
-        self.
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.partial_rotary_factor = partial_rotary_factor
 
         self.qkv_proj = QKVParallelLinear(
-
+            hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=
+            bias=True,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-
+            hidden_size,
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
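This hunk moves the head bookkeeping from the config object to explicit constructor arguments; the arithmetic itself is unchanged. As a worked example of the per-rank sizes (the numbers below are illustrative, not taken from any particular GLM-4 checkpoint):

```python
# Illustrative values only; not the head counts of an actual GLM-4 config.
hidden_size, total_num_heads, total_num_kv_heads, tp_size = 4096, 32, 2, 8

num_heads = total_num_heads // tp_size                 # 4 query heads per TP rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # KV heads replicated -> 1 per rank
head_dim = hidden_size // total_num_heads              # 128, used when head_dim is None

q_size = num_heads * head_dim        # 512  (per-rank query width)
kv_size = num_kv_heads * head_dim    # 128  (per-rank key/value width)
scaling = head_dim ** -0.5           # attention score scale
print(num_heads, num_kv_heads, head_dim, q_size, kv_size, round(scaling, 5))
```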
@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module):
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position=
-            base=
-            rope_scaling=
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             partial_rotary_factor=partial_rotary_factor,
             is_neox_style=False,
         )
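get_rope is now built from the explicit rope arguments plus partial_rotary_factor and is_neox_style=False. The sketch below only illustrates what a partial rotary factor of 0.5 means for the tensor layout: just the first rotary_dim = head_dim * 0.5 channels are rotated, the rest pass through. The "rotation" here is a deliberate placeholder (a flip), not the interleaved RoPE math that sglang's get_rope actually performs.

```python
import torch

def apply_partial_rope_layout(q: torch.Tensor, partial_rotary_factor: float = 0.5) -> torch.Tensor:
    head_dim = q.shape[-1]
    rotary_dim = int(head_dim * partial_rotary_factor)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    q_rot = torch.flip(q_rot, dims=[-1])  # placeholder for the real rotation
    return torch.cat([q_rot, q_pass], dim=-1)

q = torch.randn(2, 4, 128)                      # (tokens, heads, head_dim)
out = apply_partial_rope_layout(q)
assert torch.equal(out[..., 64:], q[..., 64:])  # the non-rotary half is untouched
print(out.shape)                                # torch.Size([2, 4, 128])
```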
@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-
-
-
-            v,
-            forward_batch,
-        )
-        attn_output, _ = self.o_proj(context_layer)
-        return attn_output
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
 
 
 class Glm4DecoderLayer(nn.Module):
@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config,
-        layer_id: int,
+        config: Glm4Config,
+        layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
         super().__init__()
-
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Glm4Attention(
-
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+            partial_rotary_factor=partial_rotary_factor,
+            prefix=add_prefix("self_attn", prefix),
         )
 
         # MLP
@@ -199,54 +291,125 @@ class Glm4Model(nn.Module):
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-
-
-
-
-
-
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                enable_tp=not is_dp_attention_enabled(),
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to Glm4DecoderLayer
+        decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix:
-
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
             ),
-
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
         )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
-
+        # For EAGLE3 support
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embed_tokens
 
-    def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
-
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+            layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 forward_batch,
                 residual,
             )
-
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
 
-        return hidden_states
+        return hidden_states, aux_hidden_states
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
 
 
 class Glm4ForCausalLM(nn.Module):
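With this hunk, Glm4Model.forward returns a PPProxyTensors bundle of hidden_states and residual on non-last pipeline ranks and a plain hidden-state tensor (plus optional aux states for EAGLE3 capture) on the last rank. The standalone sketch below mimics that hand-off with a plain dict and toy layers; ToyLayer, run_stage, and the LayerNorm stand-in for the final RMSNorm are illustrative names, not sglang APIs.

```python
import torch
from torch import nn

class ToyLayer(nn.Module):
    """Mirrors the (hidden_states, residual) contract of the decoder layers."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, residual):
        residual = hidden_states if residual is None else residual + hidden_states
        return self.proj(hidden_states), residual

def run_stage(layers, norm, *, hidden_states=None, residual=None,
              proxy=None, is_first=False, is_last=False):
    if not is_first:
        # Non-first ranks resume from the tensors handed over by the previous stage.
        hidden_states, residual = proxy["hidden_states"], proxy["residual"]
    for layer in layers:
        hidden_states, residual = layer(hidden_states, residual)
    if not is_last:
        # Intermediate ranks ship both tensors onward instead of applying the final norm.
        return {"hidden_states": hidden_states, "residual": residual}
    return norm(hidden_states + residual)

dim = 16
layers = nn.ModuleList(ToyLayer(dim) for _ in range(4))
norm = nn.LayerNorm(dim)
x = torch.randn(3, dim)

proxy = run_stage(layers[:2], norm, hidden_states=x, is_first=True)   # "first rank"
out = run_stage(layers[2:], norm, proxy=proxy, is_last=True)          # "last rank"
print(out.shape)  # torch.Size([3, 16])
```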
@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module):
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> None:
         super().__init__()
-        self.
+        self.pp_group = get_pp_group()
+        self.config = config
         self.quant_config = quant_config
-        self.model = Glm4Model(
-
-
+        self.model = Glm4Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-
-
-
-
-
-
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
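
Note: with pipeline parallelism and tied word embeddings, embed_tokens lives on the first PP rank while lm_head lives on the last, so the constructor ships the embedding weight across the pipeline group once at startup. A minimal sketch of the same hand-off using raw torch.distributed point-to-point ops (an illustrative stand-in for the pp_group helper used above):

    import torch
    import torch.distributed as dist

    def tie_lm_head_across_pp(embed_weight: torch.Tensor, lm_head_weight: torch.Tensor, group=None):
        # First rank sends embed_tokens.weight; last rank copies it into lm_head.
        rank, world = dist.get_rank(group), dist.get_world_size(group)
        if world == 1:
            return  # single stage: lm_head can simply alias embed_tokens
        if rank == 0:
            dist.send(embed_weight, dst=world - 1, group=group)
        elif rank == world - 1:
            buf = torch.empty_like(lm_head_weight)
            dist.recv(buf, src=0, group=group)
            with torch.no_grad():
                lm_head_weight.copy_(buf)
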
             
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
 
     @torch.no_grad()
     def forward(
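
Note: Pooler(pooling_type=PoolingType.LAST, normalize=True) lets the causal LM double as an embedding model: when get_embedding=True, the last-token hidden state of each sequence is selected and L2-normalized. A minimal sketch of that pooling, assuming hidden states are packed as [total_tokens, hidden_size] with per-sequence lengths (an assumed layout, not ForwardBatch's actual fields):

    import torch
    import torch.nn.functional as F

    def last_token_pool(hidden: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden: [total_tokens, hidden_size] packed sequences; seq_lens: [batch_size]
        last_idx = torch.cumsum(seq_lens, dim=0) - 1   # index of each sequence's final token
        pooled = hidden[last_idx]                      # [batch_size, hidden_size]
        return F.normalize(pooled, p=2, dim=-1)        # normalize=True -> unit-norm embeddings
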
@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
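
Note: under pipeline parallelism only the last rank turns hidden states into logits (or pooled embeddings); every other rank returns raw hidden states, which the runtime hands to the next stage via PPProxyTensors. A toy sketch of that dispatch with stand-in modules (not sglang's scheduler code):

    import torch
    from torch import nn

    stage0 = nn.Linear(16, 16)   # stands in for the first half of the decoder layers
    stage1 = nn.Linear(16, 32)   # stands in for the last half plus the lm_head

    def pp_forward(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        if rank < world_size - 1:
            return stage0(x)     # non-last rank: pass hidden states to the next stage
        return stage1(x)         # last rank: produce logits

    hidden = pp_forward(torch.randn(4, 16), rank=0, world_size=2)
    logits = pp_forward(hidden, rank=1, world_size=2)
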
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
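
Note: forward_split_prefill runs only the decoder layers in [start, end) and keeps the running activations on forward_batch, so a long prefill can be spread over several calls; logits are produced only by the chunk that reaches the final layer. A sketch of how a caller might drive it in fixed-size layer chunks (the helper and chunk size are illustrative, not sglang's scheduler):

    def chunked_prefill(model, input_ids, positions, forward_batch, layers_per_step: int = 8):
        num_layers = model.model.config.num_hidden_layers
        result = None
        for start in range(0, num_layers, layers_per_step):
            end = min(start + layers_per_step, num_layers)
            # Intermediate chunks return None; the final chunk returns the logits result.
            result = model.forward_split_prefill(
                input_ids, positions, forward_batch, (start, end)
            )
        return result
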
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
-            # (param_name,
+            # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
             (".qkv_proj", ".k_proj", "k"),
             (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
         ]
+
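
Note: checkpoints store q/k/v and gate/up projections as separate tensors, while the runtime fuses them into qkv_proj and gate_up_proj; the mapping above rewrites each checkpoint name and tells the fused parameter's weight loader which shard it is. A minimal sketch of that renaming step:

    stacked_params_mapping = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".up_proj", 1),
        (".gate_up_proj", ".gate_proj", 0),
    ]

    def resolve_fused_name(ckpt_name: str):
        # "...self_attn.k_proj.weight" -> ("...self_attn.qkv_proj.weight", "k")
        for fused, shard, shard_id in stacked_params_mapping:
            if shard in ckpt_name:
                return ckpt_name.replace(shard, fused), shard_id
        return ckpt_name, None

    print(resolve_fused_name("model.layers.0.self_attn.k_proj.weight"))
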
    
         
             
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
+            if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
+
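
Note: with pipeline parallelism each rank instantiates only layers [start_layer, end_layer), so load_weights skips checkpoint tensors that belong to other ranks' layers. A sketch of that kind of filter, using a stand-in regex helper rather than sglang's get_layer_id:

    import re

    def layer_id_of(name: str):
        # Stand-in helper: extract N from "model.layers.N....", else None for shared weights.
        m = re.search(r"\blayers\.(\d+)\.", name)
        return int(m.group(1)) if m else None

    def owned_by_this_rank(name: str, start_layer: int, end_layer: int) -> bool:
        lid = layer_id_of(name)
        return lid is None or start_layer <= lid < end_layer

    assert owned_by_this_rank("model.layers.3.mlp.up_proj.weight", 0, 8)
    assert not owned_by_this_rank("model.layers.12.mlp.up_proj.weight", 0, 8)
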
    
         
             
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
                 if name in params_dict.keys():
                     param = params_dict[name]
                     weight_loader = getattr(
@@ -312,7 +612,21 @@ class Glm4ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
                 else:
-
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
 
 
 EntryClass = [Glm4ForCausalLM]
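
Note: EntryClass is how sglang's model loader discovers the architectures a module implements; listing Glm4ForCausalLM here registers it under its class name. A rough, simplified sketch of that discovery pattern (the real registry lives in sglang's model loader and may differ in detail):

    import importlib

    def collect_entry_classes(module_names):
        # Simplified discovery: import each model module and read its EntryClass attribute.
        registry = {}
        for mod_name in module_names:
            module = importlib.import_module(mod_name)
            entries = getattr(module, "EntryClass", [])
            if not isinstance(entries, (list, tuple)):
                entries = [entries]
            for cls in entries:
                registry[cls.__name__] = cls
        return registry

    # e.g. collect_entry_classes(["sglang.srt.models.glm4"])["Glm4ForCausalLM"]
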