sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +57 -2
- sglang/api.py +8 -5
- sglang/backend/anthropic.py +18 -4
- sglang/backend/openai.py +2 -1
- sglang/backend/runtime_endpoint.py +18 -5
- sglang/backend/vertexai.py +1 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +83 -2
- sglang/lang/interpreter.py +92 -35
- sglang/lang/ir.py +12 -9
- sglang/lang/tracer.py +6 -4
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +1 -0
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +2 -2
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +10 -2
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +1 -0
- sglang/srt/layers/logits_processor.py +114 -54
- sglang/srt/layers/radix_attention.py +2 -1
- sglang/srt/layers/token_attention.py +1 -0
- sglang/srt/managers/detokenizer_manager.py +5 -1
- sglang/srt/managers/io_struct.py +27 -3
- sglang/srt/managers/router/infer_batch.py +97 -48
- sglang/srt/managers/router/manager.py +11 -8
- sglang/srt/managers/router/model_rpc.py +169 -90
- sglang/srt/managers/router/model_runner.py +110 -166
- sglang/srt/managers/router/radix_cache.py +89 -51
- sglang/srt/managers/router/scheduler.py +17 -28
- sglang/srt/managers/tokenizer_manager.py +110 -33
- sglang/srt/memory_pool.py +5 -14
- sglang/srt/model_config.py +11 -0
- sglang/srt/models/commandr.py +372 -0
- sglang/srt/models/dbrx.py +412 -0
- sglang/srt/models/dbrx_config.py +281 -0
- sglang/srt/models/gemma.py +24 -25
- sglang/srt/models/llama2.py +25 -26
- sglang/srt/models/llava.py +8 -10
- sglang/srt/models/llavavid.py +307 -0
- sglang/srt/models/mixtral.py +29 -33
- sglang/srt/models/qwen.py +34 -25
- sglang/srt/models/qwen2.py +25 -26
- sglang/srt/models/stablelm.py +26 -26
- sglang/srt/models/yivl.py +3 -5
- sglang/srt/openai_api_adapter.py +356 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +91 -456
- sglang/srt/server_args.py +79 -49
- sglang/srt/utils.py +212 -47
- sglang/srt/weight_utils.py +417 -0
- sglang/test/test_programs.py +8 -7
- sglang/test/test_utils.py +195 -7
- sglang/utils.py +77 -26
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
- sglang-0.1.16.dist-info/RECORD +72 -0
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,412 @@
|
|
1
|
+
# Adapted from:
|
2
|
+
# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
|
3
|
+
# coding=utf-8
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
import torch
|
7
|
+
import torch.nn as nn
|
8
|
+
from vllm.distributed import (
|
9
|
+
get_tensor_model_parallel_rank,
|
10
|
+
get_tensor_model_parallel_world_size,
|
11
|
+
tensor_model_parallel_all_reduce,
|
12
|
+
)
|
13
|
+
from vllm.model_executor.layers.fused_moe import fused_moe
|
14
|
+
from vllm.model_executor.layers.linear import (
|
15
|
+
QKVParallelLinear,
|
16
|
+
ReplicatedLinear,
|
17
|
+
RowParallelLinear,
|
18
|
+
)
|
19
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
20
|
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
21
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
22
|
+
DEFAULT_VOCAB_PADDING_SIZE,
|
23
|
+
ParallelLMHead,
|
24
|
+
VocabParallelEmbedding,
|
25
|
+
)
|
26
|
+
from vllm.model_executor.utils import set_weight_attrs
|
27
|
+
|
28
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
29
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
30
|
+
from sglang.srt.managers.router.model_runner import InputMetadata
|
31
|
+
from sglang.srt.models.dbrx_config import DbrxConfig
|
32
|
+
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
|
33
|
+
|
34
|
+
|
35
|
+
class DbrxRouter(nn.Module):
    """Gating network of the DBRX mixture-of-experts layer.

    Maps each token's hidden state to one logit per expert using a bias-free
    ``ReplicatedLinear`` projection of size ``d_model -> moe_num_experts``.
    """

    def __init__(
        self,
        config: DbrxConfig,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        ffn_cfg = config.ffn_config
        self.num_total_experts = ffn_cfg.moe_num_experts
        self.d_model = config.d_model
        # Replicated (not tensor-parallel-sharded) projection; no quantization.
        self.layer = ReplicatedLinear(
            self.d_model,
            self.num_total_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the per-expert router logits for ``hidden_states``.

        ``ReplicatedLinear`` returns a pair; only its first element (the
        projection output) is used.
        """
        logits = self.layer(hidden_states)[0]
        return logits
|
60
|
+
|
61
|
+
|
62
|
+
class DbrxExperts(nn.Module):
    """A tensor-parallel MoE implementation for DBRX.

    Each expert's weights are sharded across all ranks along the FFN hidden
    dimension. The forward pass routes tokens with :class:`DbrxRouter`, runs
    vLLM's fused MoE kernel on this rank's shard, and all-reduces the partial
    outputs across ranks when ``tp_size > 1``.
    """

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        # NOTE: ``quant_config`` is accepted for interface parity with the other
        # layers but is not used: the expert weights below are always allocated
        # unquantized in ``params_dtype``.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.ffn_config.moe_num_experts
        self.top_k = config.ffn_config.moe_top_k
        self.d_model = config.d_model
        # Per-rank slice of the FFN hidden size.
        self.intermediate_size = config.ffn_config.ffn_hidden_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.router = DbrxRouter(config, self.params_dtype)
        # ws stacks each expert's w1 shard (rows [0, I)) and v1 shard
        # (rows [I, 2I)): shape (num_experts, 2 * I, d_model), where I is the
        # per-rank intermediate size.
        self.ws = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                2 * self.intermediate_size,
                self.d_model,
                device="cuda",
                dtype=self.params_dtype,
            )
        )
        # w2s holds each expert's w2 shard: shape (num_experts, d_model, I).
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_total_experts,
                self.d_model,
                self.intermediate_size,
                device="cuda",
                dtype=self.params_dtype,
            )
        )

        set_weight_attrs(
            self.ws,
            {
                "weight_loader": self.weight_loader,
            },
        )
        set_weight_attrs(
            self.w2s,
            {
                "weight_loader": self.weight_loader,
            },
        )

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str
    ):
        """Copy this rank's shard of one checkpoint expert weight into ``param``.

        DBRX uses a GLU for each expert, with three linear layers: ``w1``,
        ``v1`` and ``w2``. The checkpoint tensor is reshaped to
        ``(-1, ffn_hidden_size, d_model)`` and this rank's slice of the FFN
        hidden dimension is copied as follows:

        * ``w1`` -> ``param[:, 0:I, :]``
        * ``v1`` -> ``param[:, I:2I, :]``
        * ``w2`` -> transposed to ``(-1, d_model, ffn_hidden_size)`` and
          copied into all of ``param``

        Args:
            param: Destination parameter (``ws`` for w1/v1, ``w2s`` for w2).
            loaded_weight: The full, unsharded checkpoint tensor.
            weight_name: Checkpoint name; its suffix selects the projection.

        Raises:
            ValueError: If ``weight_name`` ends in none of ``w1``, ``v1``,
                ``w2``. Ignoring such a name would leave ``param`` holding
                uninitialized ``torch.empty`` memory.
        """
        if not weight_name.endswith(("w1", "v1", "w2")):
            raise ValueError(f"Unexpected DBRX expert weight name: {weight_name!r}")

        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # All three projections share the same stacked checkpoint layout.
        full_weight = torch.reshape(
            loaded_weight,
            [-1, self.intermediate_size * self.tp_size, self.d_model],
        )
        if weight_name.endswith("w1"):
            param_data[:, 0:shard_size, :] = full_weight[:, shard, :]
        elif weight_name.endswith("v1"):
            param_data[:, shard_size : 2 * shard_size, :] = full_weight[:, shard, :]
        else:  # "w2": stored transposed relative to w1/v1.
            param_data[:] = full_weight.transpose(1, 2)[:, :, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE FFN to ``hidden_states`` of shape ``(num_tokens, d_model)``.

        NOTE: ``fused_moe`` is called with ``inplace=True``, so the input
        tensor's storage may be overwritten by the kernel.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.d_model)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.router(hidden_states)
        final_hidden_states = fused_moe(
            hidden_states,
            self.ws,
            self.w2s,
            router_logits,
            self.top_k,
            renormalize=True,
            inplace=True,
        )

        if self.tp_size > 1:
            # Each rank only computed its FFN shard; sum the partial results.
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)
|
167
|
+
|
168
|
+
|
169
|
+
class DbrxAttention(nn.Module):
    """Tensor-parallel self-attention used inside each DBRX block.

    Q, K and V come from one fused projection (``Wqkv``). They are optionally
    clamped to ``[-clip_qkv, clip_qkv]``, rotated with RoPE and passed to
    :class:`RadixAttention`. ``out_proj`` maps the result back to ``d_model``.
    """

    def __init__(
        self,
        config: DbrxConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        attn_cfg = config.attn_config
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.total_num_kv_heads = attn_cfg.kv_n_heads
        self.clip_qkv = attn_cfg.clip_qkv
        self.rope_theta = attn_cfg.rope_theta
        self.max_position = config.max_seq_len

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
        )
        # NOTE(review): rope_theta is truncated to an int before being used as
        # the RoPE base; a non-integral theta would lose its fractional part.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        world = get_tensor_model_parallel_world_size()
        self.tp_size = world
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world
        kv_heads = self.total_num_kv_heads
        if kv_heads < world:
            # Fewer KV heads than ranks: each KV head is replicated on several ranks.
            assert world % kv_heads == 0
        else:
            # At least as many KV heads as ranks: KV heads are partitioned.
            assert kv_heads % world == 0
        # Never fewer than one KV head per rank.
        self.num_kv_heads = max(1, kv_heads // world)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and return the ``out_proj`` output."""
        qkv = self.Wqkv(hidden_states)[0]
        if self.clip_qkv is not None:
            bound = self.clip_qkv
            qkv.clamp_(min=-bound, max=bound)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context = self.attn(q, k, v, input_metadata)
        projected = self.out_proj(context)[0]
        return projected
|
246
|
+
|
247
|
+
|
248
|
+
class DbrxFusedNormAttention(nn.Module):
    """Norm -> attention -> residual add -> norm sub-block of a DBRX layer.

    ``forward`` actually returns a pair ``(normed, residual)``. ``residual`` is
    the input plus the attention output. ``normed`` is ``norm_2(residual)``,
    which is the input for the following FFN.
    """

    def __init__(
        self,
        config: DbrxConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
        self.norm_1 = nn.LayerNorm(self.d_model)
        self.norm_2 = nn.LayerNorm(self.d_model)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
            input_metadata=input_metadata,
        )
        summed = hidden_states + attn_out
        return self.norm_2(summed), summed
|
278
|
+
|
279
|
+
|
280
|
+
class DbrxBlock(nn.Module):
    """One DBRX layer: the fused norm/attention sub-block followed by the MoE FFN.

    The FFN output is added to the post-attention residual returned by
    ``norm_attn_norm``.
    """

    def __init__(
        self,
        config: DbrxConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.norm_attn_norm = DbrxFusedNormAttention(
            config, layer_id, quant_config=quant_config
        )
        self.ffn = DbrxExperts(config, quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        normed, residual = self.norm_attn_norm(
            position_ids=position_ids,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
        ffn_out = self.ffn(normed)
        return ffn_out + residual
|
307
|
+
|
308
|
+
|
309
|
+
class DbrxModel(nn.Module):
    """DBRX backbone: token embedding, ``n_layers`` :class:`DbrxBlock`s, final LayerNorm."""

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [
                DbrxBlock(config, i, quant_config=quant_config)
                for i in range(config.n_layers)
            ]
        )
        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
        for module in self.modules():
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                # Remove the bias term in Linear and LayerNorm modules
                # (e.g. the nn.LayerNorm instances created above).
                module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the ``norm_f``-normalized hidden states.

        Args:
            input_ids: Token ids. They are embedded with ``wte`` only when
                ``input_embeds`` is None.
            position_ids: Positions forwarded to every block (used for RoPE).
            input_metadata: Batch metadata forwarded to every block.
            input_embeds: Optional precomputed embeddings used instead of
                ``wte(input_ids)``.
        """
        hidden_states = self.wte(input_ids) if input_embeds is None else input_embeds
        for block in self.blocks:
            hidden_states = block(position_ids, hidden_states, input_metadata)
        return self.norm_f(hidden_states)
|
348
|
+
|
349
|
+
|
350
|
+
class DbrxForCausalLM(nn.Module):
    """Causal LM wrapper: :class:`DbrxModel` plus a vocab-parallel LM head and logits processor."""

    def __init__(
        self,
        config: DbrxConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.unpadded_vocab_size = config.vocab_size
        self.transformer = DbrxModel(config, quant_config=quant_config)
        vocab = config.vocab_size
        self.lm_head = ParallelLMHead(
            vocab,
            config.d_model,
            org_num_embeddings=vocab,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return ``logits_processor``'s output using the LM head weight."""
        hidden = self.transformer(input_ids, positions, input_metadata)
        return self.logits_processor(
            input_ids, hidden, self.lm_head.weight, input_metadata
        )

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load every checkpoint tensor yielded by ``hf_model_weights_iterator``.

        A name containing ``experts.mlp.w1`` / ``v1`` / ``w2`` has that
        substring replaced by the fused parameter name (``ws`` for w1 and v1,
        ``w2s`` for w2). It is then loaded through the parameter's
        ``weight_loader``, together with the matched substring. Every other
        tensor goes to the parameter's own ``weight_loader`` if it has one,
        else to ``default_weight_loader``. A checkpoint name with no matching
        parameter raises ``KeyError``.
        """
        # (fused parameter name, checkpoint substring) pairs, checked in order.
        expert_params_mapping = [
            ("ws", "experts.mlp.w1"),
            ("ws", "experts.mlp.v1"),
            ("w2s", "experts.mlp.w2"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights = hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        )
        for name, loaded_weight in weights:
            hit = next(
                (
                    (fused, ckpt)
                    for fused, ckpt in expert_params_mapping
                    if ckpt in name
                ),
                None,
            )
            if hit is None:
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
                continue
            fused_name, ckpt_name = hit
            param = params_dict[name.replace(ckpt_name, fused_name)]
            param.weight_loader(param, loaded_weight, ckpt_name)
|
410
|
+
|
411
|
+
|
412
|
+
# Module-level alias naming the model class this file provides.
# NOTE(review): presumably discovered by sglang's model loader — confirm there.
EntryClass = DbrxForCausalLM
|
@@ -0,0 +1,281 @@
|
|
1
|
+
# Adapted from:
|
2
|
+
# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/transformers_utils/configs/dbrx.py
|
3
|
+
# yapf: disable
|
4
|
+
# ruff: noqa: E501
|
5
|
+
# coding=utf-8
|
6
|
+
# Copied from
|
7
|
+
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
|
8
|
+
"""Dbrx configuration."""
|
9
|
+
|
10
|
+
# FIXME: remove this once vllm releases a new version
|
11
|
+
|
12
|
+
from typing import Any, Optional
|
13
|
+
|
14
|
+
from transformers.configuration_utils import PretrainedConfig
|
15
|
+
from transformers.utils import logging
|
16
|
+
|
17
|
+
# Module logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)

# Pretrained-config archive map; intentionally empty (no entries registered here).
DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
20
|
+
|
21
|
+
|
22
|
+
class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    This is the configuration class that stores the settings of the
    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to `None`):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (`int`, *optional*, defaults to 1):
            For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base frequency for rope.
        **kwargs:
            Forwarded to [`PretrainedConfig`]. Apart from `model_type`, no
            extra keyword is accepted.

    Raises:
        ValueError: If `kwargs` contains any key other than `model_type`.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        # `model_type` is the only tolerated extra key (it appears in the
        # nested attn_config of DBRX config files); anything else is an error.
        kwargs.pop("model_type", None)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Build an attention config from a pretrained config source.

        If the loaded config dict is a full DBRX config (`model_type == "dbrx"`),
        only its `attn_config` sub-dict is used. A warning is logged when the
        remaining dict declares a `model_type` different from `cls.model_type`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
84
|
+
|
85
|
+
|
86
|
+
class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    This is the configuration class that stores the settings of the
    [`DbrxFFN`] class. It is used to instantiate feedforward layers according
    to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (`dict`, *optional*, defaults to `{"name": "silu"}`):
            A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (`int`, *optional*, defaults to 3584):
            The hidden size of the feedforward network.
        moe_num_experts (`int`, *optional*, defaults to 4):
            The number of experts in the mixture of experts layer.
        moe_top_k (`int`, *optional*, defaults to 1):
            The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (`float`, *optional*, defaults to `None`):
            The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (`float`, *optional*, defaults to 0.01):
            The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (`float`, *optional*, defaults to 1):
            The normalization factor for the expert weights.
        uniform_expert_assignment (`bool`, *optional*, defaults to `False`):
            Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
        **kwargs:
            Not forwarded to [`PretrainedConfig`]. Only `model_type` is
            accepted, and it is discarded.

    Raises:
        ValueError: If `kwargs` contains any key other than `model_type`.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        # A fresh dict per instance (never a shared mutable default).
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        # `model_type` is the only tolerated extra key (it appears in the
        # nested ffn_config of DBRX config files); anything else is an error.
        kwargs.pop("model_type", None)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        """Build an FFN config from a pretrained config source.

        If the loaded config dict is a full DBRX config (`model_type == "dbrx"`),
        only its `ffn_config` sub-dict is used. A warning is logged when the
        remaining dict declares a `model_type` different from `cls.model_type`.
        """
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
163
|
+
|
164
|
+
|
165
|
+
class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    This is the configuration class that stores the settings of a
    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 2048):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer.
        n_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers (transformer blocks).
        max_seq_len (`int`, *optional*, defaults to 2048):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`DbrxAttentionConfig` or `dict`, *optional*):
            Configures the model's attention module. `None` yields a default
            `DbrxAttentionConfig()`. A `dict` is expanded into
            `DbrxAttentionConfig(**attn_config)`. Any other value is stored as-is.
        ffn_config (`DbrxFFNConfig` or `dict`, *optional*):
            Configures the model's FFN module. `None` yields a default
            `DbrxFFNConfig()`. A `dict` is expanded into
            `DbrxFFNConfig(**ffn_config)`. Any other value is stored as-is.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.05):
            The aux loss factor for the total loss.
        **kwargs:
            Forwarded to [`PretrainedConfig`]. `tie_word_embeddings` may be
            passed but must be falsy.

    Raises:
        ValueError: If `tie_word_embeddings` is truthy.


    Example:
    ```python
    >>> from sglang.srt.models.dbrx_config import DbrxConfig

    >>> # Initializing a Dbrx configuration with the defaults listed above
    >>> configuration = DbrxConfig()
    >>> configuration.d_model
    2048
    ```
    """

    model_type = "dbrx"
    # Aliases so the standard PretrainedConfig attribute names resolve to the
    # DBRX-specific ones (e.g. `config.hidden_size` reads `d_model`).
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        # Nested configs may arrive as objects or as plain dicts (e.g. from JSON).
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        # Tied input/output embeddings are rejected for this model.
        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models."
            )

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|