sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -1172,6 +1172,202 @@ class MRotaryEmbedding(RotaryEmbedding):
         )
 
 
+class DualChunkRotaryEmbedding(CustomOp):
+    """Rotary positional embedding for Dual Chunk Attention."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+        chunk_size: int,
+        local_size: int,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+        self.chunk_size = chunk_size
+        self.local_size = local_size
+        self.dtype = dtype
+        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
+        (q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache) = (
+            self._compute_cos_sin_cache()
+        )
+
+        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
+        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
+        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
+        self.register_buffer(
+            "cos_sin_qc_no_clamp_cache", qc_no_clamp_cache, persistent=False
+        )
+        self.register_buffer("cos_sin_q_inter_cache", q_inter_cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
+        )
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        chunk_len = self.chunk_size - self.local_size
+        q_t = torch.arange(chunk_len, dtype=torch.float)
+        qc_t = (torch.arange(chunk_len, dtype=torch.float) + chunk_len).clamp(
+            max=self.chunk_size
+        )
+        k_t = torch.arange(self.max_position_embeddings, dtype=torch.float) % chunk_len
+
+        # count from chunk_len, no clamp(self.chunk_size) restriction
+        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
+        # count from self.chunk_size for q_inter's rope
+        q_inter_t = torch.arange(chunk_len, dtype=torch.float) + self.chunk_size
+
+        q_freqs = torch.outer(q_t, inv_freq)
+        qc_freqs = torch.outer(qc_t, inv_freq)
+        k_freqs = torch.outer(k_t, inv_freq)
+        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
+        q_inter_freqs = torch.outer(q_inter_t, inv_freq)
+
+        q_cos = q_freqs.cos()
+        q_sin = q_freqs.sin()
+        qc_cos = qc_freqs.cos()
+        qc_sin = qc_freqs.sin()
+        k_cos = k_freqs.cos()
+        k_sin = k_freqs.sin()
+
+        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
+        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
+        q_inter_cos = q_inter_freqs.cos()
+        q_inter_sin = q_inter_freqs.sin()
+
+        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin), dim=-1).to(
+            dtype=self.dtype, device=self.device
+        )
+        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+        query_rot = query[..., : self.rotary_dim]
+        key_rot = key[..., : self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim :]
+            key_pass = key[..., self.rotary_dim :]
+        else:
+            query_pass = None
+            key_pass = None
+
+        positions_with_offsets = (
+            torch.add(positions, offsets) if offsets is not None else positions
+        )
+        key = self._apply_rotary_embedding(
+            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass
+        )
+        chunk_len = self.chunk_size - self.local_size
+        query = self._apply_rotary_embedding(
+            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_succ = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter = self._apply_rotary_embedding(
+            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
+            query_rot,
+            query_pass,
+        )
+        query_succ_critical = self._apply_rotary_embedding(
+            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+        query_inter_critical = self._apply_rotary_embedding(
+            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
+            query_rot,
+            query_pass,
+        )
+
+        # merge query into one tensor to simplify the interfaces
+        query = torch.cat(
+            (
+                query,
+                query_succ,
+                query_inter,
+                query_succ_critical,
+                query_inter_critical,
+            ),
+            dim=-1,
+        )
+        return query, key
+
+    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
+        else:
+            hidden = hidden_rot
+        return hidden.flatten(-2).squeeze(0)
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
+        return s
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
 
 
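
The hunk above (from sglang/srt/layers/rotary_embedding.py, per the file list) keeps five cos/sin caches, one per query/key view, each indexed by a different position schedule. A minimal sketch of just that position bookkeeping, with toy chunk_size/local_size values chosen purely for illustration:

    import torch

    # Toy values for illustration only; the real ones come from the model's
    # dual_chunk_attention_config.
    chunk_size, local_size, max_position = 8, 2, 24
    chunk_len = chunk_size - local_size  # 6

    # Mirrors the schedules built in _compute_cos_sin_cache above.
    q_t = torch.arange(chunk_len)                                       # 0..5  intra-chunk query
    qc_t = (torch.arange(chunk_len) + chunk_len).clamp(max=chunk_size)  # 6,7,8,8,8,8  successor chunk, clamped
    k_t = torch.arange(max_position) % chunk_len                        # keys wrap every chunk_len
    qc_no_clamp_t = torch.arange(chunk_len) + chunk_len                 # 6..11  successor, no clamp
    q_inter_t = torch.arange(chunk_len) + chunk_size                    # 8..13  inter-chunk query

    print(q_t.tolist(), qc_t.tolist(), k_t[:12].tolist())

The forward pass then concatenates the five rotated query views along the last dimension so downstream attention code receives a single tensor.
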
@@ -1184,6 +1380,7 @@ def get_rope(
     rope_scaling: Optional[Dict[str, Any]] = None,
     dtype: Optional[torch.dtype] = None,
     partial_rotary_factor: float = 1.0,
+    dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
     if dtype is None:
         dtype = torch.get_default_dtype()
@@ -1195,6 +1392,17 @@ def get_rope(
         rope_scaling_args = tuple(rope_scaling_tuple.items())
     else:
         rope_scaling_args = None
+
+    if dual_chunk_attention_config is not None:
+        dual_chunk_attention_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in dual_chunk_attention_config.items()
+            if k != "sparse_attention_config"
+        }
+        dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
+    else:
+        dual_chunk_attention_args = None
+
     if partial_rotary_factor < 1.0:
         rotary_dim = int(rotary_dim * partial_rotary_factor)
     key = (
@@ -1204,12 +1412,28 @@ def get_rope(
         base,
         is_neox_style,
         rope_scaling_args,
+        dual_chunk_attention_args,
         dtype,
     )
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
 
-    if rope_scaling is None:
+    if dual_chunk_attention_config is not None:
+        extra_kwargs = {
+            k: v
+            for k, v in dual_chunk_attention_config.items()
+            if k in ("chunk_size", "local_size")
+        }
+        rotary_emb = DualChunkRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            dtype,
+            **extra_kwargs,
+        )
+    elif rope_scaling is None:
         rotary_emb = RotaryEmbedding(
             head_size, rotary_dim, max_position, base, is_neox_style, dtype
         )
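
The get_rope changes above cache constructed embeddings in _ROPE_DICT, so the new dual_chunk_attention_config has to be folded into a hashable key: list values become tuples and the nested "sparse_attention_config" entry is dropped. A standalone sketch with a made-up config dict (keys and values here are illustrative, not taken from any real model):

    # Hypothetical config, shaped like what the diff reads from the model config.
    dual_chunk_attention_config = {
        "chunk_size": 16384,
        "local_size": 1024,
        "sparse_attention_config": {"nested": "dict"},  # excluded from the cache key
    }

    dual_chunk_attention_tuple = {
        k: tuple(v) if isinstance(v, list) else v
        for k, v in dual_chunk_attention_config.items()
        if k != "sparse_attention_config"
    }
    dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())

    # The tuple of items is hashable and can participate in the _ROPE_DICT key.
    print(dual_chunk_attention_args)  # (('chunk_size', 16384), ('local_size', 1024))

When such a config is present, get_rope builds a DualChunkRotaryEmbedding from only the chunk_size and local_size entries instead of going through the rope_scaling branches. The next hunks belong to sglang/srt/layers/utils.py.
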
@@ -1,5 +1,6 @@
 import logging
 import re
+from functools import lru_cache
 
 import torch
 
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
         return (input,) if self.return_tuple else input
 
 
+@lru_cache(maxsize=1)
 def is_sm100_supported(device=None) -> bool:
     return (torch.cuda.get_device_capability(device)[0] == 10) and (
         torch.version.cuda >= "12.8"
     )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
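
These hunks wrap the SM-capability probes in functools.lru_cache so torch.cuda.get_device_capability is queried once per distinct device argument. The same pattern in isolation, with a stand-in probe so it runs without CUDA (names below are placeholders, not sglang APIs):

    from functools import lru_cache

    calls = 0

    @lru_cache(maxsize=1)
    def probe(device=None) -> bool:
        # Stand-in for torch.cuda.get_device_capability(device).
        global calls
        calls += 1
        return True

    probe(); probe(); probe()
    print(calls)  # 1 -- the body ran once; later calls hit the cache

Note the cache is keyed on the device argument, so with maxsize=1 a call for a different device evicts the previous entry and re-probes. The following hunks touch sglang/srt/layers/vocab_parallel_embedding.py.
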
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
     method_has_implemented_embedding,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.jit.script
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def get_masked_input_and_mask(
     input_: torch.Tensor,
     org_vocab_start_index: int,
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
     added_vocab_start_index: int,
     added_vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.jit.script will fuse all of the pointwise ops below
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
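
get_masked_input_and_mask splits token ids between this shard's original-vocabulary range and its added-token range; the decorator swap above keeps those pointwise ops fused, now via torch.compile with the backend chosen by get_compiler_backend(). A rough CPU-only sketch of the range masks visible in the hunk (simplified; not the exact sglang kernel):

    import torch

    # Illustrative shard ranges: original vocab ids [0, 8), added-token ids [100, 104).
    org_start, org_end = 0, 8
    added_start, added_end = 100, 104

    input_ = torch.tensor([3, 7, 8, 101, 103, 999])
    org_vocab_mask = (input_ >= org_start) & (input_ < org_end)
    added_vocab_mask = (input_ >= added_start) & (input_ < added_end)

    print((org_vocab_mask | added_vocab_mask).tolist())
    # [True, True, False, True, True, False] -- ids outside both ranges get masked out

The remaining hunks in this section belong to the LoRA stack, starting with sglang/srt/lora/lora_manager.py.
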
@@ -144,6 +144,7 @@ class LoRAManager:
 
             # keep metadata for displayed messages
             self.lora_refs[lora_ref.lora_id] = lora_ref
+            self.num_pinned_loras += int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -157,13 +158,22 @@
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
+        # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
         memory_pool = getattr(self, "memory_pool", None)
         incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
-                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
-                "included in `--enable_lora_modules`."
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
+                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
+                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
+            )
+
+        # Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
+        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
+            raise ValueError(
+                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
+                "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
+                "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
             )
 
     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
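
The new pinned-adapter guard above refuses to pin the last remaining slot: with max_loras_per_batch buffer slots, at most max_loras_per_batch - 1 may be pinned so unpinned adapters and the base model can still be scheduled. The arithmetic in isolation (hypothetical helper, not part of the sglang API):

    def can_load_pinned(num_pinned_loras: int, max_loras_per_batch: int) -> bool:
        # Mirrors the check above: reject once pinning would leave no free slot.
        return not (num_pinned_loras >= max_loras_per_batch - 1)

    print(can_load_pinned(num_pinned_loras=6, max_loras_per_batch=8))  # True  -> a 7th pin still leaves a free slot
    print(can_load_pinned(num_pinned_loras=7, max_loras_per_batch=8))  # False -> pinning all 8 slots is rejected
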
@@ -172,15 +182,17 @@
         delete the corresponding LoRA modules.
         """
 
-        adapter = self.configs.get(lora_ref.lora_id, None)
+        adapter = self.configs.get(lora_ref.lora_id)
+        lora_ref = self.lora_refs.get(lora_ref.lora_id)
         assert (
-            adapter is not None
+            adapter is not None and lora_ref is not None
         ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
         try:
             del self.configs[lora_ref.lora_id]
             del self.loras[lora_ref.lora_id]
             del self.lora_refs[lora_ref.lora_id]
+            self.num_pinned_loras -= int(lora_ref.pinned)
         except Exception as e:
             return self.create_lora_update_result(
                 success=False,
@@ -189,15 +201,49 @@
 
         return self.create_lora_update_result(success=True)
 
+    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
+        """
+        Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
+        """
+        if len(lora_ids) > self.max_loras_per_batch:
+            return False
+
+        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
+        if self.num_pinned_loras == 0:
+            return True
+
+        # counting the number of pinned LoRA adapters in the batch.
+        pinned_loras_in_batch = 0
+        for lora_id in lora_ids:
+            if lora_id is not None:
+                lora_ref = self.lora_refs.get(lora_id)
+                assert (
+                    lora_ref is not None
+                ), f"LoRA ID {lora_id} not found in lora_refs."
+                pinned_loras_in_batch += int(lora_ref.pinned)
+
+        assert pinned_loras_in_batch <= self.num_pinned_loras, (
+            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
+            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
+        )
+
+        required_slots = len(lora_ids) - pinned_loras_in_batch
+        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
+
+        return required_slots <= mem_pool_vacancy
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
+
         # Load active loras into lora memory pool
-        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
-        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
-        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
-        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
-        cur_uids = set(forward_batch.lora_paths)
+        cur_uids = set(forward_batch.lora_ids)
+
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
+        self.memory_pool.prepare_lora_batch(
+            cur_uids=cur_uids,
+            lora_adapters=self.loras,
+            lora_modules=self.lora_modules,
+            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
+        )
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
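
validate_lora_batch above admits a batch only when the adapters that are not already pinned fit in the slots left over after all pinned adapters are accounted for. The slot accounting as a standalone function (numbers below are hypothetical):

    def batch_fits(num_adapters: int, pinned_in_batch: int,
                   num_pinned_loras: int, max_loras_per_batch: int) -> bool:
        if num_adapters > max_loras_per_batch:
            return False
        required_slots = num_adapters - pinned_in_batch       # only unpinned adapters need free slots
        mem_pool_vacancy = max_loras_per_batch - num_pinned_loras
        return required_slots <= mem_pool_vacancy

    # 8 slots, 3 pinned overall; a batch of 6 adapters, 2 of them pinned:
    print(batch_fits(num_adapters=6, pinned_in_batch=2, num_pinned_loras=3, max_loras_per_batch=8))  # True
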
@@ -211,10 +257,10 @@
         Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
         to device (CUDA) asynchronously.
         """
-        weight_indices = [0] * len(forward_batch.lora_paths)
+        weight_indices = [0] * len(forward_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, uid in enumerate(forward_batch.lora_paths):
+        for i, uid in enumerate(forward_batch.lora_ids):
             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
             if uid is not None:
                 lora = self.loras[uid]
@@ -370,6 +416,9 @@
         # Mapping from LoRA ID to LoRARef object.
         self.lora_refs: Dict[str, LoRARef] = {}
 
+        # Count of pinned LoRA adapters.
+        self.num_pinned_loras: int = 0
+
         if lora_paths:
             for lora_ref in lora_paths.values():
                 result = self.load_lora_adapter(lora_ref)
@@ -390,13 +439,20 @@
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
             self.max_lora_rank = max_lora_rank
         else:
             self.max_lora_rank = max(
-                [x.hf_config["r"] for x in self.configs.values()],
+                [x.r for x in self.configs.values()],
                 default=0,
             )
 
@@ -28,14 +28,15 @@ class LoRARef:
     """
     Reference record for a LoRA model.
 
-    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
-    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
+    The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
     keys (e.g., radix cache).
     """
 
     lora_id: str = field(default_factory=lambda: uuid4().hex)
     lora_name: Optional[str] = None
     lora_path: Optional[str] = None
+    pinned: Optional[bool] = None
 
     def __post_init__(self):
         if self.lora_id is None:
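
This hunk is from sglang/srt/lora/lora_registry.py: LoRARef now records the pinned flag alongside the auto-generated lora_id. A minimal stand-in dataclass with the same fields, for illustration only:

    from dataclasses import dataclass, field
    from typing import Optional
    from uuid import uuid4

    @dataclass
    class LoRARefSketch:
        # Field names follow the diff; this is not the sglang class itself.
        lora_id: str = field(default_factory=lambda: uuid4().hex)
        lora_name: Optional[str] = None
        lora_path: Optional[str] = None
        pinned: Optional[bool] = None

    ref = LoRARefSketch(lora_name="my-adapter", lora_path="/path/to/my-adapter", pinned=True)
    print(len(ref.lora_id) == 32, ref.pinned)  # unique hex id generated, pinned flag preserved

The remaining hunks come from sglang/srt/lora/mem_pool.py.
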
@@ -1,3 +1,4 @@
+import logging
 from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
     get_weight_name,
 )
 
+logger = logging.getLogger(__name__)
+
+
+class EmptySlot:
+    """
+    Singleton class to represent an empty slot in the memory pool.
+    This is used to improve readability by not using special str as a placeholder.
+    """
+
+    __slots__ = ()
+
+    def __repr__(self):
+        return "|EMPTY|"
+
+    def __new__(cls):
+        if not hasattr(cls, "_instance"):
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+
+EMPTY_SLOT = EmptySlot()
+
 
 class LoRAMemoryPool:
     """Class for memory pool management of lora modules"""
@@ -54,9 +78,11 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initialized as empty strings for empty buffer slots
+        # All uids are initialized as `EmptySlot` for empty buffer slots
         # Here we don't initialize to None since None is a valid uid
-        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+        self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
+            EMPTY_SLOT
+        ] * self.max_loras_per_batch
 
         self.init_buffers(base_model)
 
@@ -154,17 +180,29 @@
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
         lora_modules: List[Dict[str, BaseLayerWithLoRA]],
+        lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
-                if self.buffer_id_to_uid[buffer_id] == "":
+                if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
+                uid = self.buffer_id_to_uid[buffer_id]
+
                 # Evict unneeded lora
-                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                if uid not in cur_uids:
+                    # Skip pinned LoRAs
+                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                    if uid is not None:
+                        lora_ref = lora_refs.get(uid)
+                        if lora_ref is not None and lora_ref.pinned:
+                            continue
+
+                    self.uid_to_buffer_id.pop(uid)
+                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
+                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
                     return buffer_id
 
             raise ValueError(
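
Taken together, get_available_buffer_slot now prefers truly empty slots and only then evicts an adapter that is neither needed by the current batch nor pinned. A simplified pure-Python sketch of that selection loop (hypothetical helper; the real pool also manages the weight buffers):

    from typing import Dict, List, Optional, Set

    EMPTY = object()  # stand-in for EMPTY_SLOT

    def pick_slot(slots: List[object], cur_uids: Set[Optional[str]],
                  pinned: Dict[str, bool]) -> Optional[int]:
        # Prioritize empty slots.
        for buffer_id, uid in enumerate(slots):
            if uid is EMPTY:
                return buffer_id
        # Otherwise evict something not needed now and not pinned.
        for buffer_id, uid in enumerate(slots):
            if uid in cur_uids:
                continue
            if uid is not None and pinned.get(uid, False):
                continue  # skip pinned LoRAs
            slots[buffer_id] = EMPTY
            return buffer_id
        return None  # the real code raises a ValueError instead

    slots = ["a", "b", None]  # "a" and "b" loaded; the base model (None) holds slot 2
    print(pick_slot(slots, cur_uids={"b", "c"}, pinned={"a": True}))  # 2 -- "a" is pinned, "b" is still needed
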