sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (88)
  1. sglang/check_env.py +1 -0
  2. sglang/global_config.py +2 -0
  3. sglang/srt/constrained/outlines_backend.py +4 -1
  4. sglang/srt/entrypoints/engine.py +2 -2
  5. sglang/srt/layers/attention/flashinfer_backend.py +265 -147
  6. sglang/srt/layers/attention/triton_backend.py +358 -72
  7. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  8. sglang/srt/layers/linear.py +12 -5
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  20. sglang/srt/layers/moe/topk.py +1 -1
  21. sglang/srt/layers/quantization/__init__.py +51 -5
  22. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  32. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  35. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  37. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  39. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  41. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  49. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  51. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  53. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  54. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  55. sglang/srt/lora/backend/__init__.py +25 -5
  56. sglang/srt/lora/backend/base_backend.py +31 -9
  57. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  58. sglang/srt/lora/backend/triton_backend.py +34 -4
  59. sglang/srt/lora/layers.py +293 -0
  60. sglang/srt/lora/lora.py +101 -326
  61. sglang/srt/lora/lora_manager.py +101 -269
  62. sglang/srt/lora/mem_pool.py +174 -0
  63. sglang/srt/lora/triton_ops/__init__.py +7 -1
  64. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  65. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  66. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  67. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  68. sglang/srt/lora/utils.py +141 -0
  69. sglang/srt/managers/detokenizer_manager.py +1 -0
  70. sglang/srt/managers/io_struct.py +4 -0
  71. sglang/srt/managers/schedule_batch.py +16 -3
  72. sglang/srt/managers/scheduler.py +29 -0
  73. sglang/srt/managers/tokenizer_manager.py +6 -0
  74. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  75. sglang/srt/model_executor/cuda_graph_runner.py +16 -1
  76. sglang/srt/model_executor/model_runner.py +12 -2
  77. sglang/srt/models/deepseek_v2.py +17 -7
  78. sglang/srt/server_args.py +20 -1
  79. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  80. sglang/srt/speculative/eagle_utils.py +64 -21
  81. sglang/srt/speculative/eagle_worker.py +29 -8
  82. sglang/srt/utils.py +7 -0
  83. sglang/version.py +1 -1
  84. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
  85. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
  86. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
  87. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
  88. {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py (new file)
@@ -0,0 +1,170 @@
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.lora.utils import LoRABatchInfo
+
+
+ @triton.jit
+ def _gate_up_lora_b_kernel(
+     # Pointers to matrices
+     x,
+     weights,
+     output,
+     # Parameters of size
+     K,  # K = R
+     output_dim,
+     # Strides
+     x_stride_0,
+     x_stride_1,
+     w_stride_0,
+     w_stride_1,
+     w_stride_2,
+     output_stride_0,
+     output_stride_1,
+     # Information on sequence lengths and weight id
+     seg_lens,
+     seg_indptr,
+     weight_indices,
+     # Meta parameters
+     BLOCK_S: tl.constexpr,
+     BLOCK_N: tl.constexpr,
+     BLOCK_K: tl.constexpr,
+     # For fused output scaling and adding
+     fuse_scaling_add,
+     scaling,
+ ):
+     # This kernel packs 2 sgemms (gate/up) into a single kernel.
+
+     # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
+     # weights: (num_lora, 2 * output_dim, K)
+     # output: (s, 2 * output_dim)
+     # output_dim >> K
+
+     # Current block computes sequence with batch_id,
+     # which starts from row seg_start of x with length seg_len.
+     # gate_up_id decides which of gate or up (0: gate, 1: up)
+     batch_id = tl.program_id(axis=2)
+     gate_up_id = tl.program_id(axis=1)
+     pid = tl.program_id(axis=0)
+     seg_len = tl.load(seg_lens + batch_id)
+     w_index = tl.load(weight_indices + batch_id)
+     seg_start = tl.load(seg_indptr + batch_id)
+     n_start = gate_up_id * output_dim  # offset on output dim
+
+     # The tile in output matrix will have (pid_s, pid_n) as id
+     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
+     pid_s = pid // num_pid_n
+     pid_n = pid % num_pid_n
+
+     # Create pointers for the first block of x and weights
+     # The pointers will be advanced as we move in the K direction
+     # and accumulate
+     s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
+     n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
+     k_offset = tl.arange(0, BLOCK_K)
+
+     x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
+         s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
+     )
+     w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
+         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
+     )
+
+     # Iterate to compute the block in output matrix
+     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_K)):
+         x_tile = tl.load(
+             x_ptrs,
+             mask=(s_offset[:, None] < seg_len)
+             and (k_offset[None, :] < K - k * BLOCK_K),
+             other=0.0,
+         )
+         w_tile = tl.load(
+             w_ptrs,
+             mask=(k_offset[:, None] < K - k * BLOCK_K)
+             and (n_offset[None, :] < output_dim),
+             other=0.0,
+         )
+         partial_sum += tl.dot(x_tile, w_tile)
+
+         x_ptrs += BLOCK_K * x_stride_1
+         w_ptrs += BLOCK_K * w_stride_2
+
+     # Store result to output matrix
+     partial_sum *= scaling
+     partial_sum = partial_sum.to(x.dtype.element_ty)
+     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
+         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
+     )
+     output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
+     if fuse_scaling_add:
+         partial_sum += tl.load(output_ptr, mask=output_mask)
+     tl.store(output_ptr, partial_sum, mask=output_mask)
+
+
+ def gate_up_lora_b_fwd(
+     x: torch.Tensor,
+     gate_up_lora_b: torch.Tensor,
+     batch_info: LoRABatchInfo,
+     output_dim: int,
+     base_output: torch.Tensor = None,
+     scaling: float = 1.0,
+ ) -> torch.Tensor:
+
+     # x: (s, 2 * r)
+     # gate_up_lora_b: (num_lora, 2 * output_dim, r)
+     # output: (s, 2 * output_dim)
+
+     # Compute lora_output with shape (s, output_dim) as follows:
+     # lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
+     # lora_output[:, output_dim:]
+     #     = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
+
+     # Get dims
+     s = x.shape[0]
+     input_dim = x.shape[1]
+     r = gate_up_lora_b.shape[-1]
+     assert input_dim == 2 * r
+
+     BLOCK_S = 16
+     BLOCK_R = 16
+     BLOCK_OUT = 64
+
+     grid_b = (
+         triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
+         2,  # this dimension decides current block computes on gate or up proj
+         batch_info.bs,
+     )
+
+     if base_output is None:
+         output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
+         fuse_scaling_add = False
+     else:
+         output = base_output
+         fuse_scaling_add = True
+
+     _gate_up_lora_b_kernel[grid_b](
+         x,
+         gate_up_lora_b,
+         output,
+         r,
+         output_dim,
+         x.stride(0),
+         x.stride(1),
+         gate_up_lora_b.stride(0),
+         gate_up_lora_b.stride(1),
+         gate_up_lora_b.stride(2),
+         output.stride(0),
+         output.stride(1),
+         batch_info.seg_lens,
+         batch_info.seg_indptr,
+         batch_info.weight_indices,
+         BLOCK_S,
+         BLOCK_OUT,
+         BLOCK_R,
+         fuse_scaling_add,
+         scaling,
+     )
+
+     return output
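
For orientation, here is a minimal usage sketch of the new gate_up_lora_b_fwd wrapper; it is not part of the diff. The shapes follow the comments in the kernel above, while the sequence lengths, rank, output_dim, dtype, and the int32 index tensors are illustrative assumptions, and a CUDA device with Triton is required.

# Hypothetical usage example, not from the sglang sources.
import torch

from sglang.srt.lora.triton_ops.gate_up_lora_b import gate_up_lora_b_fwd
from sglang.srt.lora.utils import LoRABatchInfo

seq_lens = [3, 5]  # two sequences in the batch
s, r, output_dim, num_lora = sum(seq_lens), 16, 64, 1

batch_info = LoRABatchInfo(
    bs=len(seq_lens),
    seg_lens=torch.tensor(seq_lens, dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=max(seq_lens),
    weight_indices=torch.zeros(len(seq_lens), dtype=torch.int32, device="cuda"),
)

x = torch.randn(s, 2 * r, dtype=torch.float16, device="cuda")                     # (s, 2 * r)
w = torch.randn(num_lora, 2 * output_dim, r, dtype=torch.float16, device="cuda")  # (num_lora, 2 * output_dim, r)

out = gate_up_lora_b_fwd(x, w, batch_info, output_dim, scaling=1.0)
print(out.shape)  # torch.Size([8, 128]), i.e. (s, 2 * output_dim)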
sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
  def qkv_lora_b_fwd(
      x: torch.Tensor,
      qkv_lora_b: torch.Tensor,
-     batch_info: LoraBatchInfo,
+     batch_info: LoRABatchInfo,
      output_offset: torch.Tensor,
      max_qkv_out_dim: int,
      base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
      # output: (s, output_dim_q + 2 * output_dim_kv)

      # Compute lora_output with shape (s, output_dim) as follows:
-     # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
+     # lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
      # lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
-     #     = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
+     #     = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
      # lora_output[:, output_dim_q + output_dim_kv: ]
-     #     = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
+     #     = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])

      # Get dims
      s = x.shape[0]
sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(


  def sgemm_lora_a_fwd(
-     x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
+     x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
  ) -> torch.Tensor:
      # x: (s, input_dim)
      # weights: (num_lora, r, input_dim)
sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -2,7 +2,7 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.lora.lora import LoraBatchInfo
+ from sglang.srt.lora.utils import LoRABatchInfo


  @triton.jit
@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
  def sgemm_lora_b_fwd(
      x: torch.Tensor,
      weights: torch.Tensor,
-     batch_info: LoraBatchInfo,
+     batch_info: LoRABatchInfo,
      base_output: torch.Tensor = None,
      scaling: float = 1.0,
  ) -> torch.Tensor:
sglang/srt/lora/utils.py (new file)
@@ -0,0 +1,141 @@
+ import re
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Optional, Set, Tuple
+
+ import torch
+
+ from sglang.srt.hf_transformers_utils import AutoConfig
+
+
+ @dataclass
+ class LoRABatchInfo:
+     # Batch size
+     bs: int
+
+     # Lengths of each sequence in shape (bs,)
+     seg_lens: torch.Tensor
+
+     # Indice pointers of each sequence in shape (bs + 1, )
+     seg_indptr: torch.Tensor
+
+     # Maximum sequence length of current batch
+     max_len: int
+
+     # The index of lora adapter used by each sequence, in shape (bs,)
+     weight_indices: torch.Tensor
+
+
+ class LoRAType(Enum):
+     LORA_A = 0
+     LORA_B = 1
+
+
+ def get_layer_id(name: str) -> int:
+     """
+     Extract integer id of layer from its name in string.
+     """
+     match = re.search(r"layers\.(\d+)\.", name)
+     if match is None:
+         return None
+     return int(match.group(1))
+
+
+ def get_customized_names_from_hf_names(
+     hf_module_names: Set[str], base_model: torch.nn.Module
+ ) -> Set[str]:
+     """
+     This function takes in a set of huggingface style module names:
+         e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+     and outputs a set of module names of customized sglang layers:
+         e.g., {"qkv_proj", "o_proj"}
+     """
+     if hasattr(base_model, "get_module_name"):
+         return {base_model.get_module_name(name) for name in hf_module_names}
+     else:
+         """
+         Fallback solution of mapping from config module name to module name in model class.
+         Please check if it aligns with your base model.
+         Please implement the function in the model class if it is not.
+         You can reference this function in llama.py.
+         """
+         params_mapping = {
+             "q_proj": "qkv_proj",
+             "k_proj": "qkv_proj",
+             "v_proj": "qkv_proj",
+             "gate_proj": "gate_up_proj",
+             "up_proj": "gate_up_proj",
+         }
+         return {params_mapping.get(name, name) for name in hf_module_names}
+
+
+ def get_hidden_dim(
+     module_name: str, config: AutoConfig, base_model: torch.nn.Module
+ ) -> Tuple[int]:
+     """
+     Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+     """
+
+     if hasattr(base_model, "get_hidden_dim"):
+         return base_model.get_hidden_dim(module_name)
+     else:
+         """
+         WARNING: get_hidden_dim() is not defined,
+         which is used to get the hidden dim for different lora modules
+         Use the default one, but please check if it is correct for your model.
+         Please implement the function in the model class if it is not.
+         You can reference this function in llama.py.
+         """
+         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+             return config.hidden_size, config.hidden_size
+         elif module_name in ["kv_proj"]:
+             return config.hidden_size, config.hidden_size // (
+                 config.num_attention_heads // config.num_key_value_heads
+             )
+         elif module_name == "gate_up_proj":
+             return config.hidden_size, config.intermediate_size
+         elif module_name == "down_proj":
+             return config.intermediate_size, config.hidden_size
+         else:
+             raise NotImplementedError()
+
+
+ def get_stacked_name(name: str) -> Tuple[str]:
+     """
+     Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
+     """
+     params_mapping = {
+         "q_proj": ("qkv_proj", "q_proj"),
+         "k_proj": ("qkv_proj", "kv_proj"),
+         "v_proj": ("qkv_proj", "kv_proj"),
+         "gate_proj": ("gate_up_proj", "gate_up_proj"),
+         "up_proj": ("gate_up_proj", "gate_up_proj"),
+     }
+     return params_mapping.get(name, (name, name))
+
+
+ def get_stacked_multiply(module_name: str) -> int:
+     """
+     Mapping a lora module name to its magnification at output dimension
+     """
+     stacked_rank = {
+         "qkv_proj": 3,
+         "kv_proj": 2,
+         "gate_up_proj": 2,
+     }
+     return stacked_rank[module_name] if module_name in stacked_rank else 1
+
+
+ def get_weight_name(
+     target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
+ ) -> Optional[str]:
+     """
+     target_name is name of a given module,
+     lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+     If there is a weight name in lora_weight_names that can match target_name, return this name
+     Else return None
+     """
+     idx = 0 if lora_type == LoRAType.LORA_A else 1
+     for weight_name_pair in lora_weight_names:
+         if weight_name_pair[idx] in target_name:
+             return weight_name_pair[idx]
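
The naming helpers above are small and pure, so their behavior is easiest to see by example. The calls below are illustrative (typical HF/Llama-style module names) and are not taken from the diff.

# Illustrative calls to the new helpers in sglang/srt/lora/utils.py.
from sglang.srt.lora.utils import (
    LoRAType,
    get_stacked_multiply,
    get_stacked_name,
    get_weight_name,
)

print(get_stacked_name("q_proj"))         # ('qkv_proj', 'q_proj')
print(get_stacked_name("gate_proj"))      # ('gate_up_proj', 'gate_up_proj')
print(get_stacked_multiply("qkv_proj"))   # 3 (q, k, v stacked on the output dim)
print(get_stacked_multiply("down_proj"))  # 1 (not a stacked module)

# get_weight_name matches a module's full name against the stacked-name pairs.
pairs = {("qkv_proj", "q_proj"), ("gate_up_proj", "gate_up_proj")}
print(get_weight_name("model.layers.0.self_attn.qkv_proj", pairs, LoRAType.LORA_A))  # qkv_proj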
sglang/srt/managers/detokenizer_manager.py
@@ -210,6 +210,7 @@ class DetokenizerManager:
                      input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                      output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                      output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                     output_hidden_states=recv_obj.output_hidden_states,
                  )
              )

sglang/srt/managers/io_struct.py
@@ -371,6 +371,8 @@ class BatchTokenIDOut:
      output_top_logprobs_val: List[List]
      output_top_logprobs_idx: List[List]

+     output_hidden_states: List[List[float]]
+

  @dataclass
  class BatchStrOut:
@@ -397,6 +399,8 @@ class BatchStrOut:
      output_top_logprobs_val: List[List]
      output_top_logprobs_idx: List[List]

+     output_hidden_states: List[List[float]]
+

  @dataclass
  class BatchEmbeddingOut:
sglang/srt/managers/schedule_batch.py
@@ -65,6 +65,7 @@ global_server_args_dict = {
      "enable_dp_attention": ServerArgs.enable_dp_attention,
      "enable_ep_moe": ServerArgs.enable_ep_moe,
      "device": ServerArgs.device,
+     "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
  }

  logger = logging.getLogger(__name__)
@@ -315,6 +316,7 @@ class Req:
          self.output_token_logprobs_val = self.output_token_logprobs_idx = (
              self.output_top_logprobs_val
          ) = self.output_top_logprobs_idx = None
+         self.hidden_states = []

          # Logprobs (internal values)
          # The tokens is prefilled but need to be considered as decode tokens
@@ -604,6 +606,9 @@ class ScheduleBatch:
      # Enable custom logit processor
      enable_custom_logit_processor: bool = False

+     # Return hidden states
+     return_hidden_states: bool = False
+
      @classmethod
      def init_new(
          cls,
@@ -615,6 +620,7 @@
          enable_overlap: bool,
          spec_algorithm: SpeculativeAlgorithm,
          enable_custom_logit_processor: bool,
+         return_hidden_states: bool = False,
      ):
          return cls(
              reqs=reqs,
@@ -629,6 +635,7 @@
              device=req_to_token_pool.device,
              spec_algorithm=spec_algorithm,
              enable_custom_logit_processor=enable_custom_logit_processor,
+             return_hidden_states=return_hidden_states,
          )

      def batch_size(self):
@@ -1196,9 +1203,15 @@
              spec_algorithm=self.spec_algorithm,
              spec_info=self.spec_info,
              capture_hidden_mode=(
-                 getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL)
-                 if self.spec_info
-                 else CaptureHiddenMode.NULL
+                 CaptureHiddenMode.FULL
+                 if self.return_hidden_states
+                 else (
+                     getattr(
+                         self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
+                     )
+                     if self.spec_info
+                     else CaptureHiddenMode.NULL
+                 )
              ),
          )

sglang/srt/managers/scheduler.py
@@ -997,6 +997,7 @@ class Scheduler:
              self.enable_overlap,
              self.spec_algorithm,
              self.server_args.enable_custom_logit_processor,
+             self.server_args.return_hidden_states,
          )
          new_batch.prepare_for_extend()

@@ -1156,6 +1157,8 @@
                  logits_output.input_token_logprobs.tolist()
              )

+         hidden_state_offset = 0
+
          # Check finish conditions
          logprob_pt = 0
          for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
@@ -1182,6 +1185,21 @@
                      i, req, logprob_pt, next_token_ids, logits_output
                  )

+                 if (
+                     self.server_args.return_hidden_states
+                     and logits_output.hidden_states is not None
+                 ):
+                     req.hidden_states.append(
+                         logits_output.hidden_states[
+                             hidden_state_offset : (
+                                 hidden_state_offset := hidden_state_offset
+                                 + len(req.origin_input_ids)
+                             )
+                         ]
+                         .cpu()
+                         .clone()
+                     )
+
                  if req.grammar is not None:
                      req.grammar.accept_token(next_token_id)
                      req.grammar.finished = req.finished()
@@ -1275,6 +1293,12 @@
                      logits_output.next_token_top_logprobs_idx[i]
                  )

+             if (
+                 self.server_args.return_hidden_states
+                 and logits_output.hidden_states is not None
+             ):
+                 req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+
              if req.grammar is not None:
                  req.grammar.accept_token(next_token_id)
                  req.grammar.finished = req.finished()
@@ -1398,6 +1422,7 @@
          completion_tokens = []
          cached_tokens = []
          spec_verify_ct = []
+         hidden_states = []

          if return_logprob:
              input_token_logprobs_val = []
@@ -1464,6 +1489,8 @@
                  output_top_logprobs_val.append(req.output_top_logprobs_val)
                  output_top_logprobs_idx.append(req.output_top_logprobs_idx)

+             hidden_states.append(req.hidden_states)
+
          # Send to detokenizer
          if rids:
              self.send_to_detokenizer.send_pyobj(
@@ -1490,6 +1517,7 @@
                      input_top_logprobs_idx,
                      output_top_logprobs_val,
                      output_top_logprobs_idx,
+                     hidden_states,
                  )
              )
          else:  # embedding or reward model
@@ -1553,6 +1581,7 @@
              self.enable_overlap,
              self.spec_algorithm,
              self.server_args.enable_custom_logit_processor,
+             self.server_args.return_hidden_states,
          )
          idle_batch.prepare_for_idle()
          return idle_batch
sglang/srt/managers/tokenizer_manager.py
@@ -796,6 +796,12 @@ class TokenizerManager:
                  }
              )

+         if (
+             hasattr(recv_obj, "output_hidden_states")
+             and len(recv_obj.output_hidden_states[i]) > 0
+         ):
+             meta_info["hidden_states"] = recv_obj.output_hidden_states[i]
+
          if isinstance(recv_obj, BatchStrOut):
              out_dict = {
                  "text": recv_obj.output_strs[i],
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -156,6 +156,10 @@ class TpModelWorkerClient:
                  logits_output.input_token_logprobs = (
                      logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                  )
+             if logits_output.hidden_states is not None:
+                 logits_output.hidden_states = logits_output.hidden_states.to(
+                     "cpu", non_blocking=True
+                 )
              next_token_ids = next_token_ids.to("cpu", non_blocking=True)
              copy_done.record()

sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import (
      ForwardBatch,
      ForwardMode,
  )
+ from sglang.srt.utils import is_hip
+
+ is_hip_ = is_hip()

  if TYPE_CHECKING:
      from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
          if bs <= model_runner.req_to_token_pool.size
          and bs <= server_args.cuda_graph_max_bs
      ]
+     if is_hip_:
+         capture_bs += [i * 8 for i in range(21, 33)]
      compile_bs = (
          [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
          if server_args.enable_torch_compile
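
For reference, the HIP branch added to get_batch_sizes_to_capture above extends the captured CUDA-graph batch sizes by the following values; this is just the arithmetic of the list comprehension, shown for clarity.

extra_capture_bs = [i * 8 for i in range(21, 33)]
print(extra_capture_bs)  # [168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256]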
sglang/srt/model_executor/cuda_graph_runner.py (continued)
@@ -237,6 +242,7 @@ class CudaGraphRunner:
              "1. disable cuda graph by --disable-cuda-graph\n"
              "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
              "3. disable torch compile by not using --enable-torch-compile\n"
+             "4. set --cuda-graph-max-bs to a smaller value (e.g., 32)\n"
              "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
          )

@@ -348,7 +354,13 @@
              spec_algorithm=self.model_runner.spec_algorithm,
              spec_info=spec_info,
              capture_hidden_mode=(
-                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
+                 CaptureHiddenMode.FULL
+                 if self.model_runner.server_args.return_hidden_states
+                 else (
+                     spec_info.capture_hidden_mode
+                     if spec_info
+                     else CaptureHiddenMode.NULL
+                 )
              ),
          )

@@ -462,8 +474,11 @@
              ),
              positions=None,
              retrive_index=None,
+             retrive_next_token=None,
+             retrive_next_sibling=None,
              retrive_cum_len=None,
              draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
+             spec_steps=self.model_runner.server_args.speculative_num_steps,
              capture_hidden_mode=CaptureHiddenMode.FULL,
          )

sglang/srt/model_executor/model_runner.py
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
      monkey_patch_p2p_access_check,
      monkey_patch_vllm_gguf_config,
      set_cpu_offload_max_bytes,
+     set_cuda_arch,
  )

  logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
          ):
              # TODO: add MLA optimization on CPU
              if self.server_args.device != "cpu":
-                 logger.info("MLA optimization is turned on. Use triton backend.")
-                 self.server_args.attention_backend = "triton"
+                 if server_args.enable_flashinfer_mla:
+                     logger.info(
+                         "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                     )
+                     self.server_args.attention_backend = "flashinfer"
+                 else:
+                     logger.info("MLA optimization is turned on. Use triton backend.")
+                     self.server_args.attention_backend = "triton"

          if self.server_args.enable_double_sparsity:
              logger.info(
@@ -169,6 +176,7 @@
              "enable_dp_attention": server_args.enable_dp_attention,
              "enable_ep_moe": server_args.enable_ep_moe,
              "device": server_args.device,
+             "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
          }
      )

@@ -292,6 +300,8 @@
          if torch.cuda.get_device_capability()[1] < 5:
              raise RuntimeError("SGLang only supports sm75 and above.")

+         set_cuda_arch()
+
          # Prepare the model config
          self.load_config = LoadConfig(
              load_format=self.server_args.load_format,
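
The ModelRunner change above routes MLA models to the FlashInfer attention backend when enable_flashinfer_mla is set. A hedged launch sketch follows; it assumes the new ServerArgs field is exposed as an Engine keyword argument (and, like other server arguments, as a corresponding CLI flag), and the model path and tp_size are placeholders.

# Hedged usage example, not from the sglang sources.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # placeholder MLA model
    tp_size=8,
    trust_remote_code=True,
    enable_flashinfer_mla=True,  # per the hunk above, selects attention_backend="flashinfer"
)
print(llm.generate("Hello", {"max_new_tokens": 4})["text"])
llm.shutdown()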