sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/unquant.py

@@ -24,6 +24,7 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE
     from sglang.srt.layers.moe.topk import TopKOutput

 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -129,6 +130,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels

+        self.triton_kernel_moe_forward = None
+        if torch.cuda.is_available() and has_triton_kernels:
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_forward as _tk_forward,
+            )
+
+            self.triton_kernel_moe_forward = _tk_forward
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -194,6 +203,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+
+        if isinstance(layer, EPMoE):
+            return layer.run_moe(
+                hidden_states=x,
+                topk_output=topk_output,
+            )
+
         return self.forward(
             x=x,
             layer=layer,
@@ -219,16 +237,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:

         if self.use_triton_kernels:
-            # TODO(ch-wan): re-enable the Triton kernel
-            raise NotImplementedError("The Triton kernel is temporarily disabled.")
-            # return triton_kernel_moe_forward(
-            #     hidden_states=x,
-            #     w1=layer.w13_weight,
-            #     w2=layer.w2_weight,
-            #     gating_output=router_logits,
-            #     topk=top_k,
-            #     renormalize=renormalize,
-            # )
+            return self.triton_kernel_moe_forward(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_output=topk_output,
+            )
         else:
             if _use_aiter:
                 assert not no_combine, "unsupported"
@@ -354,69 +368,3 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             raise NotImplementedError("The TPU backend currently does not support MoE.")

     forward_native = forward_cpu
-
-
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ) -> torch.Tensor:
-        raise NotImplementedError
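
The net effect of the unquant.py change is that the generic unquantized MoE method no longer carries its own expert-parallel implementation (the removed UnquantizedEPMoEMethod); instead, apply() detects an expert-parallel layer at call time and hands execution back to it. Below is a minimal, self-contained sketch of that dispatch shape; EPMoELike, UnquantizedMoEMethodSketch, and the body of run_moe are illustrative stand-ins, not sglang APIs.

import torch

class EPMoELike(torch.nn.Module):
    # Stand-in for sglang's EPMoE: the layer owns its expert-parallel forward path.
    def run_moe(self, hidden_states: torch.Tensor, topk_output) -> torch.Tensor:
        return hidden_states  # placeholder so the sketch runs end to end

class UnquantizedMoEMethodSketch:
    def apply(self, layer: torch.nn.Module, x: torch.Tensor, topk_output) -> torch.Tensor:
        if isinstance(layer, EPMoELike):
            # Expert-parallel layers execute their own MoE kernel.
            return layer.run_moe(hidden_states=x, topk_output=topk_output)
        # Tensor-parallel layers fall through to the fused kernel path.
        return x

method = UnquantizedMoEMethodSketch()
out = method.apply(EPMoELike(), torch.randn(4, 16), topk_output=None)
print(out.shape)  # torch.Size([4, 16])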

sglang/srt/layers/quantization/w4afp8.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

 import torch
 from torch.nn import Module
@@ -17,6 +17,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs

+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]

 logger = logging.getLogger(__name__)
@@ -84,13 +87,14 @@ class W4AFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.ep_moe.layer import EPMoE
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
+        elif isinstance(layer, EPMoE):
             return W4AFp8MoEMethod(self)
         return None

@@ -105,8 +109,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

     def create_weights(
         self,
-        layer: Module,
-        num_experts_per_partition: int,
+        layer: EPMoE,
+        num_experts: int,
         hidden_size: int,
         intermediate_size: int,
         params_dtype: torch.dtype,
@@ -117,7 +121,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts_per_partition,
+                num_experts,
                 intermediate_size * 2,
                 hidden_size // 2,
                 dtype=torch.int8,
@@ -130,7 +134,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts_per_partition,
+                num_experts,
                 hidden_size,
                 intermediate_size // 2,
                 dtype=torch.int8,
@@ -142,7 +146,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts_per_partition,
+                num_experts,
                 2 * intermediate_size,
                 hidden_size // self.quant_config.group_size,
                 dtype=torch.float32,
@@ -154,7 +158,7 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         w2_weight_scale = torch.nn.Parameter(
             torch.zeros(
-                num_experts_per_partition,
+                num_experts,
                 hidden_size,
                 intermediate_size // self.quant_config.group_size,
                 dtype=torch.float32,
@@ -166,14 +170,14 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):

         # Input scales
         w13_input_scale = torch.nn.Parameter(
-            torch.ones((num_experts_per_partition, 2), dtype=torch.bfloat16),
+            torch.ones((num_experts, 2), dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w13_input_scale", w13_input_scale)
         set_weight_attrs(w13_input_scale, extra_weight_attrs)

         w2_input_scale = torch.nn.Parameter(
-            torch.ones(num_experts_per_partition, dtype=torch.bfloat16),
+            torch.ones(num_experts, dtype=torch.bfloat16),
             requires_grad=False,
         )
         layer.register_parameter("w2_input_scale", w2_input_scale)
@@ -183,25 +187,25 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         device = layer.w13_weight.device

         self.a_strides1 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides1 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             2 * intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.a_strides2 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             intermediate_size,
             device=device,
             dtype=torch.int64,
         )
         self.c_strides2 = torch.full(
-            (num_experts_per_partition, 3),
+            (num_experts, 3),
             hidden_size,
             device=device,
             dtype=torch.int64,
@@ -212,13 +216,13 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         self.s_strides2 = self.c_strides2

         self.expert_offsets = torch.empty(
-            (num_experts_per_partition + 1), dtype=torch.int32, device=device
+            (num_experts + 1), dtype=torch.int32, device=device
         )
         self.problem_sizes1 = torch.empty(
-            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+            (num_experts, 3), dtype=torch.int32, device=device
         )
         self.problem_sizes2 = torch.empty(
-            (num_experts_per_partition, 3), dtype=torch.int32, device=device
+            (num_experts, 3), dtype=torch.int32, device=device
         )

         return
@@ -266,3 +270,50 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
             [w2_input_scale_max], dtype=dtype, device=device
         )
         layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)
+
+    def apply(
+        self,
+        layer: EPMoE,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ) -> torch.Tensor:
+
+        # TODO(ch-wan): move it out of this class
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
+        topk_ids, topk_weights, _ = topk_output
+        local_topk_ids = topk_ids
+        if layer.expert_map is not None:
+            "Translate info from expert_map to topk_ids"
+            local_topk_ids = torch.where(
+                layer.expert_map[topk_ids] != layer.num_experts,
+                layer.expert_map[topk_ids],
+                layer.num_experts,
+            )
+
+        return cutlass_w4a8_moe(
+            layer.start_expert_id,
+            layer.end_expert_id,
+            layer.num_experts,
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_weights,
+            topk_ids,
+            local_topk_ids,
+            self.a_strides1,
+            self.b_strides1,
+            self.c_strides1,
+            self.a_strides2,
+            self.b_strides2,
+            self.c_strides2,
+            self.s_strides13,
+            self.s_strides2,
+            self.expert_offsets,
+            self.problem_sizes1,
+            self.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
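
The new W4AFp8MoEMethod.apply remaps the router's global top-k expert ids through layer.expert_map before calling the CUTLASS kernel, so ids of experts hosted on other ranks collapse to a sentinel value. A toy illustration of that torch.where remapping, with made-up shapes, names, and values that mirror the pattern above but are not taken from sglang:

import torch

num_local_experts = 4  # also used here as the "not on this rank" sentinel id
expert_map = torch.tensor([0, 1, 2, 3, 4, 4, 4, 4])  # global expert id -> local id
topk_ids = torch.tensor([[0, 5], [3, 7]])            # global ids picked by the router

local_topk_ids = torch.where(
    expert_map[topk_ids] != num_local_experts,  # expert lives on this rank
    expert_map[topk_ids],                       # keep its local id
    num_local_experts,                          # otherwise the sentinel
)
print(local_topk_ids)  # tensor([[0, 4], [3, 4]])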

sglang/srt/lora/lora_registry.py

@@ -14,12 +14,16 @@


 import asyncio
+from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4

+from sglang.srt.aio_rwlock import RWLock
+from sglang.srt.utils import ConcurrentCounter

-@dataclass(frozen=True, slots=True)
+
+@dataclass(frozen=True)
 class LoRARef:
     """
     Reference record for a LoRA model.
@@ -48,10 +52,11 @@ class LoRARef:

 class LoRARegistry:
     """
-    The central registry to keep track of available LoRA adapters.
+    The central registry to keep track of available LoRA adapters and ongoing LoRA requests.

-    TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
-    to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
+    The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
+    available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
+    update / eventual consistency model between the tokenizer manager process and the scheduler processes.
     """

     def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
@@ -62,8 +67,19 @@ class LoRARegistry:
                 "Please file an issue if you see this error."
             )

+        # A read-write lock to ensure adapters loading / unloading operations are exclusive.
+        # Please note that the counter increment/decrement operations are not synchronized through this
+        # lock, as they are designed to be non-blocking and can be performed concurrently.
+        self._registry_lock = RWLock()
         # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
-        self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
+        self._registry: Dict[str, LoRARef] = {}
+        # Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
+        self._counters: Dict[str, ConcurrentCounter] = {}
+
+        # Initialize the registry with provided LoRA paths, if present.
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                self._register_adapter(lora_ref)

     async def register(self, lora_ref: LoRARef):
         """
@@ -72,11 +88,8 @@ class LoRARegistry:
         Args:
             lora_ref (LoRARef): The LoRARef object to register.
         """
-        if lora_ref.lora_name in self._registry:
-            raise ValueError(
-                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
-            )
-        self._registry[lora_ref.lora_name] = lora_ref
+        async with self._registry_lock.writer_lock:
+            self._register_adapter(lora_ref)

     async def unregister(self, lora_name: str) -> str:
         """
@@ -85,12 +98,14 @@ class LoRARegistry:
         Args:
             lora_name (str): The name of the LoRA model to unregister.
         """
-        lora_ref = self._registry.get(lora_name, None)
-        if lora_ref is None:
-            raise ValueError(
-                f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
-            )
-        del self._registry[lora_name]
+        async with self._registry_lock.writer_lock:
+            lora_ref = self._registry.get(lora_name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+                )
+            del self._registry[lora_name]
+            del self._counters[lora_ref.lora_id]

         return lora_ref.lora_id

@@ -98,27 +113,76 @@ class LoRARegistry:
         """
         Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
         by incrementing its counter.
-
-        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
         """

-        async def _acquire_single(name: str) -> str:
+        def _lookup(name: str) -> str:
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
                     f"The following requested LoRA adapters are not loaded: {name}\n"
                     f"Loaded adapters: {self._registry.keys()}."
                 )
-            # await self._counters[lora_ref.lora_id].increment()
             return lora_ref.lora_id

-        if isinstance(lora_name, str):
-            lora_id = await _acquire_single(lora_name)
-            return lora_id
-        elif isinstance(lora_name, list):
-            lora_ids = await asyncio.gather(
-                *[_acquire_single(name) for name in lora_name]
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_name, str):
+                lora_id = _lookup(lora_name)
+                await self._counters[lora_id].increment(notify_all=False)
+                return lora_id
+            elif isinstance(lora_name, list):
+                lora_ids = [_lookup(name) for name in lora_name]
+
+                # Increment the counters only after all IDs are looked up.
+                await asyncio.gather(
+                    *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                )
+                return lora_ids
+            else:
+                raise TypeError(
+                    "lora_name must be either a string or a list of strings."
+                )
+
+    async def release(self, lora_id: Union[str, List[str]]):
+        """
+        Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
+        """
+
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_id, str):
+                await self._counters[lora_id].decrement()
+            elif isinstance(lora_id, list):
+                await asyncio.gather(
+                    *[self._counters[id].decrement() for id in lora_id]
+                )
+            else:
+                raise TypeError("lora_id must be either a string or a list of strings.")
+
+    async def wait_for_unload(self, lora_id: str):
+        """
+        Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
+        This is useful for ensuring that a LoRA adapter can be safely unloaded.
+
+        This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
+        which itself is guaranteed to be sequential.
+        """
+        assert (
+            lora_id not in self._registry
+        ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
+        counter = self._counters.get(lora_id)
+        if counter:
+            # Wait until no requests are using this LoRA adapter.
+            await counter.wait_for_zero()
+            del self._counters[lora_id]
+
+    def _register_adapter(self, lora_ref: LoRARef):
+        """
+        Internal helper method to register a LoRA adapter.
+        """
+
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
            )
-            return lora_ids
-        else:
-            raise TypeError("lora_name must be either a string or a list of strings.")
+        self._registry[lora_ref.lora_name] = lora_ref
+        self._counters[lora_ref.lora_id] = ConcurrentCounter()
+        return lora_ref
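
Taken together, the registry changes give each request an acquire/release lifecycle around its adapters, while unloading waits for in-flight requests to drain. The sketch below shows one plausible caller-side ordering; it assumes the query-and-increment method shown above is named acquire (its signature falls outside this excerpt), and generate_fn / notify_schedulers are hypothetical callables, not sglang APIs.

async def run_request(registry, lora_name: str, generate_fn):
    # Resolve the adapter name to a stable ID and bump its in-flight counter.
    lora_id = await registry.acquire(lora_name)
    try:
        return await generate_fn(lora_id)
    finally:
        # Always release, so a pending unload can observe the counter reach zero.
        await registry.release(lora_id)

async def unload_adapter(registry, lora_name: str, notify_schedulers):
    # Unregister first so no new request can acquire the adapter, then wait for
    # in-flight requests to drain before telling the scheduler processes to drop it.
    lora_id = await registry.unregister(lora_name)
    await registry.wait_for_unload(lora_id)
    await notify_schedulers(lora_id)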

sglang/srt/managers/cache_controller.py

@@ -201,8 +201,9 @@ class PrefetchOperation(StorageOperation):
     def increment(self, num_tokens: int):
         with self._lock:
             if self._done_flag:
-                return
+                return False
             self.completed_tokens += num_tokens
+            return True

     def mark_done(self):
         with self._lock:
@@ -528,12 +529,12 @@ class HiCacheController:
                        f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
                    )
                    break
-                self.mem_pool_host.set_from_flat_data_page(
-                    operation.host_indices[operation.completed_tokens],
-                    page_data,
-                )
-                operation.increment(self.page_size)
-                if operation.is_done():
+                if operation.increment(self.page_size):
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                else:
                     # operation terminated by controller, release pre-allocated memory
                     self.mem_pool_host.free(
                         operation.host_indices[operation.completed_tokens :]
@@ -589,6 +590,7 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
+                self.mem_pool_host.free(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
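
The cache_controller.py change turns increment() into a claim: it returns False once the operation has been marked done, so the prefetch worker only writes a fetched page into host memory when its claim succeeds, and otherwise frees the pre-allocated slots. A small self-contained sketch of that pattern (PrefetchProgress is a hypothetical stand-in for PrefetchOperation, not sglang code):

import threading

class PrefetchProgress:
    def __init__(self):
        self._lock = threading.Lock()
        self._done = False
        self.completed_tokens = 0

    def increment(self, num_tokens: int) -> bool:
        # Atomically claim progress; refuse once the operation was terminated.
        with self._lock:
            if self._done:
                return False
            self.completed_tokens += num_tokens
            return True

    def mark_done(self):
        with self._lock:
            self._done = True

progress = PrefetchProgress()
print(progress.increment(64))  # True: safe to copy this page into host memory
progress.mark_done()           # controller terminates the prefetch
print(progress.increment(64))  # False: drop the late page and free its host slots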