sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their public registry.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -6,11 +6,14 @@ from typing import List, Mapping, Tuple, Union
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm._custom_ops import scaled_fp8_quant
 
 
@@ -18,7 +18,6 @@ from typing import Optional
 
 from torch import nn
 
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -8,10 +8,13 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -84,7 +87,9 @@ class RotaryEmbedding(CustomOp):
         if not _is_cuda:
             cache = cache.to(dtype)
 
-        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+        if (
+            not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]
+        ) and not (_is_cpu and _is_cpu_amx_available):
             from vllm._custom_ops import rotary_embedding
 
             self.vllm_rotary_embedding = rotary_embedding
@@ -147,6 +152,26 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
     def forward_cuda(
         self,
         positions: torch.Tensor,
@@ -696,6 +721,21 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key = key_rot
         return query.to(dtype), key.to(dtype)
 
+    def forward_cpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        if _is_cpu_amx_available:
+            return torch.ops.sgl_kernel.rotary_embedding_cpu(
+                positions, query, key, self.head_size, self.cos_sin_cache, False
+            )
+        else:
+            return self.forward_native(positions, query, key, offsets)
+
 
 class Llama3RotaryEmbedding(RotaryEmbedding):
 
@@ -91,7 +91,7 @@ class Sampler(nn.Module):
                 )
             else:
                 batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                    probs,
+                    probs.contiguous(),
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
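
The `.contiguous()` call is the only functional change in the hunk above. Below is a minimal standalone illustration of what it does; the assumption that the motivation is a fused sampling kernel expecting a dense row-major `probs` layout is mine, not stated in the diff.

import torch

# A transposed (or otherwise strided) view is not contiguous in memory;
# .contiguous() materializes it into a dense row-major copy and is a
# cheap pass-through when the tensor is already contiguous.
probs = torch.rand(4, 8).t()      # transpose -> non-contiguous view
print(probs.is_contiguous())      # False
dense = probs.contiguous()        # dense row-major copy
print(dense.is_contiguous())      # True
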
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-from typing import Dict, List, Set, Tuple
+from typing import Dict, Set, Tuple
 
 import torch
 
@@ -45,7 +45,6 @@ class LoRAManager:
     def __init__(
         self,
         base_model: torch.nn.Module,
-        lora_paths: Dict[str, str],
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
         load_config: LoadConfig,
@@ -55,7 +54,6 @@ class LoRAManager:
         tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
-        self.lora_paths: Dict[str, str] = lora_paths
         self.base_hf_config: AutoConfig = base_hf_config
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
@@ -69,8 +67,8 @@ class LoRAManager:
         backend_type = get_backend_from_name(lora_backend)
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
-        self.init_loras()
-        self.init_lora_memory_pool()
+        # Initialize mutable internal state of the LoRAManager.
+        self.init_state()
 
     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -81,7 +79,7 @@ class LoRAManager:
             seg_indptr=torch.zeros(
                 self.max_bs_in_cuda_graph + 1, dtype=torch.int32
             ),
-            max_len=0,
+            max_len=1,
             weight_indices=torch.zeros(
                 self.max_bs_in_cuda_graph, dtype=torch.int32
             ),
@@ -89,76 +87,103 @@ class LoRAManager:
             scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
         )
 
-    def init_loras(self):
-        # Config of each LoRA adapter
-        self.configs: Dict[str, LoRAConfig] = {}
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )
 
-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        self.hf_target_names: Set[str] = set()
-        for name, path in self.lora_paths.items():
-            self.configs[name] = LoRAConfig(path)
-            self.hf_target_names.update(self.configs[name].target_modules)
+    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+        """
+        Load LoRA adapters from the specified paths.
+        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
-        # Target lora weight names for lora_a and lora_b modules respectively.
-        weights_A: List[str] = []
-        weights_B: List[str] = []
-        for module in self.hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            weights_A += lora_A
-            weights_B += lora_B
-        self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
+        Args:
+            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
+                If a LoRA adapter is already loaded, it will be skipped with a warning.
+        """
 
-        # load all weights to cpu
-        self.loras: Dict[str, LoRAAdapter] = {}
-        for name in self.lora_paths.keys():
-            lora_adapter = LoRAAdapter(
-                name,
-                self.configs[name],
-                self.base_hf_config,
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
+        for lora_name, lora_path in lora_paths.items():
+            if lora_name in self.loras:
+                logger.warning(
+                    f"LoRA adapter {lora_name} is already loaded."
+                    "If you want to reload it, please unload it first."
+                )
+                continue
 
-        # misc lora configs
-        self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+            self.configs[lora_name] = LoRAConfig(lora_path)
 
-        if self.lora_backend == "flashinfer":
-            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
-            scaling = list(self.loras.values())[0].scaling
-            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
-            assert all(x.scaling == scaling for x in self.loras.values())
+        self.update_state_from_configs()
 
-        # Convert original model layers to layers with LoRA
-        self.convert_to_lora_layers()
+    def unload_lora_adapters(self, lora_names: Set[str]):
+        """
+        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        delete the corresponding LoRA modules.
 
-    def init_lora_memory_pool(self):
-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.max_lora_dim,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-            self.lora_modules,
-        )
+        Args:
+            lora_names (Set[str]): A set of LoRA adapter names to unload.
+        """
+        for lora_name in lora_names:
+            if lora_name in self.loras:
+                del self.configs[lora_name]
+            else:
+                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
 
-        # Initialize target lora modules in memory pool
-        self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
+        self.update_state_from_configs()
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
-        self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
+        self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
 
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
@@ -166,51 +191,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[:bs],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
             )
-            self.cuda_graph_batch_info.max_len = 1
 
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                self.cuda_graph_batch_info.weight_indices[i] = (
-                    self.memory_pool.get_buffer_id(lora_path)
-                )
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    self.cuda_graph_batch_info.lora_ranks[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.config.hf_config["r"]
-                    self.cuda_graph_batch_info.scalings[
-                        self.cuda_graph_batch_info.weight_indices[i]
-                    ] = lora.scaling
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
             )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
 
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
@@ -222,9 +242,16 @@ class LoRAManager:
             )
         self.lora_backend.set_batch_info(batch_info)
 
-        # call set_lora_info for each lora modules
-        for layer_id, modules in self.lora_modules.items():
-            for module_name, module in modules:
+        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
+        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
+        self.update_lora_info()
+
+    def update_lora_info(self):
+        """
+        Update all LoRA modules to associate them with the latest memory buffer.
+        """
+        for layer_id, layer_modules in self.lora_modules.items():
+            for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -250,23 +277,139 @@ class LoRAManager:
                         ),
                     )
 
+    def init_state(self):
+        """
+        Initialize the internal (mutable) state of the LoRAManager.
+
+        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        """
+
+        # Configs of all active LoRA adapters.
+        self.configs: Dict[str, LoRAConfig] = {}
+
+        # LoRA adapter weights cached in CPU memory.
+        self.loras: Dict[str, LoRAAdapter] = {}
+
+        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
+            i: {} for i in range(self.base_hf_config.num_hidden_layers)
+        }
+
+        # Initialize memory pool
+        self.memory_pool = LoRAMemoryPool(
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+        )
+
+    def update_state_from_configs(self):
+        """
+        Update the internal state of the LoRAManager based on the current `self.configs`. This method
+        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
+
+        This includes:
+        - Initializing LoRA adapters if they are not already loaded.
+        - Collect all LoRA weight names based on the current loaded adapters.
+        - Lazily monkey-patching the base model to use LoRA layers where applicable.
+        - Preparing the GPU buffer pool for active LoRA weights.
+        """
+
+        # Target module names in huggingface lora configs.
+        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
+        hf_target_module_names: Set[str] = set()
+        for config in self.configs.values():
+            hf_target_module_names.update(config.target_modules)
+        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
+
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
+        #
+        # Please note that the following update operations are "monotonic" by design, meaning that we update
+        # multiple places to support the new weight names when the first adapter targeting such weight names
+        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
+        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
+        # list of LoRA weight names is expected to be extremely finite and stable.
+        self.update_lora_weight_names(hf_target_module_names)
+        self.update_lora_modules(hf_target_module_names)
+        self.update_memory_buffers(max_lora_dim)
+
+    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        Add new LoRA weight names if needed based on the current `self.configs`.
+        """
+
+        # Target lora weight names for lora_a and lora_b modules respectively.
+        for module in hf_target_names:
+            lora_A, lora_B = get_normalized_lora_weight_names(module)
+            self.lora_weight_names[0].update(lora_A)
+            self.lora_weight_names[1].update(lora_B)
+
+    def update_lora_adapters(self):
+        """
+        Update the LoRA adapters in CPU memory based on the current `self.configs`.
+        It loads any new adapters that are not already loaded, and unloads any adapters
+        that are no longer in `self.configs` (e.g., unloaded).
+        """
+
+        # Load new adapter weights to cpu
+        for name, config in self.configs.items():
+            if name not in self.loras:
+                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
+                lora_adapter = LoRAAdapter(
+                    name,
+                    config,
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
+                )
+                lora_adapter.initialize_weights()
+                self.loras[name] = lora_adapter
+
+        # Clean up unused LoRA adapters
+        for name in self.loras:
+            if name not in self.configs:
+                logger.info(f"Unloading LoRA adapter {name}")
+                del self.loras[name]
+
+        # Additional checks for flashinfer backend
+        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+        if self.lora_backend == "flashinfer":
+            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            scalings = set(x.scaling for x in self.loras.values())
+            assert (
+                len(lora_dims) == 1 and len(scalings) == 1
+            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
+
+    def update_memory_buffers(self, max_lora_dim: int):
+        """
+        Update the LoRA memory pool buffers based on the current LoRA configurations and update
+        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
+        are set or updated.
+        """
+
+        self.memory_pool.init_buffers(
+            self.lora_weight_names, self.base_model, max_lora_dim
+        )
+
     def set_lora_module(self, module_name, module):
         lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def convert_to_lora_layers(self):
+    def update_lora_modules(self, hf_target_names: Set[str]):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-            self.hf_target_names, self.base_model
+            hf_target_names, self.base_model
         )
 
-        # Monkey patch to use the LoRA version layers
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
-            i: [] for i in range(self.base_hf_config.num_hidden_layers)
-        }
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -281,6 +424,7 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                self.lora_modules[layer_id].append(
-                    (module_name, self.set_lora_module(module_name, module))
-                )
+                if module_name not in self.lora_modules[layer_id]:
+                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                        module_name, module
+                    )
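
The `transfer_adapter_info` helper introduced in the prepare_lora_batch changes above stages adapter metadata in pinned CPU memory and copies it into preallocated device tensors with `non_blocking=True`. The following self-contained sketch shows that same pattern in isolation; the buffer names and shapes are hypothetical and for illustration only.

import torch

def copy_metadata_to_device(values, out: torch.Tensor) -> None:
    # Stage the host data in pinned (page-locked) memory so the copy below
    # can be issued asynchronously without an implicit synchronization.
    staging = torch.tensor(values, dtype=out.dtype, pin_memory=True, device="cpu")
    # Asynchronous host-to-device copy into a preallocated device buffer.
    out[: len(values)].copy_(staging, non_blocking=True)

if torch.cuda.is_available():
    device_buf = torch.zeros(8, dtype=torch.int32, device="cuda")
    copy_metadata_to_device([3, 1, 4, 1, 5], device_buf)
    torch.cuda.synchronize()  # wait for the async copy before reading the result
    print(device_buf[:5])
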