sglang 0.3.1__py3-none-any.whl → 0.3.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -3
- sglang/bench_server_latency.py +187 -0
- sglang/bench_serving.py +1 -1
- sglang/global_config.py +5 -13
- sglang/lang/interpreter.py +0 -3
- sglang/srt/constrained/fsm_cache.py +5 -1
- sglang/srt/layers/activation.py +16 -1
- sglang/srt/layers/attention_backend.py +12 -12
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +21 -6
- sglang/srt/layers/sampler.py +40 -98
- sglang/srt/lora/lora_manager.py +11 -8
- sglang/srt/managers/io_struct.py +3 -0
- sglang/srt/managers/policy_scheduler.py +49 -93
- sglang/srt/managers/schedule_batch.py +2 -1
- sglang/srt/managers/tp_worker.py +19 -13
- sglang/srt/model_executor/cuda_graph_runner.py +25 -13
- sglang/srt/model_executor/model_runner.py +37 -46
- sglang/srt/models/deepseek_v2.py +8 -3
- sglang/srt/models/llama.py +1 -3
- sglang/srt/models/llama_classification.py +2 -3
- sglang/srt/models/minicpm3.py +7 -3
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/xverse.py +1 -3
- sglang/srt/models/xverse_moe.py +1 -4
- sglang/srt/sampling/sampling_batch_info.py +3 -50
- sglang/srt/server.py +6 -1
- sglang/srt/server_args.py +39 -10
- sglang/srt/utils.py +7 -51
- sglang/test/few_shot_gsm8k.py +8 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/METADATA +4 -5
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/RECORD +37 -35
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/WHEEL +1 -1
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
@@ -96,10 +99,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +117,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
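Note (not part of the diff): the LoRAManager changes above switch `lora_paths` from a flat list of paths to a name-to-path mapping and guard the flashinfer import so the module still imports on ROCm. A minimal, self-contained sketch of the name-keyed pattern; `StubLoRAConfig` and the example paths are hypothetical stand-ins, not sglang APIs:

# Hypothetical stand-in for sglang.srt.lora.lora_config.LoRAConfig.
class StubLoRAConfig:
    def __init__(self, path):
        self.path = path
        self.target_modules = ["q_proj", "v_proj"]

# Adapters are now keyed by a caller-chosen name instead of the raw path.
lora_paths = {"sql": "/adapters/sql-lora", "chat": "/adapters/chat-lora"}

configs = {name: StubLoRAConfig(path) for name, path in lora_paths.items()}
target_modules = set()
for name in lora_paths:
    target_modules |= set(configs[name].target_modules)

lora_id = {name: i for i, name in enumerate(lora_paths)}
print(lora_id, sorted(target_modules))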
sglang/srt/managers/io_struct.py
CHANGED
@@ -133,6 +133,9 @@ class GenerateReqInput:
             self.image_data = [None] * num
         elif not isinstance(self.image_data, list):
             self.image_data = [self.image_data] * num
+        elif isinstance(self.image_data, list):
+            # multi-image with n > 1
+            self.image_data = self.image_data * num
 
         if self.sampling_params is None:
             self.sampling_params = [{}] * num
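Note (not part of the diff): the new `elif` branch repeats a multi-image list when a request asks for several parallel samples, so each sample still sees the full image list. A standalone sketch of that normalization; `normalize_image_data` is a made-up helper, not the sglang API:

def normalize_image_data(image_data, num):
    # Mirrors the branch order in GenerateReqInput above.
    if image_data is None:
        return [None] * num              # no images
    if not isinstance(image_data, list):
        return [image_data] * num        # single image, one copy per sample
    return image_data * num              # multi-image request with n > 1

print(normalize_image_data(None, 2))                # [None, None]
print(normalize_image_data("cat.png", 2))           # ['cat.png', 'cat.png']
print(normalize_image_data(["a.png", "b.png"], 2))  # ['a.png', 'b.png', 'a.png', 'b.png']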
sglang/srt/managers/policy_scheduler.py
CHANGED
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
@@ -141,31 +154,14 @@ class PrefillAdder:
             if self.rem_chunk_tokens is not None
             else False
         )
-
-
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
+            or self.cur_rem_tokens <= 0
         )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
            new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,38 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-
-
-
-
-
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
                             break
-
-
-
-
-
-
-
-
-
-                )
-
-
-
-
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
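Note (not part of the diff): the rewritten `add_one_req_ignore_eos` keeps a separate `cur_rem_tokens` budget and simulates the running requests finishing in order of their remaining decode steps; admission is refused if the KV-cache budget would go negative before enough requests finish and free their tokens. A rough standalone re-implementation of that loop on plain tuples; the numbers in the example are invented:

def can_admit(req_states, cur_rem_tokens):
    # req_states: (tokens_left, tokens_occupied) per running request, as in the diff.
    req_states = sorted(req_states)              # soonest-to-finish first
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i                 # requests still decoding in this phase
        if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
            return False                         # budget exhausted before anything frees
        tokens_freed += tokens_occupied          # this request finishes; its tokens free up
    return True

print(can_admit([(5, 80), (10, 50), (30, 200)], cur_rem_tokens=300))  # True
print(can_admit([(5, 80), (10, 50), (30, 200)], cur_rem_tokens=20))   # False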
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -40,7 +40,7 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "enable_mla": ServerArgs.enable_mla,
+    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
 }
 
@@ -360,6 +360,7 @@ class ScheduleBatch:
     tree_cache: BasePrefixCache
 
     forward_mode: ForwardMode = None
+    sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
sglang/srt/managers/tp_worker.py
CHANGED
@@ -198,6 +198,7 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -414,7 +415,7 @@ class ModelTpServer:
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -444,9 +445,6 @@ class ModelTpServer:
            num_mixed_running,
        )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
        has_inflight = self.current_inflight_req is not None
        if self.current_inflight_req is not None:
            self.current_inflight_req.init_next_round_input(
@@ -464,9 +462,6 @@ class ModelTpServer:
            )
 
        for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            if (
                self.lora_paths is not None
                and len(
@@ -477,6 +472,10 @@ class ModelTpServer:
                > self.max_loras_per_batch
            ):
                break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
            res = adder.add_one_req(req)
            if (
                not res
@@ -506,6 +505,11 @@ class ModelTpServer:
        else:
            tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
        if num_mixed_running > 0:
            logger.info(
                f"Prefill batch"
@@ -514,6 +518,7 @@ class ModelTpServer:
                f"#new-token: {adder.log_input_tokens}, "
                f"#cached-token: {adder.log_hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
            )
        else:
@@ -523,6 +528,7 @@ class ModelTpServer:
                f"#new-token: {adder.log_input_tokens}, "
                f"#cached-token: {adder.log_hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                f"#running-req: {running_bs}, "
                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
            )
@@ -807,12 +813,10 @@ class ModelTpServer:
                unfinished_indices.append(i)
 
            if req.finished() or (
-
-
-
-
-                or len(req.output_ids) == 1
-            )
+                req.stream
+                and (
+                    self.decode_forward_ct % self.stream_interval == 0
+                    or len(req.output_ids) == 1
                )
            ):
                output_rids.append(req.rid)
@@ -937,6 +941,8 @@ class ModelTpServer:
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
        return success, message
 
 
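Note (not part of the diff): the prefill log line now also reports KV-cache token usage, computed as the pool size minus the tokens that are still free or evictable from the prefix cache. A toy calculation with invented pool sizes:

max_total_num_tokens = 100_000   # total KV-cache slots (invented)
available_size = 62_000          # free slots reported by the token-to-KV pool
evictable_size = 8_000           # cached prefixes the radix cache could evict

num_used = max_total_num_tokens - (available_size + evictable_size)
print(f"token usage: {num_used / max_total_num_tokens:.2f}")   # token usage: 0.30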
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: FusedMoE torch native implementaiton is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -105,23 +108,22 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
 
-
-
-
-        self.
-
-
-
-
-
-
-
-            (self.max_bs,), dtype=torch.int32, device="cuda"
+        self.capture_bs = [
+            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+        ]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
         )
 
         # Attention backend
+        self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -130,6 +132,16 @@ class CudaGraphRunner:
         if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Common inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
         # Capture
         try:
             self.capture()
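Note (not part of the diff): capture batch sizes are now clamped to the request-to-token pool size, and torch.compile is only applied up to `max_torch_compile_bs`, so small deployments no longer capture graphs for batch sizes they can never schedule. A standalone sketch of that filtering; the limits below are invented, the real ones come from ServerArgs and the request pool:

capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]   # default ladder up to 160
req_to_token_pool_size = 96
max_torch_compile_bs = 32
use_torch_compile = True

capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
compile_bs = (
    [bs for bs in capture_bs if bs <= max_torch_compile_bs]
    if use_torch_compile
    else []
)
print(max(capture_bs), compile_bs)   # 96 [1, 2, 4, 8, 16, 24, 32]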
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -88,12 +86,20 @@ class ModelRunner:
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
         )
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not self.server_args.disable_mla
+        ):
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.server_args.attention_backend = "triton"
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "enable_mla": server_args.enable_mla,
+                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
             }
         )
@@ -166,10 +172,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +187,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +198,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +258,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-
-            return False,
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-
-            return False,
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +296,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message =
-
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +318,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
@@ -334,7 +337,7 @@ class ModelRunner:
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -397,12 +400,12 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs,
-            self.model_config.context_len +
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -521,21 +524,6 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
 
-    def _check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
-
     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
     ):
@@ -564,13 +552,16 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
     ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         batch.sampling_info.update_regex_vocab_mask(batch)
         batch.sampling_info.update_penalties()
        logits = self._apply_logits_bias(
            logits_output.next_token_logits, batch.sampling_info
        )
-
-
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()
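Note (not part of the diff): with `_check_sample_results` removed, `ModelRunner.sample` now applies penalties and logit bias and then hands the logits straight to the `Sampler` layer; the fallback handling for failed sampling lives in sampler.py instead. A rough outline of that call order with stand-in objects (the classes below are hypothetical, not the real sglang ones):

import torch

class ToySampler:
    # Stand-in for sglang.srt.layers.sampler.Sampler.
    def __call__(self, logits, temperature):
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

def sample(next_token_logits, temperature=0.7, logit_bias=None):
    if logit_bias is not None:                 # rough _apply_logits_bias equivalent
        next_token_logits = next_token_logits + logit_bias
    return ToySampler()(next_token_logits, temperature)

print(sample(torch.randn(2, 8)))               # one next-token id per request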
|