sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +20 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +69 -65
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +15 -1
- sglang/srt/model_executor/model_runner.py +11 -4
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/grok.py +0 -5
- sglang/srt/models/llama.py +0 -5
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +3 -3
- sglang/srt/server_args.py +43 -4
- sglang/srt/utils.py +50 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,23 +2,24 @@
 Common utilities for torchao.
 """
 
-from typing import Dict, Set
-
 import torch
 
 
-def
-
+def apply_torchao_config_to_model(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a modelwith torchao quantization specified by torchao_config
 
     Args:
-       `
-       `torchao_config
-       quantize the
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+       quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-
-
-
-
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +46,11 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +61,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
 
-    return
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
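
Note: a minimal usage sketch of the new apply_torchao_config_to_model API shown above, assuming a toy module whose parameter path contains "proj" so the default filter matches; the module and the "int8wo" config value here are illustrative, not taken from this diff.

    import torch
    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model

    # Toy stand-in for a real model; the child name "up_proj" matches the
    # default filter_fn shown in the diff ("proj" in fqn).
    model = torch.nn.Sequential()
    model.add_module("up_proj", torch.nn.Linear(64, 128, bias=False))

    # "int8wo" = int8 weight-only quantization; per the docstring, "int4wo-128"
    # would mean int4 weight-only with group_size 128, and an empty string is a no-op.
    model = apply_torchao_config_to_model(model, "int8wo")
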
sglang/srt/managers/scheduler.py
CHANGED
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
 
@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +357,7 @@ class Scheduler:
         )
 
     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
 
@@ -433,61 +435,6 @@ class Scheduler:
 
         self.last_batch = batch
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -993,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ class Scheduler:
 
                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
@@ -1127,10 +1075,11 @@ class Scheduler:
 
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@ class Scheduler:
             )
         )
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
 
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -73,12 +74,13 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -97,7 +99,7 @@ class TpModelWorkerClient:
 
     def forward_thread_func(self):
         try:
-            with torch.
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +124,7 @@ class TpModelWorkerClient:
 
             # Create event
             self.launch_done = threading.Event()
-            copy_done = torch.
+            copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
@@ -190,7 +192,7 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-
+        self.scheduler_stream.synchronize()
 
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)
 
     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
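
Note: the backend=get_compiler_backend() argument that now appears on these torch.compile decorators points at a helper imported from sglang/srt/utils.py (listed above with +50 lines), whose body is not shown in this diff. A plausible sketch of the device-dispatch idea follows; the non-CUDA branch and its backend name are assumptions, not taken from this diff.

    import torch

    def get_compiler_backend() -> str:
        # Hypothetical sketch: pick a torch.compile backend per device.
        # "inductor" is torch.compile's default backend on CUDA/CPU; the HPU
        # branch and the "hpu_backend" name are assumptions for illustration.
        if hasattr(torch, "hpu") and torch.hpu.is_available():
            return "hpu_backend"
        return "inductor"
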
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                 if "FusedMoE" in sub.__class__.__name__:
                     if batch_size == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
-                        # so we decide to
+                        # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
                     else:
                         sub._forward_method = sub.forward_native
@@ -130,6 +130,20 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+            # is very samll. We add more values here to make sure we capture the maximum bs.
+            self.capture_bs = list(
+                sorted(
+                    set(
+                        self.capture_bs
+                        + [model_runner.req_to_token_pool.size - 1]
+                        + [model_runner.req_to_token_pool.size]
+                    )
+                )
+            )
+
         self.capture_bs = [
             bs
             for bs in self.capture_bs
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
|
|
111
111
|
)
|
112
112
|
|
113
113
|
if self.is_multimodal:
|
114
|
-
logger.info(
|
115
|
-
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
116
|
-
)
|
117
114
|
server_args.chunked_prefill_size = -1
|
118
115
|
self.mem_fraction_static *= 0.95
|
116
|
+
logger.info(
|
117
|
+
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
|
118
|
+
f"and turn off chunked prefill "
|
119
|
+
f"because this is a multimodal model."
|
120
|
+
)
|
119
121
|
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
120
122
|
if self.model_config.hf_config.architectures == [
|
121
123
|
"Qwen2VLForConditionalGeneration"
|
@@ -139,6 +141,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )
 
@@ -159,6 +162,10 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
sglang/srt/model_parallel.py
CHANGED
@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
         )._prepare_output_fn(
             output_layouts, use_local_output, mod, outputs, device_mesh
         )
-
-        if isinstance(outputs, AsyncCollectiveTensor):
-            return outputs.wait()
-        else:
-            return outputs
+        return torch.distributed._functional_collectives.wait_tensor(outputs)
 
 
 def tensor_parallel(
sglang/srt/models/commandr.py
CHANGED
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
|