sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py CHANGED
@@ -16,7 +16,7 @@
  # and "Punica: Multi-Tenant LoRA Serving"

  import logging
- from typing import Dict, List, Set, Tuple
+ from typing import Dict, Set, Tuple

  import torch

@@ -45,7 +45,6 @@ class LoRAManager:
  def __init__(
  self,
  base_model: torch.nn.Module,
- lora_paths: Dict[str, str],
  base_hf_config: AutoConfig,
  max_loras_per_batch: int,
  load_config: LoadConfig,
@@ -55,7 +54,6 @@ class LoRAManager:
  tp_rank: int = 0,
  ):
  self.base_model: torch.nn.Module = base_model
- self.lora_paths: Dict[str, str] = lora_paths
  self.base_hf_config: AutoConfig = base_hf_config
  self.max_loras_per_batch: int = max_loras_per_batch
  self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ class LoRAManager:
  backend_type = get_backend_from_name(lora_backend)
  self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

- self.init_loras()
- self.init_lora_memory_pool()
+ # Initialize mutable internal state of the LoRAManager.
+ self.init_state()

  def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
  self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -100,72 +98,49 @@ class LoRAManager:
  ],
  )

- def init_loras(self):
- # Config of each LoRA adapter
- self.configs: Dict[str, LoRAConfig] = {}
+ def load_lora_adapters(self, lora_paths: Dict[str, str]):
+ """
+ Load LoRA adapters from the specified paths.
+ TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+ Args:
+ lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+ If a LoRA adapter is already loaded, it will be skipped with a warning.
+ """
+
+ for lora_name, lora_path in lora_paths.items():
+ if lora_name in self.loras:
+ logger.warning(
+ f"LoRA adapter {lora_name} is already loaded."
+ "If you want to reload it, please unload it first."
+ )
+ continue

- # Target module names in huggingface lora configs.
- # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
- self.hf_target_names: Set[str] = set()
- for name, path in self.lora_paths.items():
- self.configs[name] = LoRAConfig(path)
- self.hf_target_names.update(self.configs[name].target_modules)
+ self.configs[lora_name] = LoRAConfig(lora_path)

- # Target lora weight names for lora_a and lora_b modules respectively.
- weights_A: List[str] = []
- weights_B: List[str] = []
- for module in self.hf_target_names:
- lora_A, lora_B = get_normalized_lora_weight_names(module)
- weights_A += lora_A
- weights_B += lora_B
- self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+ self.update_state_from_configs()

- # load all weights to cpu
- self.loras: Dict[str, LoRAAdapter] = {}
- for name in self.lora_paths.keys():
- lora_adapter = LoRAAdapter(
- name,
- self.configs[name],
- self.base_hf_config,
- self.load_config,
- self.lora_backend,
- )
- lora_adapter.initialize_weights()
- self.loras[name] = lora_adapter
-
- # misc lora configs
- self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+ def unload_lora_adapters(self, lora_names: Set[str]):
+ """
+ Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+ delete the corresponding LoRA modules.

- if self.lora_backend == "flashinfer":
- # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
- max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
- scaling = list(self.loras.values())[0].scaling
- assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
- assert all(x.scaling == scaling for x in self.loras.values())
+ Args:
+ lora_names (Set[str]): A set of LoRA adapter names to unload.
+ """
+ for lora_name in lora_names:
+ if lora_name in self.loras:
+ del self.configs[lora_name]
+ else:
+ logger.warning(f"LoRA adapter {lora_name} is not loaded.")

- # Convert original model layers to layers with LoRA
- self.convert_to_lora_layers()
-
- def init_lora_memory_pool(self):
- # Initialize memory pool
- self.memory_pool = LoRAMemoryPool(
- self.base_hf_config,
- self.max_loras_per_batch,
- self.max_lora_dim,
- self.dtype,
- self.tp_size,
- self.tp_rank,
- self.lora_modules,
- )
-
- # Initialize target lora modules in memory pool
- self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+ self.update_state_from_configs()

  def prepare_lora_batch(self, forward_batch: ForwardBatch):
  # load active loras into lora memory pool
  cur_uids = set(forward_batch.lora_paths)
  assert len(cur_uids) <= self.max_loras_per_batch
- self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+ self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

  # set up batch info shared by all lora modules
  bs = forward_batch.batch_size
@@ -267,9 +242,16 @@ class LoRAManager:
  )
  self.lora_backend.set_batch_info(batch_info)

- # call set_lora_info for each lora modules
- for layer_id, modules in self.lora_modules.items():
- for module_name, module in modules:
+ # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+ # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+ self.update_lora_info()
+
+ def update_lora_info(self):
+ """
+ Update all LoRA modules to associate them with the latest memory buffer.
+ """
+ for layer_id, layer_modules in self.lora_modules.items():
+ for module_name, module in layer_modules.items():
  if "qkv_proj" in module_name:
  module.set_lora_info(
  self.memory_pool.get_tensor(
@@ -295,23 +277,139 @@ class LoRAManager:
  ),
  )

+ def init_state(self):
+ """
+ Initialize the internal (mutable) state of the LoRAManager.
+
+ These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+ """
+
+ # Configs of all active LoRA adapters.
+ self.configs: Dict[str, LoRAConfig] = {}
+
+ # LoRA adapter weights cached in CPU memory.
+ self.loras: Dict[str, LoRAAdapter] = {}
+
+ # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+ self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+ # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+ self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+ i: {} for i in range(self.base_hf_config.num_hidden_layers)
+ }
+
+ # Initialize memory pool
+ self.memory_pool = LoRAMemoryPool(
+ self.base_hf_config,
+ self.max_loras_per_batch,
+ self.dtype,
+ self.tp_size,
+ self.tp_rank,
+ )
+
+ def update_state_from_configs(self):
+ """
+ Update the internal state of the LoRAManager based on the current `self.configs`. This method
+ should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+ This includes:
+ - Initializing LoRA adapters if they are not already loaded.
+ - Collect all LoRA weight names based on the current loaded adapters.
+ - Lazily monkey-patching the base model to use LoRA layers where applicable.
+ - Preparing the GPU buffer pool for active LoRA weights.
+ """
+
+ # Target module names in huggingface lora configs.
+ # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+ hf_target_module_names: Set[str] = set()
+ for config in self.configs.values():
+ hf_target_module_names.update(config.target_modules)
+ max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+ # Loads / unloads LoRA adapters based on the latest configs.
+ self.update_lora_adapters()
+
+ # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+ #
+ # Please note that the following update operations are "monotonic" by design, meaning that we update
+ # multiple places to support the new weight names when the first adapter targeting such weight names
+ # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+ # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+ # list of LoRA weight names is expected to be extremely finite and stable.
+ self.update_lora_weight_names(hf_target_module_names)
+ self.update_lora_modules(hf_target_module_names)
+ self.update_memory_buffers(max_lora_dim)
+
+ def update_lora_weight_names(self, hf_target_names: Set[str]):
+ """
+ Add new LoRA weight names if needed based on the current `self.configs`.
+ """
+
+ # Target lora weight names for lora_a and lora_b modules respectively.
+ for module in hf_target_names:
+ lora_A, lora_B = get_normalized_lora_weight_names(module)
+ self.lora_weight_names[0].update(lora_A)
+ self.lora_weight_names[1].update(lora_B)
+
+ def update_lora_adapters(self):
+ """
+ Update the LoRA adapters in CPU memory based on the current `self.configs`.
+ It loads any new adapters that are not already loaded, and unloads any adapters
+ that are no longer in `self.configs` (e.g., unloaded).
+ """
+
+ # Load new adapter weights to cpu
+ for name, config in self.configs.items():
+ if name not in self.loras:
+ logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+ lora_adapter = LoRAAdapter(
+ name,
+ config,
+ self.base_hf_config,
+ self.load_config,
+ self.lora_backend,
+ )
+ lora_adapter.initialize_weights()
+ self.loras[name] = lora_adapter
+
+ # Clean up unused LoRA adapters
+ for name in self.loras:
+ if name not in self.configs:
+ logger.info(f"Unloading LoRA adapter {name}")
+ del self.loras[name]
+
+ # Additional checks for flashinfer backend
+ # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+ if self.lora_backend == "flashinfer":
+ lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+ scalings = set(x.scaling for x in self.loras.values())
+ assert (
+ len(lora_dims) == 1 and len(scalings) == 1
+ ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+ def update_memory_buffers(self, max_lora_dim: int):
+ """
+ Update the LoRA memory pool buffers based on the current LoRA configurations and update
+ LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+ are set or updated.
+ """
+
+ self.memory_pool.init_buffers(
+ self.lora_weight_names, self.base_model, max_lora_dim
+ )
+
  def set_lora_module(self, module_name, module):
  lora_module = get_lora_layer(module, self.lora_backend)
  replace_submodule(self.base_model, module_name, lora_module)
  return lora_module

- def convert_to_lora_layers(self):
+ def update_lora_modules(self, hf_target_names: Set[str]):
  # Target module names of customized layers defined in python/sglang/srt/layers
  # e.g., {"qkv_proj", "o_proj"}
  customized_target_names = get_customized_names_from_hf_names(
- self.hf_target_names, self.base_model
+ hf_target_names, self.base_model
  )

- # Monkey patch to use the LoRA version layers
- self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
- i: [] for i in range(self.base_hf_config.num_hidden_layers)
- }
-
  for module_name, module in self.base_model.named_modules():
  # TODO (lifuhuang): in the future, we should consider generalizing the
  # should_apply_lora function to support mapping by full module name instead
@@ -326,6 +424,7 @@ class LoRAManager:
  # The module should be converted if it is included in target_names
  if module_name.split(".")[-1] in customized_target_names:
  layer_id = get_layer_id(module_name)
- self.lora_modules[layer_id].append(
- (module_name, self.set_lora_module(module_name, module))
- )
+ if module_name not in self.lora_modules[layer_id]:
+ self.lora_modules[layer_id][module_name] = self.set_lora_module(
+ module_name, module
+ )
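The lora_manager.py hunks above replace the constructor-time `init_loras` / `init_lora_memory_pool` calls with a `configs`-driven flow: `load_lora_adapters` and `unload_lora_adapters` only edit `self.configs`, and `update_state_from_configs` reconciles the CPU adapter cache, weight names, patched modules, and GPU buffers against it. A minimal standalone sketch of that reconcile-from-configs pattern follows; it is illustrative only (none of these names are the sglang API, and the real manager also rebuilds weight names, LoRA modules, and memory buffers on each update).

```python
# Minimal standalone sketch of the reconcile-from-configs pattern described above.
# This is NOT sglang code: names mirror the diff for readability only.
import logging
from typing import Dict, Set

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("lora_sketch")


class TinyLoRARegistry:
    def __init__(self) -> None:
        self.configs: Dict[str, str] = {}  # adapter name -> path (stands in for LoRAConfig)
        self.loras: Dict[str, str] = {}    # adapter name -> loaded-weights placeholder

    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> None:
        for name, path in lora_paths.items():
            if name in self.loras:
                logger.warning("LoRA adapter %s is already loaded; unload it first.", name)
                continue
            self.configs[name] = path
        self.update_state_from_configs()

    def unload_lora_adapters(self, lora_names: Set[str]) -> None:
        for name in lora_names:
            if name in self.loras:
                del self.configs[name]
            else:
                logger.warning("LoRA adapter %s is not loaded.", name)
        self.update_state_from_configs()

    def update_state_from_configs(self) -> None:
        # Materialize anything present in configs but not loaded yet.
        for name, path in self.configs.items():
            if name not in self.loras:
                logger.info("Loading weights of %s from %s", name, path)
                self.loras[name] = f"weights({path})"
        # Drop anything that is no longer in configs.
        for name in list(self.loras):
            if name not in self.configs:
                logger.info("Unloading %s", name)
                del self.loras[name]


registry = TinyLoRARegistry()
registry.load_lora_adapters({"sql-adapter": "/models/lora/sql"})
registry.unload_lora_adapters({"sql-adapter", "missing-adapter"})
```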
sglang/srt/lora/mem_pool.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Dict, List, Optional, Set, Tuple
+ from typing import Callable, Dict, List, Optional, Set, Tuple

  import torch

@@ -22,21 +22,16 @@ class LoRAMemoryPool:
  self,
  base_hf_config: AutoConfig,
  max_loras_per_batch: int,
- max_lora_dim: int,
  dtype: torch.dtype,
  tp_size: int,
  tp_rank: int,
- lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
  ):
-
  self.base_hf_config: AutoConfig = base_hf_config
  self.num_layer: int = base_hf_config.num_hidden_layers
  self.max_loras_per_batch: int = max_loras_per_batch
- self.max_lora_dim: int = max_lora_dim
  self.dtype: torch.dtype = dtype
  self.tp_size: int = tp_size
  self.tp_rank: int = tp_rank
- self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

  # Both A_buffer and B_buffer maps lora weight names to its buffer space.
  # A_buffer contains num_layer number of row-major tensors with shape
@@ -55,79 +50,84 @@ class LoRAMemoryPool:
  self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

  def get_lora_A_shape(
- self, module_name: str, base_model: torch.nn.Module
+ self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
  ) -> Tuple[int]:
  """
  Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
  """
  input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
  c = get_stacked_multiply(module_name)
- if self.tp_size > 1:
- if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
- input_dim = divide(input_dim, self.tp_size)
+ if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+ input_dim = divide(input_dim, self.tp_size)
  return (
  self.max_loras_per_batch,
- self.max_lora_dim * c,
+ max_lora_dim * c,
  input_dim,
  )

  def get_lora_B_shape(
- self, module_name: str, base_model: torch.nn.Module
+ self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
  ) -> Tuple[int]:
  """
  Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
  """
  _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
  c = get_stacked_multiply(module_name)
- if self.tp_size > 1:
- if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
- output_dim = divide(output_dim, self.tp_size)
+ if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+ output_dim = divide(output_dim, self.tp_size)
  return (
  c,
  self.max_loras_per_batch,
  output_dim,
- self.max_lora_dim,
+ max_lora_dim,
  )

  def init_buffers(
  self,
  lora_weight_names: Tuple[Set[str]],
  base_model: torch.nn.Module,
+ max_lora_dim: int,
  ):
-
  # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
  # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
  self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
  device = next(base_model.parameters()).device
- # Init A tensor, column_major=False
- for module_A in lora_weight_names[0]:
- lora_A_shape = self.get_lora_A_shape(module_A, base_model)
- self.A_buffer[module_A] = [
- torch.empty(
- lora_A_shape,
- dtype=self.dtype,
- device=device,
- )
- for _ in range(self.num_layer)
- ]
- # Init B tensor, column_major=True
- for module_B in lora_weight_names[1]:
- lora_B_shape = self.get_lora_B_shape(module_B, base_model)
- self.B_buffer[module_B] = [
- torch.empty(
- lora_B_shape,
- dtype=self.dtype,
- device=device,
- )
- for _ in range(self.num_layer)
- ]
+
+ def update_buffer(
+ buffer: Dict[str, List[torch.Tensor]],
+ lora_weight_names: Set[str],
+ get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+ ):
+ new_weight_names = lora_weight_names - buffer.keys()
+ for module_name in new_weight_names:
+ lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+ buffer[module_name] = [
+ torch.empty(
+ lora_shape,
+ dtype=self.dtype,
+ device=device,
+ )
+ for _ in range(self.num_layer)
+ ]
+
+ update_buffer(
+ self.A_buffer,
+ lora_weight_names[0],
+ self.get_lora_A_shape,
+ )
+
+ update_buffer(
+ self.B_buffer,
+ lora_weight_names[1],
+ self.get_lora_B_shape,
+ )

  def prepare_lora_batch(
  self,
  cur_uids: Set[Optional[str]],
  lora_adapters: Dict[str, LoRAAdapter],
+ lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
  ):
-
  def get_available_buffer_slot():
  for buffer_id in range(self.max_loras_per_batch):
  # Prioritize empty slots
@@ -147,14 +147,19 @@ class LoRAMemoryPool:
  for uid in cur_uids:
  if uid not in self.uid_to_buffer_id:
  buffer_id = get_available_buffer_slot()
+ lora_adapter = lora_adapters.get(uid, None)
  self.load_lora_weight_to_buffer(
- uid, buffer_id, lora_adapters.get(uid, None)
+ uid, buffer_id, lora_adapter, lora_modules
  )
  self.uid_to_buffer_id[uid] = buffer_id
  self.buffer_id_to_uid[buffer_id] = uid

  def load_lora_weight_to_buffer(
- self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+ self,
+ uid: str,
+ buffer_id: int,
+ lora_adapter: LoRAAdapter,
+ lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
  ):
  def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
  assert (
@@ -186,8 +191,8 @@ class LoRAMemoryPool:
  temp_B_buffer[lora_weight_name] = weights

  if self.tp_size > 1:
- cur_layer_modules = self.lora_modules[layer_id]
- for module_name, module in cur_layer_modules:
+ cur_layer_modules = lora_modules[layer_id]
+ for module_name, module in cur_layer_modules.items():
  if "qkv_proj" in module_name:
  temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
  temp_A_buffer["qkv_proj"], self.tp_rank
@@ -236,7 +241,6 @@ class LoRAMemoryPool:
  def get_tensor(
  self, weight_name: str, layer_id: int, lora_type: LoRAType
  ) -> torch.Tensor:
-
  if lora_type == LoRAType.LORA_A:
  return self.A_buffer[weight_name][layer_id]
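In the mem_pool.py hunks above, `init_buffers` now receives `max_lora_dim` per call and only allocates buffers for weight names it has not seen before, which is what makes repeated initialization after dynamic adapter loading cheap. A small standalone illustration of that idea follows; the names, shapes, and dtype are placeholders rather than the sglang API.

```python
# Standalone illustration (not the sglang API; names and shapes are made up) of the
# incremental allocation that init_buffers now performs: only weight names without an
# existing buffer get new per-layer tensors, so a second call never reallocates
# what is already there.
from typing import Callable, Dict, List, Set, Tuple

import torch

NUM_LAYERS = 2
MAX_LORAS_PER_BATCH = 4


def update_buffer(
    buffer: Dict[str, List[torch.Tensor]],
    weight_names: Set[str],
    shape_fn: Callable[[str], Tuple[int, int, int]],
) -> None:
    # Allocate only for names that are not present yet ("monotonic" growth).
    for name in weight_names - buffer.keys():
        buffer[name] = [
            torch.empty(shape_fn(name), dtype=torch.float16) for _ in range(NUM_LAYERS)
        ]


def a_shape(name: str) -> Tuple[int, int, int]:
    # Placeholder for a get_lora_A_shape-style helper: (max_loras_per_batch, rank, input_dim).
    return (MAX_LORAS_PER_BATCH, 16, 128)


A_buffer: Dict[str, List[torch.Tensor]] = {}
update_buffer(A_buffer, {"qkv_proj", "o_proj"}, a_shape)
update_buffer(A_buffer, {"qkv_proj", "gate_up_proj"}, a_shape)  # only gate_up_proj is new
print(sorted(A_buffer))  # ['gate_up_proj', 'o_proj', 'qkv_proj']
```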
sglang/srt/lora/utils.py CHANGED
@@ -108,7 +108,7 @@ def get_hidden_dim(

  def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
  """
- Mapping a target module name to names of the normized LoRA weights.
+ Mapping a target module name to names of the normalized LoRA weights.
  Returned tuple contains (name for Lora A, name for Lora B)
  """
  params_mapping = {
sglang/srt/managers/cache_controller.py CHANGED
@@ -18,34 +18,50 @@ import logging
  import math
  import threading
  from queue import Empty, Full, PriorityQueue, Queue
- from typing import List, Optional
+ from typing import TYPE_CHECKING, List, Optional

  import torch

- from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
- from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+ if TYPE_CHECKING:
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.memory_pool_host import HostKVCache

  logger = logging.getLogger(__name__)


  class LayerDoneCounter:
  def __init__(self, num_layers):
- self.counter = num_layers
- self.condition = threading.Condition()
+ self.num_layers = num_layers
+ # extra producer and consumer counters for overlap mode
+ self.num_counters = 3
+ self.counters = [num_layers] * self.num_counters
+ self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+ self.producer_index = 0
+ self.consumer_index = 0
+
+ def next_producer(self):
+ return (self.producer_index + 1) % self.num_counters
+
+ def update_producer(self):
+ self.producer_index = self.next_producer()
+ return self.producer_index
+
+ def set_consumer(self, index):
+ self.consumer_index = index

  def increment(self):
- with self.condition:
- self.counter += 1
- self.condition.notify_all()
+ with self.conditions[self.producer_index]:
+ self.counters[self.producer_index] += 1
+ self.conditions[self.producer_index].notify_all()

  def wait_until(self, threshold):
- with self.condition:
- while self.counter <= threshold:
- self.condition.wait()
+ with self.conditions[self.consumer_index]:
+ while self.counters[self.consumer_index] <= threshold:
+ self.conditions[self.consumer_index].wait()

  def reset(self):
- with self.condition:
- self.counter = 0
+ with self.conditions[self.producer_index]:
+ self.counters[self.producer_index] = 0


  class CacheOperation:
@@ -148,7 +164,7 @@ class HiCacheController:

  def __init__(
  self,
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
  mem_pool_host: HostKVCache,
  page_size: int,
  load_cache_event: threading.Event = None,
@@ -296,7 +312,6 @@ class HiCacheController:
  while not self.stop_event.is_set():
  try:
  operation = self.load_queue.get(block=True, timeout=1)
- # time.sleep(18e-6 * len(operation.host_indices))
  operation.data = self.mem_pool_host.get_flat_data(
  operation.host_indices
  )
@@ -320,6 +335,7 @@ class HiCacheController:
  if not self.load_cache_event.is_set():
  continue
  self.load_cache_event.clear()
+ self.layer_done_counter.update_producer()

  batch_operation = None
  while self.load_queue.qsize() > 0:
@@ -331,6 +347,7 @@ class HiCacheController:
  if batch_operation is None:
  continue

+ # start layer-wise KV cache transfer from CPU to GPU
  self.layer_done_counter.reset()
  for i in range(self.mem_pool_host.layer_num):
  if self.page_size == 1:
@@ -466,6 +483,7 @@ class HiCacheController:
  except Exception as e:
  logger.error(e)

+ # todo (zhiqiang): double buffering to be deprecated
  def write_thread_func_buffer(self):
  aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
  aux_thread.start()
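The LayerDoneCounter change above swaps the single counter/condition pair for a small ring of counters so that, in overlap mode, the cache loader can start publishing layer completions for a new transfer batch while the compute side is still consuming the previous one. The following standalone sketch models that handshake under simplified assumptions (fixed layer count, one producer and one consumer, made-up timings); it is not the sglang implementation.

```python
# Standalone sketch (not the sglang class) of the rotating layer-done counters:
# the loader picks a fresh counter slot per transfer batch and marks layers done
# one by one; the compute side pins itself to that slot and waits per layer, so
# two in-flight batches never share a counter.
import threading
import time

NUM_LAYERS = 4
NUM_COUNTERS = 3  # mirrors the "extra producer and consumer counters" comment

counters = [NUM_LAYERS] * NUM_COUNTERS
conditions = [threading.Condition() for _ in range(NUM_COUNTERS)]
producer_index = 0


def load_batch(slot: int) -> None:
    # Producer: after each (simulated) host-to-device copy, mark one more layer done.
    for _ in range(NUM_LAYERS):
        time.sleep(0.01)
        with conditions[slot]:
            counters[slot] += 1
            conditions[slot].notify_all()


def consume(slot: int) -> None:
    # Consumer: before using layer i, wait until that layer's KV cache has landed.
    for layer in range(NUM_LAYERS):
        with conditions[slot]:
            while counters[slot] <= layer:
                conditions[slot].wait()
        print(f"layer {layer} ready on slot {slot}")


producer_index = (producer_index + 1) % NUM_COUNTERS  # akin to update_producer()
slot = producer_index
with conditions[slot]:
    counters[slot] = 0  # akin to reset() before the transfer starts
threading.Thread(target=load_batch, args=(slot,), daemon=True).start()
consume(slot)  # akin to set_consumer(slot) followed by per-layer wait_until(layer)
```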