sglang 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -5
- sglang/lang/interpreter.py +2 -1
- sglang/lang/ir.py +0 -1
- sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +2 -2
- sglang/srt/constrained/fsm_cache.py +2 -2
- sglang/srt/constrained/jump_forward.py +2 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +29 -9
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/managers/tp_worker.py +29 -6
- sglang/srt/mem_cache/base_cache.py +43 -0
- sglang/srt/mem_cache/chunk_cache.py +60 -0
- sglang/srt/mem_cache/radix_cache.py +5 -2
- sglang/srt/model_executor/model_runner.py +17 -2
- sglang/srt/models/llama2.py +5 -21
- sglang/srt/openai_api/adapter.py +76 -22
- sglang/srt/openai_api/protocol.py +20 -2
- sglang/srt/server.py +9 -14
- sglang/srt/server_args.py +18 -4
- sglang/srt/utils.py +20 -0
- sglang/test/run_eval.py +104 -0
- sglang/test/simple_eval_common.py +467 -0
- sglang/test/simple_eval_humaneval.py +139 -0
- sglang/test/simple_eval_mmlu.py +120 -0
- sglang/test/test_programs.py +12 -9
- sglang/test/test_utils.py +32 -0
- sglang/version.py +1 -1
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/METADATA +4 -4
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/RECORD +32 -28
- sglang/test/test_conversation.py +0 -46
- sglang/test/test_openai_protocol.py +0 -51
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/LICENSE +0 -0
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/WHEEL +0 -0
- {sglang-0.2.7.dist-info → sglang-0.2.9.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -21,7 +21,7 @@ import sys
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):


 if __name__ == "__main__":
-    parser = ArgumentParser(
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
-        required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
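Note on the hunk above: `--backend` is no longer a required flag; it now defaults to "sglang" while still being validated against the registered request functions. A minimal sketch of the new behavior (the two-entry choices list is a stand-in for list(ASYNC_REQUEST_FUNCS.keys())):

from argparse import ArgumentParser

parser = ArgumentParser(description="Benchmark the online serving throughput.")
parser.add_argument(
    "--backend",
    type=str,
    default="sglang",            # was required=True in 0.2.7
    choices=["sglang", "vllm"],  # stand-in for list(ASYNC_REQUEST_FUNCS.keys())
)
args = parser.parse_args([])     # no --backend flag needed any more
assert args.backend == "sglang"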
sglang/lang/interpreter.py
CHANGED
@@ -553,7 +553,8 @@ class StreamExecutor:
                 "output_token_logprobs": output_token_logprobs,
             }
         self.variable_event[name].set()
-        self.stream_var_event[name].set()
+        if self.stream_var_event:
+            self.stream_var_event[name].set()
         self.text_ += decision

     def _execute_variable(self, expr: SglVariable):
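The fix above guards the per-variable stream event against the case where streaming is disabled and stream_var_event is unset. A minimal sketch of the pattern with plain threading.Event objects (the surrounding StreamExecutor state is assumed):

import threading

variable_event = {"answer": threading.Event()}
stream_var_event = None  # None when streaming is not active

name = "answer"
variable_event[name].set()
if stream_var_event:  # the 0.2.9 guard: only signal stream consumers if they exist
    stream_var_event[name].set()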
sglang/srt/constrained/{base_cache.py → base_tool_cache.py}
RENAMED
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Base cache class."""
+"""Base tool cache for constrained decoding tools."""

 import time


-class BaseCache:
+class BaseToolCache:
     def __init__(self, enable=True):
         self.enable = enable
         self.reset()
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -16,10 +16,10 @@ limitations under the License.
 """Cache for the compressed finite state machine."""

 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
-from sglang.srt.constrained.base_cache import BaseCache
+from sglang.srt.constrained.base_tool_cache import BaseToolCache


-class FSMCache(BaseCache):
+class FSMCache(BaseToolCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)

sglang/srt/constrained/jump_forward.py
CHANGED
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
     make_byte_level_fsm,
     make_deterministic_fsm,
 )
-from sglang.srt.constrained.base_cache import BaseCache
+from sglang.srt.constrained.base_tool_cache import BaseToolCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

@@ -151,7 +151,7 @@ class JumpForwardMap:
     )


-class JumpForwardCache(BaseCache):
+class JumpForwardCache(BaseToolCache):
     def __init__(self):
         super().__init__()

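The rename propagates to both constrained-decoding caches above. Any external code that subclassed the old base class would migrate the same way; a hedged sketch (MyToolCache and its dict-based reset are hypothetical, only the import path and constructor behavior are taken from the diff):

from sglang.srt.constrained.base_tool_cache import BaseToolCache  # was ...base_cache


class MyToolCache(BaseToolCache):  # hypothetical subclass for illustration
    def reset(self):
        # BaseToolCache.__init__ calls self.reset(), so per-cache state lives here.
        self.cache = {}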
sglang/srt/layers/logits_processor.py
CHANGED
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             all_logits = all_logits[:, : self.config.vocab_size].float()

             all_logprobs = all_logits
-            del all_logits
+            del all_logits, hidden_states
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

             # Get the logprob of top-k tokens
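The extra hidden_states in the del matters because the log-softmax is computed in place: every other live reference to the large activation tensors must be dropped first, or their memory cannot be reclaimed during the computation. A standalone sketch of the same trick:

import torch

logits = torch.randn(4, 32000)
logprobs = logits        # alias the tensor rather than copying it
del logits               # drop the extra reference; the storage stays alive via logprobs
logprobs[:] = torch.nn.functional.log_softmax(logprobs, dim=-1)  # in-place overwrite
print(logprobs.exp().sum(dim=-1))  # each row sums to ~1.0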
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

@@ -486,15 +487,33 @@ class Batch:
             req = self.reqs[idx]
             retracted_reqs.append(req)

-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = len(req.prefix_indices)
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.free(token_indices)
-
-            # release the last node
-            self.tree_cache.dec_lock_ref(req.last_node)
+            if isinstance(self.tree_cache, ChunkCache):
+                # ChunkCache does not have eviction
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][: seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+                del self.tree_cache.entries[req.rid]
+            else:
+                # TODO: apply more fine-grained retraction
+                last_uncached_pos = len(req.prefix_indices)
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req_pool_indices_cpu[idx]
+                ][last_uncached_pos : seq_lens_cpu[idx]]
+                self.token_to_kv_pool.free(token_indices)
+                self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
+
+                # release the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
+
+            # NOTE(lsyin): we should use the newly evictable memory instantly.
+            residual_size = (
+                len(sorted_indices) * global_config.retract_decode_steps
+                - self.token_to_kv_pool.available_size()
+            )
+            residual_size = max(0, residual_size)
+            self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

             req.prefix_indices = None
             req.last_node = None
@@ -575,6 +594,7 @@ class Batch:
         if req_pool_indices_cpu is None:
             req_pool_indices_cpu = self.req_pool_indices.tolist()
         self.tree_cache.cache_req(
+            rid=req.rid,
             token_ids=cur_all_ids,
             last_uncached_pos=len(req.prefix_indices),
             req_pool_idx=req_pool_indices_cpu[i],
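The retraction hunk frees a retracted request's KV slots and then immediately evicts enough tree-cache tokens to cover the next retract_decode_steps decode steps. A worked example of the residual_size arithmetic (all three numbers are illustrative):

retract_decode_steps = 32   # stand-in for global_config.retract_decode_steps
remaining_reqs = 3          # stand-in for len(sorted_indices)
available = 40              # stand-in for token_to_kv_pool.available_size()

residual_size = max(0, remaining_reqs * retract_decode_steps - available)
assert residual_size == 56  # 96 slots needed, 40 already free, so evict 56 more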
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -79,6 +79,7 @@ class TokenizerManager:
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")

         self.model_path = server_args.model_path
+        self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
     ForwardMode,
     Req,
 )
+from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -144,11 +145,20 @@ class ModelTpServer:
         )

         # Init cache
-        self.tree_cache = RadixCache(
-            req_to_token_pool=self.model_runner.req_to_token_pool,
-            token_to_kv_pool=self.model_runner.token_to_kv_pool,
-            disable=server_args.disable_radix_cache,
-        )
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            )
+        else:
+            self.tree_cache = RadixCache(
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                disable=server_args.disable_radix_cache,
+            )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = PolicyScheduler(
             self.schedule_policy,
@@ -280,6 +290,14 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )

+        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+            warnings.warn(
+                "Warning: "
+                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"total slots={self.req_to_token_pool.size}\n"
+                "Memory pool leak detected!"
+            )
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
@@ -346,7 +364,10 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
-            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+            prefix_indices, last_node = self.tree_cache.match_prefix(
+                rid=req.rid,
+                key=req.input_ids,
+            )
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -606,6 +627,7 @@
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                rid=req.rid,
                 token_ids=tuple(req.input_ids),
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
@@ -763,6 +785,7 @@
         for i in finished_indices:
             req = batch.reqs[i]
             self.tree_cache.cache_req(
+                rid=req.rid,
                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
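Taken together, the tp_worker changes pick the cache implementation once at startup and then talk to it through a single rid-aware interface. A condensed sketch of the selection rule (the helper function is hypothetical; the condition and constructor arguments come from the hunks above):

from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache


def pick_tree_cache(chunked_prefill_size, disable_radix_cache, req_pool, kv_pool):
    # Chunked prefill with the radix cache disabled gets the rid-keyed ChunkCache;
    # every other configuration keeps the RadixCache.
    if chunked_prefill_size is not None and disable_radix_cache:
        return ChunkCache(req_to_token_pool=req_pool, token_to_kv_pool=kv_pool)
    return RadixCache(
        req_to_token_pool=req_pool,
        token_to_kv_pool=kv_pool,
        disable=disable_radix_cache,
    )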
sglang/srt/mem_cache/base_cache.py
ADDED
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+
+
+class BasePrefixCache(ABC):
+    """Cache can be indexed by either rid or key."""
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def match_prefix(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def insert(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def cache_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def evict(self, num_tokens, evict_callback):
+        pass
+
+    @abstractmethod
+    def inc_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def dec_lock_ref(self, node):
+        pass
+
+    @abstractmethod
+    def evictable_size(self):
+        pass
+
+    def total_size(self):
+        raise NotImplementedError
+
+    def pretty_print(self):
+        raise NotImplementedError
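Since BasePrefixCache is an ABC, a new cache backend must implement all eight abstract methods before it can be instantiated. A minimal in-memory implementation to illustrate the contract (DictPrefixCache is hypothetical; only the method names and signatures come from the file above):

from sglang.srt.mem_cache.base_cache import BasePrefixCache


class DictPrefixCache(BasePrefixCache):  # hypothetical, for illustration only
    def reset(self):
        self.store = {}

    def match_prefix(self, rid=None, key=None, **kwargs):
        return self.store.get(rid, []), None

    def insert(self, **kwargs):
        pass

    def cache_req(self, rid=None, token_ids=None, **kwargs):
        self.store[rid] = list(token_ids or [])

    def evict(self, num_tokens, evict_callback):
        pass

    def inc_lock_ref(self, node):
        return 0

    def dec_lock_ref(self, node):
        return 0

    def evictable_size(self):
        return 0


cache = DictPrefixCache()
cache.reset()
cache.cache_req(rid="r1", token_ids=[1, 2, 3])
assert cache.match_prefix(rid="r1")[0] == [1, 2, 3]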
sglang/srt/mem_cache/chunk_cache.py
ADDED
@@ -0,0 +1,60 @@
+"""Cache for chunked prefill, used when RadixCache is disabled."""
+
+from sglang.srt.mem_cache.base_cache import BasePrefixCache
+
+
+class ChunkCacheEntry:
+    def __init__(self, rid, value):
+        self.rid = rid
+        self.value = value
+
+
+class ChunkCache(BasePrefixCache):
+    def __init__(self, req_to_token_pool, token_to_kv_pool):
+        self.disable = True
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
+
+        self.reset()
+
+    def reset(self):
+        self.entries = {}
+
+    def match_prefix(self, rid, **kwargs):
+        if rid not in self.entries:
+            return [], None
+
+        entry = self.entries[rid]
+        return entry.value, entry
+
+    def cache_req(
+        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
+    ):
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        if del_in_memory_pool:
+            assert rid in self.entries
+            self.req_to_token_pool.free(req_pool_idx)
+            self.token_to_kv_pool.free(indices)
+            return
+
+        if rid not in self.entries:
+            self.entries[rid] = ChunkCacheEntry(rid, indices)
+
+        entry = self.entries[rid]
+        entry.value = indices
+        return indices, entry
+
+    def insert(self):
+        raise NotImplementedError
+
+    def evict(self, num_tokens, evict_callback):
+        pass
+
+    def inc_lock_ref(self, node):
+        return 0
+
+    def dec_lock_ref(self, node):
+        return 0
+
+    def evictable_size(self):
+        return 0
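A hedged lifecycle sketch for ChunkCache: during chunked prefill each chunk re-registers the request's token indices under its rid (del_in_memory_pool=False), lookups go by rid rather than by token key, and the final call frees both pools. The two Mini* pools are stand-ins for the real classes in sglang.srt.mem_cache.memory_pool:

import numpy as np

from sglang.srt.mem_cache.chunk_cache import ChunkCache


class MiniReqPool:  # stand-in for ReqToTokenPool
    def __init__(self):
        self.req_to_token = np.arange(100).reshape(10, 10)

    def free(self, req_pool_idx):
        pass


class MiniKVPool:  # stand-in for TokenToKVPool
    def free(self, indices):
        pass


cache = ChunkCache(req_to_token_pool=MiniReqPool(), token_to_kv_pool=MiniKVPool())

# First chunk: remember the KV indices for this request, keyed by rid.
cache.cache_req(rid="req-1", token_ids=[5, 6, 7], req_pool_idx=0,
                del_in_memory_pool=False)

# Later chunk: the "prefix" comes back by rid, not by matching token ids.
prefix, entry = cache.match_prefix(rid="req-1")
assert len(prefix) == 3

# Request finished: release the slots in both pools.
cache.cache_req(rid="req-1", token_ids=[5, 6, 7, 8], req_pool_idx=0)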
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -23,6 +23,8 @@ from collections import defaultdict

 import torch

+from sglang.srt.mem_cache.base_cache import BasePrefixCache
+

 class TreeNode:
     def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
     return i


-class RadixCache:
+class RadixCache(BasePrefixCache):
     def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ class RadixCache:
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key):
+    def match_prefix(self, key, **kwargs):
         if self.disable:
             return [], self.root_node

@@ -90,6 +92,7 @@ class RadixCache:
         req_pool_idx,
         del_in_memory_pool=True,
         old_last_node=None,
+        **kwargs,
     ):
         # Insert the request into radix cache
         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
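The **kwargs padding on match_prefix and cache_req is what lets the scheduler call both cache types with one shape: RadixCache consumes key and ignores rid, while ChunkCache consumes rid and ignores key. A small sketch of a call site that works against either backend (the lookup helper is hypothetical):

from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache


def lookup(tree_cache, rid, input_ids):
    # One call shape for both backends; each picks the argument it understands.
    return tree_cache.match_prefix(rid=rid, key=input_ids)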
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -19,6 +19,7 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
+import warnings
 from functools import lru_cache
 from typing import Optional, Type

@@ -121,7 +122,11 @@ class ModelRunner:

         # Load the model and create memory pool
         self.load_model()
-        self.init_memory_pool(total_gpu_memory)
+        self.init_memory_pool(
+            total_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
         self.init_cublas()
         self.init_flash_infer()

@@ -203,8 +208,18 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

-    def init_memory_pool(self, total_gpu_memory):
+    def init_memory_pool(
+        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+    ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+        if max_total_tokens is not None:
+            if max_total_tokens > self.max_total_num_tokens:
+                warnings.warn(
+                    f"max_total_tokens={max_total_tokens} is larger than the profiled value "
+                    f"{self.max_total_num_tokens}. "
+                    f"Use the profiled value instead."
+                )
+            self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
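The new max_total_tokens argument can only shrink the KV pool: a value above the profiled capacity triggers a warning and is clamped back down. A worked example of the clamp (both numbers are illustrative):

profiled = 120_000   # stand-in for profile_max_num_token(total_gpu_memory)
requested = 200_000  # stand-in for server_args.max_total_tokens

max_total_num_tokens = min(profiled, requested)  # the warning fires before this
assert max_total_num_tokens == 120_000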
sglang/srt/models/llama2.py
CHANGED
@@ -26,6 +26,11 @@ from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -38,10 +43,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.model_runner import InputMetadata

-MergedColumnParallelLinear = None
-QKVParallelLinear = None
-RowParallelLinear = None
-

 class LlamaMLP(nn.Module):
     def __init__(
@@ -295,23 +296,6 @@ class LlamaForCausalLM(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         efficient_weight_load=False,
     ) -> None:
-        global MergedColumnParallelLinear
-        global QKVParallelLinear
-        global RowParallelLinear
-
-        if efficient_weight_load:
-            from sglang.srt.layers.linear import (
-                MergedColumnParallelLinear,
-                QKVParallelLinear,
-                RowParallelLinear,
-            )
-        else:
-            from vllm.model_executor.layers.linear import (
-                MergedColumnParallelLinear,
-                QKVParallelLinear,
-                RowParallelLinear,
-            )
-
         super().__init__()
         self.config = config
         self.quant_config = quant_config