sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
- from typing import Any, Dict, List, Optional
+ from typing import Any, Callable, Dict, List, Optional

  import torch

- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda_available, set_weight_attrs

  is_cuda = is_cuda_available()
  if is_cuda:
@@ -10,6 +10,7 @@ if is_cuda:

  from torch.nn.parameter import Parameter

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.linear import LinearMethodBase
  from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
@@ -55,9 +56,12 @@ class W8A8Int8Config(QuantizationConfig):
  prefix: str,
  ) -> Optional["QuantizeMethodBase"]:
  from sglang.srt.layers.linear import LinearBase
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

  if isinstance(layer, LinearBase):
  return W8A8Int8LinearMethod(self)
+ elif isinstance(layer, FusedMoE):
+ return W8A8Int8MoEMethod(self)
  return None

  def get_scaled_act_names(self) -> List[str]:
@@ -81,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
  input_size: int,
  output_size: int,
  params_dtype: torch.dtype,
- **extra_weight_attrs
+ **extra_weight_attrs,
  ):

  weight_loader = extra_weight_attrs.get("weight_loader")
@@ -115,3 +119,148 @@ class W8A8Int8LinearMethod(LinearMethodBase):
  return int8_scaled_mm(
  x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
  )
+
+
+ class W8A8Int8MoEMethod:
+ """MoE method for INT8.
+ Supports loading INT8 checkpoints with static weight scale and
+ dynamic/static activation scale.
+ Also supports loading quantized FP16/BF16 model checkpoints with dynamic
+ activation scaling. The weight scaling factor will be initialized after
+ the model weights are loaded.
+ Args:
+ quant_config: The quantization config.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+ if not hasattr(cls, "_initialized"):
+ original_init = cls.__init__
+ new_cls = type(
+ cls.__name__,
+ (FusedMoEMethodBase,),
+ {
+ "__init__": original_init,
+ **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+ },
+ )
+ obj = super(new_cls, new_cls).__new__(new_cls)
+ obj.__init__(*args, **kwargs)
+ return obj
+ return super().__new__(cls)
+
+ def __init__(self, quant_config):
+ self.quant_config = quant_config
+
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ):
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+ tp_size = get_tensor_model_parallel_world_size()
+
+ # WEIGHTS
+ w13_weight = torch.nn.Parameter(
+ torch.empty(
+ num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+
+ w2_weight = torch.nn.Parameter(
+ torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_weight", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+
+ w13_weight_scale = torch.nn.Parameter(
+ torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+ requires_grad=False,
+ )
+ w2_weight_scale = torch.nn.Parameter(
+ torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+ )
+
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+ w13_input_scale = None
+ layer.register_parameter("w13_input_scale", w13_input_scale)
+
+ w2_input_scale = None
+ layer.register_parameter("w2_input_scale", w2_input_scale)
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+ layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+ layer.w13_weight_scale = Parameter(
+ layer.w13_weight_scale.data, requires_grad=False
+ )
+ layer.w2_weight_scale = Parameter(
+ layer.w2_weight_scale.data, requires_grad=False
+ )
+
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ renormalize: bool,
+ use_grouped_topk: bool,
+ topk_group: Optional[int] = None,
+ num_expert_group: Optional[int] = None,
+ custom_routing_function: Optional[Callable] = None,
+ correction_bias: Optional[torch.Tensor] = None,
+ activation: str = "silu",
+ inplace: bool = True,
+ no_combine: bool = False,
+ ) -> torch.Tensor:
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+ from sglang.srt.layers.moe.topk import select_experts
+
+ # Expert selection
+ topk_weights, topk_ids = select_experts(
+ hidden_states=x,
+ router_logits=router_logits,
+ use_grouped_topk=use_grouped_topk,
+ top_k=top_k,
+ renormalize=renormalize,
+ topk_group=topk_group,
+ num_expert_group=num_expert_group,
+ custom_routing_function=custom_routing_function,
+ correction_bias=correction_bias,
+ )
+
+ return fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ inplace=inplace,
+ activation=activation,
+ use_int8_w8a8=True,
+ w1_scale=(layer.w13_weight_scale),
+ w2_scale=(layer.w2_weight_scale),
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ no_combine=no_combine,
+ )
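
The hunks above are from sglang/srt/layers/quantization/w8a8_int8.py. The unusual part of the new W8A8Int8MoEMethod is its __new__, which rebuilds the class at runtime with FusedMoEMethodBase injected as the parent so the fused-MoE module is only imported when the method is actually constructed. Below is a minimal, self-contained sketch of that dynamic-base trick; the Base and Impl names are made up for illustration and are not sglang classes, and the filter here also drops __weakref__ for safety.

class Base:
    """Stand-in for a base class that lives in a heavyweight module."""

    def hello(self):
        return "base"


class Impl:
    """Declared without the base; the base is attached at instantiation time."""

    def __new__(cls, *args, **kwargs):
        # Build a new type with the same name and attributes, but with Base
        # injected as the parent, then instantiate that type instead of cls.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, value):
        self.value = value


inst = Impl(42)
assert isinstance(inst, Base) and inst.value == 42 and inst.hello() == "base"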
@@ -403,12 +403,12 @@ def _yarn_find_correction_range(


  def _yarn_linear_ramp_mask(
- low: float, high: float, dim: int, dtype: torch.dtype
+ low: float, high: float, dim: int, dtype: torch.dtype, device: torch.device = None
  ) -> torch.Tensor:
  if low == high:
  high += 0.001  # Prevent singularity

- linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+ linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
  ramp_func = torch.clamp(linear_func, 0, 1)
  return ramp_func

@@ -688,7 +688,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
  # Get n-d rotational scaling corrected for extrapolation
  inv_freq_mask = (
  1
- - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+ - _yarn_linear_ramp_mask(
+ low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device
+ )
  ) * self.extrapolation_factor
  inv_freq = (
  inv_freq_interpolation * (1 - inv_freq_mask)
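
These rotary_embedding.py hunks thread a device argument through _yarn_linear_ramp_mask so the YaRN ramp is built directly on the target device rather than on CPU and copied later. A small torch-only sketch of the same computation; the constants and function name here are illustrative:

import torch


def yarn_linear_ramp_mask(low: float, high: float, dim: int, device=None) -> torch.Tensor:
    # A 0..1 ramp over dim positions, clamped outside [low, high]; passing
    # device avoids materializing the arange on CPU first.
    if low == high:
        high += 0.001  # prevent division by zero
    linear = (torch.arange(dim, dtype=torch.float32, device=device) - low) / (high - low)
    return torch.clamp(linear, 0, 1)


device = "cuda" if torch.cuda.is_available() else "cpu"
mask = yarn_linear_ramp_mask(low=4.0, high=28.0, dim=64, device=device)
print(mask.device, mask.min().item(), mask.max().item())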
@@ -1,5 +1,5 @@
  import logging
- from typing import List, Optional
+ from typing import List

  import torch
  import torch.distributed as dist
@@ -42,7 +42,6 @@ class Sampler(nn.Module):
  return_logprob: bool,
  top_logprobs_nums: List[int],
  token_ids_logprobs: List[List[int]],
- batch_next_token_ids: Optional[torch.Tensor] = None,
  ):
  """Run a sampler & compute logprobs and update logits_output accordingly.

@@ -72,8 +71,7 @@

  if sampling_info.is_all_greedy:
  # Use torch.argmax if all requests use greedy sampling
- if batch_next_token_ids is None:
- batch_next_token_ids = torch.argmax(logits, -1)
+ batch_next_token_ids = torch.argmax(logits, -1)
  if return_logprob:
  logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
  else:
@@ -94,43 +92,39 @@
  top_p_normalize_probs_torch(probs, sampling_info.top_ps)
  ).clamp(min=torch.finfo(probs.dtype).min)

- if batch_next_token_ids is None:
- max_top_k_round, batch_size = 32, probs.shape[0]
- uniform_samples = torch.rand(
- (max_top_k_round, batch_size), device=probs.device
+ max_top_k_round, batch_size = 32, probs.shape[0]
+ uniform_samples = torch.rand(
+ (max_top_k_round, batch_size), device=probs.device
+ )
+ if sampling_info.need_min_p_sampling:
+ probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+ probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+ batch_next_token_ids = min_p_sampling_from_probs(
+ probs, uniform_samples, sampling_info.min_ps
  )
- if sampling_info.need_min_p_sampling:
- probs = top_k_renorm_prob(probs, sampling_info.top_ks)
- probs = top_p_renorm_prob(probs, sampling_info.top_ps)
- batch_next_token_ids = min_p_sampling_from_probs(
- probs, uniform_samples, sampling_info.min_ps
- )
- else:
- batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
- probs,
- uniform_samples,
- sampling_info.top_ks,
- sampling_info.top_ps,
- filter_apply_order="joint",
- )
-
- if self.use_nan_detection and not torch.all(success):
- logger.warning("Detected errors during sampling!")
- batch_next_token_ids = torch.zeros_like(
- batch_next_token_ids
- )
-
- elif global_server_args_dict["sampling_backend"] == "pytorch":
- if batch_next_token_ids is None:
- # A slower fallback implementation with torch native operations.
- batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+ else:
+ batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
  probs,
+ uniform_samples,
  sampling_info.top_ks,
  sampling_info.top_ps,
- sampling_info.min_ps,
- sampling_info.need_min_p_sampling,
+ filter_apply_order="joint",
  )

+ if self.use_nan_detection and not torch.all(success):
+ logger.warning("Detected errors during sampling!")
+ batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
+ elif global_server_args_dict["sampling_backend"] == "pytorch":
+ # A slower fallback implementation with torch native operations.
+ batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+ probs,
+ sampling_info.top_ks,
+ sampling_info.top_ps,
+ sampling_info.min_ps,
+ sampling_info.need_min_p_sampling,
+ )
+
  if return_logprob:
  # clamp to avoid -inf
  logprobs = torch.log(
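
The sampler.py hunks above drop the batch_next_token_ids passthrough, so the sampler always draws tokens itself, either with the flashinfer kernels or with the torch-native fallback. A rough torch-only sketch of joint top-k/top-p sampling in the spirit of that fallback; the function name and scalar top_k/top_p are illustrative (sglang's top_k_top_p_min_p_sampling_from_probs_torch takes per-request tensors and an optional min-p step):

import torch


def sample_top_k_top_p(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    """Sample one token id per row after top-k and top-p filtering."""
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Drop tokens outside the nucleus: cumulative mass before a token exceeds top_p.
    cumulative = sorted_probs.cumsum(dim=-1)
    nucleus_mask = (cumulative - sorted_probs) > top_p
    # Drop tokens beyond the top_k highest-probability entries.
    rank_mask = torch.arange(probs.shape[-1], device=probs.device) >= top_k
    sorted_probs = sorted_probs.masked_fill(nucleus_mask | rank_mask, 0.0)
    # Renormalize the surviving mass, sample, and map back to vocabulary ids.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


logits = torch.randn(4, 32000)
print(sample_top_k_top_p(logits, top_k=20, top_p=0.9))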
@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
  quant_method = None
  if quant_config is not None:
  quant_method = quant_config.get_quant_method(self, prefix=prefix)
- print("quant_method", quant_method)
  if quant_method is None:
  quant_method = UnquantizedEmbeddingMethod()

@@ -1,23 +1,20 @@
- from .base_backend import BaseLoRABackend
- from .flashinfer_backend import FlashInferLoRABackend
- from .triton_backend import TritonLoRABackend
+ from sglang.srt.lora.backend.base_backend import BaseLoRABackend


  def get_backend_from_name(name: str) -> BaseLoRABackend:
  """
  Get corresponding backend class from backend's name
  """
- backend_mapping = {
- "triton": TritonLoRABackend,
- "flashinfer": FlashInferLoRABackend,
- }
+ if name == "triton":
+ from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

- if name in backend_mapping:
- return backend_mapping[name]
+ return TritonLoRABackend
+ elif name == "flashinfer":
+ from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend

- raise Exception(
- f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
- )
+ return FlashInferLoRABackend
+ else:
+ raise ValueError(f"Invalid backend: {name}")


  __all__ = [
@@ -22,11 +22,34 @@ from typing import List, Optional

  import torch

- from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
+ from sglang.srt.mem_cache.memory_pool import (
+ MHATokenToKVPoolHost,
+ TokenToKVPoolAllocator,
+ )

  logger = logging.getLogger(__name__)


+ class LayerDoneCounter:
+ def __init__(self, num_layers):
+ self.counter = num_layers
+ self.condition = threading.Condition()
+
+ def increment(self):
+ with self.condition:
+ self.counter += 1
+ self.condition.notify_all()
+
+ def wait_until(self, threshold):
+ with self.condition:
+ while self.counter <= threshold:
+ self.condition.wait()
+
+ def reset(self):
+ with self.condition:
+ self.counter = 0
+
+
  class CacheOperation:

  counter = 0
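
The LayerDoneCounter added to cache_controller.py above is a condition-variable counter: the load thread increments it after each layer's KV transfer, and the compute side can wait_until(layer_id) before touching that layer. A toy producer/consumer sketch of the same pattern; the layer count and sleep are made up, and the class body mirrors the diff:

import threading
import time


class LayerDoneCounter:
    def __init__(self, num_layers):
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        with self.condition:
            while self.counter <= threshold:
                self.condition.wait()

    def reset(self):
        with self.condition:
            self.counter = 0


counter = LayerDoneCounter(num_layers=4)
counter.reset()


def loader():
    for _ in range(4):
        time.sleep(0.01)  # pretend to copy one layer host -> device
        counter.increment()


threading.Thread(target=loader, daemon=True).start()
for layer_id in range(4):
    counter.wait_until(layer_id)  # block until this layer has been transferred
    print(f"layer {layer_id} ready")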
@@ -127,15 +150,20 @@ class HiCacheController:

  def __init__(
  self,
- mem_pool_device: MHATokenToKVPool,
+ token_to_kv_pool_allocator: TokenToKVPoolAllocator,
  mem_pool_host: MHATokenToKVPoolHost,
+ load_cache_event: threading.Event = None,
  write_policy: str = "write_through_selective",
  ):
-
- self.mem_pool_device = mem_pool_device
+ self.mem_pool_device_allocator = token_to_kv_pool_allocator
+ self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
  self.mem_pool_host = mem_pool_host
  self.write_policy = write_policy

+ self.load_cache_event = load_cache_event
+ self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+ self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
+
  if write_policy not in [
  "write_through",
  "write_through_selective",
@@ -162,7 +190,7 @@ class HiCacheController:
  target=self.write_thread_func_buffer, daemon=True
  )
  self.load_thread = threading.Thread(
- target=self.load_thread_func_buffer, daemon=True
+ target=self.load_thread_func_layer_by_layer, daemon=True
  )
  self.write_thread.start()
  self.load_thread.start()
@@ -183,7 +211,7 @@ class HiCacheController:
  target=self.write_thread_func_buffer, daemon=True
  )
  self.load_thread = threading.Thread(
- target=self.load_thread_func_buffer, daemon=True
+ target=self.load_thread_func_layer_by_layer, daemon=True
  )
  self.stop_event.clear()
  self.write_thread.start()
@@ -216,10 +244,12 @@ class HiCacheController:
  """
  Load KV caches from host memory to device memory.
  """
- device_indices = self.mem_pool_device.alloc(len(host_indices))
+ device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
  if device_indices is None:
  return None
  self.mem_pool_host.protect_load(host_indices)
+ # to ensure the device indices are ready before accessed by another CUDA stream
+ torch.cuda.current_stream().synchronize()
  self.load_queue.put(
  CacheOperation(host_indices, device_indices, node_id, priority)
  )
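
The extra torch.cuda.current_stream().synchronize() above guards a cross-stream hazard: device_indices is produced on the default stream but consumed by the load thread's side stream. A small sketch of that ordering problem and the coarse fix, assuming a CUDA-capable torch build; the tensor names are illustrative:

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()

    # Produced on the default stream (stands in for freshly allocated indices).
    device_indices = torch.randperm(4096, device="cuda")

    # Coarse fix used in the diff: block until the default stream has finished,
    # so the side stream cannot observe half-written data.
    torch.cuda.current_stream().synchronize()

    with torch.cuda.stream(side_stream):
        staged = device_indices.to(torch.int64) + 1  # safe to read on the side stream

    # A finer-grained alternative is event ordering, which does not stall the host:
    ready = torch.cuda.Event()
    ready.record()                 # recorded on the current (default) stream
    side_stream.wait_event(ready)  # the side stream waits only for that event
    side_stream.synchronize()
    print(staged[:4])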
@@ -270,6 +300,42 @@ class HiCacheController:
  except Exception as e:
  logger.error(e)

+ def load_thread_func_layer_by_layer(self):
+ """
+ Load KV caches from host memory to device memory layer by layer.
+ """
+ with torch.cuda.stream(self.load_stream):
+ while not self.stop_event.is_set():
+ self.load_cache_event.wait(timeout=1)
+ if not self.load_cache_event.is_set():
+ continue
+ self.load_cache_event.clear()
+
+ batch_operation = None
+ while self.load_queue.qsize() > 0:
+ op = self.load_queue.get(block=True)
+ if batch_operation is None:
+ batch_operation = op
+ else:
+ batch_operation.merge(op)
+ if batch_operation is None:
+ continue
+
+ self.layer_done_counter.reset()
+ for i in range(self.mem_pool_host.layer_num):
+ flat_data = self.mem_pool_host.get_flat_data_by_layer(
+ batch_operation.host_indices, i
+ )
+ self.mem_pool_device.transfer_per_layer(
+ batch_operation.device_indices, flat_data, i
+ )
+ self.layer_done_counter.increment()
+
+ self.mem_pool_host.complete_io(batch_operation.host_indices)
+ for node_id in batch_operation.node_ids:
+ if node_id != 0:
+ self.ack_load_queue.put(node_id)
+
  def write_aux_func(self, no_wait=False):
  """
  Auxiliary function to prepare the buffer for write operations.
@@ -417,7 +483,7 @@ class HiCacheController:
  self, device_indices: torch.Tensor, host_indices: torch.Tensor
  ) -> int:
  if self.mem_pool_host.is_synced(host_indices):
- self.mem_pool_device.free(device_indices)
+ self.mem_pool_device_allocator.free(device_indices)
  self.mem_pool_host.update_backup(host_indices)
  return len(device_indices)
  else:
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
  class DataParallelController:
  """A controller that dispatches requests to multiple data parallel workers."""

- def __init__(self, server_args, port_args) -> None:
+ def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
  # Parse args
  self.max_total_num_tokens = None
  self.server_args = server_args