sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )
 
 
+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
        return _ROPE_DICT[key]
 
-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
+        rotary_emb = RotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        )
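For orientation, here is a minimal usage sketch of the new `dual_chunk_attention_config` path in `get_rope`. The parameter names other than `dual_chunk_attention_config` are inferred from the surrounding code in the hunk above, and every numeric value is an illustrative assumption rather than a value taken from this diff:

    import torch
    from sglang.srt.layers.rotary_embedding import get_rope

    # Illustrative values only; a real model config supplies these.
    rope = get_rope(
        head_size=128,
        rotary_dim=128,
        max_position=131072,
        base=1_000_000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        dual_chunk_attention_config={
            "chunk_size": 8192,   # per-chunk rope window
            "local_size": 1024,   # local region subtracted when computing chunk_len
        },
    )
    # With this config, get_rope returns a DualChunkRotaryEmbedding; its forward()
    # concatenates the intra-chunk, successor, and inter-chunk query variants along dim=-1.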
sglang/srt/layers/utils.py
CHANGED
@@ -1,5 +1,6 @@
 import logging
 import re
+from functools import lru_cache
 
 import torch
 
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input
 
 
+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
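A small sketch of how these capability helpers can be used for kernel dispatch; the backend names below are placeholders, not identifiers from this diff:

    from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported

    def pick_attention_backend() -> str:
        # Placeholder backend names for illustration.
        if is_sm100_supported():   # SM 10.x (Blackwell) and CUDA >= 12.8
            return "sm100-optimized"
        if is_sm90_supported():    # SM 9.x (Hopper) and CUDA >= 12.3
            return "sm90-optimized"
        return "generic"

Because both helpers are now wrapped in `@lru_cache(maxsize=1)`, the device-capability query runs once and repeated calls return the cached result.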
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
     method_has_implemented_embedding,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def get_masked_input_and_mask(
     input_: torch.Tensor,
     org_vocab_start_index: int,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
     added_vocab_start_index: int,
     added_vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
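The decorator change compiles the masking helper so its element-wise ops fuse into one kernel. A standalone sketch of the same pattern, under the assumption that `get_compiler_backend` (imported above) returns a valid `torch.compile` backend; the function itself is hypothetical:

    import torch

    from sglang.srt.utils import get_compiler_backend

    @torch.compile(dynamic=True, backend=get_compiler_backend())
    def clamp_to_vocab(input_: torch.Tensor, start: int, end: int) -> torch.Tensor:
        # Pointwise ops like these are fused into a single kernel by torch.compile.
        mask = (input_ >= start) & (input_ < end)
        return torch.where(mask, input_ - start, torch.zeros_like(input_))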
sglang/srt/lora/lora_manager.py
CHANGED
@@ -144,6 +144,7 @@ class LoRAManager:
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
+            self.num_pinned_loras += int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -157,13 +158,22 @@ class LoRAManager:
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
+        # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current
-                "Please ensure that the LoRA adapter's rank is within the configured
-                "included in `--
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
+                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
+                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
+            )
+
+        # Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
+                "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
+                "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )
 
     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
@@ -172,15 +182,17 @@ class LoRAManager:
         delete the corresponding LoRA modules.
         """
 
-        adapter = self.configs.get(lora_ref.lora_id
+        adapter = self.configs.get(lora_ref.lora_id)
+        lora_ref = self.lora_refs.get(lora_ref.lora_id)
         assert (
-            adapter is not None
+            adapter is not None and lora_ref is not None
         ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
         try:
             del self.configs[lora_ref.lora_id]
             del self.loras[lora_ref.lora_id]
             del self.lora_refs[lora_ref.lora_id]
+            self.num_pinned_loras -= int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -189,15 +201,49 @@ class LoRAManager:
 
         return self.create_lora_update_result(success=True)
 
+    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
+        """
+        Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
+        """
+        if len(lora_ids) > self.max_loras_per_batch:
+            return False
+
+        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
+        if self.num_pinned_loras == 0:
+            return True
+
+        # counting the number of pinned LoRA adapters in the batch.
+        pinned_loras_in_batch = 0
+        for lora_id in lora_ids:
+            if lora_id is not None:
+                lora_ref = self.lora_refs.get(lora_id)
+                assert (
+                    lora_ref is not None
+                ), f"LoRA ID {lora_id} not found in lora_refs."
+                pinned_loras_in_batch += int(lora_ref.pinned)
+
+        assert pinned_loras_in_batch <= self.num_pinned_loras, (
+            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
+            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
+        )
+
+        required_slots = len(lora_ids) - pinned_loras_in_batch
+        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
+
+        return required_slots <= mem_pool_vacancy
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
+
         # Load active loras into lora memory pool
-
-
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
+
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(
+        self.memory_pool.prepare_lora_batch(
+            cur_uids=cur_uids,
+            lora_adapters=self.loras,
+            lora_modules=self.lora_modules,
+            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
+        )
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
@@ -211,10 +257,10 @@ class LoRAManager:
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
@@ -370,6 +416,9 @@ class LoRAManager:
         # Mapping from LoRA ID to LoRARef object.
         self.lora_refs: Dict[str, LoRARef] = {}
 
+        # Count of pinned LoRA adapters.
+        self.num_pinned_loras: int = 0
+
         if lora_paths:
             for lora_ref in lora_paths.values():
                 result = self.load_lora_adapter(lora_ref)
@@ -390,13 +439,20 @@ class LoRAManager:
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
         else:
             self.max_lora_rank = max(
-                [x.
+                [x.r for x in self.configs.values()],
                 default=0,
             )
 
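The capacity check added in `validate_lora_batch` boils down to one inequality: the unpinned adapters requested by a batch must fit in the slots not reserved by pinned adapters. A standalone restatement with made-up numbers:

    max_loras_per_batch = 8   # total buffer slots in the memory pool
    num_pinned_loras = 3      # adapters pinned at load time (never evicted)

    lora_ids_in_batch = 6     # distinct adapters referenced by this batch
    pinned_in_batch = 2       # of those, already pinned (no new slot needed)

    required_slots = lora_ids_in_batch - pinned_in_batch   # 4
    vacancy = max_loras_per_batch - num_pinned_loras       # 5
    assert required_slots <= vacancy                       # batch is schedulable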
sglang/srt/lora/lora_registry.py
CHANGED
@@ -28,14 +28,15 @@ class LoRARef:
     """
     Reference record for a LoRA model.
 
-    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``
-    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
+    The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
     keys (e.g., radix cache).
     """
 
     lora_id: str = field(default_factory=lambda: uuid4().hex)
     lora_name: Optional[str] = None
     lora_path: Optional[str] = None
+    pinned: Optional[bool] = None
 
     def __post_init__(self):
         if self.lora_id is None:
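A minimal sketch of constructing a pinned `LoRARef`; the adapter name and path are placeholders:

    from sglang.srt.lora.lora_registry import LoRARef

    ref = LoRARef(
        lora_name="sql-adapter",          # placeholder name
        lora_path="/models/sql-adapter",  # placeholder path
        pinned=True,                      # keep this adapter resident in the pool
    )
    print(ref.lora_id)  # auto-generated uuid4 hex, usable as a deterministic cache key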
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
     get_weight_name,
 )
 
+logger = logging.getLogger(__name__)
+
+
+class EmptySlot:
+    """
+    Singleton class to represent an empty slot in the memory pool.
+    This is used to improve readability by not using special str as a placeholder.
+    """
+
+    __slots__ = ()
+
+    def __repr__(self):
+        return "|EMPTY|"
+
+    def __new__(cls):
+        if not hasattr(cls, "_instance"):
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+
+EMPTY_SLOT = EmptySlot()
+
 
 class LoRAMemoryPool:
     """Class for memory pool management of lora modules"""
@@ -54,9 +78,11 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initialized as
+        # All uids are initialized as `EmptySlot` for empty buffer slots
         # Here we don't initialize to None since None is a valid uid
-        self.buffer_id_to_uid: List[
+        self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
+            EMPTY_SLOT
+        ] * self.max_loras_per_batch
 
         self.init_buffers(base_model)
 
@@ -154,17 +180,29 @@ class LoRAMemoryPool:
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
         lora_modules: List[Dict[str, BaseLayerWithLoRA]],
+        lora_refs: Dict[str, LoRARef],
    ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
-                if self.buffer_id_to_uid[buffer_id] ==
+                if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
+                uid = self.buffer_id_to_uid[buffer_id]
+
                 # Evict unneeded lora
-                if
-
+                if uid not in cur_uids:
+                    # Skip pinned LoRAs
+                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                    if uid is not None:
+                        lora_ref = lora_refs.get(uid)
+                        if lora_ref is not None and lora_ref.pinned:
+                            continue
+
+                    self.uid_to_buffer_id.pop(uid)
+                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
+                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
                     return buffer_id
 
             raise ValueError(
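Since `EmptySlot.__new__` always returns the same instance, the module-level `EMPTY_SLOT` sentinel can be compared with `==` or `is` interchangeably; a short sketch:

    from sglang.srt.lora.mem_pool import EMPTY_SLOT, EmptySlot

    slot_a = EmptySlot()
    slot_b = EmptySlot()
    assert slot_a is slot_b       # singleton: __new__ caches the first instance
    assert slot_a is EMPTY_SLOT   # the module-level sentinel is that same object
    print(slot_a)                 # "|EMPTY|" via __repr__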