sglang 0.3.1.post2__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (54)
  1. sglang/bench_latency.py +12 -11
  2. sglang/bench_server_latency.py +0 -6
  3. sglang/srt/hf_transformers_utils.py +1 -0
  4. sglang/srt/layers/activation.py +3 -2
  5. sglang/srt/layers/attention_backend.py +6 -12
  6. sglang/srt/layers/fused_moe/patch.py +117 -0
  7. sglang/srt/layers/linear.py +1133 -0
  8. sglang/srt/layers/quantization/__init__.py +76 -0
  9. sglang/srt/layers/quantization/base_config.py +122 -0
  10. sglang/srt/managers/schedule_batch.py +3 -5
  11. sglang/srt/managers/tokenizer_manager.py +1 -0
  12. sglang/srt/managers/tp_worker.py +1 -1
  13. sglang/srt/mem_cache/radix_cache.py +5 -5
  14. sglang/srt/model_executor/cuda_graph_runner.py +10 -6
  15. sglang/srt/model_executor/forward_batch_info.py +2 -4
  16. sglang/srt/model_executor/model_runner.py +0 -3
  17. sglang/srt/models/baichuan.py +1 -1
  18. sglang/srt/models/chatglm.py +6 -6
  19. sglang/srt/models/commandr.py +7 -7
  20. sglang/srt/models/dbrx.py +7 -7
  21. sglang/srt/models/deepseek.py +7 -7
  22. sglang/srt/models/deepseek_v2.py +7 -7
  23. sglang/srt/models/exaone.py +6 -6
  24. sglang/srt/models/gemma.py +6 -6
  25. sglang/srt/models/gemma2.py +6 -6
  26. sglang/srt/models/gpt_bigcode.py +6 -6
  27. sglang/srt/models/grok.py +6 -6
  28. sglang/srt/models/internlm2.py +6 -6
  29. sglang/srt/models/llama.py +14 -6
  30. sglang/srt/models/llama_classification.py +1 -1
  31. sglang/srt/models/llava.py +1 -1
  32. sglang/srt/models/llavavid.py +1 -1
  33. sglang/srt/models/minicpm.py +6 -6
  34. sglang/srt/models/minicpm3.py +1 -1
  35. sglang/srt/models/mixtral.py +6 -6
  36. sglang/srt/models/mixtral_quant.py +6 -6
  37. sglang/srt/models/olmoe.py +1 -1
  38. sglang/srt/models/qwen.py +6 -6
  39. sglang/srt/models/qwen2.py +6 -6
  40. sglang/srt/models/qwen2_moe.py +7 -7
  41. sglang/srt/models/stablelm.py +6 -6
  42. sglang/srt/models/xverse.py +1 -1
  43. sglang/srt/models/xverse_moe.py +1 -1
  44. sglang/srt/models/yivl.py +1 -1
  45. sglang/srt/openai_api/adapter.py +7 -0
  46. sglang/srt/utils.py +21 -1
  47. sglang/test/runners.py +7 -9
  48. sglang/test/test_utils.py +39 -2
  49. sglang/version.py +1 -1
  50. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/METADATA +8 -6
  51. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/RECORD +54 -50
  52. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/LICENSE +0 -0
  53. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/WHEEL +0 -0
  54. {sglang-0.3.1.post2.dist-info → sglang-0.3.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py ADDED
@@ -0,0 +1,76 @@
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+
+ from typing import Dict, Type
+
+ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+ from vllm.model_executor.layers.quantization.awq import AWQConfig
+ from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+ from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+     CompressedTensorsConfig,
+ )
+ from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+ from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+ from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
+ from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
+ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+ from vllm.model_executor.layers.quantization.qqq import QQQConfig
+ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+     "aqlm": AQLMConfig,
+     "awq": AWQConfig,
+     "deepspeedfp": DeepSpeedFPConfig,
+     "tpu_int8": Int8TpuConfig,
+     "fp8": Fp8Config,
+     "fbgemm_fp8": FBGEMMFp8Config,
+     # The order of gptq methods is important for config.py iteration over
+     # override_quantization_method(..)
+     "marlin": MarlinConfig,
+     "gguf": GGUFConfig,
+     "gptq_marlin_24": GPTQMarlin24Config,
+     "gptq_marlin": GPTQMarlinConfig,
+     "awq_marlin": AWQMarlinConfig,
+     "gptq": GPTQConfig,
+     "squeezellm": SqueezeLLMConfig,
+     "compressed-tensors": CompressedTensorsConfig,
+     "bitsandbytes": BitsAndBytesConfig,
+     "qqq": QQQConfig,
+     "experts_int8": ExpertsInt8Config,
+ }
+
+
+ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+     if quantization not in QUANTIZATION_METHODS:
+         raise ValueError(f"Invalid quantization method: {quantization}")
+     return QUANTIZATION_METHODS[quantization]
+
+
+ __all__ = [
+     "QuantizationConfig",
+     "get_quantization_config",
+     "QUANTIZATION_METHODS",
+ ]
+
+ """
+ def fp8_get_quant_method(
+     self, layer: torch.nn.Module, prefix: str
+ ) -> Optional["QuantizeMethodBase"]:
+     if isinstance(layer, LinearBase):
+         if is_layer_skipped(prefix, self.ignored_layers):
+             return UnquantizedLinearMethod()
+         return Fp8LinearMethod(self)
+     elif isinstance(layer, FusedMoE):
+         return Fp8MoEMethod(self)
+     return None
+
+
+ setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+ """
sglang/srt/layers/quantization/base_config.py ADDED
@@ -0,0 +1,122 @@
+ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from torch import nn
+
+
+ class QuantizeMethodBase(ABC):
+     """Base class for different quantized methods."""
+
+     @abstractmethod
+     def create_weights(
+         self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
+     ):
+         """Create weights for a layer.
+
+         The weights will be set as attributes of the layer."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+         """Apply the weights in layer to the input tensor.
+
+         Expects create_weights to have been called before on the layer."""
+         raise NotImplementedError
+
+     def process_weights_after_loading(self, layer: nn.Module) -> None:
+         """Process the weight after loading.
+
+         This can be used for example, to transpose weights for computation.
+         """
+         return
+
+
+ class QuantizationConfig(ABC):
+     """Base class for quantization configs."""
+
+     @abstractmethod
+     def get_name(self) -> str:
+         """Name of the quantization method."""
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_supported_act_dtypes(self) -> List[torch.dtype]:
+         """List of supported activation dtypes."""
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def get_min_capability(cls) -> int:
+         """Minimum GPU capability to support the quantization method.
+
+         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+         This requirement is due to the custom CUDA kernels used by the
+         quantization method.
+         """
+         raise NotImplementedError
+
+     @staticmethod
+     @abstractmethod
+     def get_config_filenames() -> List[str]:
+         """List of filenames to search for in the model directory."""
+         raise NotImplementedError
+
+     @classmethod
+     @abstractmethod
+     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+         """Create a config class from the model's quantization config."""
+         raise NotImplementedError
+
+     @classmethod
+     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+         """
+         Detects if this quantization method can support a given checkpoint
+         format by overriding the user specified quantization method --
+         this method should only be overwritten by subclasses in exceptional
+         circumstances
+         """
+         return None
+
+     @staticmethod
+     def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+         """Get a value from the model's quantization config."""
+         for key in keys:
+             if key in config:
+                 return config[key]
+         raise ValueError(
+             f"Cannot find any of {keys} in the model's " "quantization config."
+         )
+
+     @staticmethod
+     def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any:
+         """Get a optional value from the model's quantization config."""
+         try:
+             return QuantizationConfig.get_from_keys(config, keys)
+         except ValueError:
+             return default
+
+     @abstractmethod
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional[QuantizeMethodBase]:
+         """Get the quantize method to use for the quantized layer.
+
+         Args:
+             layer: The layer for the quant method.
+             prefix: The full name of the layer in the state dict
+         Returns:
+             The quantize method. None if the given layer doesn't support quant
+             method.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     def get_scaled_act_names(self) -> List[str]:
+         """Returns the activation function names that should be post-scaled.
+
+         For now, this is only used by AWQ.
+         """
+         raise NotImplementedError
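To see what this ABC demands of a backend, here is a hypothetical minimal subclass; `NoopQuantConfig` is invented for illustration and is not part of sglang or vLLM:

```python
from typing import Any, Dict, List, Optional

import torch

from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)


class NoopQuantConfig(QuantizationConfig):
    """Hypothetical do-nothing config showing the required overrides."""

    def get_name(self) -> str:
        return "noop"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # Volta and newer

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NoopQuantConfig":
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        return None  # leave every layer unquantized

    def get_scaled_act_names(self) -> List[str]:
        return []


cfg = NoopQuantConfig.from_config({})  # instantiable: all abstract methods defined
```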
sglang/srt/managers/schedule_batch.py CHANGED
@@ -429,7 +429,7 @@ class ScheduleBatch:
      def prepare_for_extend(self, vocab_size: int):
          self.forward_mode = ForwardMode.EXTEND
 
-         bs = self.batch_size()
+         bs = len(self.reqs)
          reqs = self.reqs
          input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
          extend_num_tokens = sum(len(ids) for ids in input_ids)
@@ -509,7 +509,7 @@ class ScheduleBatch:
          self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
 
      def check_decode_mem(self):
-         bs = self.batch_size()
+         bs = len(self.reqs)
          if self.token_to_kv_pool.available_size() >= bs:
              return True
 
@@ -680,14 +680,12 @@ class ScheduleBatch:
                  r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
                  for r in self.reqs
              ]
-         else:
-             self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids)
 
          self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
          self.seq_lens.add_(1)
 
          # Alloc mem
-         bs = self.batch_size()
+         bs = len(self.reqs)
          self.out_cache_loc = self.alloc_token_slots(bs)
 
          self.req_to_token_pool.req_to_token[
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -123,6 +123,7 @@ class TokenizerManager:
                  initializer=init_global_processor,
                  mp_context=mp.get_context("fork"),
                  initargs=(server_args,),
+                 max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
              )
          else:
              self.tokenizer = get_tokenizer(
sglang/srt/managers/tp_worker.py CHANGED
@@ -215,6 +215,7 @@ class ModelTpServer:
          self.new_token_ratio_decay = global_config.new_token_ratio_decay
          self.do_not_get_new_batch = False
 
+     @torch.inference_mode()
      def exposed_step(self, recv_reqs: List):
          try:
              # Recv requests
@@ -246,7 +247,6 @@
          self.out_pyobjs = []
          return ret
 
-     @torch.inference_mode()
      def forward_step(self):
          if self.do_not_get_new_batch and self.current_inflight_req is None:
              new_batch = None
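The decorator moves from `forward_step` up to `exposed_step`, so the whole step (scheduling and tensor bookkeeping included, not just the model forward) now runs under inference mode. A minimal sketch of the pattern, with hypothetical names:

```python
import torch


class Worker:
    @torch.inference_mode()  # as on exposed_step: covers everything below
    def step(self):
        # Tensors created anywhere in the step are inference tensors,
        # skipping autograd tracking and version-counter updates.
        x = torch.randn(4, 4)
        return self.forward(x)

    def forward(self, x):
        # Inherits the caller's inference mode; no decorator needed here.
        return x @ x


assert not Worker().step().requires_grad
```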
sglang/srt/mem_cache/radix_cache.py CHANGED
@@ -291,15 +291,15 @@ class RadixCache(BasePrefixCache):
 
      def _collect_leaves(self):
          ret_list = []
+         stack = [self.root_node]
 
-         def dfs_(cur_node):
+         while stack:
+             cur_node = stack.pop()
              if len(cur_node.children) == 0:
                  ret_list.append(cur_node)
+             else:
+                 stack.extend(cur_node.children.values())
 
-             for x in cur_node.children.values():
-                 dfs_(x)
-
-         dfs_(self.root_node)
          return ret_list
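The recursive `dfs_` becomes an explicit stack, which avoids Python's recursion limit on deep radix trees and per-call overhead; it collects the same leaf set, just in a different traversal order. A standalone sketch of the transformation (`Node` is a stand-in for illustration, not sglang's tree node):

```python
class Node:
    def __init__(self, children=None):
        self.children = children or {}


def collect_leaves(root):
    ret_list, stack = [], [root]
    while stack:
        cur_node = stack.pop()
        if len(cur_node.children) == 0:
            ret_list.append(cur_node)  # leaf: record it
        else:
            stack.extend(cur_node.children.values())  # interior: descend
    return ret_list


root = Node({"a": Node(), "b": Node({"c": Node()})})
assert len(collect_leaves(root)) == 2  # leaves: "a" and "c"
```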
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -25,6 +25,7 @@ import torch
  from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp
 
+ from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
  from sglang.srt.layers.logits_processor import (
      LogitsMetadata,
      LogitsProcessor,
@@ -41,14 +42,15 @@ if TYPE_CHECKING:
  def _to_torch(model: torch.nn.Module, reverse: bool = False):
      for sub in model._modules.values():
          if isinstance(sub, CustomOp):
-             # NOTE: FusedMoE torch native implementaiton is not efficient
-             if "FusedMoE" in sub.__class__.__name__:
-                 continue
              if reverse:
                  sub._forward_method = sub.forward_cuda
                  setattr(sub, "is_torch_compile", False)
              else:
-                 sub._forward_method = sub.forward_native
+                 # NOTE: Temporarily workaround MoE
+                 if "FusedMoE" in sub.__class__.__name__:
+                     sub._forward_method = fused_moe_forward_native
+                 else:
+                     sub._forward_method = sub.forward_native
                  setattr(sub, "is_torch_compile", True)
          if isinstance(sub, torch.nn.Module):
              _to_torch(sub, reverse)
@@ -67,7 +69,9 @@ def patch_model(
              monkey_patch_vllm_all_gather()
              backup_ca_comm = tp_group.ca_comm
              tp_group.ca_comm = None
-             yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+             yield torch.compile(
+                 torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+             )
          else:
              yield model.forward
      finally:
@@ -150,7 +154,7 @@ class CudaGraphRunner:
                  f"Capture cuda graph failed: {e}\n"
                  "Possible solutions:\n"
                  "1. disable cuda graph by --disable-cuda-graph\n"
-                 "2. set --mem-fraction-static to a smaller value\n"
+                 "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                  "3. disable torch compile by not using --enable-torch-compile\n"
                  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
              )
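Two things change here: `FusedMoE` layers are no longer skipped but routed to a native forward from the new `fused_moe/patch.py`, and the compiled callable is wrapped in `torch.no_grad()` so gradient tracking stays out of the captured graph. The wrapper pattern in isolation (the toy `forward` below is illustrative, not sglang code):

```python
import torch


def forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


# torch.no_grad() also acts as a function wrapper: the returned callable
# runs with grad tracking disabled, and torch.compile traces it as such.
compiled = torch.compile(torch.no_grad()(forward))

y = compiled(torch.ones(3, requires_grad=True))
assert not y.requires_grad  # grad was disabled inside the wrapped forward
```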
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -97,14 +97,12 @@ class InputMetadata:
          self.modalities = [r.modalities for r in reqs]
 
      def compute_positions(self, batch: ScheduleBatch):
-         position_ids_offsets = batch.position_ids_offsets
-
          if self.forward_mode.is_decode():
              if True:
                  self.positions = self.seq_lens - 1
              else:
                  # Deprecated
-                 self.positions = (self.seq_lens - 1) + position_ids_offsets
+                 self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
          else:
              if True:
                  self.positions = torch.tensor(
@@ -119,7 +117,7 @@ class InputMetadata:
                  )
              else:
                  # Deprecated
-                 position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                 position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
                  self.positions = torch.tensor(
                      np.concatenate(
                          [
sglang/srt/model_executor/model_runner.py CHANGED
@@ -467,7 +467,6 @@ class ModelRunner:
          logger.info("Capture cuda graph begin. This can take up to several minutes.")
          self.cuda_graph_runner = CudaGraphRunner(self)
 
-     @torch.inference_mode()
      def forward_decode(self, batch: ScheduleBatch):
          if self.server_args.lora_paths is not None:
              self.lora_manager.prepare_lora_batch(batch)
@@ -481,7 +480,6 @@ class ModelRunner:
              batch.input_ids, input_metadata.positions, input_metadata
          )
 
-     @torch.inference_mode()
      def forward_extend(self, batch: ScheduleBatch):
          input_metadata = InputMetadata.from_schedule_batch(self, batch)
          if self.server_args.lora_paths is not None:
@@ -500,7 +498,6 @@ class ModelRunner:
              get_embedding=True,
          )
 
-     @torch.inference_mode()
      def forward_extend_multi_modal(self, batch: ScheduleBatch):
          input_metadata = InputMetadata.from_schedule_batch(self, batch)
          return self.model.forward(
sglang/srt/models/baichuan.py CHANGED
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -45,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/chatglm.py CHANGED
@@ -24,12 +24,6 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -40,7 +34,13 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/commandr.py CHANGED
@@ -50,21 +50,21 @@ from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.utils import set_weight_attrs
 
  from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.utils import set_weight_attrs
 
 
  @torch.compile
sglang/srt/models/dbrx.py CHANGED
@@ -27,12 +27,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.fused_moe import fused_moe
- from vllm.model_executor.layers.linear import (
-     QKVParallelLinear,
-     ReplicatedLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
@@ -40,12 +34,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.utils import set_weight_attrs
  from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+ from sglang.srt.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.utils import set_weight_attrs
 
 
  class DbrxRouter(nn.Module):
sglang/srt/models/deepseek.py CHANGED
@@ -28,13 +28,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.fused_moe import fused_moe
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     ReplicatedLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -44,7 +37,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/deepseek_v2.py CHANGED
@@ -27,13 +27,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.fused_moe import FusedMoE
- from vllm.model_executor.layers.linear import (
-     ColumnParallelLinear,
-     MergedColumnParallelLinear,
-     ReplicatedLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -43,7 +36,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/exaone.py CHANGED
@@ -23,12 +23,6 @@ import torch
  from torch import nn
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -38,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/gemma.py CHANGED
@@ -23,19 +23,19 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/gemma2.py CHANGED
@@ -22,12 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 
  # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -35,7 +29,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,17 +23,17 @@ from torch import nn
  from transformers import GPTBigCodeConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.linear import (
-     ColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import get_act_fn
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata