sglang 0.3.1__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/bench_latency.py CHANGED
@@ -63,7 +63,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import suppress_other_loggers
+ from sglang.srt.utils import kill_child_process, suppress_other_loggers


  @dataclasses.dataclass
@@ -502,4 +502,9 @@ if __name__ == "__main__":
  format="%(message)s",
  )

- main(server_args, bench_args)
+ try:
+ main(server_args, bench_args)
+ except Exception as e:
+ raise e
+ finally:
+ kill_child_process(os.getpid(), including_parent=False)
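
The new try/finally wrapper guarantees that worker processes spawned during the benchmark are cleaned up even when main() raises. A minimal sketch of the same pattern for reuse elsewhere, using only the kill_child_process helper and keyword shown above; run_benchmark is a hypothetical wrapper, not part of the package:

    import os

    from sglang.srt.utils import kill_child_process


    def run_benchmark(entrypoint, *args):
        # Hypothetical wrapper mirroring the cleanup added to bench_latency.py.
        try:
            entrypoint(*args)
        finally:
            # Reap any lingering child processes (e.g. tensor-parallel workers),
            # but keep the current process alive so the exit code propagates.
            kill_child_process(os.getpid(), including_parent=False)
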
sglang/global_config.py CHANGED
@@ -1,5 +1,7 @@
  """Global configurations"""

+ import os
+

  class GlobalConfig:
  def __init__(self):
@@ -16,30 +18,20 @@ class GlobalConfig:
  self.base_min_new_token_ratio = 0.1
  self.new_token_ratio_decay = 0.001

- # Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
- # This can improve the speed for large batch sizes during prefill.
- self.layer_sync_threshold = 8192
-
  # Runtime constants: others
  self.num_continue_decode_steps = 10
  self.retract_decode_steps = 20
- self.flashinfer_workspace_size = 384 * 1024 * 1024
+ self.flashinfer_workspace_size = os.environ.get(
+ "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+ )

  # Output tokenization configs
  self.skip_special_tokens_in_output = True
  self.spaces_between_special_tokens_in_out = True

  # Interpreter optimization configs
- self.eager_fill_image = False
  self.enable_precache_with_tracing = True
  self.enable_parallel_encoding = True
- self.enable_parallel_decoding = True
-
- # Deprecated
- # Choices: ["no_adjust", "adjust_cache"]
- # no_adjust: Do not adjust the position embedding of KV cache.
- # adjust_cache: Adjust the position embedding of KV cache.
- self.concate_and_append_mode = "no_adjust"


  global_config = GlobalConfig()
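
The FlashInfer workspace size, previously hard-coded to 384 * 1024 * 1024 bytes, can now be overridden through the FLASHINFER_WORKSPACE_SIZE environment variable. A minimal usage sketch; note that os.environ.get returns a string when the variable is set, so the int() cast below is illustrative and not part of the package:

    import os

    # Override the workspace size (here 512 MiB) before sglang's global_config
    # module is imported, since GlobalConfig() is instantiated at import time.
    os.environ["FLASHINFER_WORKSPACE_SIZE"] = str(512 * 1024 * 1024)

    from sglang.global_config import global_config

    # Environment values are strings; coerce if an integer byte count is needed.
    workspace_bytes = int(global_config.flashinfer_workspace_size)
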
@@ -434,9 +434,6 @@ class StreamExecutor:
  self.cur_images.append((path, base64_data))
  self.text_ += self.chat_template.image_token

- # if global_config.eager_fill_image:
- # self.backend.fill_image(self)
-
  def _spec_gen(self, sampling_params):
  stop = sampling_params.stop
  max_new_tokens = sampling_params.max_new_tokens
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
  tokenizer_args_dict,
  enable=True,
  skip_tokenizer_init=False,
+ constrained_json_whitespace_pattern=None,
  ):
  super().__init__(enable=enable)

@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
  self.outlines_tokenizer.vocabulary = (
  self.outlines_tokenizer.tokenizer.get_vocab()
  )
+ self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern

  def init_value(self, key):
  key_type, key_string = key
  if key_type == "json":
- regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+ regex = build_regex_from_schema(
+ key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
+ )
  elif key_type == "regex":
  regex = key_string
  else:
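
The FSMCache hunks above make the JSON whitespace pattern for constrained decoding configurable instead of fixing it to r"[\n\t ]*"; passing None defers to outlines' own default handling. A standalone sketch of the underlying call, assuming build_regex_from_schema is outlines' helper of the same name (the import path can vary across outlines releases, and the schema below is illustrative):

    import json

    # Import path differs between outlines versions; adjust as needed.
    from outlines.fsm.json_schema import build_regex_from_schema

    schema = json.dumps({"type": "object", "properties": {"name": {"type": "string"}}})

    # New default: let outlines decide how whitespace is matched.
    regex_default = build_regex_from_schema(schema, whitespace_pattern=None)

    # Old behavior: restrict whitespace to newlines, tabs, and spaces.
    regex_strict = build_regex_from_schema(schema, whitespace_pattern=r"[\n\t ]*")

On the server side, the value is threaded in from server_args.constrained_json_whitespace_pattern (see the ModelTpServer hunk further down).
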
@@ -13,6 +13,7 @@ limitations under the License.

  """Fused operators for activation layers."""

+ import logging
  from typing import Optional

  import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
  from vllm.model_executor.layers.quantization import QuantizationConfig
  from vllm.model_executor.utils import set_weight_attrs

+ from sglang.srt.utils import is_hip
+
+ logger = logging.getLogger(__name__)
+

  class SiluAndMul(CustomOp):
  def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
  act_fn, intermediate_size, input_is_parallel, params_dtype
  )
  return act_fn
+
+
+ if is_hip():
+ logger.info(
+ "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+ )
+ from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
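
The trailing is_hip() block above is a fallback by name rebinding: on ROCm builds, where FlashInfer is unavailable, the module-level class names defined earlier in the file are shadowed by vLLM's implementations, so imports of this module keep working unchanged. The same pattern recurs in the normalization, attention-backend, sampler, and LoRA hunks below. A condensed sketch of the idea, using only symbols that appear in the hunk above:

    import logging

    from sglang.srt.utils import is_hip

    logger = logging.getLogger(__name__)

    # ... FlashInfer-backed CustomOp subclasses (SiluAndMul, GeluAndMul, ...)
    # are defined here for CUDA builds ...

    if is_hip():
        # Rebind the module-level names to vLLM's kernels so that importing
        # SiluAndMul / GeluAndMul from this module still works on ROCm.
        logger.info(
            "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
        )
        from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
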
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING

  import torch
  import torch.nn as nn
- from flashinfer import (
- BatchDecodeWithPagedKVCacheWrapper,
- BatchPrefillWithPagedKVCacheWrapper,
- BatchPrefillWithRaggedKVCacheWrapper,
- )
- from flashinfer.cascade import merge_state
- from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

  from sglang.global_config import global_config
  from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
  from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+ from sglang.srt.utils import is_hip

  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner

+ # ROCm: flashinfer available later
+ if not is_hip():
+ from flashinfer import (
+ BatchDecodeWithPagedKVCacheWrapper,
+ BatchPrefillWithPagedKVCacheWrapper,
+ BatchPrefillWithRaggedKVCacheWrapper,
+ )
+ from flashinfer.cascade import merge_state
+ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+

  class AttentionBackend(ABC):
  """The base class of attention backends"""
@@ -150,7 +154,7 @@ class FlashInferAttnBackend(AttentionBackend):
  # Some heuristics to check whether to use ragged forward
  use_ragged = False
  if (
- int(torch.sum(input_metadata.seq_lens)) > 4096
+ torch.sum(input_metadata.seq_lens).item() >= 4096
  and self.model_runner.sliding_window_size is None
  ):
  use_ragged = True
@@ -301,10 +305,6 @@ class FlashInferAttnBackend(AttentionBackend):
  layer.layer_id, input_metadata.out_cache_loc, k, v
  )

- if total_num_tokens >= global_config.layer_sync_threshold:
- # TODO: Revisit this. Why is this synchronize needed?
- torch.cuda.synchronize()
-
  return o.view(-1, layer.tp_q_head_num * layer.head_dim)

  def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata):
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
  from vllm.model_executor.layers.quantization.fp8 import Fp8Config
  from vllm.model_executor.utils import set_weight_attrs

+ from sglang.srt.utils import is_hip
+
  logger = init_logger(__name__)

@@ -381,6 +383,7 @@ from torch.nn import Module
  from vllm import _custom_ops as ops
  from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
  all_close_1d,
+ normalize_e4m3fn_to_e4m3fnuz,
  per_tensor_dequantize,
  )
  from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

  def process_weights_after_loading(self, layer: Module) -> None:

- # If checkpoint is fp16, quantize in place.
+ # If checkpoint is fp16 or bfloat16, quantize in place.
  if not self.quant_config.is_checkpoint_fp8_serialized:
- w13_weight = torch.empty_like(
- layer.w13_weight.data, dtype=torch.float8_e4m3fn
- )
- w2_weight = torch.empty_like(
- layer.w2_weight.data, dtype=torch.float8_e4m3fn
- )
+ # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+ fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+ w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+ w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

  # Re-initialize w13_scale because we directly quantize
  # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
  layer.a2_scale.max(), requires_grad=False
  )

+ # If ROCm, normalize the weights and scales to e4m3fnuz
+ if is_hip():
+ # Normalize the weights and scales
+ w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+ layer.w13_weight, layer.w13_scale, layer.a13_scale
+ )
+ w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+ layer.w2_weight, layer.w2_scale, layer.a2_scale
+ )
+ # Reset the parameters
+ layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+ layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+ if a13_scale is not None:
+ layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+ layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+ layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+ if a2_scale is not None:
+ layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
  # Fp8 moe kernel needs single weight scale for w13 per expert.
  # We take the max then dequant and requant each expert.
  assert layer.w13_scale is not None
@@ -15,6 +15,7 @@ limitations under the License.

  """Fused operators for normalization layers."""

+ import logging
  from typing import Optional, Tuple, Union

  import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
  )
  from vllm.model_executor.custom_op import CustomOp

+ from sglang.srt.utils import is_hip
+
+ logger = logging.getLogger(__name__)
+

  class RMSNorm(CustomOp):
  def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
  return x, residual
  out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
  return out
+
+
+ if is_hip():
+ logger.info(
+ "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+ )
+ from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -1,51 +1,28 @@
- import dataclasses
  import logging
- from typing import Tuple, Union
+ from typing import Union

  import torch
- from flashinfer.sampling import (
- min_p_sampling_from_probs,
- top_k_renorm_prob,
- top_k_top_p_sampling_from_probs,
- top_p_renorm_prob,
- )
- from torch.library import custom_op as torch_custom_op
- from vllm.model_executor.custom_op import CustomOp
+ from torch import nn

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
- # TODO: move this dict to another place
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+ from sglang.srt.utils import is_hip
+
+ # ROCm: flashinfer available later
+ if not is_hip():
+ from flashinfer.sampling import (
+ min_p_sampling_from_probs,
+ top_k_renorm_prob,
+ top_k_top_p_sampling_from_probs,
+ top_p_renorm_prob,
+ )

  logger = logging.getLogger(__name__)


- @dataclasses.dataclass
- class SampleOutput:
- success: torch.Tensor
- probs: torch.Tensor
- batch_next_token_ids: torch.Tensor
-
-
- class Sampler(CustomOp):
- def __init__(self):
- super().__init__()
- # FIXME: torch.multinomial has too many bugs
- self.forward_native = self.forward_cuda
- self.is_torch_compile = False
-
- def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
- # Post process logits
- logits = logits.contiguous()
- logits.div_(sampling_info.temperatures)
- if self.is_torch_compile:
- # FIXME: Temporary workaround for unknown bugs in torch.compile
- logits.add_(0)
-
- return torch.softmax(logits, dim=-1)
-
- def forward_cuda(
+ class Sampler(nn.Module):
+ def forward(
  self,
  logits: Union[torch.Tensor, LogitsProcessorOutput],
  sampling_info: SamplingBatchInfo,
@@ -53,7 +30,15 @@ class Sampler(CustomOp):
  if isinstance(logits, LogitsProcessorOutput):
  logits = logits.next_token_logits

- probs = self._get_probs(logits, sampling_info)
+ # Post process logits
+ logits.div_(sampling_info.temperatures)
+ probs = logits[:] = torch.softmax(logits, dim=-1)
+
+ if torch.any(torch.isnan(probs)):
+ logger.warning("Detected errors during sampling! NaN in the probability.")
+ probs = torch.where(
+ torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+ )

  if global_server_args_dict["sampling_backend"] == "flashinfer":
  max_top_k_round, batch_size = 32, probs.shape[0]
@@ -67,12 +52,16 @@ class Sampler(CustomOp):
  probs, uniform_samples, sampling_info.min_ps
  )
  else:
- batch_next_token_ids, success = flashinfer_top_k_top_p(
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
  probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
  )
+
+ if not torch.all(success):
+ logger.warning("Detected errors during sampling!")
+ batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
  elif global_server_args_dict["sampling_backend"] == "pytorch":
  # Here we provide a slower fallback implementation.
- batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
  probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
  )
  else:
@@ -80,48 +69,7 @@ class Sampler(CustomOp):
  f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
  )

- return SampleOutput(success, probs, batch_next_token_ids)
-
- def forward_native(
- self,
- logits: Union[torch.Tensor, LogitsProcessorOutput],
- sampling_info: SamplingBatchInfo,
- ):
- if isinstance(logits, LogitsProcessorOutput):
- logits = logits.next_token_logits
-
- probs = self._get_probs(logits, sampling_info)
-
- batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
- probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
- )
-
- return SampleOutput(success, probs, batch_next_token_ids)
-
-
- @torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
- def flashinfer_top_k_top_p(
- probs: torch.Tensor,
- uniform_samples: torch.Tensor,
- top_ks: torch.Tensor,
- top_ps: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # NOTE: we do not use min_p neither in CUDA nor in torch.compile
- return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
- @flashinfer_top_k_top_p.register_fake
- def _(
- probs: torch.Tensor,
- uniform_samples: torch.Tensor,
- top_ks: torch.Tensor,
- top_ps: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- bs = probs.shape[0]
- return (
- torch.ones(bs, dtype=torch.bool, device=probs.device),
- torch.zeros(bs, dtype=torch.int32, device=probs.device),
- )
+ return batch_next_token_ids


  def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -141,19 +89,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
  ] = 0.0
  probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
  probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
- try:
- # FIXME: torch.multiomial does not support num_samples = 1
- sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
- :, :1
- ]
- except RuntimeError as e:
- logger.warning(f"Sampling error: {e}")
- batch_next_token_ids = torch.zeros(
- (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
- )
- success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
- return batch_next_token_ids, success
-
+ sampled_index = torch.multinomial(probs_sort, num_samples=1)
  batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
- success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
- return batch_next_token_ids, success
+ return batch_next_token_ids
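
In the rewritten sampler above, the PyTorch fallback returns only the sampled token ids and lets torch.multinomial errors propagate instead of reporting a success mask. A minimal usage sketch of the simplified helper; the tensors are illustrative, and the module path is assumed from the package layout:

    import torch

    from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

    batch_size, vocab_size = 2, 8
    probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)

    top_ks = torch.tensor([5, 3])        # per-request top-k cutoffs
    top_ps = torch.tensor([0.9, 0.8])    # per-request top-p thresholds
    min_ps = torch.tensor([0.0, 0.05])   # per-request min-p thresholds

    # Returns a 1-D tensor of sampled token ids, one per request.
    next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
        probs, top_ks, top_ps, min_ps
    )
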
@@ -21,12 +21,15 @@ import re
  from dataclasses import dataclass

  import torch
- from flashinfer import SegmentGEMMWrapper

  from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
- from sglang.srt.utils import replace_submodule
+ from sglang.srt.utils import is_hip, replace_submodule
+
+ # ROCm: flashinfer available later
+ if not is_hip():
+ from flashinfer import SegmentGEMMWrapper


  def get_stacked_name(name):
@@ -96,10 +99,10 @@ class LoRAManager:
  # get configs and target modules
  self.configs = {}
  self.origin_target_modules = set()
- for path in self.lora_paths:
- self.configs[path] = LoRAConfig(path)
+ for name, path in self.lora_paths.items():
+ self.configs[name] = LoRAConfig(path)
  self.origin_target_modules = set(self.origin_target_modules) | set(
- self.configs[path].target_modules
+ self.configs[name].target_modules
  )
  self.target_modules = set(
  [
@@ -114,11 +117,11 @@ class LoRAManager:
  # load all weights to cpu
  self.loras = []
  self.lora_id = {}
- for path in self.lora_paths:
- self.lora_id[path] = len(self.loras)
+ for name in self.lora_paths.keys():
+ self.lora_id[name] = len(self.loras)
  self.loras.append(
  LoRAAdapter(
- path, self.configs[path], self.base_hf_config, self.load_config
+ name, self.configs[name], self.base_hf_config, self.load_config
  )
  )
  self.loras[-1].initialize_weights()
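
The LoRAManager hunks above switch self.lora_paths from a flat list of paths to a mapping, so configs and adapter ids are keyed by adapter name while the path is only used to load the config. A small sketch of the structure the new loops expect, assuming lora_paths is a dict of {adapter_name: checkpoint_path} as the .items() iteration implies (names and paths are illustrative):

    # Hypothetical input mapping adapter names to their checkpoint paths.
    lora_paths = {
        "sql-adapter": "/models/loras/sql-adapter",
        "chat-adapter": "/models/loras/chat-adapter",
    }

    # Mirrors the loops above: keyed by name, loaded from path.
    configs = {}
    lora_id = {}
    for i, (name, path) in enumerate(lora_paths.items()):
        configs[name] = path  # stand-in for LoRAConfig(path)
        lora_id[name] = i
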
@@ -360,6 +360,7 @@ class ScheduleBatch:
  tree_cache: BasePrefixCache

  forward_mode: ForwardMode = None
+ sampling_info: SamplingBatchInfo = None

  # Batched arguments to model runner
  input_ids: torch.Tensor = None
@@ -198,6 +198,7 @@ class ModelTpServer:
  "trust_remote_code": server_args.trust_remote_code,
  },
  skip_tokenizer_init=server_args.skip_tokenizer_init,
+ constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
  )
  self.jump_forward_cache = JumpForwardCache()

@@ -414,7 +415,7 @@ class ModelTpServer:

  # Truncate prompts that are too long
  if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warn(
+ logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
  )
@@ -807,12 +808,10 @@ class ModelTpServer:
  unfinished_indices.append(i)

  if req.finished() or (
- (
- req.stream
- and (
- self.decode_forward_ct % self.stream_interval == 0
- or len(req.output_ids) == 1
- )
+ req.stream
+ and (
+ self.decode_forward_ct % self.stream_interval == 0
+ or len(req.output_ids) == 1
  )
  ):
  output_rids.append(req.rid)
@@ -937,6 +936,8 @@ class ModelTpServer:
  if success:
  flash_cache_success = self.flush_cache()
  assert flash_cache_success, "Cache flush failed after updating weights"
+ else:
+ logger.error(message)
  return success, message

@@ -41,6 +41,9 @@ if TYPE_CHECKING:
  def _to_torch(model: torch.nn.Module, reverse: bool = False):
  for sub in model._modules.values():
  if isinstance(sub, CustomOp):
+ # NOTE: FusedMoE torch native implementaiton is not efficient
+ if "FusedMoE" in sub.__class__.__name__:
+ continue
  if reverse:
  sub._forward_method = sub.forward_cuda
  setattr(sub, "is_torch_compile", False)
@@ -105,7 +108,15 @@ class CudaGraphRunner:
  self.capture_bs = list(range(1, 32)) + [64, 128]
  else:
  self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
- self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
+ self.compile_bs = (
+ [
+ bs
+ for bs in self.capture_bs
+ if bs <= self.model_runner.server_args.max_torch_compile_bs
+ ]
+ if self.use_torch_compile
+ else []
+ )

  # Common inputs
  self.max_bs = max(self.capture_bs)
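
In the CudaGraphRunner hunk above, compile_bs is no longer a hard-coded list: it is derived from capture_bs and capped by the new max_torch_compile_bs server argument. A small worked sketch of the same filtering with illustrative stand-in values:

    # Illustrative stand-ins for the attributes used above.
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # batch sizes up to 160
    max_torch_compile_bs = 32                               # from server args
    use_torch_compile = True

    compile_bs = (
        [bs for bs in capture_bs if bs <= max_torch_compile_bs]
        if use_torch_compile
        else []
    )
    assert compile_bs == [1, 2, 4, 8, 16, 24, 32]
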