vllm_npu-0.4.2-py3-none-any.whl
- vllm/__init__.py +23 -0
- vllm/_custom_ops.py +251 -0
- vllm/attention/__init__.py +13 -0
- vllm/attention/backends/__init__.py +0 -0
- vllm/attention/backends/abstract.py +127 -0
- vllm/attention/backends/flash_attn.py +271 -0
- vllm/attention/backends/flashinfer.py +220 -0
- vllm/attention/backends/rocm_flash_attn.py +374 -0
- vllm/attention/backends/torch_sdpa.py +250 -0
- vllm/attention/backends/xformers.py +393 -0
- vllm/attention/layer.py +56 -0
- vllm/attention/ops/__init__.py +0 -0
- vllm/attention/ops/paged_attn.py +216 -0
- vllm/attention/ops/prefix_prefill.py +792 -0
- vllm/attention/ops/triton_flash_attention.py +810 -0
- vllm/attention/selector.py +91 -0
- vllm/block.py +84 -0
- vllm/config.py +1225 -0
- vllm/core/__init__.py +0 -0
- vllm/core/block/__init__.py +0 -0
- vllm/core/block/block_table.py +295 -0
- vllm/core/block/common.py +199 -0
- vllm/core/block/cpu_gpu_block_allocator.py +228 -0
- vllm/core/block/interfaces.py +205 -0
- vllm/core/block/naive_block.py +318 -0
- vllm/core/block/prefix_caching_block.py +606 -0
- vllm/core/block_manager_v1.py +625 -0
- vllm/core/block_manager_v2.py +258 -0
- vllm/core/evictor_v1.py +105 -0
- vllm/core/evictor_v2.py +127 -0
- vllm/core/interfaces.py +113 -0
- vllm/core/policy.py +45 -0
- vllm/core/scheduler.py +1163 -0
- vllm/distributed/__init__.py +3 -0
- vllm/distributed/communication_op.py +237 -0
- vllm/distributed/device_communicators/__init__.py +0 -0
- vllm/distributed/device_communicators/custom_all_reduce.py +274 -0
- vllm/distributed/device_communicators/pynccl.py +287 -0
- vllm/distributed/device_communicators/pynccl_utils.py +66 -0
- vllm/distributed/parallel_state.py +339 -0
- vllm/distributed/utils.py +136 -0
- vllm/engine/__init__.py +0 -0
- vllm/engine/arg_utils.py +649 -0
- vllm/engine/async_llm_engine.py +737 -0
- vllm/engine/llm_engine.py +784 -0
- vllm/engine/metrics.py +368 -0
- vllm/engine/output_processor/__init__.py +0 -0
- vllm/engine/output_processor/interfaces.py +76 -0
- vllm/engine/output_processor/multi_step.py +142 -0
- vllm/engine/output_processor/single_step.py +284 -0
- vllm/engine/output_processor/stop_checker.py +101 -0
- vllm/engine/output_processor/util.py +19 -0
- vllm/entrypoints/__init__.py +0 -0
- vllm/entrypoints/api_server.py +119 -0
- vllm/entrypoints/llm.py +259 -0
- vllm/entrypoints/openai/__init__.py +0 -0
- vllm/entrypoints/openai/api_server.py +186 -0
- vllm/entrypoints/openai/cli_args.py +115 -0
- vllm/entrypoints/openai/protocol.py +460 -0
- vllm/entrypoints/openai/serving_chat.py +392 -0
- vllm/entrypoints/openai/serving_completion.py +347 -0
- vllm/entrypoints/openai/serving_engine.py +234 -0
- vllm/envs.py +217 -0
- vllm/executor/__init__.py +0 -0
- vllm/executor/cpu_executor.py +152 -0
- vllm/executor/distributed_gpu_executor.py +115 -0
- vllm/executor/executor_base.py +115 -0
- vllm/executor/gpu_executor.py +150 -0
- vllm/executor/multiproc_worker_utils.py +263 -0
- vllm/executor/neuron_executor.py +91 -0
- vllm/executor/ray_gpu_executor.py +327 -0
- vllm/executor/ray_utils.py +119 -0
- vllm/logger.py +153 -0
- vllm/logging/__init__.py +5 -0
- vllm/logging/formatter.py +15 -0
- vllm/lora/__init__.py +0 -0
- vllm/lora/fully_sharded_layers.py +262 -0
- vllm/lora/layers.py +1181 -0
- vllm/lora/lora.py +167 -0
- vllm/lora/models.py +645 -0
- vllm/lora/punica.py +213 -0
- vllm/lora/request.py +32 -0
- vllm/lora/utils.py +98 -0
- vllm/lora/worker_manager.py +251 -0
- vllm/model_executor/__init__.py +7 -0
- vllm/model_executor/guided_decoding/__init__.py +25 -0
- vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +70 -0
- vllm/model_executor/guided_decoding/outlines_decoding.py +130 -0
- vllm/model_executor/guided_decoding/outlines_logits_processors.py +184 -0
- vllm/model_executor/layers/__init__.py +0 -0
- vllm/model_executor/layers/activation.py +173 -0
- vllm/model_executor/layers/fused_moe/__init__.py +7 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +140 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
- vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- vllm/model_executor/layers/fused_moe/fused_moe.py +479 -0
- vllm/model_executor/layers/layernorm.py +71 -0
- vllm/model_executor/layers/linear.py +709 -0
- vllm/model_executor/layers/logits_processor.py +115 -0
- vllm/model_executor/layers/ops/__init__.py +0 -0
- vllm/model_executor/layers/ops/rand.py +157 -0
- vllm/model_executor/layers/ops/sample.py +406 -0
- vllm/model_executor/layers/quantization/__init__.py +35 -0
- vllm/model_executor/layers/quantization/aqlm.py +376 -0
- vllm/model_executor/layers/quantization/awq.py +175 -0
- vllm/model_executor/layers/quantization/base_config.py +97 -0
- vllm/model_executor/layers/quantization/fp8.py +265 -0
- vllm/model_executor/layers/quantization/gptq.py +224 -0
- vllm/model_executor/layers/quantization/gptq_marlin.py +438 -0
- vllm/model_executor/layers/quantization/marlin.py +227 -0
- vllm/model_executor/layers/quantization/schema.py +84 -0
- vllm/model_executor/layers/quantization/squeezellm.py +137 -0
- vllm/model_executor/layers/rejection_sampler.py +405 -0
- vllm/model_executor/layers/rotary_embedding.py +525 -0
- vllm/model_executor/layers/sampler.py +1051 -0
- vllm/model_executor/layers/vocab_parallel_embedding.py +155 -0
- vllm/model_executor/model_loader/__init__.py +30 -0
- vllm/model_executor/model_loader/loader.py +362 -0
- vllm/model_executor/model_loader/neuron.py +136 -0
- vllm/model_executor/model_loader/tensorizer.py +368 -0
- vllm/model_executor/model_loader/utils.py +41 -0
- vllm/model_executor/model_loader/weight_utils.py +372 -0
- vllm/model_executor/models/__init__.py +119 -0
- vllm/model_executor/models/baichuan.py +410 -0
- vllm/model_executor/models/bloom.py +327 -0
- vllm/model_executor/models/chatglm.py +386 -0
- vllm/model_executor/models/commandr.py +373 -0
- vllm/model_executor/models/dbrx.py +413 -0
- vllm/model_executor/models/decilm.py +122 -0
- vllm/model_executor/models/deepseek.py +438 -0
- vllm/model_executor/models/falcon.py +444 -0
- vllm/model_executor/models/gemma.py +393 -0
- vllm/model_executor/models/gpt2.py +266 -0
- vllm/model_executor/models/gpt_bigcode.py +274 -0
- vllm/model_executor/models/gpt_j.py +281 -0
- vllm/model_executor/models/gpt_neox.py +295 -0
- vllm/model_executor/models/internlm2.py +323 -0
- vllm/model_executor/models/jais.py +333 -0
- vllm/model_executor/models/llama.py +442 -0
- vllm/model_executor/models/llava.py +239 -0
- vllm/model_executor/models/minicpm.py +531 -0
- vllm/model_executor/models/mixtral.py +583 -0
- vllm/model_executor/models/mixtral_quant.py +404 -0
- vllm/model_executor/models/mpt.py +295 -0
- vllm/model_executor/models/olmo.py +356 -0
- vllm/model_executor/models/opt.py +349 -0
- vllm/model_executor/models/orion.py +319 -0
- vllm/model_executor/models/phi.py +300 -0
- vllm/model_executor/models/qwen.py +284 -0
- vllm/model_executor/models/qwen2.py +367 -0
- vllm/model_executor/models/qwen2_moe.py +447 -0
- vllm/model_executor/models/stablelm.py +301 -0
- vllm/model_executor/models/starcoder2.py +302 -0
- vllm/model_executor/models/xverse.py +366 -0
- vllm/model_executor/sampling_metadata.py +588 -0
- vllm/model_executor/utils.py +35 -0
- vllm/outputs.py +150 -0
- vllm/py.typed +2 -0
- vllm/sampling_params.py +340 -0
- vllm/sequence.py +766 -0
- vllm/spec_decode/__init__.py +0 -0
- vllm/spec_decode/batch_expansion.py +397 -0
- vllm/spec_decode/interfaces.py +73 -0
- vllm/spec_decode/metrics.py +191 -0
- vllm/spec_decode/multi_step_worker.py +203 -0
- vllm/spec_decode/ngram_worker.py +176 -0
- vllm/spec_decode/spec_decode_worker.py +472 -0
- vllm/spec_decode/top1_proposer.py +200 -0
- vllm/spec_decode/util.py +228 -0
- vllm/test_utils.py +41 -0
- vllm/transformers_utils/__init__.py +0 -0
- vllm/transformers_utils/config.py +58 -0
- vllm/transformers_utils/configs/__init__.py +16 -0
- vllm/transformers_utils/configs/chatglm.py +68 -0
- vllm/transformers_utils/configs/dbrx.py +278 -0
- vllm/transformers_utils/configs/falcon.py +87 -0
- vllm/transformers_utils/configs/jais.py +236 -0
- vllm/transformers_utils/configs/mpt.py +178 -0
- vllm/transformers_utils/detokenizer.py +313 -0
- vllm/transformers_utils/tokenizer.py +149 -0
- vllm/transformers_utils/tokenizer_group/__init__.py +33 -0
- vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py +55 -0
- vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py +169 -0
- vllm/transformers_utils/tokenizer_group/tokenizer_group.py +78 -0
- vllm/transformers_utils/tokenizers/__init__.py +5 -0
- vllm/transformers_utils/tokenizers/baichuan.py +255 -0
- vllm/usage/__init__.py +0 -0
- vllm/usage/usage_lib.py +209 -0
- vllm/utils.py +677 -0
- vllm/worker/__init__.py +0 -0
- vllm/worker/cache_engine.py +105 -0
- vllm/worker/cpu_model_runner.py +346 -0
- vllm/worker/cpu_worker.py +321 -0
- vllm/worker/model_runner.py +1168 -0
- vllm/worker/neuron_model_runner.py +196 -0
- vllm/worker/neuron_worker.py +98 -0
- vllm/worker/worker.py +345 -0
- vllm/worker/worker_base.py +146 -0
- vllm_npu-0.4.2.dist-info/LICENSE +201 -0
- vllm_npu-0.4.2.dist-info/METADATA +173 -0
- vllm_npu-0.4.2.dist-info/RECORD +219 -0
- vllm_npu-0.4.2.dist-info/WHEEL +5 -0
- vllm_npu-0.4.2.dist-info/top_level.txt +1 -0
vllm/spec_decode/util.py
ADDED
@@ -0,0 +1,228 @@
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Tuple

import torch

from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata,
                           SequenceGroupOutput, SequenceOutput)

SeqId = int


def get_all_seq_ids(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
    """Given a list of SequenceGroupMetadata, create a list of all
    sequence ids.
    """
    return list(
        chain.from_iterable([
            seq_group_metadata.seq_data.keys()
            for seq_group_metadata in seq_group_metadata_list
        ]))


def get_all_num_logprobs(
        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]:
    """Given a list of SequenceGroupMetadata, create a list of all num_logprobs.

    If the sampling params do not call for any logprobs, return 0 for that
    sequence.
    """

    all_num_logprobs = []
    for seq_group_metadata in seq_group_metadata_list:
        num_logprobs = seq_group_metadata.sampling_params.logprobs
        if seq_group_metadata.sampling_params.logprobs is None:
            num_logprobs = 0
        all_num_logprobs.append(num_logprobs)

    return all_num_logprobs


def get_sampled_token_logprobs(
        # shape [num_steps, batch_size, vocab_size]
        logprob_tensor: torch.Tensor,
        sampled_token_ids: torch.Tensor,  # shape [num_steps, batch_size]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get the logprobs for the sampled tokens. Returns the ranks and logprobs.
    """
    num_steps, batch_size, vocab_size = logprob_tensor.shape

    selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
                                       torch.arange(batch_size),
                                       sampled_token_ids, ]
    expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
        -1, -1, vocab_size)
    sampled_token_ids_ranks = (logprob_tensor >=
                               expanded_selected_logprobs).sum(-1)

    return sampled_token_ids_ranks, selected_logprobs


def create_sequence_group_output(
        token_id: int,
        token_id_logprob_rank: int,
        token_id_logprob: float,
        seq_id: SeqId,
        topk_token_ids: List[int],
        topk_logprobs: List[float],
) -> SequenceGroupOutput:
    """Create a SequenceGroupOutput given the sampling results.

    Args:
        token_id (int): The sampled token for the sequence.
        token_id_logprob_rank (int): The logprob rank of the sampled token.
        token_id_logprob (float): The logprob value of the sampled token.
        seq_id (int): The sequence id.
        topk_token_ids (List[int]): The list of top-k token ids.
        topk_logprobs (List[float]): The list of top-k logprobs.
    """
    # vLLM logprobs always include the sampled token. In addition, the user may
    # request topk-logprobs (where top-k varies per user up to max_logprobs).
    logprobs: Dict[int, Logprob] = {
        token_id: Logprob(
            logprob=token_id_logprob,
            rank=token_id_logprob_rank,
        ),
    }
    logprobs.update({
        topk_token_ids[topk_logprob_index]: Logprob(
            logprob=topk_logprobs[topk_logprob_index],
            rank=topk_logprob_index + 1,
        )
        for topk_logprob_index, _ in enumerate(topk_token_ids)
    })

    return SequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=seq_id,
                           output_token=token_id,
                           logprobs=logprobs)
        ],
        # TODO add prompt logprobs support.
        prompt_logprobs=None,
    )


def split_batch_by_proposal_len(
    seq_group_metadata_list: List[SequenceGroupMetadata],
    proposal_lens: List[int], select_proposal_len_zero: bool
) -> Tuple[List[SequenceGroupMetadata], List[int]]:
    """Utility function that splits a batch based on whether the proposal len is
    zero or not. We should remove this once vLLM supports per-sequence proposal
    lens in a batch.
    """

    if select_proposal_len_zero:
        predicate = lambda proposal_len: proposal_len == 0
    else:
        predicate = lambda proposal_len: proposal_len != 0

    indices = [
        i for i, (_, proposal_len
                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
        if predicate(proposal_len)
    ]
    seq_groups = [
        seq_group for seq_group, proposal_len in zip(
            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
    ]

    return seq_groups, indices


def sampler_output_to_torch(
    sampler_output_list: List[SamplerOutput], sampler_transposed: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Utility function which converts a list of SamplerOutput to tensors.

    sampler_transposed here is used as the indicator for whether
    we need do additional tensor transpose logic here.

    Returns:
        sampled_token_ids: torch.Tensor
            shape: [batch_size, len(sampler_output_list)]

        sampled_token_probs: torch.Tensor
            shape: [batch_size, len(sampler_output_list), vocab_size]
    """

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_probs = torch.stack(
        [
            sampler_output.sampled_token_probs
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )

    if sampler_transposed:
        sampled_token_probs = sampled_token_probs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output, vocab_size]
    sampled_token_logprobs = torch.stack(
        [sampler_output.logprobs for sampler_output in sampler_output_list],
        dim=0,
    )

    if sampler_transposed:
        sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1)

    # shape: [batch_size, num_sampler_output]
    sampled_token_ids = torch.stack(
        [
            sampler_output.sampled_token_ids.flatten()
            for sampler_output in sampler_output_list
        ],
        dim=0,
    )
    if sampler_transposed:
        sampled_token_ids = sampled_token_ids.transpose(0, 1)

    return sampled_token_ids, sampled_token_probs, sampled_token_logprobs


def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,
                              vocab_size: int, device: str) -> None:
    """Helper method which mocks out the GPU tensors in SamplerOutput with dummy
    values. This will be removed in PR 7/9.
    https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer
    """
    values = [
        sampler_output.sampled_token_probs, sampler_output.sampled_token_ids
    ]
    assert all(v is None for v in values) or not any(v is None for v in values)
    if not any(v is None for v in values):
        # Do nothing if the tensors are already created (usually in unit tests).
        return

    # Softmax to ensure valid probs.
    sampler_output.sampled_token_probs = torch.nn.functional.softmax(
        torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device),
        dim=-1)

    sampler_output.sampled_token_ids = torch.randint(low=10,
                                                     high=100,
                                                     size=(batch_size, ),
                                                     dtype=torch.long,
                                                     device=device)


@contextmanager
def nvtx_range(msg, *args, **kwargs):
    """
    Context manager / decorator that pushes an NVTX range at the beginning
    of its scope, and pops it at the end. If extra arguments are given,
    they are passed as arguments to msg.format().

    If running with cuda graphs, you must enable nsys cuda graph profiling.

    Arguments:
        msg (string): message to associate with the range
    """
    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()
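A quick orientation sketch, not part of the wheel itself: the tensor helper
get_sampled_token_logprobs above can be exercised standalone with toy shapes.
The shapes and values below are invented purely for illustration.

# Toy illustration of get_sampled_token_logprobs; shapes and values are made up.
import torch

from vllm.spec_decode.util import get_sampled_token_logprobs

num_steps, batch_size, vocab_size = 2, 3, 8

# Per-step logprobs over the vocabulary, and the token id sampled at each step.
logprob_tensor = torch.log_softmax(
    torch.randn(num_steps, batch_size, vocab_size), dim=-1)
sampled_token_ids = torch.randint(vocab_size, (num_steps, batch_size))

ranks, selected = get_sampled_token_logprobs(logprob_tensor, sampled_token_ids)
# Rank 1 means the sampled token had the highest logprob at that position.
assert ranks.shape == selected.shape == (num_steps, batch_size)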
vllm/test_utils.py
ADDED
@@ -0,0 +1,41 @@
import ray

from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.utils import get_open_port


def init_test_distributed_environment(
    pipeline_parallel_size: int,
    tensor_parallel_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pipeline_parallel_size * tensor_parallel_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tensor_parallel_size,
                                      pipeline_parallel_size)


def multi_process_tensor_parallel(
    tensor_parallel_size: int,
    test_target,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    ray.init()

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tensor_parallel_size):
        refs.append(
            test_target.remote(tensor_parallel_size, rank,
                               distributed_init_port))
    ray.get(refs)

    ray.shutdown()
vllm/transformers_utils/__init__.py
File without changes

vllm/transformers_utils/config.py
ADDED
@@ -0,0 +1,58 @@
from typing import Dict, Optional

from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                             JAISConfig, MPTConfig, RWConfig)

_CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
    "chatglm": ChatGLMConfig,
    "dbrx": DbrxConfig,
    "mpt": MPTConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "jais": JAISConfig,
}


def get_config(model: str,
               trust_remote_code: bool,
               revision: Optional[str] = None,
               code_revision: Optional[str] = None) -> PretrainedConfig:
    try:
        config = AutoConfig.from_pretrained(
            model,
            trust_remote_code=trust_remote_code,
            revision=revision,
            code_revision=code_revision)
    except ValueError as e:
        if (not trust_remote_code and
                "requires you to execute the configuration file" in str(e)):
            err_msg = (
                "Failed to load the model config. If the model is a custom "
                "model not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model,
                                              revision=revision,
                                              code_revision=code_revision)
    return config


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
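A hedged sketch of how this loader is meant to be called; the model name is
just an example and the call fetches the config from the Hugging Face Hub.

from vllm.transformers_utils.config import get_config, get_hf_text_config

# Resolves to a HF PretrainedConfig, swapping in vLLM's own config classes
# for model types listed in _CONFIG_REGISTRY (chatglm, dbrx, mpt, ...).
config = get_config("facebook/opt-125m", trust_remote_code=False)

# Multi-modal configs are unwrapped to their text_config; pure text models
# such as OPT come back unchanged.
text_config = get_hf_text_config(config)
print(type(config).__name__, text_config.num_attention_heads)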
vllm/transformers_utils/configs/__init__.py
ADDED
@@ -0,0 +1,16 @@
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
    "ChatGLMConfig",
    "DbrxConfig",
    "MPTConfig",
    "RWConfig",
    "JAISConfig",
]
vllm/transformers_utils/configs/chatglm.py
ADDED
@@ -0,0 +1,68 @@
# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"
    attribute_map = {
        "num_hidden_layers": "num_layers",
        "n_head_kv": "multi_query_group_num",
    }

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 interleaved_qkv=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm)
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        self.interleaved_qkv = interleaved_qkv
        super().__init__(**kwargs)
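Small illustration, not part of the wheel, of what attribute_map buys here:
generic vLLM code can read HF-style attribute names from a ChatGLM config.

from vllm.transformers_utils.configs import ChatGLMConfig

cfg = ChatGLMConfig(num_layers=2, multi_query_group_num=2)
# "num_hidden_layers" and "n_head_kv" are aliases declared in attribute_map.
assert cfg.num_hidden_layers == cfg.num_layers == 2
assert cfg.n_head_kv == cfg.multi_query_group_num == 2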
vllm/transformers_utils/configs/dbrx.py
ADDED
@@ -0,0 +1,278 @@
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py
"""Dbrx configuration."""

from typing import Any, Optional

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {}  # type: ignore


class DbrxAttentionConfig(PretrainedConfig):
    """Configuration class for Dbrx Attention.

    [`DbrxAttention`] class. It is used to instantiate attention layers
    according to the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        attn_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the attention layers.
        clip_qkv (`float`, *optional*, defaults to None):
            If not `None`, clip the queries, keys, and values in the attention layer to this value.
        kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
        rope_theta (float): The base frequency for rope.
    """

    def __init__(
        self,
        attn_pdrop: float = 0,
        clip_qkv: Optional[float] = None,
        kv_n_heads: int = 1,
        rope_theta: float = 10000.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.attn_pdrop = attn_pdrop
        self.clip_qkv = clip_qkv
        self.kv_n_heads = kv_n_heads
        self.rope_theta = rope_theta

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["attn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all configurations of "
                "models and can yield errors.",
                config_dict["model_type"], cls.model_type)

        return cls.from_dict(config_dict, **kwargs)


class DbrxFFNConfig(PretrainedConfig):
    """Configuration class for Dbrx FFN.

    [`DbrxFFN`] class. It is used to instantiate feedforward layers according to
    the specified arguments, defining the layers architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        ffn_act_fn (dict, optional): A dict specifying activation function for the FFN.
            The dict should have a key 'name' with the value being the name of
            the activation function along with any additional keyword arguments.
        ffn_hidden_size (int, optional): The hidden size of the feedforward network.
        moe_num_experts (int, optional): The number of experts in the mixture of experts layer.
        moe_top_k (int, optional): The number of experts to use in the mixture of experts layer.
        moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer.
        moe_loss_weight (float, optional): The loss weight for the mixture of experts layer.
        moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights.
        uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment.
            This should only be used for benchmarking purposes.
    """

    def __init__(
        self,
        ffn_act_fn: Optional[dict] = None,
        ffn_hidden_size: int = 3584,
        moe_num_experts: int = 4,
        moe_top_k: int = 1,
        moe_jitter_eps: Optional[float] = None,
        moe_loss_weight: float = 0.01,
        moe_normalize_expert_weights: Optional[float] = 1,
        uniform_expert_assignment: bool = False,
        **kwargs: Any,
    ):
        super().__init__()
        if ffn_act_fn is None:
            ffn_act_fn = {"name": "silu"}
        self.ffn_act_fn = ffn_act_fn
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k
        self.moe_jitter_eps = moe_jitter_eps
        self.moe_loss_weight = moe_loss_weight
        self.moe_normalize_expert_weights = moe_normalize_expert_weights
        self.uniform_expert_assignment = uniform_expert_assignment

        for k in ["model_type"]:
            if k in kwargs:
                kwargs.pop(k)
        if len(kwargs) != 0:
            raise ValueError(f"Found unknown {kwargs=}")

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, **kwargs: Any
    ) -> "PretrainedConfig":
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if config_dict.get("model_type") == "dbrx":
            config_dict = config_dict["ffn_config"]

        if (
            "model_type" in config_dict
            and hasattr(cls, "model_type")
            and config_dict["model_type"] != cls.model_type
        ):
            logger.warning(
                "You are using a model of type %s to instantiate a model of "
                "type %s. This is not supported for all "
                "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type)

        return cls.from_dict(config_dict, **kwargs)


class DbrxConfig(PretrainedConfig):
    """Configuration class for Dbrx.

    [`DbrxModel`]. It is used to instantiate a Dbrx model according to the
    specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        d_model (`int`, *optional*, defaults to 6144):
            Dimensionality of the embeddings and hidden states.
        n_heads (`int`, *optional*, defaults to 48):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_layers (`int`, *optional*, defaults to 40):
            Number of hidden layers in the Transformer encoder.
        max_seq_len (`int`, *optional*, defaults to 32768):
            The maximum sequence length of the model.
        vocab_size (`int`, *optional*, defaults to 100352):
            Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by
            the `inputs_ids` passed when calling [`DbrxModel`].
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability applied to the attention output before combining with residual.
        emb_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout probability for the embedding layer.
        attn_config (`dict`, *optional*):
            A dictionary used to configure the model's attention module.
        ffn_config (`dict`, *optional*):
            A dictionary used to configure the model's FFN module.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.


    Example:
    ```python
    >>> from transformers import DbrxConfig, DbrxModel

    >>> # Initializing a Dbrx configuration
    >>> configuration = DbrxConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = DbrxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "dbrx"
    attribute_map = {
        "num_attention_heads": "n_heads",
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layers",
        "max_position_embeddings": "max_seq_len",
    }

    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 16,
        n_layers: int = 24,
        max_seq_len: int = 2048,
        vocab_size: int = 32000,
        resid_pdrop: float = 0.0,
        emb_pdrop: float = 0.0,
        attn_config: Optional[DbrxAttentionConfig] = None,
        ffn_config: Optional[DbrxFFNConfig] = None,
        use_cache: bool = True,
        initializer_range: float = 0.02,
        output_router_logits: bool = False,
        router_aux_loss_coef: float = 0.05,
        **kwargs: Any,
    ):
        if attn_config is None:
            self.attn_config = DbrxAttentionConfig()
        elif isinstance(attn_config, dict):
            self.attn_config = DbrxAttentionConfig(**attn_config)
        else:
            self.attn_config = attn_config

        if ffn_config is None:
            self.ffn_config = DbrxFFNConfig()
        elif isinstance(ffn_config, dict):
            self.ffn_config = DbrxFFNConfig(**ffn_config)
        else:
            self.ffn_config = ffn_config

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.resid_pdrop = resid_pdrop
        self.emb_pdrop = emb_pdrop
        self.use_cache = use_cache
        self.initializer_range = initializer_range
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef

        tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
        if tie_word_embeddings:
            raise ValueError(
                "tie_word_embeddings is not supported for Dbrx models."
            )

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
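Illustrative only: the nested attn/ffn sections can be passed as plain dicts,
as they would appear in a checkpoint's config.json, and are promoted to
DbrxAttentionConfig / DbrxFFNConfig instances. The numbers below are made up.

from vllm.transformers_utils.configs.dbrx import DbrxConfig

cfg = DbrxConfig(
    d_model=1024,
    n_heads=8,
    attn_config={"kv_n_heads": 2, "rope_theta": 500000.0},
    ffn_config={"moe_num_experts": 8, "moe_top_k": 2},
)
assert cfg.attn_config.kv_n_heads == 2
assert cfg.ffn_config.moe_top_k == 2
# attribute_map exposes the HF-style aliases as well.
assert cfg.num_attention_heads == cfg.n_heads == 8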