sglang 0.1.22__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. sglang/__init__.py +2 -2
  2. sglang/bench_serving.py +243 -25
  3. sglang/global_config.py +3 -2
  4. sglang/lang/interpreter.py +1 -0
  5. sglang/srt/hf_transformers_utils.py +13 -1
  6. sglang/srt/layers/logits_processor.py +4 -5
  7. sglang/srt/layers/radix_attention.py +38 -49
  8. sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
  9. sglang/srt/managers/controller/infer_batch.py +51 -22
  10. sglang/srt/managers/controller/model_runner.py +58 -4
  11. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  12. sglang/srt/managers/controller/tp_worker.py +9 -11
  13. sglang/srt/memory_pool.py +13 -5
  14. sglang/srt/models/deepseek.py +430 -0
  15. sglang/srt/models/gpt_bigcode.py +282 -0
  16. sglang/srt/models/llama2.py +19 -10
  17. sglang/srt/server.py +26 -1
  18. sglang/srt/server_args.py +12 -6
  19. sglang/srt/utils.py +93 -1
  20. sglang/version.py +1 -0
  21. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/METADATA +10 -6
  22. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/RECORD +25 -36
  23. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/WHEEL +1 -1
  24. sglang/backend/__init__.py +0 -0
  25. sglang/backend/anthropic.py +0 -77
  26. sglang/backend/base_backend.py +0 -80
  27. sglang/backend/litellm.py +0 -90
  28. sglang/backend/openai.py +0 -438
  29. sglang/backend/runtime_endpoint.py +0 -283
  30. sglang/backend/vertexai.py +0 -149
  31. sglang/bench.py +0 -627
  32. sglang/srt/managers/controller/dp_worker.py +0 -113
  33. sglang/srt/openai_api/api_adapter.py +0 -432
  34. sglang/srt/openai_api/openai_api_adapter.py +0 -431
  35. sglang/srt/openai_api/openai_protocol.py +0 -207
  36. sglang/srt/openai_api_adapter.py +0 -411
  37. sglang/srt/openai_protocol.py +0 -207
  38. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/LICENSE +0 -0
  39. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/top_level.txt +0 -0

sglang/srt/models/deepseek.py ADDED
@@ -0,0 +1,430 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/14f91fe67c2342f2fe859dc6a5c40810df0e1c61/vllm/model_executor/models/deepseek.py
+"""Inference-only Deepseek model."""
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.infer_batch import InputMetadata
+
+
+class DeepseekMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class DeepseekMoE(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.n_routed_experts = config.n_routed_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.n_routed_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.n_routed_experts}."
+            )
+
+        self.experts = nn.ModuleList(
+            [
+                DeepseekMLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.moe_intermediate_size,
+                    hidden_act=config.hidden_act,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                )
+                for idx in range(self.n_routed_experts)
+            ]
+        )
+        self.pack_params()
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
+        )
+
+        if config.n_shared_experts is not None:
+            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+            self.shared_experts = DeepseekMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+            )
+
+    def pack_params(self):
+        w1 = []
+        w2 = []
+        for expert in self.experts:
+            w1.append(expert.gate_up_proj.weight)
+            w2.append(expert.down_proj.weight)
+        self.w1 = torch._utils._flatten_dense_tensors(w1)
+        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
+        for data, param in zip(w1s, w1):
+            param.data = data
+        self.w1 = self.w1.view(len(w1), *w1s[0].shape)
+
+        self.w2 = torch._utils._flatten_dense_tensors(w2)
+        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
+        for data, param in zip(w2s, w2):
+            param.data = data
+
+        self.w2 = self.w2.view(len(w2), *w2s[0].shape)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        if self.config.n_shared_experts is not None:
+            shared_output = self.shared_experts(hidden_states)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w1,
+            self.w2,
+            router_logits,
+            self.top_k,
+            renormalize=self.config.norm_topk_prob,
+            inplace=True,
+        )
+
+        if self.config.n_shared_experts is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class DeepseekAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class DeepseekDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.self_attn = DeepseekAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        if (
+            config.n_routed_experts is not None
+            and layer_id >= config.first_k_dense_replace
+            and layer_id % config.moe_layer_freq == 0
+        ):
+            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
+        else:
+            self.mlp = DeepseekMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class DeepseekModel(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                DeepseekDecoderLayer(
+                    config, layer_id, cache_config, quant_config=quant_config
+                )
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeepseekForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = DeepseekModel(config, cache_config, quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if (
+                    "mlp.experts." in name or "mlp.shared_experts." in name
+                ) and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if (
+                    "mlp.experts." in name or "mlp.shared_experts." in name
+                ) and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = DeepseekForCausalLM
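
The trailing `EntryClass = DeepseekForCausalLM` is how sglang's model loader picks up the new implementation: each module under sglang/srt/models exports an EntryClass, and the runtime matches its class name against the `architectures` field of the checkpoint's Hugging Face config. As a rough illustration of how the added support might be exercised (the model path below is an assumption chosen for illustration, not something stated in this diff), a Deepseek MoE checkpoint could be served with the usual launch command:

python -m sglang.launch_server --model-path deepseek-ai/deepseek-moe-16b-chat --port 30000 --trust-remote-code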

sglang/srt/models/gpt_bigcode.py ADDED
@@ -0,0 +1,282 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
+"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPTBigCodeConfig
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.infer_batch import InputMetadata
+
+
+class GPTBigCodeAttention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            total_num_kv_heads = 1
+            self.num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+            self.num_kv_heads = self.num_heads
+        self.kv_dim = self.head_dim * self.num_kv_heads
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            scaling=self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.split(
+            [
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim,
+                self.kv_dim,
+            ],
+            dim=-1,
+        )
+        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPTBigMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPTBigCodeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.act = get_act_fn(
+            config.activation_function, quant_config, intermediate_size
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states, input_metadata=input_metadata
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPTBigCodeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+
+        self.embed_dim = config.hidden_size
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.wte = VocabParallelEmbedding(
+            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
+        )
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeForCausalLM(nn.Module):
+    packed_modules_mapping = {"c_attn": ["c_attn"]}
+
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
+
+    embedding_modules = {
+        "wte": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.transformer = GPTBigCodeModel(
+            config, cache_config, quant_config, lora_config
+        )
+        self.lm_head = self.transformer.wte
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                continue
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, "q")
+                weight_loader(param, loaded_weight, "k")
+                weight_loader(param, loaded_weight, "v")
+            else:
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = GPTBigCodeForCausalLM
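
GPTBigCode is registered the same way through `EntryClass = GPTBigCodeForCausalLM`, so StarCoder-family checkpoints should be usable from the Python frontend as well. A minimal sketch, assuming a StarCoder-style checkpoint and illustrative generation settings (neither is taken from this diff):

import sglang as sgl

# Start an in-process runtime on a GPTBigCode checkpoint (hypothetical choice).
runtime = sgl.Runtime(model_path="bigcode/starcoderbase-1b")
sgl.set_default_backend(runtime)

@sgl.function
def complete(s, prefix):
    # Append the code prefix, then sample a short greedy continuation.
    s += prefix
    s += sgl.gen("completion", max_tokens=32, temperature=0)

state = complete.run(prefix="def fibonacci(n):")
print(state["completion"])
runtime.shutdown()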