sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict,
+from typing import Dict, Set, Tuple

 import torch

@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ class LoRAManager:
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ class LoRAManager:
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

-
-        self.
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -100,72 +98,49 @@ class LoRAManager:
             ],
         )

-    def
-
-
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
+
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
+
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue

-
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+            self.configs[lora_name] = LoRAConfig(lora_path)

-
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        self.update_state_from_configs()

-
-
-
-
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.

-
-
-
-
-
-
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")

-
-        self.convert_to_lora_layers()
-
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
-
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)

         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
@@ -267,9 +242,16 @@ class LoRAManager:
             )
         self.lora_backend.set_batch_info(batch_info)

-        #
-
-
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -295,23 +277,139 @@ class LoRAManager:
                     ),
                 )

+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-
+            hf_target_names, self.base_model
         )

-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -326,6 +424,7 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id]
-
-
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
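The net effect of these hunks is that LoRAManager no longer takes a fixed `lora_paths` dict at construction: `load_lora_adapters` and `unload_lora_adapters` only edit `self.configs` and then call `update_state_from_configs`, which reconciles the CPU-side adapters, weight names, patched modules, and GPU buffers against that config dict. The toy sketch below illustrates just that reconcile-from-configs control flow; `ToyLoRAManager` and `ToyConfig` are hypothetical stand-ins, not the real sglang classes.

```python
# Minimal, self-contained sketch of the config-driven update pattern used by the new
# LoRAManager methods (both load_lora_adapters and unload_lora_adapters end by calling
# update_state_from_configs). Names here are illustrative only.
from typing import Dict, Set


class ToyConfig:
    def __init__(self, path: str):
        self.path = path


class ToyLoRAManager:
    def __init__(self):
        self.configs: Dict[str, ToyConfig] = {}  # desired state (what should be loaded)
        self.loras: Dict[str, str] = {}          # actual state (weights "in CPU memory")

    def load_lora_adapters(self, lora_paths: Dict[str, str]):
        for name, path in lora_paths.items():
            if name in self.loras:
                print(f"LoRA adapter {name} is already loaded.")
                continue
            self.configs[name] = ToyConfig(path)
        self.update_state_from_configs()

    def unload_lora_adapters(self, lora_names: Set[str]):
        # Simplified: the real method warns when the adapter is not loaded.
        for name in lora_names:
            self.configs.pop(name, None)
        self.update_state_from_configs()

    def update_state_from_configs(self):
        # Load anything that is configured but not yet resident.
        for name, config in self.configs.items():
            if name not in self.loras:
                self.loras[name] = f"weights loaded from {config.path}"
        # Drop anything resident that is no longer configured.
        for name in list(self.loras):
            if name not in self.configs:
                del self.loras[name]


if __name__ == "__main__":
    mgr = ToyLoRAManager()
    mgr.load_lora_adapters({"my_adapter": "/tmp/my_adapter"})
    mgr.unload_lora_adapters({"my_adapter"})
    print(mgr.loras)  # -> {}
```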
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple

 import torch

@@ -22,21 +22,16 @@ class LoRAMemoryPool:
         self,
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
-        max_lora_dim: int,
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
-        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
-
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
         self.max_loras_per_batch: int = max_loras_per_batch
-        self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -55,79 +50,84 @@ class LoRAMemoryPool:
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-
-            input_dim = divide(input_dim, self.tp_size)
+        if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            input_dim = divide(input_dim, self.tp_size)
         return (
             self.max_loras_per_batch,
-
+            max_lora_dim * c,
             input_dim,
         )

     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-
-            output_dim = divide(output_dim, self.tp_size)
+        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            output_dim = divide(output_dim, self.tp_size)
         return (
             c,
             self.max_loras_per_batch,
             output_dim,
-
+            max_lora_dim,
         )

     def init_buffers(
         self,
         lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
+        max_lora_dim: int,
     ):
-
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-
-
-
-
-
-
-
-
-        )
-
-
-
-
-
-
-
-
-
-
-
-
+
+        def update_buffer(
+            buffer: Dict[str, List[torch.Tensor]],
+            lora_weight_names: Set[str],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+        ):
+            new_weight_names = lora_weight_names - buffer.keys()
+            for module_name in new_weight_names:
+                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+                buffer[module_name] = [
+                    torch.empty(
+                        lora_shape,
+                        dtype=self.dtype,
+                        device=device,
+                    )
+                    for _ in range(self.num_layer)
+                ]
+
+        update_buffer(
+            self.A_buffer,
+            lora_weight_names[0],
+            self.get_lora_A_shape,
+        )
+
+        update_buffer(
+            self.B_buffer,
+            lora_weight_names[1],
+            self.get_lora_B_shape,
+        )

     def prepare_lora_batch(
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
@@ -147,14 +147,19 @@ class LoRAMemoryPool:
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
                 buffer_id = get_available_buffer_slot()
+                lora_adapter = lora_adapters.get(uid, None)
                 self.load_lora_weight_to_buffer(
-                    uid, buffer_id,
+                    uid, buffer_id, lora_adapter, lora_modules
                 )
                 self.uid_to_buffer_id[uid] = buffer_id
                 self.buffer_id_to_uid[buffer_id] = uid

     def load_lora_weight_to_buffer(
-        self,
+        self,
+        uid: str,
+        buffer_id: int,
+        lora_adapter: LoRAAdapter,
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
         def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
             assert (
@@ -186,8 +191,8 @@ class LoRAMemoryPool:
                 temp_B_buffer[lora_weight_name] = weights

         if self.tp_size > 1:
-            cur_layer_modules =
-            for module_name, module in cur_layer_modules:
+            cur_layer_modules = lora_modules[layer_id]
+            for module_name, module in cur_layer_modules.items():
                 if "qkv_proj" in module_name:
                     temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                         temp_A_buffer["qkv_proj"], self.tp_rank
@@ -236,7 +241,6 @@ class LoRAMemoryPool:
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
-
         if lora_type == LoRAType.LORA_A:
             return self.A_buffer[weight_name][layer_id]

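To make the new shape logic concrete: `max_lora_dim` is now passed into `get_lora_A_shape` / `get_lora_B_shape` per call rather than stored on the pool, the A buffer is shaped `(max_loras_per_batch, max_lora_dim * c, input_dim)`, the B buffer `(c, max_loras_per_batch, output_dim, max_lora_dim)`, and tensor-parallel sharding now splits `input_dim` only for row-parallel modules (and `output_dim` otherwise). The standalone sketch below reproduces that arithmetic; the `ROW_PARALLEL` set and the example sizes are illustrative placeholders, not the real `ROW_PARALLELISM_LINEAR_LORA_NAMES` constant.

```python
# Sketch of the buffer-shape arithmetic after this change, with placeholder inputs.
from typing import Tuple

ROW_PARALLEL = {"o_proj", "down_proj"}  # assumed row-parallel module names (illustrative)


def lora_a_shape(module_name: str, input_dim: int, stacked: int,
                 max_loras_per_batch: int, max_lora_dim: int, tp_size: int) -> Tuple[int, ...]:
    # Only row-parallel modules shard their input dim across TP ranks.
    if tp_size > 1 and module_name in ROW_PARALLEL:
        input_dim //= tp_size
    return (max_loras_per_batch, max_lora_dim * stacked, input_dim)


def lora_b_shape(module_name: str, output_dim: int, stacked: int,
                 max_loras_per_batch: int, max_lora_dim: int, tp_size: int) -> Tuple[int, ...]:
    # Column-parallel (non-row-parallel) modules shard their output dim instead.
    if tp_size > 1 and module_name not in ROW_PARALLEL:
        output_dim //= tp_size
    return (stacked, max_loras_per_batch, output_dim, max_lora_dim)


if __name__ == "__main__":
    # e.g. a stacked qkv_proj (stacking factor 3), hidden size 4096, rank 16, tp=2
    print(lora_a_shape("qkv_proj", 4096, 3, 8, 16, 2))  # (8, 48, 4096)
    print(lora_b_shape("qkv_proj", 4096, 3, 8, 16, 2))  # (3, 8, 2048, 16)
```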
sglang/srt/lora/utils.py
CHANGED
@@ -108,7 +108,7 @@ def get_hidden_dim(


 def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to names of the
+    Mapping a target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
sglang/srt/managers/cache_controller.py
CHANGED
@@ -18,34 +18,50 @@ import logging
 import math
 import threading
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch

-
-from sglang.srt.mem_cache.
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 logger = logging.getLogger(__name__)


 class LayerDoneCounter:
     def __init__(self, num_layers):
-        self.
-
+        self.num_layers = num_layers
+        # extra producer and consumer counters for overlap mode
+        self.num_counters = 3
+        self.counters = [num_layers] * self.num_counters
+        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0
+
+    def next_producer(self):
+        return (self.producer_index + 1) % self.num_counters
+
+    def update_producer(self):
+        self.producer_index = self.next_producer()
+        return self.producer_index
+
+    def set_consumer(self, index):
+        self.consumer_index = index

     def increment(self):
-        with self.
-            self.
-            self.
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] += 1
+            self.conditions[self.producer_index].notify_all()

     def wait_until(self, threshold):
-        with self.
-            while self.
-                self.
+        with self.conditions[self.consumer_index]:
+            while self.counters[self.consumer_index] <= threshold:
+                self.conditions[self.consumer_index].wait()

     def reset(self):
-        with self.
-            self.
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] = 0


 class CacheOperation:
@@ -148,7 +164,7 @@ class HiCacheController:

     def __init__(
         self,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
         load_cache_event: threading.Event = None,
@@ -296,7 +312,6 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
-                # time.sleep(18e-6 * len(operation.host_indices))
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )
@@ -320,6 +335,7 @@ class HiCacheController:
             if not self.load_cache_event.is_set():
                 continue
             self.load_cache_event.clear()
+            self.layer_done_counter.update_producer()

             batch_operation = None
             while self.load_queue.qsize() > 0:
@@ -331,6 +347,7 @@ class HiCacheController:
             if batch_operation is None:
                 continue

+            # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
                 if self.page_size == 1:
@@ -466,6 +483,7 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

+    # todo (zhiqiang): double buffering to be deprecated
     def write_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
         aux_thread.start()
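The LayerDoneCounter rewrite replaces the single counter with three rotating counters (the diff's comment attributes this to overlap mode): the cache controller rotates to a fresh counter per transfer batch via `update_producer()` and increments it once per finished layer, while a consumer pins itself to that counter with `set_consumer()` and blocks in `wait_until(layer_id)` before touching that layer's KV cache. The self-contained sketch below reuses the class as it appears in the diff and shows one plausible producer/consumer pairing; the `transfer_kv_cache` helper and timing are illustrative only, not the actual scheduler wiring.

```python
import threading
import time


class LayerDoneCounter:
    """Same structure as the new counter in cache_controller.py (3 rotating counters)."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.num_counters = 3
        self.counters = [num_layers] * self.num_counters
        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
        self.producer_index = 0
        self.consumer_index = 0

    def update_producer(self):
        # Rotate to a fresh counter before starting a new layer-wise transfer batch.
        self.producer_index = (self.producer_index + 1) % self.num_counters
        return self.producer_index

    def set_consumer(self, index):
        self.consumer_index = index

    def reset(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] = 0

    def increment(self):
        with self.conditions[self.producer_index]:
            self.counters[self.producer_index] += 1
            self.conditions[self.producer_index].notify_all()

    def wait_until(self, threshold):
        with self.conditions[self.consumer_index]:
            while self.counters[self.consumer_index] <= threshold:
                self.conditions[self.consumer_index].wait()


def transfer_kv_cache(counter: LayerDoneCounter, num_layers: int):
    # Producer side: pretend to copy one layer's KV cache host -> device per iteration.
    for _ in range(num_layers):
        time.sleep(0.01)
        counter.increment()


if __name__ == "__main__":
    num_layers = 4
    counter = LayerDoneCounter(num_layers)
    # Consumer pins itself to the counter the producer is about to use.
    counter.set_consumer(counter.update_producer())
    counter.reset()
    threading.Thread(target=transfer_kv_cache, args=(counter, num_layers), daemon=True).start()
    for layer_id in range(num_layers):
        counter.wait_until(layer_id)  # returns once layer `layer_id` has been transferred
        print(f"layer {layer_id} ready")
```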