sglang 0.4.0__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. sglang/__init__.py +1 -1
  2. sglang/srt/constrained/outlines_backend.py +5 -0
  3. sglang/srt/constrained/xgrammar_backend.py +5 -5
  4. sglang/srt/layers/attention/__init__.py +5 -2
  5. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  6. sglang/srt/layers/attention/flashinfer_backend.py +20 -5
  7. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  8. sglang/srt/layers/attention/triton_backend.py +22 -8
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  10. sglang/srt/layers/ep_moe/__init__.py +0 -0
  11. sglang/srt/layers/ep_moe/kernels.py +349 -0
  12. sglang/srt/layers/ep_moe/layer.py +661 -0
  13. sglang/srt/layers/quantization/__init__.py +2 -2
  14. sglang/srt/layers/quantization/fp8.py +559 -0
  15. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  16. sglang/srt/layers/radix_attention.py +4 -2
  17. sglang/srt/layers/sampler.py +2 -0
  18. sglang/srt/layers/torchao_utils.py +23 -45
  19. sglang/srt/managers/schedule_batch.py +1 -0
  20. sglang/srt/managers/scheduler.py +69 -65
  21. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  22. sglang/srt/mem_cache/memory_pool.py +5 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +15 -1
  24. sglang/srt/model_executor/model_runner.py +11 -4
  25. sglang/srt/model_parallel.py +1 -5
  26. sglang/srt/models/commandr.py +2 -2
  27. sglang/srt/models/deepseek_v2.py +87 -7
  28. sglang/srt/models/grok.py +0 -5
  29. sglang/srt/models/llama.py +0 -5
  30. sglang/srt/models/mixtral.py +12 -9
  31. sglang/srt/models/phi3_small.py +0 -5
  32. sglang/srt/models/qwen2_moe.py +0 -5
  33. sglang/srt/models/torch_native_llama.py +0 -5
  34. sglang/srt/sampling/sampling_batch_info.py +9 -8
  35. sglang/srt/server.py +3 -3
  36. sglang/srt/server_args.py +43 -4
  37. sglang/srt/utils.py +50 -0
  38. sglang/version.py +1 -1
  39. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  40. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/RECORD +43 -38
  41. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  42. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  43. {sglang-0.4.0.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/torchao_utils.py
@@ -2,23 +2,24 @@
 Common utilities for torchao.
 """

-from typing import Dict, Set
-
 import torch


-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a model with torchao quantization specified by torchao_config

     Args:
-       `param`: weight parameter of the linear module
-       `torchao_config`: type of quantization and their arguments we want to use to
-       quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+       quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor

-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +46,11 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +61,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-            dummy_linear,
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")

-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
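Note: the new apply_torchao_config_to_model quantizes matching sub-modules in place with torchao's quantize_ and a filter_fn (by default, modules whose qualified name contains "proj"), replacing the old per-parameter dummy-Linear round trip. A minimal usage sketch, assuming sglang 0.4.0.post1 and torchao are installed; ToyBlock is an illustrative stand-in, not part of sglang, and the fp8/int4 configs may additionally require recent GPUs:

# Minimal usage sketch of the new helper; ToyBlock is a made-up module and
# "int8wo" is one of the recognized config prefixes
# (int8wo, int8dq, int4wo-<group_size>, fp8wo, fp8dq-<granularity>).
import torch
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The default filter_fn quantizes sub-modules whose qualified name contains "proj".
        self.q_proj = torch.nn.Linear(256, 256, bias=False)
        self.gate = torch.nn.Linear(256, 256, bias=False)  # no "proj" in the name -> left as-is

    def forward(self, x):
        return self.gate(self.q_proj(x))

model = ToyBlock().eval()
model = apply_torchao_config_to_model(model, "int8wo")  # returns the in-place quantized model
print(type(model.q_proj.weight))  # now a torchao tensor subclass; model.gate.weight is unchanged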
sglang/srt/managers/schedule_batch.py
@@ -58,6 +58,7 @@ global_server_args_dict = {
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
 }


sglang/srt/managers/scheduler.py
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)

@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +357,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watchdog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()

@@ -433,61 +435,6 @@

         self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -993,7 +940,7 @@
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@

                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.cuda.current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1127,10 +1075,11 @@

             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.cuda.current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@
             )
         )

+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
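Note: prepare_dp_attn_batch and get_idle_batch are moved here unchanged from earlier in the file (see the -433 hunk above). The cuda-graph gate relies on a MIN all-reduce, so a data-parallel rank only uses the DP cuda graph when every rank reports a decode or idle batch. A self-contained sketch of that agreement pattern (single rank, gloo backend; the address and port are illustrative only):

# Standalone sketch of the MIN all-reduce agreement used by prepare_dp_attn_batch.
import torch
import torch.distributed as dist

def can_run_dp_cuda_graph(local_is_decode_or_idle: bool) -> bool:
    state = torch.tensor(1 if local_is_decode_or_idle else 0, dtype=torch.int32)
    # MIN means the result stays 1 only if *every* rank contributed 1.
    dist.all_reduce(state, op=dist.ReduceOp.MIN)
    return state.item() == 1

if __name__ == "__main__":
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
    )
    print(can_run_dp_cuda_graph(True))   # True: the only rank agrees
    print(can_run_dp_cuda_graph(False))  # False: this rank is doing an extend/prefill
    dist.destroy_process_group()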
@@ -1469,10 +1473,6 @@
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()

sglang/srt/managers/tp_worker_overlap_thread.py
@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
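Note: get_compiler_backend() is provided by sglang/srt/utils.py (also changed in this release, +50 lines) so that torch.compile is no longer pinned to its CUDA-oriented default backend. A hedged sketch of the idea follows; the selector name and its "hpu_backend" branch are assumptions for illustration, not the actual utils.py implementation:

# Hedged sketch of a compiler-backend selector; the real get_compiler_backend()
# in sglang/srt/utils.py may differ. "hpu_backend" is registered by Habana's
# framework bridge and is only an assumption here.
import torch

def pick_compiler_backend() -> str:
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"
    return "inductor"  # torch.compile's default backend elsewhere

@torch.compile(dynamic=True, backend=pick_compiler_backend())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

print(double(torch.arange(4)))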
@@ -73,12 +74,13 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()
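Note: the Stream/Event changes in this file all follow the same device-agnostic pattern: resolve the device module once via torch.get_device_module and use its Stream, Event, and current_stream instead of torch.cuda directly. A small runnable sketch (assumes a PyTorch version that exposes torch.get_device_module; the CPU branch simply skips stream handling):

# Sketch of the torch.get_device_module pattern used above; the CPU fallback is
# only for machines without an accelerator.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    dev = torch.get_device_module(device)     # here: torch.cuda
    forward_stream = dev.Stream()
    scheduler_stream = dev.current_stream()
    copy_done = dev.Event()
    with dev.stream(forward_stream):
        y = torch.ones(4, device=device) * 3  # work enqueued on the side stream
        copy_done.record()
    copy_done.synchronize()
    scheduler_stream.synchronize()
else:
    y = torch.ones(4) * 3                     # no explicit stream handling on CPU
print(y)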
@@ -97,7 +99,7 @@ class TpModelWorkerClient:

     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +124,7 @@

         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()

         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
@@ -190,7 +192,7 @@
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
sglang/srt/mem_cache/memory_pool.py
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch

 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend

 logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)

     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
sglang/srt/model_executor/cuda_graph_runner.py
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
         if "FusedMoE" in sub.__class__.__name__:
             if batch_size == 1:
                 # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to skip it for now.
+                # so we decide to only use torch.compile when bs = 1
                 sub._forward_method = fused_moe_forward_native
             else:
                 sub._forward_method = sub.forward_native
@@ -130,6 +130,20 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+            # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+            # is very small. We add more values here to make sure we capture the maximum bs.
+            self.capture_bs = list(
+                sorted(
+                    set(
+                        self.capture_bs
+                        + [model_runner.req_to_token_pool.size - 1]
+                        + [model_runner.req_to_token_pool.size]
+                    )
+                )
+            )
+
         self.capture_bs = [
             bs
             for bs in self.capture_bs
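Note: a quick worked example of the new capture-list clamping with a hypothetical pool size of 6; the final filter below mirrors what the list comprehension truncated above is assumed to do (keep only batch sizes the pool can hold):

# Worked example with a made-up req_to_token_pool size of 6.
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]   # default list, largest entry is 160
req_to_token_pool_size = 6                               # e.g. a tiny --max-running-requests
if max(capture_bs) > req_to_token_pool_size:
    # make sure the maximum runnable batch size (and size - 1) are captured
    capture_bs = sorted(set(capture_bs + [req_to_token_pool_size - 1, req_to_token_pool_size]))
capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
print(capture_bs)  # [1, 2, 4, 5, 6]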
sglang/srt/model_executor/model_runner.py
@@ -27,7 +27,6 @@ from vllm.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -111,11 +111,13 @@ class ModelRunner:
         )

         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
             server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                f"and turn off chunked prefill "
+                f"because this is a multimodal model."
+            )
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -139,6 +141,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )

@@ -159,6 +162,10 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
sglang/srt/model_parallel.py
@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
         )._prepare_output_fn(
             output_layouts, use_local_output, mod, outputs, device_mesh
         )
-        # wait for the output to be ready
-        if isinstance(outputs, AsyncCollectiveTensor):
-            return outputs.wait()
-        else:
-            return outputs
+        return torch.distributed._functional_collectives.wait_tensor(outputs)


 def tensor_parallel(
sglang/srt/models/commandr.py
@@ -62,10 +62,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)