sglang-0.2.5-py3-none-any.whl → sglang-0.2.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +4 -4
  9. sglang/lang/interpreter.py +24 -9
  10. sglang/lang/ir.py +1 -1
  11. sglang/srt/constrained/__init__.py +15 -0
  12. sglang/srt/constrained/base_cache.py +15 -0
  13. sglang/srt/constrained/fsm_cache.py +36 -1
  14. sglang/srt/constrained/jump_forward.py +15 -0
  15. sglang/srt/conversation.py +26 -0
  16. sglang/srt/hf_transformers_utils.py +18 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  18. sglang/srt/layers/extend_attention.py +15 -0
  19. sglang/srt/layers/fused_moe.py +15 -0
  20. sglang/srt/layers/linear.py +15 -0
  21. sglang/srt/layers/logits_processor.py +109 -72
  22. sglang/srt/layers/quantization/__init__.py +15 -0
  23. sglang/srt/layers/quantization/fp8.py +15 -0
  24. sglang/srt/layers/radix_attention.py +21 -3
  25. sglang/srt/layers/token_attention.py +16 -1
  26. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  27. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  28. sglang/srt/managers/detokenizer_manager.py +16 -1
  29. sglang/srt/managers/io_struct.py +38 -5
  30. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  31. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
  32. sglang/srt/managers/tokenizer_manager.py +99 -57
  33. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
  34. sglang/srt/mem_cache/flush_cache.py +33 -0
  35. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  36. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  37. sglang/srt/mm_utils.py +15 -0
  38. sglang/srt/model_config.py +20 -0
  39. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
  40. sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
  41. sglang/srt/model_loader/model_loader.py +15 -0
  42. sglang/srt/model_loader/utils.py +16 -1
  43. sglang/srt/models/chatglm.py +16 -1
  44. sglang/srt/models/commandr.py +16 -1
  45. sglang/srt/models/dbrx.py +16 -1
  46. sglang/srt/models/deepseek.py +16 -1
  47. sglang/srt/models/deepseek_v2.py +532 -0
  48. sglang/srt/models/gemma.py +16 -1
  49. sglang/srt/models/gemma2.py +16 -1
  50. sglang/srt/models/gpt_bigcode.py +16 -1
  51. sglang/srt/models/grok.py +16 -1
  52. sglang/srt/models/internlm2.py +16 -1
  53. sglang/srt/models/llama2.py +16 -1
  54. sglang/srt/models/llama_classification.py +19 -4
  55. sglang/srt/models/llava.py +17 -2
  56. sglang/srt/models/llavavid.py +17 -2
  57. sglang/srt/models/minicpm.py +16 -1
  58. sglang/srt/models/mistral.py +15 -0
  59. sglang/srt/models/mixtral.py +16 -1
  60. sglang/srt/models/mixtral_quant.py +16 -1
  61. sglang/srt/models/qwen.py +16 -1
  62. sglang/srt/models/qwen2.py +16 -1
  63. sglang/srt/models/qwen2_moe.py +16 -1
  64. sglang/srt/models/stablelm.py +16 -1
  65. sglang/srt/models/yivl.py +15 -0
  66. sglang/srt/openai_api/adapter.py +545 -160
  67. sglang/srt/openai_api/protocol.py +65 -1
  68. sglang/srt/sampling_params.py +20 -4
  69. sglang/srt/server.py +90 -37
  70. sglang/srt/server_args.py +76 -17
  71. sglang/srt/utils.py +15 -0
  72. sglang/test/test_programs.py +5 -1
  73. sglang/utils.py +22 -0
  74. sglang/version.py +1 -1
  75. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
  76. sglang-0.2.7.dist-info/RECORD +93 -0
  77. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  78. sglang/srt/flush_cache.py +0 -18
  79. sglang-0.2.5.dist-info/RECORD +0 -92
  80. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py ADDED
@@ -0,0 +1,532 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
+ """Inference-only DeepseekV2 model."""
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+ class DeepseekV2MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class DeepseekV2MoE(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.routed_scaling_factor = config.routed_scaling_factor
+         self.n_shared_experts = config.n_shared_experts
+         self.routed_scaling_factor = config.routed_scaling_factor
+         if self.tp_size > config.n_routed_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {config.n_routed_experts}."
+             )
+
+         if config.hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {config.hidden_act}. "
+                 "Only silu is supported for now."
+             )
+
+         self.experts = FusedMoE(
+             num_experts=config.n_routed_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.moe_intermediate_size,
+             reduce_results=False,
+             renormalize=config.norm_topk_prob,
+             quant_config=quant_config,
+             use_grouped_topk=True,
+             num_expert_group=config.n_group,
+             topk_group=config.topk_group,
+         )
+
+         self.gate = ReplicatedLinear(
+             config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
+         )
+         if config.n_shared_experts is not None:
+             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+             self.shared_experts = DeepseekV2MLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 reduce_results=False,
+             )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         if self.n_shared_experts is not None:
+             shared_output = self.shared_experts(hidden_states)
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = (
+             self.experts(hidden_states=hidden_states, router_logits=router_logits)
+             * self.routed_scaling_factor
+         )
+         if shared_output is not None:
+             final_hidden_states = final_hidden_states + shared_output
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
+     import math
+
+     if scale <= 1:
+         return 1.0
+     return 0.1 * mscale * math.log(scale) + 1.0
+
+
+ class DeepseekV2Attention(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         hidden_size: int,
+         num_heads: int,
+         qk_nope_head_dim: int,
+         qk_rope_head_dim: int,
+         v_head_dim: int,
+         q_lora_rank: int,
+         kv_lora_rank: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         layer_id=None,
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.hidden_size = hidden_size
+         self.qk_nope_head_dim = qk_nope_head_dim
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+         self.v_head_dim = v_head_dim
+         self.q_lora_rank = q_lora_rank
+         self.kv_lora_rank = kv_lora_rank
+         self.num_heads = num_heads
+         tp_size = get_tensor_model_parallel_world_size()
+         assert num_heads % tp_size == 0
+         self.num_local_heads = num_heads // tp_size
+         self.scaling = self.qk_head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         if self.q_lora_rank is not None:
+             self.q_a_proj = ReplicatedLinear(
+                 self.hidden_size,
+                 self.q_lora_rank,
+                 bias=False,
+                 quant_config=quant_config,
+             )
+             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+             self.q_b_proj = ColumnParallelLinear(
+                 q_lora_rank,
+                 self.num_heads * self.qk_head_dim,
+                 bias=False,
+                 quant_config=quant_config,
+             )
+         else:
+             self.q_proj = ColumnParallelLinear(
+                 self.hidden_size,
+                 self.num_heads * self.qk_head_dim,
+                 bias=False,
+                 quant_config=quant_config,
+             )
+
+         self.kv_a_proj_with_mqa = ReplicatedLinear(
+             self.hidden_size,
+             self.kv_lora_rank + self.qk_rope_head_dim,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+         self.kv_b_proj = ColumnParallelLinear(
+             self.kv_lora_rank,
+             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+             bias=False,
+             quant_config=quant_config,
+         )
+         # O projection.
+         self.o_proj = RowParallelLinear(
+             self.num_heads * self.v_head_dim,
+             self.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         rope_scaling["type"] = "deepseek_yarn"
+         self.rotary_emb = get_rope(
+             qk_rope_head_dim,
+             rotary_dim=qk_rope_head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+             is_neox_style=False,
+         )
+
+         if rope_scaling:
+             mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+             scaling_factor = rope_scaling["factor"]
+             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+             self.scaling = self.scaling * mscale * mscale
+
+         # self.attn = Attention(self.num_heads,
+         #                       self.qk_head_dim,
+         #                       self.scaling,
+         #                       num_kv_heads=self.num_heads)
+
+         # TODO, support head_size 192
+         self.attn = RadixAttention(
+             self.num_local_heads,
+             256,
+             self.scaling,
+             num_kv_heads=self.num_local_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         if self.q_lora_rank is not None:
+             q = self.q_a_proj(hidden_states)[0]
+             q = self.q_a_layernorm(q)
+             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+         else:
+             q = self.q_proj(hidden_states)[0].view(
+                 -1, self.num_local_heads, self.qk_head_dim
+             )
+         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+         latent_cache = latent_cache.unsqueeze(1)
+         kv_a = self.kv_a_layernorm(kv_a.contiguous())
+         kv = self.kv_b_proj(kv_a)[0]
+         kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+         k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+         k_pe = latent_cache[:, :, self.kv_lora_rank :]
+         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+         q[..., self.qk_nope_head_dim :] = q_pe
+         k = torch.empty_like(q)
+         k[..., : self.qk_nope_head_dim] = k_nope
+         k[..., self.qk_nope_head_dim :] = k_pe
+         q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], value=0).view(
+             -1, self.num_local_heads * 256
+         )
+         k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], value=0).view(
+             -1, self.num_local_heads * 256
+         )
+         v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
+             -1, self.num_local_heads * 256
+         )
+         attn_output = self.attn(q, k, v, input_metadata)
+         attn_output = attn_output.view(-1, self.num_local_heads, 256)[
+             ..., : self.v_head_dim
+         ].reshape(-1, self.num_local_heads * self.v_head_dim)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class DeepseekV2DecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.self_attn = DeepseekV2Attention(
+             config=config,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             qk_nope_head_dim=config.qk_nope_head_dim,
+             qk_rope_head_dim=config.qk_rope_head_dim,
+             v_head_dim=config.v_head_dim,
+             q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
+             kv_lora_rank=config.kv_lora_rank,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             cache_config=cache_config,
+             quant_config=quant_config,
+             layer_id=layer_id,
+         )
+         if (
+             config.n_routed_experts is not None
+             and layer_id >= config.first_k_dense_replace
+             and layer_id % config.moe_layer_freq == 0
+         ):
+             self.mlp = DeepseekV2MoE(config=config, quant_config=quant_config)
+         else:
+             self.mlp = DeepseekV2MLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class DeepseekV2Model(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_id = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 DeepseekV2DecoderLayer(
+                     config,
+                     layer_id,
+                     cache_config=cache_config,
+                     quant_config=quant_config,
+                 )
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.embed_tokens(input_ids)
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class DeepseekV2ForCausalLM(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = DeepseekV2Model(config, cache_config, quant_config)
+         self.lm_head = ParallelLMHead(
+             config.vocab_size, config.hidden_size, quant_config=quant_config
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.n_routed_experts,
+         )
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if ("mlp.experts." in name) and name not in params_dict:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         weight_name,
+                         shard_id=shard_id,
+                         expert_id=expert_id,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ EntryClass = DeepseekV2ForCausalLM
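
Note (not part of the diff): the new `EntryClass = DeepseekV2ForCausalLM` registration above is what lets the 0.2.7 runtime serve DeepSeek-V2 checkpoints. The following is a minimal, hedged usage sketch using the standard sglang frontend API; the checkpoint name, `tp_size`, and prompt are illustrative assumptions, not values taken from this release.

# Hedged sketch only: serving a DeepSeek-V2 checkpoint with the sglang frontend.
# The model path and tp_size below are assumptions chosen for illustration.
import sglang as sgl

runtime = sgl.Runtime(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # assumed HF checkpoint id
    tp_size=1,                                  # larger MoE checkpoints need more TP ranks
    trust_remote_code=True,                     # DeepSeek-V2 ships a custom config class
)
sgl.set_default_backend(runtime)

@sgl.function
def ask(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = ask.run(question="Summarize multi-head latent attention in one sentence.")
print(state["answer"])
runtime.shutdown()
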
sglang/srt/models/gemma.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Adapted from:
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
  """Inference-only Gemma model compatible with HuggingFace weights."""
@@ -22,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.model_executor.model_runner import InputMetadata


  class GemmaMLP(nn.Module):
sglang/srt/models/gemma2.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Adapted from:
  # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
  from typing import Iterable, Optional, Set, Tuple, Union
@@ -27,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.model_executor.model_runner import InputMetadata


  class GemmaRMSNorm(CustomOp):
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Adapted from:
  # https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
  """Inference-only GPTBigCode model compatible with HuggingFace weights."""
@@ -20,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.controller.infer_batch import InputMetadata
+ from sglang.srt.managers.schedule_batch import InputMetadata


  class GPTBigCodeAttention(nn.Module):
sglang/srt/models/grok.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
  """Inference-only Grok1 model."""
@@ -37,7 +52,7 @@ from vllm.utils import print_warning_once
  from sglang.srt.layers.fused_moe import fused_moe
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.model_executor.model_runner import InputMetadata

  use_fused = True

sglang/srt/models/internlm2.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # -*- coding: utf-8 -*-
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py

@@ -25,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.controller.model_runner import InputMetadata
+ from sglang.srt.model_executor.model_runner import InputMetadata


  class InternLM2MLP(nn.Module):