sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py
@@ -1,3 +1,5 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
 import logging
@@ -11,15 +13,19 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
-
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)

 logger = logging.getLogger("srt.model_runner")

@@ -29,7 +35,6 @@ global_server_args_dict = {}

 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -60,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-
-
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self,
-
-
-
-    )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens

         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], :
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
-
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1
             )
-
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype
             )

     def init_extend_args(self):
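The reconstructed hunk above replaces the old single flashinfer wrapper with a ragged prefill wrapper for the newly extended tokens plus paged prefill and decode wrappers for the cached KV. The index tensors it builds are a prefix-sum offset table (`kv_indptr`) and a flat gather of each request's KV-cache slots out of `req_to_token`. A minimal CPU-only sketch of that indexing, using a toy `req_to_token` table in place of sglang's `ReqToTokenPool` (shapes and values are illustrative assumptions, not taken from the wheel):

```python
import torch

# Toy request-to-token table: 4 request slots, up to 8 token positions each.
# Entry [r, j] holds the global KV-cache index of request r's j-th token.
req_to_token = torch.arange(32, dtype=torch.int32).reshape(4, 8)

req_pool_indices = torch.tensor([2, 0])                      # two running requests
paged_kernel_lens = torch.tensor([5, 3], dtype=torch.int32)  # cached tokens per request
batch_size = len(req_pool_indices)

# kv_indptr[i] .. kv_indptr[i + 1] delimits request i's slice inside kv_indices.
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Flat list of every cached token's KV-cache slot, request by request.
req_pool_indices_cpu = req_pool_indices.numpy()
paged_kernel_lens_cpu = paged_kernel_lens.numpy()
kv_indices = torch.cat(
    [
        req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
        for i in range(batch_size)
    ],
    dim=0,
).contiguous()

print(kv_indptr.tolist())   # [0, 5, 8]
print(kv_indices.tolist())  # [16, 17, 18, 19, 20, 0, 1, 2]
```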
@@ -150,6 +164,9 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -182,7 +199,6 @@ class InputMetadata:
         other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -200,13 +216,20 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        if global_server_args_dict.get("
-            ret.init_flashinfer_args(
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim
+            )

         return ret

@@ -229,24 +252,24 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-        monkey_patch_vllm_p2p_access_check()
+        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=
+            distributed_init_method=nccl_init_method
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -260,9 +283,18 @@
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.
+        self.init_cublas()
+        self.init_flash_infer()

     def load_model(self):
         logger.info(
@@ -278,10 +310,11 @@
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

@@ -298,6 +331,7 @@ class ModelRunner:
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -306,8 +340,8 @@
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 *
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
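For a sense of scale, `cell_size` above is the per-token KV-cache footprint: KV heads × head_dim × hidden layers × 2 (K and V) × bytes per element. A quick back-of-the-envelope with illustrative 7B-class numbers (assumed for the example, not read from this wheel):

```python
import torch

head_num, head_dim, num_hidden_layers = 32, 128, 32  # illustrative 7B-style config
dtype = torch.float16

# Bytes needed to cache one token's K and V tensors across all layers.
cell_size = head_num * head_dim * num_hidden_layers * 2 * torch._utils._element_size(dtype)
print(cell_size)          # 524288 bytes
print(cell_size / 2**20)  # 0.5 MiB per cached token
```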
@@ -319,7 +353,7 @@ class ModelRunner:

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
@@ -328,8 +362,8 @@
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=
-            head_num=self.model_config.
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
@@ -338,6 +372,47 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
@@ -351,6 +426,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -369,6 +447,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -389,6 +470,9 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -407,6 +491,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -440,16 +527,29 @@ def import_model_classes():
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
                 entry = module.EntryClass
-                if isinstance(
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
                     for tmp in entry:
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
                     model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls


 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
+
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "