sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple
 
 import torch
 
+from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
+from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.utils import (
+    ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
     get_stacked_multiply,
@@ -21,6 +24,9 @@ class LoRAMemoryPool:
         max_loras_per_batch: int,
         max_lora_dim: int,
         dtype: torch.dtype,
+        tp_size: int,
+        tp_rank: int,
+        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
 
         self.base_hf_config: AutoConfig = base_hf_config
@@ -28,6 +34,9 @@ class LoRAMemoryPool:
         self.max_loras_per_batch: int = max_loras_per_batch
         self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -45,6 +54,41 @@ class LoRAMemoryPool:
         # Here we don't initalize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
+    def get_lora_A_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                input_dim = divide(input_dim, self.tp_size)
+        return (
+            self.max_loras_per_batch,
+            self.max_lora_dim * c,
+            input_dim,
+        )
+
+    def get_lora_B_shape(
+        self, module_name: str, base_model: torch.nn.Module
+    ) -> Tuple[int]:
+        """
+        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        """
+        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        c = get_stacked_multiply(module_name)
+        if self.tp_size > 1:
+            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+                output_dim = divide(output_dim, self.tp_size)
+        return (
+            c,
+            self.max_loras_per_batch,
+            output_dim,
+            self.max_lora_dim,
+        )
+
     def init_buffers(
         self,
         lora_weight_names: Set[Tuple[str]],
@@ -54,42 +98,31 @@ class LoRAMemoryPool:
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
-
-        for module_A, module_B in lora_weight_names:
-            # Init A tensor, column_major=False
-            input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
-            c = get_stacked_multiply(module_A)
-            if module_A not in self.A_buffer:
-                self.A_buffer[module_A] = [
-                    torch.empty(
-                        (
-                            self.max_loras_per_batch,
-                            self.max_lora_dim * c,
-                            input_dim,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(self.num_layer)
-                ]
-
-            # Init B tensor, column_major=True
-            _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
-            c = get_stacked_multiply(module_B)
-            if module_B not in self.B_buffer:
-                self.B_buffer[module_B] = [
-                    torch.empty(
-                        (
-                            c,  # stacked lora_b modules might need separation
-                            self.max_loras_per_batch,
-                            output_dim,
-                            self.max_lora_dim,
-                        ),
-                        dtype=self.dtype,
-                        device="cuda",
-                    )
-                    for i in range(self.num_layer)
-                ]
+        device = next(base_model.parameters()).device
+        lora_module_A_names = set([name[0] for name in lora_weight_names])
+        lora_module_B_names = set([name[1] for name in lora_weight_names])
+        # Init A tensor, column_major=False
+        for module_A in lora_module_A_names:
+            lora_A_shape = self.get_lora_A_shape(module_A, base_model)
+            self.A_buffer[module_A] = [
+                torch.empty(
+                    lora_A_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for i in range(self.num_layer)
+            ]
+        # Init B tensor, column_major=True
+        for module_B in lora_module_B_names:
+            lora_B_shape = self.get_lora_B_shape(module_B, base_model)
+            self.B_buffer[module_B] = [
+                torch.empty(
+                    lora_B_shape,
+                    dtype=self.dtype,
+                    device=device,
+                )
+                for _ in range(self.num_layer)
+            ]
 
     def prepare_lora_batch(
         self,
@@ -130,36 +163,68 @@ class LoRAMemoryPool:
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
+        lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
+            temp_A_buffer: Dict[str, torch.Tensor] = {}
+            temp_B_buffer: Dict[str, torch.Tensor] = {}
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
                         name, self.lora_weight_names, LoRAType.LORA_A
                     )
-                    if lora_weight_name:
-                        self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
-                            weights
-                        )
+                    temp_A_buffer[lora_weight_name] = weights
                 else:
                     lora_weight_name = get_weight_name(
                         name, self.lora_weight_names, LoRAType.LORA_B
                     )
-                    if lora_weight_name:
-                        c = get_stacked_multiply(lora_weight_name)
-                        if c > 1:
-                            for stacked_id in range(c):
-                                self.B_buffer[lora_weight_name][layer_id][stacked_id][
-                                    buffer_id
-                                ].copy_(weights[stacked_id])
-                        else:
-                            self.B_buffer[lora_weight_name][layer_id][0][
-                                buffer_id
-                            ].copy_(weights)
+                    temp_B_buffer[lora_weight_name] = weights
+
+            if self.tp_size > 1:
+                cur_layer_modules = self.lora_modules[layer_id]
+                for module_name, module in cur_layer_modules:
+                    if "qkv_proj" in module_name:
+                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
+                            temp_A_buffer["qkv_proj"], self.tp_rank
+                        )
+                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
+                            module.slice_lora_b_weights(
+                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
+                                self.tp_rank,
+                            )
+                        )
+                    else:
+                        weight_name = get_weight_name(
+                            module_name, self.lora_weight_names, LoRAType.LORA_A
+                        )
+                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                            temp_A_buffer[weight_name], self.tp_rank
+                        )
+                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                            temp_B_buffer[weight_name], self.tp_rank
+                        )
+
+            for name, weights in temp_A_buffer.items():
+                c = get_stacked_multiply(name)
+                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
+                    weights
+                )
+
+            for name, weights in temp_B_buffer.items():
+                c = get_stacked_multiply(name)
+                if c > 1:
+                    for stacked_id in range(c):
+                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
+                            :, :lora_rank
+                        ].copy_(weights[stacked_id])
+                else:
+                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
+                        weights
+                    )
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
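
The two new shape helpers encode how the pool shards LoRA buffers under tensor parallelism: row-parallel modules ("o_proj", "down_proj") split the A-side input dimension across ranks, while every other module splits the B-side output dimension. A minimal standalone sketch of that logic with invented dimensions (nothing below is the package's API):

    # Illustrative sketch only: mirrors the shape logic of get_lora_A_shape /
    # get_lora_B_shape above with made-up dimensions.
    max_loras_per_batch, max_lora_dim, tp_size = 4, 64, 2
    input_dim, output_dim = 4096, 4096      # hypothetical hidden sizes
    c = 1                                   # stand-in for get_stacked_multiply(module_name)
    is_row_parallel = False                 # True for "o_proj" / "down_proj"

    def divide(numerator: int, denominator: int) -> int:
        assert numerator % denominator == 0
        return numerator // denominator

    # A side: (max_loras_per_batch, c * max_rank, input_dim); the input dim is
    # sharded across TP ranks only for row-parallel modules.
    a_in = divide(input_dim, tp_size) if tp_size > 1 and is_row_parallel else input_dim
    lora_A_shape = (max_loras_per_batch, max_lora_dim * c, a_in)

    # B side: (c, max_loras_per_batch, output_dim, max_rank); the output dim is
    # sharded for every module except the row-parallel ones.
    b_out = divide(output_dim, tp_size) if tp_size > 1 and not is_row_parallel else output_dim
    lora_B_shape = (c, max_loras_per_batch, b_out, max_lora_dim)

    print(lora_A_shape, lora_B_shape)       # (4, 64, 4096) (1, 4, 2048, 64)
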
sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -22,17 +22,18 @@ def _gate_up_lora_b_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 2 sgemms (gate/up) into a single kernel.
 
@@ -51,6 +52,11 @@
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = gate_up_id * output_dim  # offset on output dim
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(output_dim, BLOCK_N)
@@ -109,7 +115,6 @@ def gate_up_lora_b_fwd(
     batch_info: LoRABatchInfo,
     output_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 2 * r)
@@ -160,11 +165,12 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -26,6 +26,7 @@ def _qkv_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Offsets of q/k/v slice on output dimension
     n_offs,
     # Meta parameters
@@ -34,7 +35,7 @@
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # This kernel packs 3 sgemms (q/k/v) into a single kernel.
 
@@ -54,6 +55,10 @@
     seg_start = tl.load(seg_indptr + batch_id)
     n_start = tl.load(n_offs + qkv_id)
     n_size = tl.load(n_offs + qkv_id + 1) - n_start
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
@@ -112,7 +117,6 @@ def qkv_lora_b_fwd(
     output_offset: torch.Tensor,
     max_qkv_out_dim: int,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
 
     # x: (s, 3 * r)
@@ -171,12 +175,13 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         output_offset,
         BLOCK_S,
         BLOCK_OUT,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
 
     return output
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +23,11 @@
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
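
sgemm_lora_a_fwd now accepts a stack_num argument so the packed q/k/v or gate/up LoRA-A weight can be clamped to stack_num * rank inside the kernel. A shape-only sketch with invented sizes (the actual call is left commented out because it needs a CUDA device and a populated LoRABatchInfo; see the sketch after the sglang/srt/lora/utils.py diff below):

    # Shape bookkeeping only; every number here is made up.
    import torch

    s, input_dim, max_r, num_lora = 8, 4096, 64, 4
    x = torch.randn(s, input_dim, dtype=torch.float16)                    # (s, input_dim)
    qkv_lora_a = torch.randn(num_lora, 3 * max_r, input_dim, dtype=torch.float16)
    # xa = sgemm_lora_a_fwd(x, qkv_lora_a, batch_info, stack_num=3)       # -> (s, 3 * max_r)
    # Inside the kernel, N = stack_num * max_r is then clamped to
    # stack_num * rank for the adapter actually used by each sequence.
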
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s, r)
-    # weights: (num_lora, output_dim, r)
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than r
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
     return output
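
All four LoRA Triton kernels in this release stop taking a single global scaling and instead look up each adapter's true rank and scaling through its weight index, clamping the reduction dimension to that rank. Below is a plain PyTorch reference of the resulting sgemm_lora_b semantics, written as an illustrative sketch rather than the package's code (the qkv and gate/up variants apply the same per-adapter lookup):

    import torch

    def sgemm_lora_b_reference(x, weights, seg_indptr, weight_indices,
                               lora_ranks, scalings, base_output=None):
        # x: (s, max_r), weights: (num_lora, output_dim, max_r)
        s, _ = x.shape
        _, output_dim, _ = weights.shape
        out = torch.zeros(s, output_dim, dtype=x.dtype) if base_output is None else base_output.clone()
        for i in range(weight_indices.numel()):          # one segment per sequence
            start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
            w = int(weight_indices[i])                   # adapter slot used by this sequence
            r = int(lora_ranks[w])                       # true rank of that adapter
            # Only the first r columns of x / weights carry real data; the rest of
            # the max_r-wide buffer is padding, which the kernel now skips.
            out[start:end] += scalings[w] * (x[start:end, :r] @ weights[w, :, :r].T)
        return out

    # Two sequences of lengths 3 and 2, using adapter slots 0 and 1 (made-up data).
    x = torch.randn(5, 64)
    weights = torch.randn(2, 128, 64)
    out = sgemm_lora_b_reference(
        x, weights,
        seg_indptr=torch.tensor([0, 3, 5]),
        weight_indices=torch.tensor([0, 1]),
        lora_ranks=torch.tensor([8, 64]),
        scalings=torch.tensor([2.0, 1.0]),
    )
    print(out.shape)                                     # torch.Size([5, 128])
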
sglang/srt/lora/utils.py CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
@@ -133,9 +139,20 @@ def get_weight_name(
     target_name is name of a given module,
     lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
     If there is a weight name in lora_weight_names that can match target_name, return this name
-    Else return None
+    Else raise ValueError.
     """
     idx = 0 if lora_type == LoRAType.LORA_A else 1
     for weight_name_pair in lora_weight_names:
         if weight_name_pair[idx] in target_name:
             return weight_name_pair[idx]
+    raise ValueError(
+        f"Cannot find weight name for {target_name} in {lora_weight_names}"
+    )
+
+
+# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
+VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
+COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
+MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
+QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
+ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
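
LoRABatchInfo now carries one rank and one scaling per loaded adapter slot, indexed by the same weight indices the kernels already use. A sketch of how these two tensors could be filled from adapter configs, assuming the usual LoRA convention scaling = lora_alpha / r (the adapter values here are invented, not taken from the package):

    import torch

    # Hypothetical adapters occupying buffer slots 0..2.
    adapter_configs = [
        {"r": 8, "lora_alpha": 16},
        {"r": 16, "lora_alpha": 32},
        {"r": 64, "lora_alpha": 64},
    ]
    lora_ranks = torch.tensor([cfg["r"] for cfg in adapter_configs], dtype=torch.int64)
    scalings = torch.tensor(
        [cfg["lora_alpha"] / cfg["r"] for cfg in adapter_configs], dtype=torch.float32
    )
    # Each Triton kernel then resolves the per-sequence values as
    #   rank = lora_ranks[w_index]; scaling = scalings[w_index]
    # and clamps its reduction dimension to that rank.
    print(lora_ranks.tolist(), scalings.tolist())        # [8, 16, 64] [2.0, 2.0, 1.0]
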
sglang/srt/managers/cache_controller.py CHANGED
@@ -22,10 +22,7 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPoolHost,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
 
 logger = logging.getLogger(__name__)
 
@@ -151,7 +148,7 @@ class HiCacheController:
     def __init__(
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
-        mem_pool_host: MHATokenToKVPoolHost,
+        mem_pool_host: HostKVCache,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -82,10 +82,12 @@ class DataParallelController:
         self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size
 
-        if not server_args.enable_dp_attention:
-            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
-        else:
+        if server_args.enable_dp_attention:
             dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+            self.control_message_step = server_args.tp_size
+        else:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+            self.control_message_step = 1
 
         # Only node rank 0 runs the real data parallel controller that dispatches the requests.
         if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
         threads = []
         sockets = []
         dp_port_args = []
+        ready_events = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
             # We hold it first so that the next dp worker gets a different port
             sockets.append(bind_port(tmp_port_args.nccl_port))
 
+            ready_event = threading.Event()
+            ready_events.append(ready_event)
+
             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_tensor_parallel_group,
-                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+                target=self.launch_tensor_parallel_group_thread,
+                args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
             )
             threads.append(thread)
             base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
         # Start all threads
         for thread in threads:
             thread.start()
-        for thread in threads:
-            thread.join()
+        for event in ready_events:
+            event.wait()
 
         return dp_port_args
 
+    def launch_tensor_parallel_group_thread(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+        ready_event: threading.Event,
+    ):
+        self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+        ready_event.set()
+
+        # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+        # function in scheduler.py will kill the scheduler.
+        while True:
+            pass
+
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
         dp_port_args = []
@@ -223,7 +245,7 @@ class DataParallelController:
                     self.dispatching(recv_req)
                 else:
                     # Send other control messages to first worker of tp group
-                    for worker in self.workers[:: self.server_args.tp_size]:
+                    for worker in self.workers[:: self.control_message_step]:
                         worker.send_pyobj(recv_req)
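
The launcher now hands each scheduler thread a threading.Event and waits on those events instead of joining the threads, which must stay alive. A standalone sketch of that readiness pattern (worker body and rank count are placeholders, not the controller's real logic):

    import threading
    import time

    def worker(dp_rank: int, ready_event: threading.Event):
        # ... launch the tensor-parallel group for this rank here ...
        ready_event.set()              # signal the parent that this rank is up
        while True:                    # stay alive (the diff busy-waits; sleeping avoids burning CPU here)
            time.sleep(60)

    ready_events = []
    for dp_rank in range(2):           # pretend dp_size == 2
        ev = threading.Event()
        ready_events.append(ev)
        threading.Thread(target=worker, args=(dp_rank, ev), daemon=True).start()

    for ev in ready_events:            # wait for readiness instead of join()
        ev.wait()
    print("all dp ranks ready")
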
sglang/srt/managers/expert_distribution.py ADDED
@@ -0,0 +1,81 @@
+import json
+import logging
+import time
+from collections import defaultdict
+from typing import Dict, List, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+# global expert distribution recording
+class ExpertDistributionRecorder:
+    # This class is a singleton class
+    def __new__(cls):
+        if not hasattr(cls, "instance"):
+            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
+        return cls.instance
+
+    def __init__(self):
+        # the length of the dictionary is the number of layers
+        # the length of the list is the number of tokens
+        # the length of the tuple is topk's k value
+        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
+            list
+        )
+        self._record = False
+        self._current_layer_id = "UNKNOWN"
+
+    def set_current_layer(self, layer_idx):
+        self._current_layer_id = layer_idx
+
+    def record_new_token(self, topk_ids):
+        if not self._record:
+            return
+        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+        torch.cuda.synchronize()
+        for i in topk_ids_list:
+            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+
+    def reset(self):
+        """Reset the expert distribution recorder."""
+        logger.info("Resetting expert distribution record...")
+        self._record = False
+        self._expert_distribution_record.clear()
+        self._current_layer_id = "UNKNOWN"
+
+    def start_record(self):
+        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
+        if self._record == True:
+            logger.warning(
+                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
+            )
+        self.reset()
+        self._record = True
+
+    def stop_record(self):
+        """Stop recording the expert distribution. Set the recording flag to False."""
+        if self._record == False:
+            logger.warning(
+                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
+            )
+        self._record = False
+
+    def dump_record(self):
+        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
+        results = {}
+        for layer_idx, layer_record in self._expert_distribution_record.items():
+            results[layer_idx] = defaultdict(int)
+            for token_record in layer_record:
+                for expert_idx in token_record:
+                    results[layer_idx][expert_idx] += 1
+        with open(
+            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
+            "w",
+        ) as fd:
+            fd.write("layer_id,expert_id,count\n")
+            for layer_idx, layer_results in results.items():
+                for expert_idx, count in layer_results.items():
+                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
+        self.reset()
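
The new module provides a process-wide singleton that MoE layers feed with their top-k routing decisions; the HTTP endpoints named in its warning messages (/start_expert_distribution_record, /stop_expert_distribution_record, /dump_expert_distribution_record) drive it on a running server. A hedged usage sketch, assuming a CUDA device and an initialized torch.distributed process group (needed because dump_record() names its CSV after the rank):

    import torch
    import torch.distributed as dist
    from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

    # Single-process group just so torch.distributed.get_rank() works in dump_record().
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    recorder = ExpertDistributionRecorder()       # singleton: every caller gets the same instance
    recorder.start_record()
    for layer_id in range(2):                     # pretend two MoE layers
        recorder.set_current_layer(layer_id)
        topk_ids = torch.randint(0, 8, (4, 2), device="cuda")   # (tokens, top-k) routed expert ids
        recorder.record_new_token(topk_ids)
    recorder.dump_record()    # writes expert_distribution_rank0_timestamp<...>.csv and resets
    dist.destroy_process_group()
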