sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +14 -0
- sglang/backend/anthropic.py +18 -12
- sglang/backend/base_backend.py +6 -0
- sglang/backend/openai.py +41 -12
- sglang/backend/runtime_endpoint.py +57 -6
- sglang/lang/chat_template.py +47 -26
- sglang/lang/interpreter.py +15 -2
- sglang/lang/ir.py +1 -1
- sglang/srt/constrained/__init__.py +23 -1
- sglang/srt/constrained/fsm_cache.py +14 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -1
- sglang/srt/layers/extend_attention.py +7 -6
- sglang/srt/layers/radix_attention.py +2 -10
- sglang/srt/layers/token_attention.py +12 -4
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/router/infer_batch.py +6 -2
- sglang/srt/managers/router/model_rpc.py +45 -32
- sglang/srt/managers/router/model_runner.py +40 -25
- sglang/srt/managers/tokenizer_manager.py +2 -0
- sglang/srt/model_config.py +12 -5
- sglang/srt/models/gemma.py +340 -0
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llava.py +2 -4
- sglang/srt/models/mixtral.py +5 -5
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +5 -5
- sglang/srt/models/stablelm.py +293 -0
- sglang/srt/server.py +111 -47
- sglang/srt/server_args.py +44 -9
- sglang/srt/utils.py +1 -0
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +15 -12
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
- sglang-0.1.14.dist-info/RECORD +64 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
- sglang/srt/models/gpt_neox.py +0 -274
- sglang-0.1.12.dist-info/RECORD +0 -63
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
- {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/srt/layers/token_attention.py
CHANGED
@@ -4,8 +4,16 @@
 import torch
 import triton
 import triton.language as tl
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
+if global_server_args_dict.get("attention_reduce_in_fp32", False):
+    REDUCE_TRITON_TYPE = tl.float32
+    REDUCE_TORCH_TYPE = torch.float32
+else:
+    REDUCE_TRITON_TYPE = tl.float16
+    REDUCE_TORCH_TYPE = torch.float16
+
 
 @triton.jit
 def _fwd_kernel_stage1(
@@ -49,7 +57,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark)
+        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -65,7 +73,7 @@ def _fwd_kernel_stage1(
             K_Buffer + offs_buf_k,
             mask=offs_n_new[:, None] < cur_batch_end_index,
             other=0.0,
-        )
+        ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
@@ -161,7 +169,7 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128}
+    assert Lk in {16, 32, 64, 128, 256}
     sm_scale = 1.0 / (Lk**0.5)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -299,7 +307,7 @@ def token_attention_fwd(
 ):
     if att_m is None:
         att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=torch.float16, device="cuda"
+            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
         )
 
     _token_att_m_fwd(
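The block above selects the reduction precision for the Triton attention kernels once, at import time, based on a server-wide flag. Below is a minimal, self-contained sketch of the same pattern; the `server_args` dict here is a stand-in for sglang's `global_server_args_dict`, not the real object.

    import torch
    import triton.language as tl

    # Hypothetical stand-in for sglang's global_server_args_dict.
    server_args = {"attention_reduce_in_fp32": True}

    # Resolve the reduction dtypes once so every kernel launch agrees on precision.
    if server_args.get("attention_reduce_in_fp32", False):
        REDUCE_TRITON_TYPE, REDUCE_TORCH_TYPE = tl.float32, torch.float32
    else:
        REDUCE_TRITON_TYPE, REDUCE_TORCH_TYPE = tl.float16, torch.float16

    # Intermediate attention buffers are then allocated in the chosen dtype, e.g.
    # torch.empty(shape, dtype=REDUCE_TORCH_TYPE, device="cuda"), trading memory
    # for extra accumulation precision when fp32 is requested.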
sglang/srt/managers/io_struct.py
CHANGED
@@ -15,10 +15,12 @@ class GenerateReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # The request id
     rid: Optional[Union[List[str], str]] = None
-    # Whether return logprobs
+    # Whether to return logprobs
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # Whether to detokenize tokens in logprobs
+    return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
 
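For context, a client request that exercises the new field could look like the sketch below. The `/generate` route, port, and response keys are assumptions for illustration; only the field names come from the dataclass above.

    import requests

    # Hedged example: field names mirror GenerateReqInput; the /generate route,
    # port, and "meta_info" response key are assumptions, not taken from this diff.
    payload = {
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
        "return_logprob": True,
        "logprob_start_len": 0,
        # New in 0.1.14: also return detokenized text alongside each logprob entry.
        "return_text_in_logprobs": True,
    }
    resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
    print(resp.json().get("meta_info", {}).keys())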
sglang/srt/managers/router/infer_batch.py
CHANGED
@@ -27,8 +27,12 @@ class Req:
         self.input_ids = input_ids
         self.output_ids = []
 
-        #
-
+        # Since jump forward may retokenize the prompt with partial outputs,
+        # we maintain the original prompt length to report the correct usage.
+        self.prompt_tokens = len(input_ids)
+        # The number of decoded tokens for token usage report. Note that
+        # this does not include the jump forward tokens.
+        self.completion_tokens_wo_jump_forward = 0
 
         # For vision input
         self.pixel_values = None
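The motivation for the two new counters is that jump-forward decoding may re-tokenize the prompt together with partial outputs, so `len(req.input_ids)` can grow past the original prompt length. A small self-contained sketch of the bookkeeping; the `Req` class below is a stripped-down stand-in, not the real one.

    class Req:
        """Minimal stand-in for the request object, showing only the usage counters."""

        def __init__(self, input_ids):
            self.input_ids = list(input_ids)
            self.output_ids = []
            # Frozen at arrival time; survives jump-forward retokenization.
            self.prompt_tokens = len(input_ids)
            # Counts only real decode steps, not tokens injected by jump-forward.
            self.completion_tokens_wo_jump_forward = 0

    req = Req(input_ids=[1, 2, 3, 4])
    req.input_ids += [5, 6]          # jump-forward merged some output into the prompt
    req.output_ids = [7]             # one real decode step
    req.completion_tokens_wo_jump_forward += 1

    completion_tokens = len(req.input_ids) + len(req.output_ids) - req.prompt_tokens
    print(req.prompt_tokens, completion_tokens, req.completion_tokens_wo_jump_forward)
    # -> 4 3 1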
sglang/srt/managers/router/model_rpc.py
CHANGED
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -57,17 +56,26 @@ class ModelRpcServer(rpyc.Service):
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
-            server_args.model_path, server_args.trust_remote_code
+            server_args.model_path,
+            server_args.trust_remote_code,
+            context_length=server_args.context_length,
         )
+
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         self.model_runner = ModelRunner(
-            self.model_config,
-            server_args.mem_fraction_static,
-            tp_rank,
-            server_args.tp_size,
-            port_args.nccl_port,
-            server_args.load_format,
-            server_args.trust_remote_code,
-            server_args.model_mode,
+            model_config=self.model_config,
+            mem_fraction_static=server_args.mem_fraction_static,
+            tp_rank=tp_rank,
+            tp_size=server_args.tp_size,
+            nccl_port=port_args.nccl_port,
+            load_format=server_args.load_format,
+            trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
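The `ModelRunner` call now uses keyword arguments and forwards a plain dict of server-wide options to the model code. A hedged sketch of how that dict is assembled; `ServerArgsStub` is an invented stand-in with only the two flags this diff forwards.

    from dataclasses import dataclass

    # Illustrative only: the real sglang ServerArgs has many more fields.
    @dataclass
    class ServerArgsStub:
        enable_flashinfer: bool = False
        attention_reduce_in_fp32: bool = False

    server_args = ServerArgsStub(enable_flashinfer=True)

    server_args_dict = {
        "enable_flashinfer": server_args.enable_flashinfer,
        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
    }
    # ModelRunner(..., server_args_dict=server_args_dict) then publishes this dict
    # as a module-level global_server_args_dict for kernels and helpers to consult.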
@@ -102,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
 
         # Init cache
-        self.tree_cache = RadixCache()
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -208,6 +216,19 @@ class ModelRpcServer(rpyc.Service):
 
                     if self.out_pyobjs and self.running_batch.reqs[0].stream:
                         break
+
+                if self.running_batch is not None and self.tp_rank == 0:
+                    if self.decode_forward_ct % 40 == 0:
+                        num_used = self.max_total_num_token - (
+                            self.token_to_kv_pool.available_size()
+                            + self.tree_cache.evictable_size()
+                        )
+                        logger.info(
+                            f"#running-req: {len(self.running_batch.reqs)}, "
+                            f"#token: {num_used}, "
+                            f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                            f"#queue-req: {len(self.forward_queue)}"
+                        )
             else:
                 # check the available size
                 available_size = (
@@ -221,19 +242,6 @@ class ModelRpcServer(rpyc.Service):
                     "KV cache pool leak detected!"
                 )
 
-        if self.running_batch is not None and self.tp_rank == 0:
-            if self.decode_forward_ct % 20 == 0:
-                num_used = self.max_total_num_token - (
-                    self.token_to_kv_pool.available_size()
-                    + self.tree_cache.evictable_size()
-                )
-                logger.info(
-                    f"#running-req: {len(self.running_batch.reqs)}, "
-                    f"#token: {num_used}, "
-                    f"token usage: {num_used / self.max_total_num_token:.2f}, "
-                    f"#queue-req: {len(self.forward_queue)}"
-                )
-
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
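The relocated logging block derives KV-cache usage from the pool's free size plus the radix cache's evictable size rather than tracking it separately. A standalone sketch of that computation with made-up numbers and stub objects:

    # Stand-ins for the KV pool and radix cache interfaces used above.
    class PoolStub:
        def __init__(self, free): self._free = free
        def available_size(self): return self._free

    class CacheStub:
        def __init__(self, evictable): self._evictable = evictable
        def evictable_size(self): return self._evictable

    max_total_num_token = 10_000
    token_to_kv_pool = PoolStub(free=6_500)
    tree_cache = CacheStub(evictable=1_500)

    # Tokens actually pinned by running requests = capacity - (free + evictable).
    num_used = max_total_num_token - (
        token_to_kv_pool.available_size() + tree_cache.evictable_size()
    )
    print(f"#token: {num_used}, token usage: {num_used / max_total_num_token:.2f}")
    # -> #token: 2000, token usage: 0.20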
@@ -424,6 +432,7 @@ class ModelRpcServer(rpyc.Service):
         # Check finish condition
         pt = 0
         for i, req in enumerate(reqs):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
 
@@ -431,9 +440,14 @@
                 req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
                 req.normalized_logprob = normalized_logprobs[i]
 
-
-
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
+                # will be ignored.
+                prompt_token_len = len(req.logprob)
+                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
+                token_logprobs = req.logprob + [last_logprobs[i]]
                 req.token_logprob = list(zip(token_ids, token_logprobs))
+                if req.logprob_start_len == 0:
+                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
             pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
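To make the new prompt-logprob handling concrete, here is a self-contained sketch of how the (token_id, logprob) pairs are assembled for one prefill request; all values are fabricated, and `None` marks the first prompt token, which has no logprob.

    # Fabricated example values; shapes follow the code above.
    input_ids = [101, 7592, 2088, 102]      # prompt token ids
    next_token_id = 2003                     # token sampled after prefill
    logprob_start_len = 0

    # Logprobs for prompt positions 1..N-1 (the first token has no predecessor),
    # plus the logprob of the sampled next token.
    prompt_logprobs = [-2.1, -0.7, -1.3]
    last_logprob = -0.4

    prompt_token_len = len(prompt_logprobs)
    token_ids = input_ids[-prompt_token_len:] + [next_token_id]
    token_logprobs = prompt_logprobs + [last_logprob]
    token_logprob = list(zip(token_ids, token_logprobs))

    if logprob_start_len == 0:
        # Prepend the first prompt token with a None logprob, as the server does.
        token_logprob = [(input_ids[0], None)] + token_logprob

    print(token_logprob)
    # -> [(101, None), (7592, -2.1), (2088, -0.7), (102, -1.3), (2003, -0.4)]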
@@ -500,6 +514,7 @@ class ModelRpcServer(rpyc.Service):
 
         # Check finish condition
         for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+            req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_tok_id)
             req.check_finished()
 
@@ -541,15 +556,13 @@
                 req.sampling_params.skip_special_tokens
             )
 
-            # For the length of input_ids, which will be accumulated during jump-forward.
-            # Use the original length of input_ids to calculate the token usage info.
             meta_info = {
-                "prompt_tokens": req.
+                "prompt_tokens": req.prompt_tokens,
                 "completion_tokens": len(req.input_ids)
                 + len(req.output_ids)
-                - req.
+                - req.prompt_tokens,
+                "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
             }
-
            if req.return_logprob:
                 meta_info["prompt_logprob"] = req.logprob
                 meta_info["token_logprob"] = req.token_logprob
sglang/srt/managers/router/model_runner.py
CHANGED
@@ -1,9 +1,10 @@
 import importlib
 import logging
+import inspect
 from dataclasses import dataclass
 from functools import lru_cache
 from pathlib import Path
-
+import importlib.resources
 
 import numpy as np
 import torch
@@ -13,27 +14,34 @@ from sglang.srt.utils import is_multimodal_model
 from sglang.utils import get_available_gpu_memory
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
+import importlib
+import pkgutil
+
 import sglang
 
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}
+QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
 
 
-# for
-
+# for server args in model endpoints
+global_server_args_dict: dict = None
 
 
 @lru_cache()
 def import_model_classes():
     model_arch_name_to_cls = {}
-
-
-
-
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
     return model_arch_name_to_cls
 
 
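The rewritten `import_model_classes` discovers model implementations by walking the `sglang.srt.models` package and registering any module that exposes an `EntryClass` attribute, which is how the new gemma.py and stablelm.py get picked up without editing a hard-coded list. A generic sketch of the same pattern; `myapp.models` is a hypothetical package name.

    import importlib
    import pkgutil

    def discover_entry_classes(package_name: str) -> dict:
        """Map class name -> class for every module in `package_name`
        that defines a module-level EntryClass attribute."""
        registry = {}
        package = importlib.import_module(package_name)
        for _finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
            if ispkg:
                continue
            module = importlib.import_module(name)
            if hasattr(module, "EntryClass"):
                registry[module.EntryClass.__name__] = module.EntryClass
        return registry

    # e.g. discover_entry_classes("myapp.models")  # hypothetical package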
@@ -81,7 +89,6 @@ class InputMetadata:
     return_logprob: bool = False
 
     # for flashinfer
-    use_flashinfer: bool = False
     qo_indptr: torch.Tensor = None
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
@@ -126,14 +133,21 @@ class InputMetadata:
             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
             )
-            self.prefill_wrapper.begin_forward(
+            args = [
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
                 self.model_runner.model_config.num_attention_heads // tp_size,
                 self.model_runner.model_config.num_key_value_heads // tp_size,
-            )
+            ]
+
+            # flashinfer >= 0.0.3
+            # FIXME: Drop this when flashinfer updates to 0.0.4
+            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+                args.append(self.model_runner.model_config.head_dim)
+
+            self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 workspace_buffer, "NHD"
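The `inspect.signature` check is a small compatibility shim: newer flashinfer builds expect an extra `head_dim` argument to `begin_forward`, so the argument list is extended only when the installed wrapper takes seven parameters. The same idea in isolation, using stand-in functions rather than the flashinfer API:

    import inspect

    def begin_forward_old(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
                          num_qo_heads, num_kv_heads):          # older 6-arg API
        return "old"

    def begin_forward_new(qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
                          num_qo_heads, num_kv_heads, head_dim): # newer 7-arg API
        return "new"

    def call_begin_forward(fn, base_args, head_dim):
        args = list(base_args)
        # Append the extra argument only if the installed API expects it.
        if len(inspect.signature(fn).parameters) == 7:
            args.append(head_dim)
        return fn(*args)

    base = (None, None, None, None, 32, 8)
    print(call_begin_forward(begin_forward_old, base, 128))  # -> old
    print(call_begin_forward(begin_forward_new, base, 128))  # -> new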
@@ -224,8 +238,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-
-        if ret.use_flashinfer:
+        if global_server_args_dict.get("enable_flashinfer", False):
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -241,7 +254,7 @@ class ModelRunner:
         nccl_port,
         load_format="auto",
         trust_remote_code=True,
-
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -250,10 +263,9 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
-        self.model_mode = model_mode
 
-        global
-
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
@@ -292,9 +304,15 @@ class ModelRunner:
             self.model_config.hf_config, "quantization_config", None
         )
         if hf_quant_config is not None:
-
-
-
+            hf_quant_method = hf_quant_config["quant_method"]
+
+            # compat: autogptq uses is_marlin_format within quant config
+            if (hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]):
+                hf_quant_method = "marlin"
+            quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
+
             if quant_config_class is None:
                 raise ValueError(
                     f"Unsupported quantization method: {hf_quant_config['quant_method']}"
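The added lines normalize the quantization method before looking up a config class: an AutoGPTQ checkpoint exported in Marlin format still reports `quant_method == "gptq"` but sets `is_marlin_format`, so it is remapped to marlin. A hedged sketch of that resolution step with made-up config dicts and string placeholders for the vLLM config classes:

    # Stand-in mapping; in sglang it holds vLLM's AWQConfig/GPTQConfig/MarlinConfig classes.
    QUANT_CONFIG_MAPPING = {"awq": "AWQConfig", "gptq": "GPTQConfig", "marlin": "MarlinConfig"}

    def resolve_quant_config(hf_quant_config: dict):
        method = hf_quant_config["quant_method"]
        # compat: AutoGPTQ marks Marlin-format checkpoints inside a gptq config.
        if method == "gptq" and hf_quant_config.get("is_marlin_format"):
            method = "marlin"
        cls = QUANT_CONFIG_MAPPING.get(method)
        if cls is None:
            raise ValueError(f"Unsupported quantization method: {method}")
        return cls

    print(resolve_quant_config({"quant_method": "gptq", "bits": 4}))                            # GPTQConfig
    print(resolve_quant_config({"quant_method": "gptq", "is_marlin_format": True, "bits": 4}))  # MarlinConfig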
@@ -319,9 +337,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
         ) * (1 << 30)
-        head_dim = (
-            self.model_config.hidden_size // self.model_config.num_attention_heads
-        )
+        head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
@@ -346,8 +362,7 @@ class ModelRunner:
             self.max_total_num_token,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.hidden_size
-            // self.model_config.num_attention_heads,
+            head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
 
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
+
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
sglang/srt/model_config.py
CHANGED
@@ -1,7 +1,5 @@
-import
-from typing import Optional, Union
+from typing import Optional
 
-import torch
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -11,15 +9,24 @@ class ModelConfig:
         path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
+        context_length: Optional[int] = None,
     ) -> None:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.hf_config = get_config(self.path, trust_remote_code, revision)
 
+        if context_length is not None:
+            self.context_len = context_length
+        else:
+            self.context_len = get_context_length(self.hf_config)
+
         # Unify the config keys for hf_config
-        self.
-
+        self.head_dim = getattr(
+            self.hf_config,
+            "head_dim",
+            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+        )
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
         if self.num_key_value_heads is None:
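The new `head_dim` attribute prefers an explicit value from the Hugging Face config and only falls back to `hidden_size // num_attention_heads`; this matters for models such as Gemma whose head dimension is not that quotient. A short sketch with namespaces standing in for `hf_config`:

    from types import SimpleNamespace

    def resolve_head_dim(hf_config) -> int:
        # Prefer an explicit head_dim; otherwise derive it from the hidden size.
        return getattr(
            hf_config, "head_dim",
            hf_config.hidden_size // hf_config.num_attention_heads,
        )

    llama_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
    gemma_like = SimpleNamespace(hidden_size=3072, num_attention_heads=16, head_dim=256)

    print(resolve_head_dim(llama_like))  # -> 128 (derived)
    print(resolve_head_dim(gemma_like))  # -> 256 (explicit head_dim wins)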