sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py CHANGED
@@ -23,7 +23,7 @@ import torch
  from sglang.srt.configs.load_config import LoadConfig
  from sglang.srt.hf_transformers_utils import AutoConfig
  from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
- from sglang.srt.lora.layers import get_lora_layer
+ from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
  from sglang.srt.lora.mem_pool import LoRAMemoryPool
@@ -51,6 +51,8 @@ class LoRAManager:
  load_config: LoadConfig,
  dtype: torch.dtype,
  lora_backend: str = "triton",
+ tp_size: int = 1,
+ tp_rank: int = 0,
  ):
  self.base_model: torch.nn.Module = base_model
  self.lora_paths: Dict[str, str] = lora_paths
@@ -58,6 +60,9 @@ class LoRAManager:
  self.max_loras_per_batch: int = max_loras_per_batch
  self.load_config: LoadConfig = load_config
  self.dtype: torch.dtype = dtype
+ self.device: torch.device = next(self.base_model.parameters()).device
+ self.tp_size: int = tp_size
+ self.tp_rank: int = tp_rank

  # LoRA backend for running sgemm kernels
  logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -110,7 +115,13 @@ class LoRAManager:
  def init_lora_memory_pool(self):
  # Initialize memory pool
  self.memory_pool = LoRAMemoryPool(
- self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
+ self.base_hf_config,
+ self.max_loras_per_batch,
+ self.max_lora_dim,
+ self.dtype,
+ self.tp_size,
+ self.tp_rank,
+ self.lora_modules,
  )

  # Initialize target lora modules in memory pool
@@ -131,12 +142,12 @@ class LoRAManager:
  seg_lens = (
  forward_batch.extend_seq_lens
  if forward_batch.forward_mode.is_extend()
- else torch.ones(bs, device="cuda")
+ else torch.ones(bs, device=self.device)
  )
- seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
+ seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
  seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
  max_len = int(torch.max(seg_lens))
- weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
+ weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
  for i, lora_path in enumerate(forward_batch.lora_paths):
  weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)

@@ -150,22 +161,32 @@ class LoRAManager:
  self.lora_backend.set_batch_info(batch_info)

  # call set_lora_info for each lora modules
- for module_name, module in self.lora_modules:
- layer_id = get_layer_id(module_name)
- if "qkv_proj" not in module_name:
- weight_name = get_weight_name(
- module_name, self.lora_weight_names, LoRAType.LORA_A
- )
- module.set_lora_info(
- self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
- self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
- )
- else:
- module.set_lora_info(
- self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
- self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
- self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
- )
+ for layer_id, modules in self.lora_modules.items():
+ for module_name, module in modules:
+ if "qkv_proj" in module_name:
+ module.set_lora_info(
+ self.memory_pool.get_tensor(
+ "qkv_proj", layer_id, LoRAType.LORA_A
+ ),
+ self.memory_pool.get_tensor(
+ "q_proj", layer_id, LoRAType.LORA_B
+ ),
+ self.memory_pool.get_tensor(
+ "kv_proj", layer_id, LoRAType.LORA_B
+ ),
+ )
+ else:
+ weight_name = get_weight_name(
+ module_name, self.lora_weight_names, LoRAType.LORA_A
+ )
+ module.set_lora_info(
+ self.memory_pool.get_tensor(
+ weight_name, layer_id, LoRAType.LORA_A
+ ),
+ self.memory_pool.get_tensor(
+ weight_name, layer_id, LoRAType.LORA_B
+ ),
+ )

  def set_lora_module(self, module_name, module):
  lora_module = get_lora_layer(
@@ -182,10 +203,13 @@ class LoRAManager:
  )

  # Monkey patch to use the LoRA version layers
- self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
+ self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
+ i: [] for i in range(self.base_hf_config.num_hidden_layers)
+ }
  for module_name, module in self.base_model.named_modules():
  # The module should be converted if it is included in target_names
  if module_name.split(".")[-1] in customized_target_names:
- self.lora_modules.append(
+ layer_id = get_layer_id(module_name)
+ self.lora_modules[layer_id].append(
  (module_name, self.set_lora_module(module_name, module))
  )
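The segment bookkeeping in the hunk above is easier to follow with concrete numbers. Below is a minimal sketch (not part of the diff) that mirrors it on CPU with toy values; in the real code seg_lens comes from forward_batch.extend_seq_lens and the slot indices from LoRAMemoryPool.get_buffer_id().

import torch

device = torch.device("cpu")  # the diff now takes this from the base model's parameters
seg_lens = torch.tensor([3, 1, 4], device=device)  # hypothetical tokens per request
bs = seg_lens.numel()

seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)  # prefix sums -> [0, 3, 4, 8]
max_len = int(torch.max(seg_lens))              # 4

# one memory-pool slot per request; pretend the three adapters sit in slots 0, 2, 1
weight_indices = torch.tensor([0, 2, 1], dtype=torch.int64, device=device)

print(seg_indptr.tolist(), max_len, weight_indices.tolist())  # [0, 3, 4, 8] 4 [0, 2, 1]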
sglang/srt/lora/mem_pool.py CHANGED
@@ -2,9 +2,12 @@ from typing import Dict, List, Optional, Set, Tuple

  import torch

+ from sglang.srt.distributed import divide
  from sglang.srt.hf_transformers_utils import AutoConfig
+ from sglang.srt.lora.layers import BaseLayerWithLoRA
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.utils import (
+ ROW_PARALLELISM_LINEAR_LORA_NAMES,
  LoRAType,
  get_hidden_dim,
  get_stacked_multiply,
@@ -21,6 +24,9 @@ class LoRAMemoryPool:
  max_loras_per_batch: int,
  max_lora_dim: int,
  dtype: torch.dtype,
+ tp_size: int,
+ tp_rank: int,
+ lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
  ):

  self.base_hf_config: AutoConfig = base_hf_config
@@ -28,6 +34,9 @@ class LoRAMemoryPool:
  self.max_loras_per_batch: int = max_loras_per_batch
  self.max_lora_dim: int = max_lora_dim
  self.dtype: torch.dtype = dtype
+ self.tp_size: int = tp_size
+ self.tp_rank: int = tp_rank
+ self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules

  # Both A_buffer and B_buffer maps lora weight names to its buffer space.
  # A_buffer contains num_layer number of row-major tensors with shape
@@ -45,6 +54,41 @@ class LoRAMemoryPool:
  # Here we don't initalize to None since None is a valid uid
  self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+ def get_lora_A_shape(
+ self, module_name: str, base_model: torch.nn.Module
+ ) -> Tuple[int]:
+ """
+ Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+ """
+ input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+ c = get_stacked_multiply(module_name)
+ if self.tp_size > 1:
+ if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+ input_dim = divide(input_dim, self.tp_size)
+ return (
+ self.max_loras_per_batch,
+ self.max_lora_dim * c,
+ input_dim,
+ )
+
+ def get_lora_B_shape(
+ self, module_name: str, base_model: torch.nn.Module
+ ) -> Tuple[int]:
+ """
+ Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+ """
+ _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+ c = get_stacked_multiply(module_name)
+ if self.tp_size > 1:
+ if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+ output_dim = divide(output_dim, self.tp_size)
+ return (
+ c,
+ self.max_loras_per_batch,
+ output_dim,
+ self.max_lora_dim,
+ )
+
  def init_buffers(
  self,
  lora_weight_names: Set[Tuple[str]],
@@ -54,42 +98,31 @@ class LoRAMemoryPool:
  # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
  # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
  self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
-
- for module_A, module_B in lora_weight_names:
- # Init A tensor, column_major=False
- input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
- c = get_stacked_multiply(module_A)
- if module_A not in self.A_buffer:
- self.A_buffer[module_A] = [
- torch.empty(
- (
- self.max_loras_per_batch,
- self.max_lora_dim * c,
- input_dim,
- ),
- dtype=self.dtype,
- device="cuda",
- )
- for i in range(self.num_layer)
- ]
-
- # Init B tensor, column_major=True
- _, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
- c = get_stacked_multiply(module_B)
- if module_B not in self.B_buffer:
- self.B_buffer[module_B] = [
- torch.empty(
- (
- c, # stacked lora_b modules might need separation
- self.max_loras_per_batch,
- output_dim,
- self.max_lora_dim,
- ),
- dtype=self.dtype,
- device="cuda",
- )
- for i in range(self.num_layer)
- ]
+ device = next(base_model.parameters()).device
+ lora_module_A_names = set([name[0] for name in lora_weight_names])
+ lora_module_B_names = set([name[1] for name in lora_weight_names])
+ # Init A tensor, column_major=False
+ for module_A in lora_module_A_names:
+ lora_A_shape = self.get_lora_A_shape(module_A, base_model)
+ self.A_buffer[module_A] = [
+ torch.empty(
+ lora_A_shape,
+ dtype=self.dtype,
+ device=device,
+ )
+ for i in range(self.num_layer)
+ ]
+ # Init B tensor, column_major=True
+ for module_B in lora_module_B_names:
+ lora_B_shape = self.get_lora_B_shape(module_B, base_model)
+ self.B_buffer[module_B] = [
+ torch.empty(
+ lora_B_shape,
+ dtype=self.dtype,
+ device=device,
+ )
+ for _ in range(self.num_layer)
+ ]

  def prepare_lora_batch(
  self,
@@ -136,30 +169,56 @@ class LoRAMemoryPool:
  assert lora_adapter is not None
  for layer_id in range(self.num_layer):
  layer_weights = lora_adapter.layers[layer_id].weights
+ temp_A_buffer: Dict[str, torch.Tensor] = {}
+ temp_B_buffer: Dict[str, torch.Tensor] = {}
  for name, weights in layer_weights.items():
  if "lora_A" in name:
  lora_weight_name = get_weight_name(
  name, self.lora_weight_names, LoRAType.LORA_A
  )
- if lora_weight_name:
- self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
- weights
- )
+ temp_A_buffer[lora_weight_name] = weights
  else:
  lora_weight_name = get_weight_name(
  name, self.lora_weight_names, LoRAType.LORA_B
  )
- if lora_weight_name:
- c = get_stacked_multiply(lora_weight_name)
- if c > 1:
- for stacked_id in range(c):
- self.B_buffer[lora_weight_name][layer_id][stacked_id][
- buffer_id
- ].copy_(weights[stacked_id])
- else:
- self.B_buffer[lora_weight_name][layer_id][0][
- buffer_id
- ].copy_(weights)
+ temp_B_buffer[lora_weight_name] = weights
+
+ if self.tp_size > 1:
+ cur_layer_modules = self.lora_modules[layer_id]
+ for module_name, module in cur_layer_modules:
+ if "qkv_proj" in module_name:
+ temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
+ temp_A_buffer["qkv_proj"], self.tp_rank
+ )
+ temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
+ module.slice_lora_b_weights(
+ [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
+ self.tp_rank,
+ )
+ )
+ else:
+ weight_name = get_weight_name(
+ module_name, self.lora_weight_names, LoRAType.LORA_A
+ )
+ temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+ temp_A_buffer[weight_name], self.tp_rank
+ )
+ temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+ temp_B_buffer[weight_name], self.tp_rank
+ )
+
+ for name, weights in temp_A_buffer.items():
+ self.A_buffer[name][layer_id][buffer_id].copy_(weights)
+
+ for name, weights in temp_B_buffer.items():
+ c = get_stacked_multiply(name)
+ if c > 1:
+ for stacked_id in range(c):
+ self.B_buffer[name][layer_id][stacked_id][buffer_id].copy_(
+ weights[stacked_id]
+ )
+ else:
+ self.B_buffer[name][layer_id][0][buffer_id].copy_(weights)

  def get_tensor(
  self, weight_name: str, layer_id: int, lora_type: LoRAType
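The new get_lora_A_shape / get_lora_B_shape methods carry the tensor-parallel logic: with tp_size > 1, only row-parallel modules (o_proj, down_proj) shard the input dimension of the LoRA-A buffer, while every other module shards the output dimension of the LoRA-B buffer. The following self-contained sketch re-implements just that shape rule for illustration; the hidden size, rank, adapter count, and stack factors are hypothetical, and the real values come from get_hidden_dim() and get_stacked_multiply().

ROW_PARALLEL = {"o_proj", "down_proj"}  # mirrors ROW_PARALLELISM_LINEAR_LORA_NAMES

def lora_a_shape(name, in_dim, max_loras, max_rank, stack, tp_size):
    # LoRA-A: shard the input dim only for row-parallel modules
    if tp_size > 1 and name in ROW_PARALLEL:
        in_dim //= tp_size
    return (max_loras, max_rank * stack, in_dim)

def lora_b_shape(name, out_dim, max_loras, max_rank, stack, tp_size):
    # LoRA-B: shard the output dim for every non-row-parallel module
    if tp_size > 1 and name not in ROW_PARALLEL:
        out_dim //= tp_size
    return (stack, max_loras, out_dim, max_rank)

# hypothetical config: hidden size 4096, rank 16, 8 concurrent adapters, tp_size 2,
# and a fused qkv_proj treated here as a stack of 3
print(lora_a_shape("qkv_proj", 4096, 8, 16, 3, 2))    # (8, 48, 4096)    input dim kept whole
print(lora_a_shape("o_proj", 4096, 8, 16, 1, 2))      # (8, 16, 2048)    input dim sharded
print(lora_b_shape("o_proj", 4096, 8, 16, 1, 2))      # (1, 8, 4096, 16) output dim kept whole
print(lora_b_shape("gate_proj", 11008, 8, 16, 1, 2))  # (1, 8, 5504, 16) output dim sharded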
sglang/srt/lora/utils.py CHANGED
@@ -133,9 +133,20 @@ def get_weight_name(
  target_name is name of a given module,
  lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
  If there is a weight name in lora_weight_names that can match target_name, return this name
- Else return None
+ Else raise ValueError.
  """
  idx = 0 if lora_type == LoRAType.LORA_A else 1
  for weight_name_pair in lora_weight_names:
  if weight_name_pair[idx] in target_name:
  return weight_name_pair[idx]
+ raise ValueError(
+ f"Cannot find weight name for {target_name} in {lora_weight_names}"
+ )
+
+
+ # TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
+ VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
+ COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
+ MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
+ QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
+ ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
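get_weight_name() now raises ValueError on a miss instead of returning None, which is why the `if lora_weight_name:` guards disappeared from mem_pool.py above. A simplified re-implementation (not the library code) of the matching rule, using the example pair set from the mem_pool.py comment:

lora_weight_names = {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}

def get_weight_name(target_name, names, idx):
    # idx 0 matches LoRA-A names, idx 1 matches LoRA-B names (LoRAType in the real code)
    for pair in names:
        if pair[idx] in target_name:
            return pair[idx]
    raise ValueError(f"Cannot find weight name for {target_name} in {names}")

print(get_weight_name("model.layers.0.self_attn.qkv_proj.lora_A", lora_weight_names, 0))  # qkv_proj
try:
    get_weight_name("model.layers.0.mlp.gate_up_proj.lora_A", lora_weight_names, 0)
except ValueError as e:
    print("unmatched module now raises:", e)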
sglang/srt/managers/cache_controller.py CHANGED
@@ -22,10 +22,7 @@ from typing import List, Optional

  import torch

- from sglang.srt.mem_cache.memory_pool import (
- MHATokenToKVPoolHost,
- TokenToKVPoolAllocator,
- )
+ from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator

  logger = logging.getLogger(__name__)

@@ -151,7 +148,7 @@ class HiCacheController:
  def __init__(
  self,
  token_to_kv_pool_allocator: TokenToKVPoolAllocator,
- mem_pool_host: MHATokenToKVPoolHost,
+ mem_pool_host: HostKVCache,
  load_cache_event: threading.Event = None,
  write_policy: str = "write_through_selective",
  ):
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -82,10 +82,12 @@ class DataParallelController:
  self.scheduler_procs = []
  self.workers = [None] * server_args.dp_size

- if not server_args.enable_dp_attention:
- dp_port_args = self.launch_dp_schedulers(server_args, port_args)
- else:
+ if server_args.enable_dp_attention:
  dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+ self.control_message_step = server_args.tp_size
+ else:
+ dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+ self.control_message_step = 1

  # Only node rank 0 runs the real data parallel controller that dispatches the requests.
  if server_args.node_rank == 0:
@@ -105,6 +107,7 @@ class DataParallelController:
  threads = []
  sockets = []
  dp_port_args = []
+ ready_events = []
  for dp_rank in range(server_args.dp_size):
  tmp_port_args = PortArgs.init_new(server_args)
  tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
@@ -115,10 +118,13 @@ class DataParallelController:
  # We hold it first so that the next dp worker gets a different port
  sockets.append(bind_port(tmp_port_args.nccl_port))

+ ready_event = threading.Event()
+ ready_events.append(ready_event)
+
  # Create a thread for each worker
  thread = threading.Thread(
- target=self.launch_tensor_parallel_group,
- args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
+ target=self.launch_tensor_parallel_group_thread,
+ args=(server_args, tmp_port_args, base_gpu_id, dp_rank, ready_event),
  )
  threads.append(thread)
  base_gpu_id += server_args.tp_size * server_args.gpu_id_step
@@ -130,11 +136,27 @@ class DataParallelController:
  # Start all threads
  for thread in threads:
  thread.start()
- for thread in threads:
- thread.join()
+ for event in ready_events:
+ event.wait()

  return dp_port_args

+ def launch_tensor_parallel_group_thread(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ base_gpu_id: int,
+ dp_rank: int,
+ ready_event: threading.Event,
+ ):
+ self.launch_tensor_parallel_group(server_args, port_args, base_gpu_id, dp_rank)
+ ready_event.set()
+
+ # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
+ # function in scheduler.py will kill the scheduler.
+ while True:
+ pass
+
  def launch_dp_attention_schedulers(self, server_args, port_args):
  self.launch_tensor_parallel_group(server_args, port_args, 0, None)
  dp_port_args = []
@@ -223,7 +245,7 @@ class DataParallelController:
  self.dispatching(recv_req)
  else:
  # Send other control messages to first worker of tp group
- for worker in self.workers[:: self.server_args.tp_size]:
+ for worker in self.workers[:: self.control_message_step]:
  worker.send_pyobj(recv_req)
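The launcher previously join()ed the per-rank scheduler threads, which never return; it now waits on one readiness Event per worker and leaves the threads running. A standalone sketch of that handshake with toy workers (not SGLang schedulers):

import threading
import time

def worker(rank: int, ready_event: threading.Event, stop: threading.Event):
    time.sleep(0.1 * rank)  # stand-in for launching a TP group
    ready_event.set()       # signal the launcher that this worker is initialized
    stop.wait()             # stay alive afterwards (the real thread loops forever)

stop = threading.Event()
ready_events = []
for rank in range(4):
    ev = threading.Event()
    ready_events.append(ev)
    threading.Thread(target=worker, args=(rank, ev, stop), daemon=True).start()

for ev in ready_events:
    ev.wait()               # returns as soon as every worker has signalled readiness
print("all workers ready")
stop.set()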
sglang/srt/managers/expert_distribution.py ADDED
@@ -0,0 +1,81 @@
+ import json
+ import logging
+ import time
+ from collections import defaultdict
+ from typing import Dict, List, Tuple
+
+ import torch
+
+ logger = logging.getLogger(__name__)
+
+
+ # global expert distribution recording
+ class ExpertDistributionRecorder:
+ # This class is a singleton class
+ def __new__(cls):
+ if not hasattr(cls, "instance"):
+ cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
+ return cls.instance
+
+ def __init__(self):
+ # the length of the dictionary is the number of layers
+ # the length of the list is the number of tokens
+ # the length of the tuple is topk's k value
+ self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
+ list
+ )
+ self._record = False
+ self._current_layer_id = "UNKNOWN"
+
+ def set_current_layer(self, layer_idx):
+ self._current_layer_id = layer_idx
+
+ def record_new_token(self, topk_ids):
+ if not self._record:
+ return
+ topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
+ torch.cuda.synchronize()
+ for i in topk_ids_list:
+ self._expert_distribution_record[self._current_layer_id].append(tuple(i))
+
+ def reset(self):
+ """Reset the expert distribution recorder."""
+ logger.info("Resetting expert distribution record...")
+ self._record = False
+ self._expert_distribution_record.clear()
+ self._current_layer_id = "UNKNOWN"
+
+ def start_record(self):
+ """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
+ if self._record == True:
+ logger.warning(
+ "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
+ )
+ self.reset()
+ self._record = True
+
+ def stop_record(self):
+ """Stop recording the expert distribution. Set the recording flag to False."""
+ if self._record == False:
+ logger.warning(
+ "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
+ )
+ self._record = False
+
+ def dump_record(self):
+ """Dump the expert distribution record to a file. Reset the recorder after dumping."""
+ results = {}
+ for layer_idx, layer_record in self._expert_distribution_record.items():
+ results[layer_idx] = defaultdict(int)
+ for token_record in layer_record:
+ for expert_idx in token_record:
+ results[layer_idx][expert_idx] += 1
+ with open(
+ f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
+ "w",
+ ) as fd:
+ fd.write("layer_id,expert_id,count\n")
+ for layer_idx, layer_results in results.items():
+ for expert_idx, count in layer_results.items():
+ fd.write(f"{layer_idx},{expert_idx},{count}\n")
+ self.reset()
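A possible way to exercise the new recorder outside the server, assuming sglang 0.4.4.post2 is installed and torch has CUDA available; the single-process gloo group is only there because dump_record() names its CSV after torch.distributed.get_rank(). Inside the server, the same calls are driven through the /start_expert_distribution_record, /stop_expert_distribution_record, and /dump_expert_distribution_record endpoints referenced in the warnings above.

import torch
import torch.distributed as dist

from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

# single-process group so that torch.distributed.get_rank() works in dump_record()
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

recorder = ExpertDistributionRecorder()  # __new__ always returns the same instance
recorder.start_record()

recorder.set_current_layer(0)  # layer index used for subsequent recordings
topk_ids = torch.tensor([[1, 5], [1, 7]], device="cuda")  # toy top-k expert ids per token
recorder.record_new_token(topk_ids)

recorder.stop_record()
recorder.dump_record()  # writes expert_distribution_rank0_timestamp<...>.csv and resets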
sglang/srt/managers/io_struct.py CHANGED
@@ -45,6 +45,8 @@ class GenerateReqInput:
  # The image input. It can be a file name, a url, or base64 encoded string.
  # See also python/sglang/srt/utils.py:load_image.
  image_data: Optional[Union[List[str], str]] = None
+ # The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
+ audio_data: Optional[Union[List[str], str]] = None
  # The sampling_params. See descriptions below.
  sampling_params: Optional[Union[List[Dict], Dict]] = None
  # The request id.
@@ -103,6 +105,8 @@ class GenerateReqInput:
  self.batch_size = len(self.text)
  self.input_embeds = None
  elif self.input_ids is not None:
+ if len(self.input_ids) == 0:
+ raise ValueError("input_ids cannot be empty.")
  if isinstance(self.input_ids[0], int):
  self.is_single = True
  self.batch_size = 1
@@ -165,6 +169,13 @@ class GenerateReqInput:
  elif isinstance(self.image_data, list):
  pass

+ if self.audio_data is None:
+ self.audio_data = [None] * num
+ elif not isinstance(self.audio_data, list):
+ self.audio_data = [self.audio_data] * num
+ elif isinstance(self.audio_data, list):
+ pass
+
  if self.sampling_params is None:
  self.sampling_params = [{}] * num
  elif not isinstance(self.sampling_params, list):
@@ -229,6 +240,7 @@ class GenerateReqInput:
  text=self.text[i] if self.text is not None else None,
  input_ids=self.input_ids[i] if self.input_ids is not None else None,
  image_data=self.image_data[i],
+ audio_data=self.audio_data[i],
  sampling_params=self.sampling_params[i],
  rid=self.rid[i],
  return_logprob=self.return_logprob[i],
@@ -257,8 +269,8 @@ class TokenizedGenerateReqInput:
  input_text: str
  # The input token ids
  input_ids: List[int]
- # The image inputs
- image_inputs: dict
+ # The multimodal inputs
+ mm_inputs: dict
  # The sampling parameters
  sampling_params: SamplingParams
  # Whether to return the logprobs
@@ -538,7 +550,8 @@ class UpdateWeightsFromDistributedReqOutput:

  @dataclass
  class UpdateWeightsFromTensorReqInput:
- serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor]
+ # List containing one serialized Dict[str, torch.Tensor] per TP worker
+ serialized_named_tensors: List[bytes]
  load_format: Optional[str]
  flush_cache: bool

@@ -645,6 +658,17 @@ class ProfileReqType(Enum):
  STOP_PROFILE = 2


+ class ExpertDistributionReq(Enum):
+ START_RECORD = 1
+ STOP_RECORD = 2
+ DUMP_RECORD = 3
+
+
+ @dataclass
+ class ExpertDistributionReqOutput:
+ pass
+
+
  @dataclass
  class ProfileReq:
  type: ProfileReqType
@@ -723,3 +747,15 @@ class SeparateReasoningReqInput:
  class VertexGenerateReqInput:
  instances: List[dict]
  parameters: Optional[dict] = None
+
+
+ @dataclass
+ class RpcReqInput:
+ method: str
+ parameters: Optional[Dict] = None
+
+
+ @dataclass
+ class RpcReqOutput:
+ success: bool
+ message: str
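The new audio_data field is normalized for batched requests exactly like image_data: None expands to one None per request, a single value is repeated, and a list is passed through, after which each per-request slice receives audio_data[i]. A minimal sketch of that broadcasting rule (simplified, outside the dataclass; the file names are toy values):

def normalize_audio_data(audio_data, num):
    # same broadcasting rule the dataclass above applies to image_data and audio_data
    if audio_data is None:
        return [None] * num
    if not isinstance(audio_data, list):
        return [audio_data] * num
    return audio_data  # already one entry per request

print(normalize_audio_data(None, 3))                         # [None, None, None]
print(normalize_audio_data("clip.wav", 3))                   # ['clip.wav', 'clip.wav', 'clip.wav']
print(normalize_audio_data(["a.wav", "b.wav", "c.wav"], 3))  # ['a.wav', 'b.wav', 'c.wav']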