sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/torch_native_llama.py (new file)
@@ -0,0 +1,506 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+
+import types
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+def gate_up_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[int] = None,
+):
+    if loaded_shard_id is None:
+        shard_offsets: List[Tuple[int, int, int]] = []
+        for i, output_size in enumerate(self.output_sizes):
+            shard_offsets.append((i, current_shard_offset, output_size))
+            current_shard_offset += output_size
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                output_dim, shard_offset, shard_size
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:
+        assert loaded_shard_id < len(self.output_sizes)
+        param_data = param.data
+        shard_size = loaded_weight.shape[0]
+        shard_offset = loaded_shard_id * shard_size
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
+
+
+class LlamaMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = torch.nn.Linear(
+            hidden_size,
+            intermediate_size * 2,
+            bias=False,
+        )
+        self.gate_up_proj.output_sizes = [intermediate_size] * 2
+        self.gate_up_proj.weight_loader = types.MethodType(
+            gate_up_proj_weight_loader, self.gate_up_proj
+        )
+        self.gate_up_proj.weight.weight_loader = self.gate_up_proj.weight_loader
+        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
+
+
+def _get_shard_offset_mapping(self, loaded_shard_id: str):
+    shard_offset_mapping = {
+        "q": 0,
+        "k": self.num_heads * self.head_size,
+        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
+        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
+    }
+    return shard_offset_mapping.get(loaded_shard_id)
+
+
+def _get_shard_size_mapping(self, loaded_shard_id: str):
+    shard_size_mapping = {
+        "q": self.num_heads * self.head_size,
+        "k": self.num_kv_heads * self.head_size,
+        "v": self.num_kv_heads * self.head_size,
+    }
+    return shard_size_mapping.get(loaded_shard_id)
+
+
+def qkv_proj_weight_loader(
+    self,
+    param: Parameter,
+    loaded_weight: torch.Tensor,
+    loaded_shard_id: Optional[str] = None,
+):
+    if loaded_shard_id is None:
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, self.total_num_heads * self.head_size),
+            (
+                "k",
+                self.total_num_heads * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+            (
+                "v",
+                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+                self.total_num_kv_heads * self.head_size,
+            ),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            self.weight_loader(param, loaded_weight_shard, shard_id)
+    else:
+        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
+        shard_size = self._get_shard_size_mapping(loaded_shard_id)
+        param_data = param.data
+        param_data = param_data.narrow(0, shard_offset, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+    return
+
+
+class LlamaAttention(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = torch.nn.Linear(
+            hidden_size,
+            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
+            bias=False,
+        )
+        self.qkv_proj.total_num_heads = self.total_num_heads
+        self.qkv_proj.head_size = self.head_dim
+        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.num_heads = self.total_num_heads
+        self.qkv_proj.num_kv_heads = self.total_num_kv_heads
+        self.qkv_proj.weight_loader = types.MethodType(
+            qkv_proj_weight_loader, self.qkv_proj
+        )
+        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
+            _get_shard_offset_mapping, self.qkv_proj
+        )
+        self.qkv_proj._get_shard_size_mapping = types.MethodType(
+            _get_shard_size_mapping, self.qkv_proj
+        )
+        self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
+        self.qkv_proj.weight.output_dim = 0
+        self.o_proj = torch.nn.Linear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output = self.o_proj(attn_output)
+        return output
+
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.self_attn = LlamaAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = LlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class TorchNativeLlamaForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
+        )
+
+    def get_hidden_dim(self, module_name):
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size // (
+                self.config.num_attention_heads // self.config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
+    def get_module_name_from_weight_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+        ):
+            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
+            param = self.lm_head.weight
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, self.model.embed_tokens.weight)
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
+    pass
+
+
+EntryClass = [TorchNativeLlamaForCausalLM, TorchNativePhi3ForCausalLM]
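
The new file above keeps every projection as a plain torch.nn.Linear and grafts vllm-style weight_loader callables onto those modules with types.MethodType, so per-shard checkpoint tensors (q/k/v, gate/up) can still be copied into the fused weights while the model stays torch-native (presumably to suit the torchao path wired in through apply_torchao_config_). Below is a minimal, self-contained sketch of that binding pattern only; the names fused_weight_loader and shard_ids, and the toy shapes, are illustrative and not APIs from the package.

    # Sketch (not part of the package): attach a custom weight_loader to a plain
    # torch.nn.Linear via types.MethodType, then copy per-shard checkpoint tensors
    # into slices of the fused weight, as torch_native_llama.py does above.
    import types

    import torch
    from torch import nn


    def fused_weight_loader(self, param, loaded_weight, shard_id):
        # `self` is the Linear module the loader was bound to.
        shard_size = loaded_weight.shape[0]
        offset = self.shard_ids.index(shard_id) * shard_size
        param.data.narrow(0, offset, shard_size).copy_(loaded_weight)


    hidden, intermediate = 8, 16
    gate_up = nn.Linear(hidden, intermediate * 2, bias=False)
    gate_up.shard_ids = ["gate", "up"]
    gate_up.weight_loader = types.MethodType(fused_weight_loader, gate_up)

    # Fill each half of the fused weight from a separate checkpoint tensor.
    gate_up.weight_loader(gate_up.weight, torch.randn(intermediate, hidden), "gate")
    gate_up.weight_loader(gate_up.weight, torch.randn(intermediate, hidden), "up")

In the package itself the bound loaders additionally handle the loaded_shard_id is None case and the q/k/v offset maps shown above; the sketch only demonstrates the types.MethodType mechanics.
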
sglang/srt/models/xverse.py
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class XverseMLP(nn.Module):
@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
             # print(f"layer[{i}].hidden_states: {hidden_states}")
@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(
sglang/srt/models/xverse_moe.py
@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class XverseMLP(nn.Module):
@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
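
The two Xverse diffs above are representative of the mechanical change applied across the model files listed earlier (items 38 through 67): InputMetadata is renamed to ForwardBatch and is threaded through each model's forward down to RadixAttention. A hedged sketch of the resulting call shape follows, assuming a model instance and a ForwardBatch already built by sglang's model runner (neither is constructed here, and forward_step is not a package function).

    # Illustrative only: the per-model call signature after the rename, taken from
    # the forward() definitions in the hunks above.
    import torch


    @torch.no_grad()
    def forward_step(model, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch):
        # Same positional order as XverseForCausalLM.forward:
        # (input_ids, positions, forward_batch).
        return model(input_ids, positions, forward_batch)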