sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_registry.py CHANGED
@@ -14,7 +14,6 @@
 
 
 import asyncio
-from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
@@ -28,14 +27,15 @@ class LoRARef:
     """
     Reference record for a LoRA model.
 
-    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
-    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
+    The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
     keys (e.g., radix cache).
     """
 
     lora_id: str = field(default_factory=lambda: uuid4().hex)
     lora_name: Optional[str] = None
     lora_path: Optional[str] = None
+    pinned: Optional[bool] = None
 
     def __post_init__(self):
         if self.lora_id is None:
@@ -105,7 +105,6 @@ class LoRARegistry:
                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
             )
         del self._registry[lora_name]
-        del self._counters[lora_ref.lora_id]
 
         return lora_ref.lora_id
 
@@ -116,6 +115,9 @@ class LoRARegistry:
         """
 
         def _lookup(name: str) -> str:
+            if name is None:
+                return None
+
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
@@ -134,7 +136,11 @@
 
             # Increment the counters only after all IDs are looked up.
             await asyncio.gather(
-                *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                *[
+                    self._counters[id].increment(notify_all=False)
+                    for id in lora_ids
+                    if id is not None
+                ]
             )
             return lora_ids
         else:
@@ -152,7 +158,11 @@
             await self._counters[lora_id].decrement()
         elif isinstance(lora_id, list):
             await asyncio.gather(
-                *[self._counters[id].decrement() for id in lora_id]
+                *[
+                    self._counters[id].decrement()
+                    for id in lora_id
+                    if id is not None
+                ]
             )
         else:
             raise TypeError("lora_id must be either a string or a list of strings.")
@@ -168,11 +178,13 @@
         assert (
             lora_id not in self._registry
         ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
-        counter = self._counters.get(lora_id)
-        if counter:
-            # Wait until no requests are using this LoRA adapter.
-            await counter.wait_for_zero()
-            del self._counters[lora_id]
+        assert (
+            lora_id in self._counters
+        ), "The LoRA ID should still have a counter if it has been registered before."
+
+        # Wait until no requests are using this LoRA adapter.
+        await self._counters[lora_id].wait_for_zero()
+        del self._counters[lora_id]
 
     def _register_adapter(self, lora_ref: LoRARef):
         """
sglang/srt/lora/mem_pool.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
     get_weight_name,
 )
 
+logger = logging.getLogger(__name__)
+
+
+class EmptySlot:
+    """
+    Singleton class to represent an empty slot in the memory pool.
+    This is used to improve readability by not using special str as a placeholder.
+    """
+
+    __slots__ = ()
+
+    def __repr__(self):
+        return "|EMPTY|"
+
+    def __new__(cls):
+        if not hasattr(cls, "_instance"):
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+
+EMPTY_SLOT = EmptySlot()
+
 
 class LoRAMemoryPool:
     """Class for memory pool management of lora modules"""
@@ -28,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Tuple[Set[str], Set[str]],
+        lora_weight_names: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -38,9 +62,7 @@
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-
-        # lora weight names for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
+        self.lora_weight_names: Set[str] = lora_weight_names
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -54,9 +76,11 @@
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initialized as empty strings for empty buffer slots
+        # All uids are initialized as `EmptySlot` for empty buffer slots
         # Here we don't initialize to None since None is a valid uid
-        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+        self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
+            EMPTY_SLOT
+        ] * self.max_loras_per_batch
 
         self.init_buffers(base_model)
 
@@ -71,12 +95,8 @@
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights_a, weights_b = get_normalized_lora_weight_names(
-                config.target_modules
-            )
-            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
-                self.lora_weight_names[1]
-            )
+            weights = get_normalized_lora_weight_names(config.target_modules)
+            return weights.issubset(self.lora_weight_names)
 
         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -106,11 +126,9 @@
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
-        c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
-            c,
             self.max_loras_per_batch,
             output_dim,
             max_lora_dim,
@@ -139,13 +157,13 @@
 
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names[0],
+            self.lora_weight_names,
             self.get_lora_A_shape,
         )
 
         init_buffer(
             self.B_buffer,
-            self.lora_weight_names[1],
+            self.lora_weight_names,
             self.get_lora_B_shape,
         )
 
@@ -154,17 +172,29 @@
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
         lora_modules: List[Dict[str, BaseLayerWithLoRA]],
+        lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
-                if self.buffer_id_to_uid[buffer_id] == "":
+                if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
+                uid = self.buffer_id_to_uid[buffer_id]
+
                 # Evict unneeded lora
-                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                if uid not in cur_uids:
+                    # Skip pinned LoRAs
+                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                    if uid is not None:
+                        lora_ref = lora_refs.get(uid)
+                        if lora_ref is not None and lora_ref.pinned:
+                            continue
+
+                    self.uid_to_buffer_id.pop(uid)
+                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
+                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
                     return buffer_id
 
             raise ValueError(
@@ -208,7 +238,7 @@
             return
 
         assert lora_adapter is not None
-        lora_rank = lora_adapter.config.hf_config["r"]
+        lora_rank = lora_adapter.config.r
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -218,73 +248,38 @@
                 weight_name: None for weight_name in self.B_buffer
             }
             for name, weights in layer_weights.items():
+                lora_weight_name = get_weight_name(name, self.lora_weight_names)
                 if "lora_A" in name:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_A
-                    )
                     temp_A_buffer[lora_weight_name] = weights
                 else:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_B
-                    )
                     temp_B_buffer[lora_weight_name] = weights
 
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(
-                        module_name, self.lora_weight_names, LoRAType.LORA_A
-                    )
+                    weight_name = get_weight_name(module_name, self.lora_weight_names)
 
                     if temp_A_buffer[weight_name] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue
 
-                    if "qkv_proj" in module_name:
-                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
-                            temp_A_buffer["qkv_proj"], self.tp_rank
-                        )
-                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
-                            module.slice_lora_b_weights(
-                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
-                                self.tp_rank,
-                            )
-                        )
-                    else:
-                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
-                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
-                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
-                        # FlashInfer LoRA backend.
-                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                            temp_A_buffer[weight_name], self.tp_rank
-                        )
-                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                            temp_B_buffer[weight_name], self.tp_rank
-                        )
+                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                        temp_A_buffer[weight_name], self.tp_rank
+                    )
+                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                        temp_B_buffer[weight_name], self.tp_rank
+                    )
 
             for name, weights in temp_A_buffer.items():
                 c = get_stacked_multiply(name)
-                buffer_view = self.A_buffer[name][layer_id][buffer_id][
-                    : lora_rank * c, :
-                ]
+                target_buffer = self.A_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
                 load_lora_weight_tensor(buffer_view, weights)
 
             for name, weights in temp_B_buffer.items():
-                c = get_stacked_multiply(name)
-                if c > 1:
-                    for stacked_id in range(c):
-                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
-                            buffer_id
-                        ][:, :lora_rank]
-                        weight_slice = (
-                            weights[stacked_id] if weights is not None else None
-                        )
-                        load_lora_weight_tensor(buffer_view, weight_slice)
-                else:
-                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
-                        :, :lora_rank
-                    ]
-                    load_lora_weight_tensor(buffer_view, weights)
+                target_buffer = self.B_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, :, :lora_rank]
+                load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
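
The `EmptySlot` sentinel introduced above replaces the old empty-string placeholder so that a free buffer slot can never collide with a real uid or with `None` (which denotes the base model). A standalone copy of the class from the hunk above, just to show the singleton behavior:

    class EmptySlot:
        # Singleton sentinel for an empty buffer slot (copied from the hunk above).
        __slots__ = ()

        def __repr__(self):
            return "|EMPTY|"

        def __new__(cls):
            if not hasattr(cls, "_instance"):
                cls._instance = super().__new__(cls)
            return cls._instance

    EMPTY_SLOT = EmptySlot()

    assert EmptySlot() is EMPTY_SLOT  # every construction returns the same object
    assert EMPTY_SLOT is not None     # distinguishable from the base-model uid
    assert EMPTY_SLOT != ""           # and from the old empty-string placeholder
    print(repr(EMPTY_SLOT))           # |EMPTY|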
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
     partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
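
The one-line kernel fix above replaces Python's short-circuit `and` with the elementwise `&` when combining the row and column bounds of the store mask. The distinction is the same one ordinary tensors exhibit; an illustrative sketch with torch, not the Triton kernel itself:

    import torch

    s_mask = torch.tensor([True, True, False])
    n_mask = torch.tensor([True, False, True])

    print(s_mask & n_mask)  # tensor([ True, False, False]) -- per-element AND, what the mask needs

    try:
        s_mask and n_mask   # `and` tries to reduce each tensor to a single bool
    except RuntimeError as err:
        print("`and` is ambiguous for multi-element tensors:", err)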
sglang/srt/lora/utils.py CHANGED
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
     return int(match.group(1))
 
 
-def get_customized_names_from_hf_names(
-    hf_module_names: Set[str], base_model: torch.nn.Module
-) -> Set[str]:
-    """
-    This function takes in a set of huggingface style module names:
-        e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-    and outputs a set of module names of customized sglang layers:
-        e.g., {"qkv_proj", "o_proj"}
-    """
-    if hasattr(base_model, "get_module_name"):
-        return {base_model.get_module_name(name) for name in hf_module_names}
-    else:
-        """
-        Fallback solution of mapping from config module name to module name in model class.
-        Please check if it aligns with your base model.
-        Please implement the function in the model class if it is not.
-        You can reference this function in llama.py.
-        """
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return {params_mapping.get(name, name) for name in hf_module_names}
-
-
 def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
@@ -92,14 +64,20 @@ def get_hidden_dim(
         Please implement the function in the model class if it is not.
         You can reference this function in llama.py.
         """
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return config.hidden_size, config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return config.hidden_size, config.hidden_size // (
-                config.num_attention_heads // config.num_key_value_heads
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        if module_name == "qkv_proj":
+            return config.hidden_size, head_dim * (
+                config.num_attention_heads + config.num_key_value_heads * 2
+            )
+        elif module_name == "o_proj":
+            return (
+                head_dim * config.num_attention_heads,
+                config.hidden_size,
             )
         elif module_name == "gate_up_proj":
-            return config.hidden_size, config.intermediate_size
+            return config.hidden_size, config.intermediate_size * 2
         elif module_name == "down_proj":
             return config.intermediate_size, config.hidden_size
         else:
@@ -108,26 +86,22 @@
 
 def get_normalized_lora_weight_names(
     target_modules: Iterable[str],
-) -> Tuple[set[str], set[str]]:
+) -> set[str]:
     """
     Mapping a list of target module name to names of the normalized LoRA weights.
-    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj": (["qkv_proj"], ["q_proj"]),
-        "k_proj": (["qkv_proj"], ["kv_proj"]),
-        "v_proj": (["qkv_proj"], ["kv_proj"]),
-        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "q_proj": "qkv_proj",
+        "k_proj": "qkv_proj",
+        "v_proj": "qkv_proj",
+        "gate_proj": "gate_up_proj",
+        "up_proj": "gate_up_proj",
     }
 
-    result = (set(), set())
+    result = set()
     for name in target_modules:
-        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
-        result[0].update(lora_a)
-        result[1].update(lora_b)
+        weight_name = params_mapping.get(name, name)
+        result.add(weight_name)
     return result
 
 
@@ -137,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
     """
     stacked_rank = {
         "qkv_proj": 3,
-        "kv_proj": 2,
         "gate_up_proj": 2,
     }
     return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
 def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
+    target_name: str, lora_weight_names: Tuple[Set[str]]
 ) -> Optional[str]:
     """
-    target_name is name of a given module,
-    lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
+    Get the weight name in lora_weight_names that can match target_name.
+
     If there is a weight name in lora_weight_names that can match target_name, return this name
     Else raise ValueError.
     """
-    idx = 0 if lora_type == LoRAType.LORA_A else 1
-    for weight_name in lora_weight_names[idx]:
+    for weight_name in lora_weight_names:
         if weight_name in target_name:
             return weight_name
     raise ValueError(
@@ -161,9 +133,4 @@
     )
 
 
-# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
-VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
-COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
-MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
-QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
 ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
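
With this refactor, LoRA A and B share a single set of normalized weight names, and the `get_hidden_dim` fallback derives output widths from `head_dim` instead of assuming `hidden_size == num_attention_heads * head_dim`. A self-contained sketch of the new fallback arithmetic, using a hypothetical GQA config (hidden_size 4096, 32 query heads, 8 KV heads, head_dim 128, intermediate_size 14336):

    from types import SimpleNamespace

    # Hypothetical config values; any model exposing these attributes works the same way.
    cfg = SimpleNamespace(
        hidden_size=4096,
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=128,
        intermediate_size=14336,
    )

    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

    # qkv_proj output width: q heads plus 2 * kv heads, each of width head_dim.
    print(head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads * 2))  # 6144

    # gate_up_proj stacks the gate and up halves, hence the `* 2`.
    print(cfg.intermediate_size * 2)  # 28672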