sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,32 +37,34 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import
+from sglang.srt.layers.sampler import Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +88,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "
-                "
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -106,15 +104,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -162,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -174,6 +179,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -184,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -251,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-
-            return False,
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-
-            return False,
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -289,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message =
-
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -311,7 +310,18 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
+
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
@@ -343,8 +353,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -378,7 +388,7 @@ class ModelRunner:
                     ),
                     2048,
                 ),
-
+                4096,
             )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -396,9 +406,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -421,118 +428,46 @@ class ModelRunner:
             c = a @ b
             return c
 
-    def
-        """Init
-        if self.server_args.
-
-
-
-
-
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-
-
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
             )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
            return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if
-            self.
-
-
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -540,11 +475,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-
-            batch,
-
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -560,11 +494,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -574,17 +504,56 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-
-
-        if self.is_multimodal_model and forward_mode
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()