sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
-if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
-_TORCH_VERSION = torch.__version__.split("+")[0]
-try:
-    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
-except ValueError:
-    _TORCH_VERSION_TUPLE = (0, 0, 0)
-
-# The condition to determine if it is on a platform that supports
-# torch._scaled_mm rowwise feature.
-# The condition is determined once as the operations
-# are time consuming.
-USE_ROWWISE_TORCH_SCALED_MM = (
-    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
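
For reference, a small standalone sketch (not part of the diff) of how the version parsing inside use_rowwise_torch_scaled_mm resolves a few representative torch.__version__ strings; the helper name and sample strings here are illustrative only:

def parse_torch_version(version: str) -> tuple:
    # Mirror the parsing above: drop any local build suffix, then take the
    # first three dotted components as integers.
    base = version.split("+")[0]
    try:
        return tuple(map(int, base.split(".")[:3]))
    except ValueError:  # e.g. a pre-release segment such as "0a0"
        return (0, 0, 0)


assert parse_torch_version("2.7.0+cu121") == (2, 7, 0)
assert parse_torch_version("2.6.0") == (2, 6, 0)
assert parse_torch_version("2.7.0a0") == (0, 0, 0)  # falls back, so the rowwise check stays False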
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-    scale = fp8_max / amax
-    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
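As a reading aid, here is a self-contained sketch of the tensor-wise quantization round trip that input_to_float8 performs on the non-fnuz path; it uses plain torch only and the variable names are illustrative:

import torch

x = torch.randn(4, 8, dtype=torch.float32)
amax = x.abs().amax().clamp(min=1e-12)
fp_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for e4m3fn
scale = fp_max / amax
x_q = (x * scale).clamp(min=-fp_max, max=fp_max).to(torch.float8_e4m3fn)
inv_scale = scale.reciprocal()                  # this is what the function returns
x_dq = x_q.to(torch.float32) * inv_scale        # approximate reconstruction
print(f"max abs error: {(x - x_dq).abs().max().item():.4f}")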
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
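
The nested loop in block_quant_dequant scales each (block_n, block_k) tile by its entry in x_s. A vectorized sketch of the same math, valid only when n and k are exact multiples of the block size and using toy shapes, may make the tiling clearer:

import torch

block_n, block_k = 128, 128
n, k = 256, 384                      # 2 x 3 tiles
x_q = torch.randn(n, k)              # stand-in for the fp8 block-quantized weight
x_s = torch.rand(2, 3)               # one scale per (block_n, block_k) tile

# Expand each per-tile scale over its tile and multiply; this matches the
# tile loop above when n and k divide evenly by the block size.
scales_full = x_s.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
x_dq = x_q.to(torch.float32) * scales_full
assert x_dq.shape == (n, k)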
sglang/srt/layers/quantization/kv_cache.py CHANGED
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
 
sglang/srt/layers/quantization/utils.py CHANGED
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
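
Both deleted helpers performed the same gfx94 architecture probe; per the new imports, a single shared is_fp8_fnuz now lives in sglang.srt.layers.quantization.fp8_kernel (its body is not shown in this diff). A sketch of the check, based on the removed copies and with an added guard for non-ROCm builds:

import torch

def is_fp8_fnuz_sketch() -> bool:
    # Same probe as the removed helpers: only device 0 is checked, assuming
    # MI300-class (gfx94*) platforms are homogeneous. getattr guards the
    # gcnArchName attribute, which only exists on ROCm builds of torch.
    if not torch.cuda.is_available():
        return False
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return "gfx94" in arch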
sglang/srt/layers/quantization/w8a8_fp8.py CHANGED
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import set_weight_attrs
 
-_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
             # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
sglang/srt/layers/utils.py ADDED
@@ -0,0 +1,35 @@
+import logging
+import re
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_layer_id(weight_name):
+    # example weight name: model.layers.10.self_attn.qkv_proj.weight
+    match = re.search(r"layers\.(\d+)\.", weight_name)
+    if match:
+        return int(match.group(1))
+    return None
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input
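
A brief usage sketch for the two helpers in this new module; the import path is assumed from the file list above, and the tensors are illustrative:

import torch

from sglang.srt.layers.utils import PPMissingLayer, get_layer_id  # path assumed

assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
assert get_layer_id("lm_head.weight") is None

# On pipeline ranks that do not own a layer, the placeholder passes its first
# input straight through, optionally wrapped in a tuple.
layer = PPMissingLayer(return_tuple=True)
hidden = torch.zeros(2, 4)
assert layer(hidden)[0] is hidden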
sglang/srt/lora/layers.py CHANGED
@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
-            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up = torch.cat(
-                (B_buffer[0], B_buffer[1]), dim=-2
-            ).contiguous()
+            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                )
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def init__(
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
         # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-        self.B_buffer_qkv = torch.cat(
-            (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
-        ).contiguous()
+        if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+            self.B_buffer_qkv = torch.empty(
+                (
+                    B_buffer_q[0].shape[0],
+                    output_dim_q + 2 * output_dim_kv,
+                    B_buffer_q[0].shape[2],
+                ),
+                dtype=B_buffer_q[0].dtype,
+                device=B_buffer_q[0].device,
+            )
+        self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+        self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+            B_buffer_kv[0]
+        )
+        self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+            B_buffer_kv[1]
+        )
 
         # Offsets of q/k/v in output dimension
-        self.output_offset = torch.tensor(
+        if not hasattr(self, "output_offset") or self.output_offset is None:
+            self.output_offset = torch.empty(
+                4, dtype=torch.int32, device=B_buffer_q.device
+            )
+        self.output_offset[:4] = torch.tensor(
             [
                 0,
                 output_dim_q,
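
Both LoRA layer hunks replace torch.cat(...).contiguous() with a buffer that is allocated once and refreshed via copy_(). A minimal, sglang-independent sketch of that pattern, which keeps the fused tensor's address stable across CUDA graph replays:

import torch

class FusedBuffer:
    def __init__(self):
        self.buf = None

    def update(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Allocate the fused buffer once; later calls only overwrite it,
        # so the pointer captured by a CUDA graph stays valid.
        if self.buf is None:
            self.buf = torch.empty(
                (a.shape[0], a.shape[1] + b.shape[1]), dtype=a.dtype, device=a.device
            )
        self.buf[:, : a.shape[1]].copy_(a)
        self.buf[:, a.shape[1] :].copy_(b)
        return self.buf

fused = FusedBuffer()
out = fused.update(torch.randn(4, 16), torch.randn(4, 16))
assert out.shape == (4, 32)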
sglang/srt/lora/lora_manager.py CHANGED
@@ -72,6 +72,23 @@ class LoRAManager:
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -136,43 +153,72 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-        )
-        scalings = torch.empty(
-            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-        )
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            lora = self.loras[lora_path]
-            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-            scalings[weight_indices[i]] = lora.scaling
-
-        batch_info = LoRABatchInfo(
-            bs=bs,
-            seg_lens=seg_lens,
-            seg_indptr=seg_indptr,
-            max_len=max_len,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
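
On the CUDA-graph branch above, every field of cuda_graph_batch_info is preallocated and only the leading bs entries are overwritten. A toy sketch of that in-place update, mirroring the seg_lens/seg_indptr handling in the hunk (sizes are illustrative):

import torch

max_bs, bs = 8, 3
seg_lens = torch.zeros(max_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)

seg_lens[:bs].fill_(1)  # decode: one token per request
torch.cumsum(seg_lens[:bs], dim=0, out=seg_indptr[1 : bs + 1])
assert seg_indptr[: bs + 1].tolist() == [0, 1, 2, 3]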