sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4.py CHANGED
@@ -15,46 +15,119 @@
 # Modeling from:
 # ./llama.py and
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4/modular_glm4.py
-"""Inference-only GLM4 model compatible with THUDM weights."""
+"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
 
-from typing import Iterable, List, Optional, Tuple, Union
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
-from transformers import Glm4Config
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.llama import LlamaMLP as Glm4MLP
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import add_prefix, make_layers
 
+Glm4Config = None
+
+logger = logging.getLogger(__name__)
+
+
+class Glm4MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        return x
+
 
 class Glm4Attention(nn.Module):
     def __init__(
         self,
-        config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: Optional[int] = None,
         layer_id: int = 0,
+        rope_theta: float = 1000000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 131072,
         quant_config: Optional[QuantizationConfig] = None,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+        partial_rotary_factor: float = 0.5,
         prefix: str = "",
-    ):
+    ) -> None:
         super().__init__()
-        self.hidden_size = config.hidden_size
+        self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
-        self.total_num_heads = config.num_attention_heads
+        self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
-        self.total_num_kv_heads = config.num_key_value_heads
+        self.total_num_kv_heads = num_kv_heads
         if self.total_num_kv_heads >= tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
@@ -63,27 +136,30 @@ class Glm4Attention(nn.Module):
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
-        partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = config.hidden_size // self.total_num_heads
+        if head_dim is not None:
+            self.head_dim = head_dim
+        else:
+            self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.rope_theta = getattr(config, "rope_theta", 1000000)
-        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.partial_rotary_factor = partial_rotary_factor
 
         self.qkv_proj = QKVParallelLinear(
-            self.hidden_size,
+            hidden_size,
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=config.attention_bias,
+            bias=True,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
-            self.hidden_size,
+            hidden_size,
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
@@ -92,9 +168,10 @@ class Glm4Attention(nn.Module):
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
-            max_position=config.max_position_embeddings,
-            base=self.rope_theta,
-            rope_scaling=self.rope_scaling,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            dual_chunk_attention_config=dual_chunk_attention_config,
             partial_rotary_factor=partial_rotary_factor,
             is_neox_style=False,
         )
@@ -117,14 +194,9 @@ class Glm4Attention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(
-            q,
-            k,
-            v,
-            forward_batch,
-        )
-        attn_output, _ = self.o_proj(context_layer)
-        return attn_output
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
 
 
 class Glm4DecoderLayer(nn.Module):
@@ -136,15 +208,35 @@ class Glm4DecoderLayer(nn.Module):
 
     def __init__(
        self,
-        config,
-        layer_id: int,
+        config: Glm4Config,
+        layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
         super().__init__()
-        # Self attention.
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", None)
+        dual_chunk_attention_config = getattr(
+            config, "dual_chunk_attention_config", None
+        )
         self.self_attn = Glm4Attention(
-            config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            head_dim=head_dim,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+            partial_rotary_factor=partial_rotary_factor,
+            prefix=add_prefix("self_attn", prefix),
         )
 
         # MLP
@@ -199,54 +291,125 @@ class Glm4Model(nn.Module):
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Glm4DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                enable_tp=not is_dp_attention_enabled(),
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        # Use the provided decoder layer type or default to Glm4DecoderLayer
+        decoder_layer_type = decoder_layer_type or Glm4DecoderLayer
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix: Glm4DecoderLayer(
-                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
             ),
-            prefix="model.layers",
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix=add_prefix("layers", prefix),
         )
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        # For EAGLE3 support
+        self.layers_to_capture = []
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.embed_tokens
 
-    def dtype(self) -> torch.dtype:
-        return next(self.parameters()).dtype
-
-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
-        for layer in self.layers:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+            layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
 
-        return hidden_states
+        return hidden_states, aux_hidden_states
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
 
 
 class Glm4ForCausalLM(nn.Module):
@@ -255,21 +418,54 @@ class Glm4ForCausalLM(nn.Module):
         config: Glm4Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
-    ):
+    ) -> None:
         super().__init__()
-        self.config: Glm4Config = config
+        self.pp_group = get_pp_group()
+        self.config = config
         self.quant_config = quant_config
-        self.model = Glm4Model(config, quant_config, add_prefix("model", prefix))
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        self.model = Glm4Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix="lm_head",
-            )
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
 
     @torch.no_grad()
     def forward(
@@ -277,34 +473,138 @@ class Glm4ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ):
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
-            # (param_name, weight_name, shard_id)
+            # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
             (".qkv_proj", ".k_proj", "k"),
             (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
         ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
+            if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
                 if name in params_dict.keys():
                     param = params_dict[name]
                     weight_loader = getattr(
@@ -312,7 +612,21 @@ class Glm4ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
                 else:
-                    raise KeyError(f"Parameter '{name}' not found in model.")
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
 
 
 EntryClass = [Glm4ForCausalLM]