sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py ADDED
@@ -0,0 +1,161 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.srt.layers.attention import AttentionBackend
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+ class TritonAttnBackend(AttentionBackend):
+     def __init__(self, model_runner: ModelRunner):
+         # Lazy import to avoid the initialization of cuda context
+         from sglang.srt.layers.attention.triton_ops.decode_attention import (
+             decode_attention_fwd,
+         )
+         from sglang.srt.layers.attention.triton_ops.extend_attention import (
+             extend_attention_fwd,
+         )
+
+         super().__init__()
+
+         self.decode_attention_fwd = decode_attention_fwd
+         self.extend_attention_fwd = extend_attention_fwd
+         self.num_head = (
+             model_runner.model_config.num_attention_heads // model_runner.tp_size
+         )
+
+         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
+             self.reduce_dtype = torch.float32
+         else:
+             self.reduce_dtype = torch.float16
+
+         self.forward_metadata = None
+
+         self.cuda_graph_max_seq_len = model_runner.model_config.context_len
+
+     def init_forward_metadata(self, forward_batch: ForwardBatch):
+         """Init auxiliary variables for triton attention backend."""
+
+         if forward_batch.forward_mode.is_decode():
+             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+             total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+             attn_logits = torch.empty(
+                 (self.num_head, total_num_tokens),
+                 dtype=self.reduce_dtype,
+                 device="cuda",
+             )
+
+             max_seq_len = torch.max(forward_batch.seq_lens).item()
+             max_extend_len = None
+         else:
+             start_loc = attn_logits = max_seq_len = None
+             prefix_lens = forward_batch.extend_prefix_lens
+             max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+
+         self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+
+     def init_cuda_graph_state(self, max_bs: int):
+         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+         self.cuda_graph_start_loc = torch.zeros(
+             (max_bs,), dtype=torch.int32, device="cuda"
+         )
+         self.cuda_graph_attn_logits = torch.empty(
+             (
+                 self.num_head,
+                 self.cuda_graph_max_total_num_tokens,
+             ),
+             dtype=self.reduce_dtype,
+             device="cuda",
+         )
+
+     def init_forward_metadata_capture_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         self.forward_metadata = (
+             self.cuda_graph_start_loc,
+             self.cuda_graph_attn_logits,
+             self.cuda_graph_max_seq_len,
+             None,
+         )
+
+     def init_forward_metadata_replay_cuda_graph(
+         self, bs: int, req_pool_indices, seq_lens
+     ):
+         self.cuda_graph_start_loc.zero_()
+         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+     def get_cuda_graph_seq_len_fill_value(self):
+         return 1
+
+     def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         # TODO: reuse the buffer across layers
+         if layer.qk_head_dim != layer.v_head_dim:
+             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+         else:
+             o = torch.empty_like(q)
+
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             layer.layer_id, forward_batch.out_cache_loc, k, v
+         )
+
+         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+         self.extend_attention_fwd(
+             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+             k.contiguous(),
+             v.contiguous(),
+             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+             forward_batch.req_to_token_pool.req_to_token,
+             forward_batch.req_pool_indices,
+             forward_batch.seq_lens,
+             forward_batch.extend_seq_lens,
+             forward_batch.extend_start_loc,
+             max_extend_len,
+             layer.scaling,
+             layer.logit_cap,
+         )
+         return o
+
+     def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+         # During torch.compile, there is a bug in rotary_emb that causes the
+         # output value to have a 3D tensor shape. This reshapes the output correctly.
+         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+         # TODO: reuse the buffer across layers
+         if layer.qk_head_dim != layer.v_head_dim:
+             o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+         else:
+             o = torch.empty_like(q)
+
+         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+
+         forward_batch.token_to_kv_pool.set_kv_buffer(
+             layer.layer_id, forward_batch.out_cache_loc, k, v
+         )
+
+         self.decode_attention_fwd(
+             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+             forward_batch.req_to_token_pool.req_to_token,
+             forward_batch.req_pool_indices,
+             start_loc,
+             forward_batch.seq_lens,
+             attn_logits,
+             max_seq_len,
+             layer.scaling,
+             layer.logit_cap,
+         )
+         return o
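In the decode branch of init_forward_metadata above, start_loc is an exclusive cumulative sum of the batch's sequence lengths and attn_logits is a per-head scratch buffer sized to the total number of cached tokens. A minimal sketch of that index math, using made-up sequence lengths instead of a real ForwardBatch (CPU tensors here; the backend allocates on CUDA):

    import torch

    # Hypothetical per-request KV lengths for a decode batch of 3 requests.
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # Exclusive cumulative sum: request i's tokens start at start_loc[i].
    start_loc = torch.zeros_like(seq_lens)
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
    print(start_loc.tolist())        # [0, 5, 8]

    # Scratch buffer for the decode kernel: one row per attention head,
    # one column per cached token in the batch (num_head=4 is arbitrary here).
    num_head = 4
    total_num_tokens = int(seq_lens.sum())
    attn_logits = torch.empty((num_head, total_num_tokens), dtype=torch.float16)
    print(attn_logits.shape)         # torch.Size([4, 15])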
sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py RENAMED
@@ -22,7 +22,9 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
+ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+     context_attention_fwd,
+ )

  CUDA_CAPABILITY = torch.cuda.get_device_capability()

sglang/srt/layers/layernorm.py CHANGED
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
  import torch
  import torch.nn as nn

- from sglang.srt.utils import is_hip
+ from sglang.srt.utils import is_flashinfer_available

- if not is_hip():
+ if is_flashinfer_available():
      from flashinfer.norm import (
          fused_add_rmsnorm,
          gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
          return out


- if is_hip():
+ if not is_flashinfer_available():
      logger.info(
-         "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+         "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
      )
      from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
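Both layernorm hunks replace the is_hip() platform test with a capability probe, is_flashinfer_available(), so the vLLM fallback now covers any platform where FlashInfer cannot be used, not only AMD GPUs. A rough sketch of what such a probe can look like; the real helper lives in sglang/srt/utils.py and this stand-in is an assumption, not the actual implementation:

    import importlib.util


    def is_flashinfer_available() -> bool:
        """Hypothetical stand-in for sglang.srt.utils.is_flashinfer_available.

        The real helper may also check the GPU vendor or an environment
        variable; this sketch only tests whether the package is importable.
        """
        return importlib.util.find_spec("flashinfer") is not None


    print("flashinfer importable:", is_flashinfer_available())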
sglang/srt/layers/logits_processor.py CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
      tensor_model_parallel_all_gather,
  )

- from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


  @dataclasses.dataclass
@@ -61,26 +61,30 @@ class LogitsMetadata:
      extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

      @classmethod
-     def from_input_metadata(cls, input_metadata: InputMetadata):
-         return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
-         if input_metadata.forward_mode.is_extend():
+     def from_forward_batch(cls, forward_batch: ForwardBatch):
+         if forward_batch.return_logprob:
+             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+         else:
+             return_top_logprob = False
+
+         if forward_batch.forward_mode.is_extend():
              extend_logprob_pruned_lens_cpu = [
                  extend_len - start_len
                  for extend_len, start_len in zip(
-                     input_metadata.extend_seq_lens,
-                     input_metadata.extend_logprob_start_lens_cpu,
+                     forward_batch.extend_seq_lens,
+                     forward_batch.extend_logprob_start_lens_cpu,
                  )
              ]
          else:
              extend_logprob_pruned_lens_cpu = None
          return cls(
-             forward_mode=input_metadata.forward_mode,
-             top_logprobs_nums=input_metadata.top_logprobs_nums,
-             return_logprob=input_metadata.return_logprob,
+             forward_mode=forward_batch.forward_mode,
+             top_logprobs_nums=forward_batch.top_logprobs_nums,
+             return_logprob=forward_batch.return_logprob,
              return_top_logprob=return_top_logprob,
-             extend_seq_lens=input_metadata.extend_seq_lens,
-             extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-             extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+             extend_seq_lens=forward_batch.extend_seq_lens,
+             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
              extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
          )

@@ -162,10 +166,10 @@ class LogitsProcessor(nn.Module):
          input_ids,
          hidden_states,
          weight,
-         logits_metadata: Union[LogitsMetadata, InputMetadata],
+         logits_metadata: Union[LogitsMetadata, ForwardBatch],
      ):
-         if isinstance(logits_metadata, InputMetadata):
-             logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+         if isinstance(logits_metadata, ForwardBatch):
+             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
          assert isinstance(logits_metadata, LogitsMetadata)

          # Get the last hidden states and last logits for the next token prediction
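from_forward_batch also prunes each request's extend range before logprobs are gathered: only the tokens from the per-request logprob start offset onward count, so the pruned length is extend_len - start_len. A plain-Python illustration of that bookkeeping with made-up lengths:

    # Hypothetical batch: each request extends by extend_len tokens, but
    # logprobs are only wanted from start_len onward.
    extend_seq_lens = [8, 5, 12]
    extend_logprob_start_lens = [0, 3, 10]

    extend_logprob_pruned_lens = [
        extend_len - start_len
        for extend_len, start_len in zip(extend_seq_lens, extend_logprob_start_lens)
    ]
    print(extend_logprob_pruned_lens)  # [8, 2, 2]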
sglang/srt/layers/pooler.py CHANGED
@@ -7,7 +7,7 @@ from enum import IntEnum
  import torch
  import torch.nn as nn

- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.model_runner import ForwardBatch


  class PoolingType(IntEnum):
@@ -36,10 +36,10 @@ class Pooler(nn.Module):
          self.normalize = normalize

      def forward(
-         self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
      ) -> EmbeddingPoolerOutput:
          if self.pooling_type == PoolingType.LAST:
-             last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+             last_token_indices = torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1
              pooled_data = hidden_states[last_token_indices]
          else:
              raise ValueError(f"Invalid pooling type: {self.pooling_type}")
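The Pooler change is a pure rename (input_metadata → forward_batch), but the hunk also shows how LAST pooling picks one hidden state per request: the cumulative sum of extend_seq_lens minus one gives the flat index of each request's final token in the packed batch. A small self-contained illustration with dummy tensors:

    import torch

    # Two requests packed into one flat batch: 4 tokens then 3 tokens.
    extend_seq_lens = torch.tensor([4, 3])
    hidden_states = torch.arange(7 * 2, dtype=torch.float32).reshape(7, 2)

    # Index of the last token of each request in the packed dimension.
    last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1
    print(last_token_indices.tolist())   # [3, 6]

    pooled = hidden_states[last_token_indices]
    print(pooled.shape)                  # torch.Size([2, 2]) -> one row per request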
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
  from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
  from vllm.model_executor.layers.quantization.marlin import MarlinConfig
  from vllm.model_executor.layers.quantization.qqq import QQQConfig
- from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
  from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -39,7 +38,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
      "gptq_marlin": GPTQMarlinConfig,
      "awq_marlin": AWQMarlinConfig,
      "gptq": GPTQConfig,
-     "squeezellm": SqueezeLLMConfig,
      "compressed-tensors": CompressedTensorsConfig,
      "bitsandbytes": BitsAndBytesConfig,
      "qqq": QQQConfig,
sglang/srt/layers/radix_attention.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.

  from torch import nn

- from sglang.srt.model_executor.forward_batch_info import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch


  class RadixAttention(nn.Module):
@@ -32,9 +32,10 @@ class RadixAttention(nn.Module):
          scaling: float,
          num_kv_heads: int,
          layer_id: int,
-         sliding_window_size: int = -1,
          logit_cap: float = 0.0,
          v_head_dim: int = -1,
+         sliding_window_size: int = -1,
+         is_cross_attention: bool = False,
      ):
          super().__init__()
          self.tp_q_head_num = num_heads
@@ -47,12 +48,13 @@ class RadixAttention(nn.Module):
          self.layer_id = layer_id
          self.logit_cap = logit_cap
          self.sliding_window_size = sliding_window_size or -1
+         self.is_cross_attention = is_cross_attention

-     def forward(self, q, k, v, input_metadata: InputMetadata):
+     def forward(self, q, k, v, forward_batch: ForwardBatch):
          if k is not None:
              # For cross-layer sharing, kv can be None
              assert v is not None
              k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
              v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-         return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
+         return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
sglang/srt/layers/sampler.py CHANGED
@@ -7,10 +7,9 @@ from torch import nn
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import is_hip
+ from sglang.srt.utils import is_flashinfer_available

- # ROCm: flashinfer available later
- if not is_hip():
+ if is_flashinfer_available():
      from flashinfer.sampling import (
          min_p_sampling_from_probs,
          top_k_renorm_prob,
@@ -43,7 +42,10 @@ class Sampler(nn.Module):
              torch.isnan(probs), torch.full_like(probs, 1e-10), probs
          )

-         if global_server_args_dict["sampling_backend"] == "flashinfer":
+         if sampling_info.top_ks.max().item() <= 1:
+             # Use torch.argmax if all requests use greedy sampling
+             batch_next_token_ids = torch.argmax(probs, -1)
+         elif global_server_args_dict["sampling_backend"] == "flashinfer":
              max_top_k_round, batch_size = 32, probs.shape[0]
              uniform_samples = torch.rand(
                  (max_top_k_round, batch_size), device=probs.device
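The sampler hunk adds a fast path: when every request in the batch has top_k <= 1, sampling is purely greedy, so an argmax over the probabilities replaces the FlashInfer/torch sampling kernels. A tiny sketch of the idea with a made-up batch (the multinomial fallback below just stands in for the configured backend):

    import torch

    probs = torch.tensor([
        [0.1, 0.7, 0.2],
        [0.6, 0.3, 0.1],
    ])
    top_ks = torch.tensor([1, 1])  # every request is greedy

    if top_ks.max().item() <= 1:
        # Greedy: the highest-probability token is always chosen,
        # so no random sampling kernel is needed.
        next_token_ids = torch.argmax(probs, dim=-1)
    else:
        # Otherwise fall through to the configured sampling backend.
        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

    print(next_token_ids.tolist())  # [1, 0]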
sglang/srt/layers/torchao_utils.py CHANGED
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
      """
      # Lazy import to suppress some warnings
      from torchao.quantization import (
+         float8_dynamic_activation_float8_weight,
          int4_weight_only,
          int8_dynamic_activation_int8_weight,
          int8_weight_only,
          quantize_,
      )
+     from torchao.quantization.observer import PerRow, PerTensor

      dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
      dummy_linear.weight = param
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
          # this requires newer hardware
          # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
          quantize_(dummy_linear, float8_weight_only())
+     elif "fp8dq" in torchao_config:
+         granularity = torchao_config.split("-")[-1]
+         GRANULARITY_MAP = {
+             "per_row": PerRow(),
+             "per_tensor": PerTensor(),
+         }
+         assert (
+             granularity in GRANULARITY_MAP
+         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
+         quantize_(
+             dummy_linear,
+             float8_dynamic_activation_float8_weight(
+                 granularity=GRANULARITY_MAP[granularity]
+             ),
+         )
+
      return dummy_linear.weight

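The new fp8dq branch derives the quantization granularity from the tail of the config string, e.g. "fp8dq-per_row" or "fp8dq-per_tensor". The parsing itself is plain string handling and can be checked in isolation; the PerRow/PerTensor objects are only stubbed as strings here so the sketch runs without a torchao install:

    # Stand-ins for torchao.quantization.observer.PerRow / PerTensor so the
    # parsing logic can run without torchao installed.
    GRANULARITY_MAP = {"per_row": "PerRow()", "per_tensor": "PerTensor()"}

    def parse_fp8dq_config(torchao_config: str) -> str:
        granularity = torchao_config.split("-")[-1]
        assert (
            granularity in GRANULARITY_MAP
        ), f"Supported granularity are: {list(GRANULARITY_MAP)}, got {granularity}"
        return GRANULARITY_MAP[granularity]

    print(parse_fp8dq_config("fp8dq-per_row"))     # PerRow()
    print(parse_fp8dq_config("fp8dq-per_tensor"))  # PerTensor()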
sglang/srt/lora/lora.py CHANGED
@@ -28,19 +28,19 @@ from typing import Any, Dict, List, Optional, Tuple
  import safetensors.torch
  import torch
  from torch import nn
- from vllm.model_executor.layers.linear import (
-     ColumnParallelLinear,
-     MergedColumnParallelLinear,
-     QKVParallelLinear,
-     RowParallelLinear,
- )
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.loader import DefaultModelLoader

- from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


  class BaseLayerWithLoRA(nn.Module):
@@ -101,12 +101,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
      ) -> None:
          super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+     def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
          self.bs = bs
-         self.seq_lens = seq_lens
+         self.seg_indptr = seg_indptr
          self.weight_indices = weight_indices

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -115,11 +115,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
              weights=self.A_buffer,
              batch_size=self.bs,
              weight_column_major=True,
-             seg_lens=self.seq_lens,
+             seg_indptr=self.seg_indptr,
              weight_indices=self.weight_indices,
          )
          # FIXME
-         assert lora_a_output.shape[-1] == self.lora_rank * 2
          lora_output = torch.empty_like(base_output)
          output_dim = lora_output.shape[-1] // 2
          for i in range(2):
@@ -132,7 +131,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                  weights=self.B_buffer[:, left:right, :].contiguous(),
                  batch_size=self.bs,
                  weight_column_major=True,
-                 seg_lens=self.seq_lens,
+                 seg_indptr=self.seg_indptr,
                  weight_indices=self.weight_indices,
              )
          return base_output + lora_output * self.scaling
@@ -145,14 +144,14 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
          super().__init__(base_layer, segment_gemm, lora_rank, scaling)

      def set_lora_info(
-         self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seq_lens, weight_indices
+         self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
      ):
          self.set_lora = True
          self.A_buffer_qkv = A_buffer_qkv
          self.B_buffer_q = B_buffer_q
          self.B_buffer_kv = B_buffer_kv
          self.bs = bs
-         self.seq_lens = seq_lens
+         self.seg_indptr = seg_indptr
          self.weight_indices = weight_indices

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +160,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
              weights=self.A_buffer_qkv,
              batch_size=self.bs,
              weight_column_major=True,
-             seg_lens=self.seq_lens,
+             seg_indptr=self.seg_indptr,
              weight_indices=self.weight_indices,
          )
          # FIXME parallelize qkv
@@ -173,7 +172,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
              weights=self.B_buffer_q,
              batch_size=self.bs,
              weight_column_major=True,
-             seg_lens=self.seq_lens,
+             seg_indptr=self.seg_indptr,
              weight_indices=self.weight_indices,
          )
          # kv
@@ -189,7 +188,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                      weights=self.B_buffer_kv[:, left:right, :].contiguous(),
                      batch_size=self.bs,
                      weight_column_major=True,
-                     seg_lens=self.seq_lens,
+                     seg_indptr=self.seg_indptr,
                      weight_indices=self.weight_indices,
                  )
              )
@@ -202,12 +201,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
      ) -> None:
          super().__init__(base_layer, segment_gemm, lora_rank, scaling)

-     def set_lora_info(self, A_buffer, B_buffer, bs, seq_lens, weight_indices):
+     def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
          self.set_lora = True
          self.A_buffer = A_buffer
          self.B_buffer = B_buffer
          self.bs = bs
-         self.seq_lens = seq_lens
+         self.seg_indptr = seg_indptr
          self.weight_indices = weight_indices

      def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
@@ -216,7 +215,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
              weights=self.A_buffer,
              batch_size=self.bs,
              weight_column_major=True,
-             seg_lens=self.seq_lens,
+             seg_indptr=self.seg_indptr,
              weight_indices=self.weight_indices,
          )
          lora_output = self.segment_gemm.run(
@@ -224,7 +223,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
              weights=self.B_buffer,
              batch_size=self.bs,
              weight_column_major=True,
-             seg_lens=self.seq_lens,
+             seg_indptr=self.seg_indptr,
              weight_indices=self.weight_indices,
          )
          return base_output + lora_output * self.scaling
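The recurring seg_lens → seg_indptr rename across these LoRA layers reflects the segment GEMM kernel now consuming segment offsets (an indptr array of length batch_size + 1, i.e. an exclusive prefix sum with a leading zero) instead of raw per-segment lengths. Converting one to the other is a cumulative sum; a minimal sketch, independent of the lora_manager code that actually builds this array:

    import torch

    # Per-request LoRA segment lengths (hypothetical batch of 3 requests).
    seg_lens = torch.tensor([4, 2, 5], dtype=torch.int32)

    # seg_indptr[i]..seg_indptr[i+1] spans request i's rows in the packed input.
    seg_indptr = torch.zeros(len(seg_lens) + 1, dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

    print(seg_indptr.tolist())  # [0, 4, 6, 11]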