sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. sglang/srt/_custom_ops.py +29 -1
  2. sglang/srt/configs/model_config.py +1 -1
  3. sglang/srt/conversation.py +1 -1
  4. sglang/srt/disaggregation/common/conn.py +34 -6
  5. sglang/srt/disaggregation/mini_lb.py +3 -2
  6. sglang/srt/disaggregation/mooncake/conn.py +49 -20
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  8. sglang/srt/disaggregation/nixl/conn.py +17 -13
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  10. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  11. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  12. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  13. sglang/srt/distributed/parallel_state.py +70 -15
  14. sglang/srt/entrypoints/engine.py +2 -8
  15. sglang/srt/entrypoints/http_server.py +20 -32
  16. sglang/srt/entrypoints/openai/protocol.py +3 -3
  17. sglang/srt/entrypoints/openai/serving_chat.py +27 -4
  18. sglang/srt/function_call/base_format_detector.py +74 -12
  19. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  20. sglang/srt/function_call/ebnf_composer.py +95 -63
  21. sglang/srt/function_call/function_call_parser.py +4 -4
  22. sglang/srt/function_call/kimik2_detector.py +41 -16
  23. sglang/srt/function_call/llama32_detector.py +6 -3
  24. sglang/srt/function_call/mistral_detector.py +11 -3
  25. sglang/srt/function_call/pythonic_detector.py +16 -14
  26. sglang/srt/function_call/qwen25_detector.py +12 -3
  27. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
  28. sglang/srt/layers/activation.py +11 -3
  29. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  30. sglang/srt/layers/communicator.py +12 -12
  31. sglang/srt/layers/dp_attention.py +72 -24
  32. sglang/srt/layers/logits_processor.py +34 -24
  33. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  35. sglang/srt/layers/moe/topk.py +5 -13
  36. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  37. sglang/srt/layers/quantization/modelopt_quant.py +8 -4
  38. sglang/srt/layers/quantization/utils.py +0 -9
  39. sglang/srt/layers/radix_attention.py +5 -3
  40. sglang/srt/lora/lora_manager.py +133 -169
  41. sglang/srt/lora/lora_registry.py +124 -0
  42. sglang/srt/lora/mem_pool.py +2 -2
  43. sglang/srt/managers/cache_controller.py +53 -6
  44. sglang/srt/managers/io_struct.py +19 -1
  45. sglang/srt/managers/schedule_batch.py +13 -3
  46. sglang/srt/managers/scheduler.py +13 -25
  47. sglang/srt/managers/tokenizer_manager.py +28 -25
  48. sglang/srt/managers/tp_worker.py +2 -4
  49. sglang/srt/mem_cache/allocator.py +67 -7
  50. sglang/srt/mem_cache/hicache_storage.py +17 -1
  51. sglang/srt/mem_cache/hiradix_cache.py +30 -16
  52. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  53. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  54. sglang/srt/model_executor/forward_batch_info.py +201 -29
  55. sglang/srt/model_executor/model_runner.py +41 -23
  56. sglang/srt/models/deepseek_v2.py +1 -2
  57. sglang/srt/models/mllama4.py +10 -3
  58. sglang/srt/models/qwen2_moe.py +0 -4
  59. sglang/srt/models/qwen3_moe.py +1 -6
  60. sglang/srt/reasoning_parser.py +46 -4
  61. sglang/srt/sampling/sampling_batch_info.py +6 -5
  62. sglang/srt/server_args.py +76 -55
  63. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  65. sglang/srt/speculative/eagle_utils.py +51 -23
  66. sglang/srt/speculative/eagle_worker.py +59 -44
  67. sglang/srt/two_batch_overlap.py +9 -5
  68. sglang/srt/utils.py +17 -68
  69. sglang/test/test_activation.py +50 -1
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
  72. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
  73. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-from typing import Dict, Iterable, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 import torch
 
@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
@@ -55,6 +56,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -64,10 +66,6 @@
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.max_lora_rank: Optional[int] = max_lora_rank
-        self.target_modules: Optional[Set[str]] = (
-            set(target_modules) if target_modules else None
-        )
 
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -75,7 +73,11 @@
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
         # Initialize mutable internal state of the LoRAManager.
-        self.init_state()
+        self.init_state(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+            lora_paths=lora_paths,
+        )
 
     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -112,108 +114,87 @@ class LoRAManager:
             success=success,
             error_message=error_message,
             loaded_adapters={
-                name: config.path for name, config in self.configs.items()
+                lora_ref.lora_name: lora_ref.lora_path
+                for lora_ref in self.lora_refs.values()
             },
         )
 
-    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
-        """
-        Load LoRA adapters from the specified paths.
-
-        Args:
-            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-            If a LoRA adapter is already loaded, it will be skipped with a warning.
-        """
-
-        results = []
-        for lora_name, lora_path in lora_paths.items():
-            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
-            results.append(result)
-
-        self.update_state_from_configs()
-
-        return self.create_lora_update_result(
-            success=all(result.success for result in results),
-            error_message="\n".join(
-                result.error_message for result in results if not result.success
-            ),
-        )
-
-    def load_lora_adapter(
-        self, lora_name: str, lora_path: str, update_state: bool = True
-    ) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Load a single LoRA adapter from the specified path.
 
         Args:
-            lora_name (str): The name of the LoRA adapter.
-            lora_path (str): The file path to the LoRA adapter.
-            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
         """
+        assert (
+            lora_ref.lora_name is not None and lora_ref.lora_path is not None
+        ), "LoRARef must have both lora_name and lora_path set for loading."
+        assert (
+            lora_ref.lora_id not in self.loras
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
 
-        success = True
-        error_message = ""
+        try:
+            # load configs
+            new_adapter = LoRAConfig(lora_ref.lora_path)
+            self.validate_new_adapter(new_adapter, lora_ref)
+            self.configs[lora_ref.lora_id] = new_adapter
 
-        if lora_name in self.loras:
-            success = False
-            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+            # load weights
+            self.load_lora_weights(lora_ref)
 
-        try:
-            new_adapter = LoRAConfig(lora_path)
-            self.validate_new_adapter(lora_name, new_adapter)
-            self.configs[lora_name] = new_adapter
+            # keep metadata for displayed messages
+            self.lora_refs[lora_ref.lora_id] = lora_ref
        except Exception as e:
-            success = False
-            error_message = (
-                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
            )
 
-        if update_state:
-            self.update_state_from_configs()
+        return self.create_lora_update_result(success=True)
 
-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
-
-    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
         """
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """
 
-        incompatible = self.memory_pool and not self.memory_pool.can_support(
-            lora_config
-        )
+        memory_pool = getattr(self, "memory_pool", None)
+        incompatible = memory_pool and not memory_pool.can_support(lora_config)
        if incompatible:
            raise ValueError(
-                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
                "included in `--enable_lora_modules`."
            )
 
-    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
         """
 
-        success = True
-        error_message = ""
-        if lora_name in self.loras:
-            del self.configs[lora_name]
-        else:
-            error_message = f"LoRA adapter {lora_name} is not loaded."
-            success = False
+        adapter = self.configs.get(lora_ref.lora_id, None)
+        assert (
+            adapter is not None
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
-        self.update_state_from_configs()
+        try:
+            del self.configs[lora_ref.lora_id]
+            del self.loras[lora_ref.lora_id]
+            del self.lora_refs[lora_ref.lora_id]
+        except Exception as e:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
+            )
 
-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
+        return self.create_lora_update_result(success=True)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active loras into lora memory pool
+        # Load active loras into lora memory pool
+        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
+        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
+        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
+        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -233,10 +214,10 @@
         weight_indices = [0] * len(forward_batch.lora_paths)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            if lora_path is not None:
-                lora = self.loras[lora_path]
+        for i, uid in enumerate(forward_batch.lora_paths):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
                 lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling
 
@@ -326,7 +307,7 @@
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
-        for layer_id, layer_modules in self.lora_modules.items():
+        for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
@@ -353,115 +334,94 @@
                         ),
                     )
 
-    def init_state(self):
+    def init_state(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
+    ):
         """
         Initialize the internal (mutable) state of the LoRAManager.
 
-        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
+        the target modules and max_lora_rank.
         """
 
-        # Configs of all active LoRA adapters.
-        self.configs: Dict[str, LoRAConfig] = {}
-
-        # LoRA adapter weights cached in CPU memory.
-        self.loras: Dict[str, LoRAAdapter] = {}
+        assert lora_paths or (
+            max_lora_rank is not None and target_modules is not None
+        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
 
-        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
-
-        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
-        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
-            i: {} for i in range(self.base_hf_config.num_hidden_layers)
-        }
+        self.init_lora_adapters(lora_paths)
+        self.init_lora_shapes(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+        )
+        self.init_lora_weight_names()
+        self.init_lora_modules()
+        self.init_memory_pool()
 
-        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
-        # It is initialized lazily when the first LoRA adapter is loaded.
-        self.memory_pool: Optional[LoRAMemoryPool] = None
+    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        # Configs of all active LoRA adapters, indexed by LoRA ID.
+        self.configs: Dict[str, LoRAConfig] = {}
 
-    def update_state_from_configs(self):
-        """
-        Update the internal state of the LoRAManager based on the current `self.configs`. This method
-        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-        """
+        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
+        self.loras: Dict[str, LoRAAdapter] = {}
 
-        # Loads / unloads LoRA adapters based on the latest configs.
-        self.update_lora_adapters()
-        # Apply the latest LoRA configurations to the internal state for inferencing.
-        self.apply_lora_configs()
+        # Mapping from LoRA ID to LoRARef object.
+        self.lora_refs: Dict[str, LoRARef] = {}
 
-    def apply_lora_configs(self):
-        """
-        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                result = self.load_lora_adapter(lora_ref)
+                if not result.success:
+                    raise RuntimeError(
+                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
+                    )
 
-        Notes:
-        - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
-          we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
-          LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
-          early CY25H2.
-        """
+    def init_lora_shapes(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+    ):
+        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
 
-        if self.memory_pool is None:
-            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
-            if self.target_modules is None:
-                self.target_modules = set()
-                for config in self.configs.values():
-                    self.target_modules.update(config.target_modules)
-
-            if self.max_lora_rank is None:
-                self.max_lora_rank = max(
-                    [x.hf_config["r"] for x in self.configs.values()],
-                    default=0,
-                )
+        if target_modules is not None:
+            self.target_modules = set(target_modules)
+        else:
+            self.target_modules = set()
+            for config in self.configs.values():
+                self.target_modules.update(config.target_modules)
 
-            self.update_lora_weight_names()
-            self.update_lora_modules()
-            self.update_memory_buffers()
+        if max_lora_rank is not None:
+            self.max_lora_rank = max_lora_rank
        else:
-            # No-op if the memory pool can support the current LoRA configurations.
-            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
-            # module is changed once FlashInfer backend is deprecated.
-            assert self.memory_pool.can_support(self.configs.values()), (
-                "LoRA memory pool cannot support the current LoRA configuration. "
-                "This should never happen as we should have validated adapter compatibility. "
-                "Please create a Github issue to report.",
+            self.max_lora_rank = max(
+                [x.hf_config["r"] for x in self.configs.values()],
+                default=0,
            )
 
-    def update_lora_weight_names(self):
+    def init_lora_weight_names(self):
        """
        Add new LoRA weight names if needed based on the current `self.configs`.
        """
 
        # Target lora weight names for lora_a and lora_b modules respectively.
        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
-        self.lora_weight_names[0].update(lora_A)
-        self.lora_weight_names[1].update(lora_B)
+        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
 
-    def update_lora_adapters(self):
+    def load_lora_weights(self, lora_ref: LoRARef):
        """
-        Update the LoRA adapters in CPU memory based on the current `self.configs`.
-        It loads any new adapters that are not already loaded, and unloads any adapters
-        that are no longer in `self.configs` (e.g., unloaded).
+        Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
        """
-
-        # Load new adapter weights to cpu
-        for name, config in self.configs.items():
-            if name not in self.loras:
-                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
-                lora_adapter = LoRAAdapter(
-                    name,
-                    config,
-                    self.base_hf_config,
-                    self.load_config,
-                    self.lora_backend,
-                )
-                lora_adapter.initialize_weights()
-                self.loras[name] = lora_adapter
-
-        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
-        for name in list(self.loras):
-            if name not in self.configs:
-                logger.info(f"Unloading LoRA adapter {name}")
-                del self.loras[name]
+        lora_adapter = LoRAAdapter(
+            lora_ref.lora_id,
+            self.configs[lora_ref.lora_id],
+            self.base_hf_config,
+            self.load_config,
+            self.lora_backend,
+        )
+        lora_adapter.initialize_weights()
+        self.loras[lora_ref.lora_id] = lora_adapter
 
        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
@@ -472,7 +432,7 @@ class LoRAManager:
                len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
 
-    def update_memory_buffers(self):
+    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
@@ -490,7 +450,12 @@ class LoRAManager:
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module
 
-    def update_lora_modules(self):
+    def init_lora_modules(self):
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
+            {} for _ in range(self.base_hf_config.num_hidden_layers)
+        ]
+
        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
@@ -511,7 +476,6 @@ class LoRAManager:
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
                layer_id = get_layer_id(module_name)
-                if module_name not in self.lora_modules[layer_id]:
-                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
-                        module_name, module
-                    )
+                self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                    module_name, module
+                )
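
For orientation, a minimal sketch of how a caller might drive the reworked LoRAManager API after this change: adapters are now identified by LoRARef records with an auto-generated lora_id instead of name/path pairs. The `manager` argument and the `load_adapter_by_ref` helper below are hypothetical; only LoRARef, load_lora_adapter, and the success/error_message result fields come from the diff above.

# Hypothetical sketch (not part of the diff). `manager` is assumed to be an
# already-initialized LoRAManager; `load_adapter_by_ref` is an illustrative
# helper, not an sglang function.
from sglang.srt.lora.lora_registry import LoRARef


def load_adapter_by_ref(manager, name: str, path: str) -> str:
    """Load an adapter through the LoRARef-based API and return its unique ID."""
    ref = LoRARef(lora_name=name, lora_path=path)  # lora_id is auto-generated (uuid4 hex)
    result = manager.load_lora_adapter(ref)        # previously: load_lora_adapter(name, path)
    if not result.success:
        raise RuntimeError(result.error_message)
    return ref.lora_id

Unloading follows the same pattern: `manager.unload_lora_adapter(ref)` takes the LoRARef rather than the adapter name.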
sglang/srt/lora/lora_registry.py (new file)
@@ -0,0 +1,124 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import asyncio
+from dataclasses import dataclass, field, fields
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+
+@dataclass(frozen=True, slots=True)
+class LoRARef:
+    """
+    Reference record for a LoRA model.
+
+    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
+    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    keys (e.g., radix cache).
+    """
+
+    lora_id: str = field(default_factory=lambda: uuid4().hex)
+    lora_name: Optional[str] = None
+    lora_path: Optional[str] = None
+
+    def __post_init__(self):
+        if self.lora_id is None:
+            raise ValueError("lora_id cannot be None")
+
+    def __str__(self) -> str:
+        parts = [
+            f"{f.name}={value}"
+            for f in fields(self)
+            if (value := getattr(self, f.name)) is not None
+        ]
+        return f"{self.__class__.__name__}({', '.join(parts)})"
+
+
+class LoRARegistry:
+    """
+    The central registry to keep track of available LoRA adapters.
+
+    TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
+    to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
+    """
+
+    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        assert lora_paths is None or all(
+            isinstance(lora, LoRARef) for lora in lora_paths.values()
+        ), (
+            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
+            "Please file an issue if you see this error."
+        )
+
+        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
+        self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
+
+    async def register(self, lora_ref: LoRARef):
+        """
+        Register a new LoRARef object in the registry.
+
+        Args:
+            lora_ref (LoRARef): The LoRARef object to register.
+        """
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
+            )
+        self._registry[lora_ref.lora_name] = lora_ref
+
+    async def unregister(self, lora_name: str) -> str:
+        """
+        Unregister a LoRARef object from the registry and returns the removed LoRA ID.
+
+        Args:
+            lora_name (str): The name of the LoRA model to unregister.
+        """
+        lora_ref = self._registry.get(lora_name, None)
+        if lora_ref is None:
+            raise ValueError(
+                f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+            )
+        del self._registry[lora_name]
+
+        return lora_ref.lora_id
+
+    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
+        """
+        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
+        by incrementing its counter.
+
+        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
+        """
+
+        async def _acquire_single(name: str) -> str:
+            lora_ref = self._registry.get(name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"The following requested LoRA adapters are not loaded: {name}\n"
+                    f"Loaded adapters: {self._registry.keys()}."
+                )
+            # await self._counters[lora_ref.lora_id].increment()
+            return lora_ref.lora_id
+
+        if isinstance(lora_name, str):
+            lora_id = await _acquire_single(lora_name)
+            return lora_id
+        elif isinstance(lora_name, list):
+            lora_ids = await asyncio.gather(
+                *[_acquire_single(name) for name in lora_name]
+            )
+            return lora_ids
+        else:
+            raise TypeError("lora_name must be either a string or a list of strings.")
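
A minimal usage sketch of the new registry module added above, assuming it is importable as sglang.srt.lora.lora_registry; the adapter name and path below are placeholder values, not part of the diff.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def main():
    registry = LoRARegistry()

    # Each LoRARef carries an auto-generated, unique lora_id alongside the user-facing name/path.
    ref = LoRARef(lora_name="my-adapter", lora_path="/path/to/adapter")
    await registry.register(ref)

    # Requests resolve adapter names to IDs; passing a list of names returns a list of IDs.
    lora_id = await registry.acquire("my-adapter")
    assert lora_id == ref.lora_id

    # Unregistering removes the record and returns the ID that was dropped.
    assert await registry.unregister("my-adapter") == lora_id


asyncio.run(main())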
sglang/srt/lora/mem_pool.py
@@ -153,7 +153,7 @@ class LoRAMemoryPool:
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def load_lora_weight_tensor(
             buffer_view: torch.Tensor, weight: Optional[torch.Tensor]