sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/persimmon.py (new file)
@@ -0,0 +1,330 @@
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+from torch import nn
+from transformers import PersimmonConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class PersimmonMLP(nn.Module):
+
+    def __init__(
+        self, config: PersimmonConfig, quant_config: Optional[QuantizationConfig] = None
+    ):
+        super().__init__()
+        self.dense_h_to_4h = ColumnParallelLinear(
+            config.hidden_size, config.intermediate_size, quant_config=quant_config
+        )
+        self.dense_4h_to_h = RowParallelLinear(
+            config.intermediate_size, config.hidden_size, quant_config=quant_config
+        )
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states, _ = self.dense_h_to_4h(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.dense_4h_to_h(hidden_states)
+        return hidden_states
+
+
+class PersimmonAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PersimmonConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_id: int = 0,
+    ):
+        super().__init__()
+        self.config = config
+        tensor_parallel_world_size = get_tensor_model_parallel_world_size()
+
+        self.hidden_size = config.hidden_size
+        self.total_num_heads = config.num_attention_heads
+        self.num_heads = self.total_num_heads // tensor_parallel_world_size
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.partial_rotary_factor = config.partial_rotary_factor
+        self.is_causal = True
+
+        assert (self.head_dim * self.total_num_heads) == self.hidden_size
+        assert self.total_num_heads % tensor_parallel_world_size == 0
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.is_qk_layernorm = config.qk_layernorm
+
+        if self.is_qk_layernorm:
+            self.q_layernorm = nn.LayerNorm(self.head_dim)
+            self.k_layernorm = nn.LayerNorm(self.head_dim)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+        self.scaling = self.head_dim**-0.5
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        seq_length = x.shape[0]
+        return x.view(seq_length, self.num_heads, self.head_dim)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        seq_length = x.shape[0]
+        return x.view(seq_length, self.num_heads * self.head_dim)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+
+        if self.is_qk_layernorm:
+            q = self._split_heads(q)
+            k = self._split_heads(k)
+
+            q = self.q_layernorm(q)
+            k = self.k_layernorm(k)
+
+            q = self._merge_heads(q)
+            k = self._merge_heads(k)
+
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class PersimmonDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PersimmonConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        idx: int = 0,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = PersimmonAttention(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            layer_id=idx,
+        )
+        self.mlp = PersimmonMLP(config, quant_config=quant_config)
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        hidden_states = hidden_states + residual
+
+        outputs = hidden_states
+        return outputs
+
+
+class PersimmonModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PersimmonConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size, config.hidden_size
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: PersimmonDecoderLayer(
+                config, quant_config=quant_config, prefix=prefix, idx=idx
+            ),
+            prefix="model.layers",
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+        )
+
+        if self.pp_group.is_last_rank:
+            self.final_layernorm = nn.LayerNorm(
+                config.hidden_size, eps=config.layer_norm_eps
+            )
+        else:
+            self.final_layernorm = PPMissingLayer()
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.pp_group.is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+        else:
+            hidden_states = forward_batch.pp_input_hidden
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(
+                position_ids=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+            )
+        return self.final_layernorm(hidden_states)
+
+
+class PersimmonForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PersimmonConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = PersimmonModel(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name not in params_dict:
+                if name == "lm_head.weight":
+                    continue
+                print(f"Warning: weight {name} not found in model.")
+                continue
+            param = params_dict[name]
+            if "query_key_value" in name:
+                output_dim = getattr(param, "output_dim", None)
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    num_heads = self.config.num_attention_heads
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim]
+                        + (num_heads, 3, -1)
+                        + loaded_weight_shape[output_dim + 1 :]
+                    )
+                    loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = PersimmonForCausalLM
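
Note: the EntryClass = PersimmonForCausalLM line above is what registers the new architecture with sglang's model loader. As a rough sketch of how the addition would typically be exercised through sglang's offline Engine API (the checkpoint id "adept/persimmon-8b-chat" and the sampling parameters below are illustrative assumptions, not part of this diff):

# Sketch only (not part of the diff): load a Persimmon checkpoint through
# sglang's offline Engine, which dispatches to PersimmonForCausalLM via EntryClass.
# Model id and sampling parameters are assumptions for illustration.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="adept/persimmon-8b-chat")
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()
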
sglang/srt/models/phi.py (new file)
@@ -0,0 +1,321 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/phi.py
+from typing import Iterable, Optional, Union
+
+import torch
+from torch import nn
+from transformers import PhiConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class PhiAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_id: int = 0,
+    ):
+        super().__init__()
+        self.total_num_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.head_size = self.hidden_size // self.total_num_heads
+
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_size,
+            self.total_num_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            quant_config=quant_config,
+        )
+
+        scaling = self.head_size**-0.5
+        rotary_dim = int(
+            config.partial_rotary_factor
+            * (config.hidden_size // config.num_attention_heads)
+        )
+        assert rotary_dim % 2 == 0
+
+        rope_theta = getattr(config, "rope_theta", 10000.0)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 2048)
+        self.rotary_emb = get_rope(
+            self.head_size,
+            rotary_dim=rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_size,
+            scaling,
+            num_kv_heads=self.num_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.dense(attn_output)
+        return output
+
+
+class PhiMLP(nn.Module):
+
+    def __init__(
+        self, config: PhiConfig, quant_config: Optional[QuantizationConfig] = None
+    ):
+        super().__init__()
+
+        n_inner = getattr(config, "n_inner", None)
+        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
+
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            n_inner,
+            quant_config=quant_config,
+        )
+        self.fc2 = RowParallelLinear(
+            n_inner,
+            config.hidden_size,
+            quant_config=quant_config,
+        )
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(self, hidden_states):
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+        return hidden_states
+
+
+class PhiLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        idx: int = 0,
+    ):
+        super().__init__()
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+        self.self_attn = PhiAttention(
+            config,
+            quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            layer_id=idx,
+        )
+        self.mlp = PhiMLP(config, quant_config)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_outputs = self.self_attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = attn_outputs + feed_forward_hidden_states + residual
+        return hidden_states
+
+
+class PhiModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PhiConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+
+        pp_group = get_pp_group()
+        pp_size = pp_group.world_size
+        pp_rank = pp_group.rank
+
+        self.start_layer = pp_rank * config.num_hidden_layers // pp_size
+        self.end_layer = (pp_rank + 1) * config.num_hidden_layers // pp_size
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: PhiLayer(
+                config, quant_config=quant_config, prefix=prefix, idx=idx
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        self.final_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        positions: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+
+            hidden_states = layer(
+                position_ids=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class PhiForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ]
+    }
+
+    def __init__(
+        self,
+        config: PhiConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = PhiModel(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> LogitsProcessorOutput:
+
+        hidden_states = self.model(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters())
+        weights = dict(weights)
+        loaded_keys = set()
+
+        for name, param in params_dict.items():
+            if name in loaded_keys:
+                continue
+
+            # Handle packed weights
+            is_packed = False
+            for packed_name, src_names in self.packed_modules_mapping.items():
+                if packed_name not in name:
+                    continue
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                for src_name in src_names:
+                    full_src_name = name.replace(packed_name, src_name)
+                    if full_src_name in weights:
+                        loaded_weight = weights[full_src_name]
+                        # The shard_id for QKVParallelLinear is 'q', 'k', 'v'.
+                        shard_id = src_name.split("_")[0]
+                        weight_loader(param, loaded_weight, shard_id)
+                        loaded_keys.add(full_src_name)
+
+                loaded_keys.add(name)
+                is_packed = True
+                break
+            if is_packed:
+                continue
+
+            # Handle non-packed weights
+            if name not in weights:
+                # Redundant with the check in the loop, but good for safety
+                continue
+
+            loaded_weight = weights[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_keys.add(name)
+
+
+EntryClass = PhiForCausalLM
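
Note: the packed_modules_mapping on PhiForCausalLM drives the load_weights logic above: separate q_proj / k_proj / v_proj tensors from a Hugging Face checkpoint are routed into the single fused qkv_proj parameter, with the shard id taken from the source name. A standalone sketch of just that name-remapping step (the helper expand_packed_name is hypothetical, written only to illustrate the mapping, not an sglang API):

# Hypothetical helper (not in sglang): shows how one fused parameter name expands
# into the checkpoint names and shard ids that weight_loader consumes in
# PhiForCausalLM.load_weights.
from typing import Dict, List, Optional, Tuple

PACKED_MODULES_MAPPING: Dict[str, List[str]] = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}

def expand_packed_name(
    param_name: str, mapping: Dict[str, List[str]]
) -> List[Tuple[str, Optional[str]]]:
    """Return (checkpoint_name, shard_id) pairs for one model parameter name."""
    for packed_name, src_names in mapping.items():
        if packed_name in param_name:
            return [
                # shard_id is the leading letter of the source name: 'q', 'k', or 'v'
                (param_name.replace(packed_name, src), src.split("_")[0])
                for src in src_names
            ]
    return [(param_name, None)]  # non-packed weights load one-to-one

print(expand_packed_name("model.layers.0.self_attn.qkv_proj.weight", PACKED_MODULES_MAPPING))
# [('model.layers.0.self_attn.q_proj.weight', 'q'),
#  ('model.layers.0.self_attn.k_proj.weight', 'k'),
#  ('model.layers.0.self_attn.v_proj.weight', 'v')]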