sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py

@@ -1,3 +1,5 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
 import logging
@@ -11,15 +13,19 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
-
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)

 logger = logging.getLogger("srt.model_runner")

@@ -29,7 +35,6 @@ global_server_args_dict = {}

 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -60,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-
-
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self,
-
-
-
-    )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens

         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], :
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
-
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
             )
-
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-
-
-
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )

     def init_extend_args(self):
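To make the new index layout concrete, here is a minimal sketch (not part of the package) of the tensors `init_flashinfer_args` builds for a hypothetical extend batch: with page size 1, `kv_indptr` is the cumulative sum of the cached prefix lengths fed to the paged wrappers, `kv_last_page_len` is all ones, and `qo_indptr` is the cumulative sum of the new (extend) token counts fed to the ragged prefill wrapper. The batch shape below is an assumption for illustration only.

    import torch

    # Hypothetical extend batch of two requests:
    #   request 0: 4 cached prefix tokens, 3 new tokens
    #   request 1: 0 cached prefix tokens, 5 new tokens
    batch_size = 2
    prefix_lens = torch.tensor([4, 0], dtype=torch.int32)
    extend_seq_lens = torch.tensor([3, 5], dtype=torch.int32)

    # Paged (cached-prefix) side, page size 1.
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)                 # [0, 4, 4]
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)  # [1, 1]

    # Ragged (new-token) side.
    qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)             # [0, 3, 8]

    print(kv_indptr.tolist(), kv_last_page_len.tolist(), qo_indptr.tolist())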
@@ -150,6 +164,9 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -182,7 +199,6 @@ class InputMetadata:
         other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -200,13 +216,20 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        if global_server_args_dict.get("
-            ret.init_flashinfer_args(
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim,
+            )

         return ret

@@ -229,24 +252,27 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -260,9 +286,18 @@ class ModelRunner:
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.
+        self.init_cublas()
+        self.init_flash_infer()

     def load_model(self):
         logger.info(
@@ -278,10 +313,11 @@ class ModelRunner:
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

@@ -290,7 +326,7 @@ class ModelRunner:
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-
+            multimodal_config=None,
             parallel_config=None,
             scheduler_config=None,
             cache_config=None,
@@ -298,6 +334,7 @@ class ModelRunner:
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -306,8 +343,14 @@ class ModelRunner:
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.
-        cell_size =
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
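As a worked example of the new `cell_size` computation (bytes of KV cache per token across all layers), assume a Llama-2-7B-like shape on a single GPU in fp16; the model shape and memory budget below are illustrative assumptions, not values from the package:

    import torch

    # Assumed shape (illustrative): 32 KV heads, head_dim 128, 32 layers, fp16, tp_size = 1.
    head_num = 32              # what model_config.get_num_kv_heads(tp_size) would return here
    head_dim = 128
    num_hidden_layers = 32
    dtype = torch.float16

    cell_size = (
        head_num
        * head_dim
        * num_hidden_layers
        * 2                                   # one K and one V vector per layer
        * torch._utils._element_size(dtype)   # 2 bytes for fp16
    )
    print(cell_size)                          # 524288 bytes, i.e. 0.5 MiB per token

    # If roughly 8 GiB of GPU memory were left for the KV pool:
    print(8 * 1024**3 // cell_size)           # about 16384 cacheable tokens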
@@ -319,7 +362,7 @@ class ModelRunner:

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
@@ -328,8 +371,8 @@ class ModelRunner:
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=
-            head_num=self.model_config.
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
@@ -338,6 +381,50 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
+            self.flashinfer_decode_wrapper = None
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
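The head counts handed to flashinfer in `init_flash_infer` and `init_flashinfer_args` are per tensor-parallel rank. A small illustration with assumed GQA numbers (not taken from the package) shows how the query/KV group size behind the `use_tensor_cores` choice comes about:

    # Assumed GQA shape (illustrative): 32 query heads, 8 KV heads, tensor parallel size 2.
    num_attention_heads = 32
    total_kv_heads = 8
    tp_size = 2

    num_qo_heads = num_attention_heads // tp_size  # 16 query heads per rank
    num_kv_heads = total_kv_heads // tp_size       # 4 KV heads per rank (what get_num_kv_heads yields here)
    group_size = num_qo_heads // num_kv_heads      # 4 query heads share each KV head

    # init_flash_infer enables use_tensor_cores for the decode wrapper when
    # flashinfer's decode kernels were not compiled for this group size.
    print(num_qo_heads, num_kv_heads, group_size)  # 16 4 4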
@@ -351,6 +438,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -369,6 +459,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -389,6 +482,9 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -407,6 +503,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -440,16 +539,29 @@ def import_model_classes():
         module = importlib.import_module(name)
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
-            if isinstance(
+            if isinstance(
+                entry, list
+            ):  # To support multiple model classes in one module
                 for tmp in entry:
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:
                 model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls


 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
+
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "