sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/gate_up_lora_b.py
@@ -31,28 +31,44 @@ def _gate_up_lora_b_kernel(
  BLOCK_S: tl.constexpr,
  BLOCK_N: tl.constexpr,
  BLOCK_K: tl.constexpr,
- # For fused output scaling and adding
- fuse_scaling_add,
+ # For fused output scaling
  scalings,
  ):
- # This kernel packs 2 sgemms (gate/up) into a single kernel.
-
- # x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
- # weights: (num_lora, 2 * output_dim, K)
- # output: (s, 2 * output_dim)
+ """
+ This kernel packs 2 sgemms (gate/up) into a single kernel. The multiplication
+ results are accumulated into the output tensor.
+
+ When a sequence's rank is 0, the kernel is essentially a no-op, following
+ the convention in pytorch where the product of two matrices of shape (m, 0)
+ and (0, n) is an all-zero matrix of shape (m, n).
+
+ Args:
+ x (Tensor): The input tensor, which is the result of the LoRA A projection.
+ Shape: (s, 2 * K), where s is the sum of all sequence lengths in the
+ batch and K is the maximum LoRA rank.
+ weights (Tensor): The LoRA B weights for all adapters.
+ Shape: (num_lora, 2 * output_dim, K).
+ output (Tensor): The output tensor where the result is stored.
+ Shape: (s, 2 * output_dim).
+ """
  # output_dim >> K

  # Current block computes sequence with batch_id,
  # which starts from row seg_start of x with length seg_len.
  # gate_up_id decides which of gate or up (0: gate, 1: up)
  batch_id = tl.program_id(axis=2)
+ w_index = tl.load(weight_indices + batch_id)
+ rank = tl.load(lora_ranks + w_index)
+
+ # If rank is 0, this kernel is a no-op.
+ if rank == 0:
+ return
+
  gate_up_id = tl.program_id(axis=1)
  pid = tl.program_id(axis=0)
  seg_len = tl.load(seg_lens + batch_id)
- w_index = tl.load(weight_indices + batch_id)
  seg_start = tl.load(seg_indptr + batch_id)
  n_start = gate_up_id * output_dim # offset on output dim
- rank = tl.load(lora_ranks + w_index)
  scaling = tl.load(scalings + w_index)

  # Adjust K (rank) according to the specific LoRA adapter
@@ -82,14 +98,13 @@ def _gate_up_lora_b_kernel(
  for k in range(0, tl.cdiv(K, BLOCK_K)):
  x_tile = tl.load(
  x_ptrs,
- mask=(s_offset[:, None] < seg_len)
- and (k_offset[None, :] < K - k * BLOCK_K),
+ mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
  other=0.0,
  )
  w_tile = tl.load(
  w_ptrs,
  mask=(k_offset[:, None] < K - k * BLOCK_K)
- and (n_offset[None, :] < output_dim),
+ & (n_offset[None, :] < output_dim),
  other=0.0,
  )
  partial_sum += tl.dot(x_tile, w_tile)
@@ -103,9 +118,8 @@ def _gate_up_lora_b_kernel(
  output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
  s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
  )
- output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
- if fuse_scaling_add:
- partial_sum += tl.load(output_ptr, mask=output_mask)
+ output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < output_dim)
+ partial_sum += tl.load(output_ptr, mask=output_mask)
  tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -143,11 +157,9 @@ def gate_up_lora_b_fwd(
  )

  if base_output is None:
- output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
- fuse_scaling_add = False
+ output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype)
  else:
  output = base_output
- fuse_scaling_add = True

  _gate_up_lora_b_kernel[grid_b](
  x,
@@ -169,7 +181,6 @@ def gate_up_lora_b_fwd(
  BLOCK_S,
  BLOCK_OUT,
  BLOCK_R,
- fuse_scaling_add,
  batch_info.scalings,
  )

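The gate/up hunks above (and the matching q/k/v and sgemm ones below) drop the fuse_scaling_add flag: when no base_output is given, the output buffer is now created with torch.zeros, so the kernel can always accumulate its scaled partial sum into it. A minimal host-side sketch of that semantics in plain PyTorch, not the Triton kernel itself; tensor names are illustrative:

import torch

def lora_b_accumulate(x, lora_b_weight, scaling, base_output=None):
    # x: (s, K) LoRA-A activations; lora_b_weight: (N, K); returns (s, N).
    # The buffer starts at zero (or at the base model output), and the scaled
    # LoRA product is always added in, so no fuse_scaling_add flag is needed.
    s, n = x.shape[0], lora_b_weight.shape[0]
    output = (
        torch.zeros((s, n), device=x.device, dtype=x.dtype)
        if base_output is None
        else base_output
    )
    output += scaling * (x @ lora_b_weight.T)
    return output
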
sglang/srt/lora/triton_ops/qkv_lora_b.py
@@ -33,29 +33,45 @@ def _qkv_lora_b_kernel(
  BLOCK_S: tl.constexpr,
  BLOCK_N: tl.constexpr,
  BLOCK_K: tl.constexpr,
- # For fused output scaling and adding
- fuse_scaling_add,
+ # For fused output scaling
  scalings,
  ):
- # This kernel packs 3 sgemms (q/k/v) into a single kernel.
-
- # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
- # weights: (num_lora, N_Q + 2 * N_KV, K)
- # output: (s, N_Q + 2 * N_KV)
- # N_Q >> K, N_KV >> K
+ """
+ This kernel packs 3 sgemms (q/k/v) into a single kernel. The multiplication
+ results are accumulated into the output tensor.
+
+ When a sequence's rank is 0, the kernel is essentially a no-op, following
+ the convention in pytorch where the product of two matrices of shape (m, 0)
+ and (0, n) is an all-zero matrix of shape (m, n).
+
+ Args:
+ x (Tensor): The input tensor, which is the result of the LoRA A projection.
+ Shape: (s, 3 * K), where s is the sum of all sequence lengths in the
+ batch and K is the maximum LoRA rank. The second dimension is partitioned
+ for Q, K, and V.
+ weights (Tensor): The LoRA B weights for all adapters.
+ Shape: (num_lora, N_Q + 2 * N_KV, K).
+ output (Tensor): The output tensor where the result is stored.
+ Shape: (s, N_Q + 2 * N_KV).
+ """

  # Current block computes sequence with batch_id,
  # which starts from row seg_start of x with length seg_len.
  # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
  batch_id = tl.program_id(axis=2)
+ w_index = tl.load(weight_indices + batch_id)
+ rank = tl.load(lora_ranks + w_index)
+
+ # If rank is 0, this kernel is a no-op.
+ if rank == 0:
+ return
+
  qkv_id = tl.program_id(axis=1)
  pid = tl.program_id(axis=0)
  seg_len = tl.load(seg_lens + batch_id)
- w_index = tl.load(weight_indices + batch_id)
  seg_start = tl.load(seg_indptr + batch_id)
  n_start = tl.load(n_offs + qkv_id)
  n_size = tl.load(n_offs + qkv_id + 1) - n_start
- rank = tl.load(lora_ranks + w_index)
  scaling = tl.load(scalings + w_index)
  # Adjust K (rank) according to the specific LoRA adapter
  K = tl.minimum(K, rank)
@@ -84,13 +100,12 @@ def _qkv_lora_b_kernel(
  for k in range(0, tl.cdiv(K, BLOCK_K)):
  x_tile = tl.load(
  x_ptrs,
- mask=(s_offset[:, None] < seg_len)
- and (k_offset[None, :] < K - k * BLOCK_K),
+ mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
  other=0.0,
  )
  w_tile = tl.load(
  w_ptrs,
- mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
+ mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < n_size),
  other=0.0,
  )
  partial_sum += tl.dot(x_tile, w_tile)
@@ -105,8 +120,7 @@ def _qkv_lora_b_kernel(
  s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
  )
  output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
- if fuse_scaling_add:
- partial_sum += tl.load(output_ptr, mask=output_mask)
+ partial_sum += tl.load(output_ptr, mask=output_mask)
  tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -153,11 +167,9 @@ def qkv_lora_b_fwd(
  )

  if base_output is None:
- output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
- fuse_scaling_add = False
+ output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
  else:
  output = base_output
- fuse_scaling_add = True

  _qkv_lora_b_kernel[grid_b](
  x,
@@ -180,7 +192,6 @@ def qkv_lora_b_fwd(
  BLOCK_S,
  BLOCK_OUT,
  BLOCK_R,
- fuse_scaling_add,
  batch_info.scalings,
  )

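As the new docstring notes, the packed output's second dimension is laid out as N_Q columns for Q followed by N_KV columns each for K and V (the kernel reads the actual boundaries from n_offs). A small sketch of how such a packed tensor would be sliced apart; the helper name and arguments are hypothetical:

def split_packed_qkv(output, n_q: int, n_kv: int):
    # output: (s, n_q + 2 * n_kv), packed as [Q | K | V] along dim 1
    q = output[:, :n_q]
    k = output[:, n_q:n_q + n_kv]
    v = output[:, n_q + n_kv:n_q + 2 * n_kv]
    return q, k, v
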
sglang/srt/lora/triton_ops/sgemm_lora_a.py
@@ -33,19 +33,36 @@ def _sgemm_lora_a_kernel(
  BLOCK_N: tl.constexpr,
  BLOCK_K: tl.constexpr,
  ):
-
- # x: (s, K), s is the sum of sequence lengths
- # weights: (num_lora, N, K)
- # output: (s, N)
+ """
+ Computes a segmented batched matrix multiplication for the LoRA A matrix.
+
+ The kernel ensures that output[seg_start:seg_start + seg_len, :rank * stack_num]
+ stores the product of the input `x` and the LoRA weights for the corresponding
+ sequence. This implies that when rank is 0, the kernel is essentially a no-op,
+ as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
+
+ Args:
+ x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
+ is the sum of all sequence lengths in the batch.
+ weights (torch.Tensor): The LoRA 'A' weights for all available adapters,
+ with shape `(num_lora, N, K)`.
+ output (torch.Tensor): The output tensor of shape `(s, N)`.
+ """

  # Current block computes sequence with batch_id,
  # which starts from row seg_start of x with length seg_len
  batch_id = tl.program_id(axis=1)
- pid = tl.program_id(axis=0)
- seg_len = tl.load(seg_lens + batch_id)
  w_index = tl.load(weight_indices + batch_id)
- seg_start = tl.load(seg_indptr + batch_id)
  rank = tl.load(lora_ranks + w_index)
+
+ # If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
+ if rank == 0:
+ return
+
+ pid = tl.program_id(axis=0)
+ seg_start = tl.load(seg_indptr + batch_id)
+ seg_len = tl.load(seg_lens + batch_id)
+
  # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
  N = tl.minimum(N, rank * stack_num)

@@ -72,13 +89,12 @@ def _sgemm_lora_a_kernel(
  for k in range(0, tl.cdiv(K, BLOCK_K)):
  x_tile = tl.load(
  x_ptrs,
- mask=(s_offset[:, None] < seg_len)
- and (k_offset[None, :] < K - k * BLOCK_K),
+ mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
  other=0.0,
  )
  w_tile = tl.load(
  w_ptrs,
- mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
+ mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < N),
  other=0.0,
  )
  partial_sum += tl.dot(x_tile, w_tile)
@@ -91,7 +107,7 @@ def _sgemm_lora_a_kernel(
  output_ptr = (output + seg_start * output_stride_0) + (
  s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
  )
- output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
+ output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < N)
  tl.store(output_ptr, partial_sum, mask=output_mask)

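The early return added above relies on the PyTorch convention the docstrings cite: a matmul whose inner dimension is zero produces an all-zero matrix, so a rank-0 adapter contributes nothing and the pre-initialized output is already correct. A two-line check of that convention:

import torch

# (m, 0) @ (0, n) yields an all-zero (m, n) matrix, so a rank-0 LoRA slot
# can be skipped entirely without changing the result.
m, n = 4, 8
assert torch.equal(torch.empty(m, 0) @ torch.empty(0, n), torch.zeros(m, n))
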
sglang/srt/lora/triton_ops/sgemm_lora_b.py
@@ -31,22 +31,39 @@ def _sgemm_lora_b_kernel(
  BLOCK_S: tl.constexpr,
  BLOCK_N: tl.constexpr,
  BLOCK_K: tl.constexpr,
- # For fused output scaling and adding
- fuse_scaling_add,
+ # For fused output scaling
  scalings,
  ):
- # x: (s, K), s is the sum of sequence lengths
- # weights: (num_lora, N, K)
- # output: (s, N)
+ """
+ Computes a segmented batched matrix multiplication for the LoRA B matrix
+ and adds the result to the output in-place.
+
+ When a sequence's rank is 0, the kernel is essentially a no-op, following
+ the convention in pytorch where the product of two matrices of shape (m, 0)
+ and (0, n) is an all-zero matrix of shape (m, n).
+
+ Args:
+ x (torch.Tensor): The intermediate tensor from the LoRA 'A' multiplication,
+ of shape `(s, K)`, where `s` is the total number of tokens.
+ weights (torch.Tensor): The LoRA 'B' weights for all available adapters,
+ with shape `(num_lora, N, K)`.
+ output (torch.Tensor): The output tensor of shape `(s, N)`. This can be
+ the base model's output for a fused add operation.
+ """

  # Current block computes sequence with batch_id,
  # which starts from row seg_start of x with length seg_len
  batch_id = tl.program_id(axis=1)
+ w_index = tl.load(weight_indices + batch_id)
+ rank = tl.load(lora_ranks + w_index)
+
+ # If rank is 0, this kernel is a no-op.
+ if rank == 0:
+ return
+
  pid = tl.program_id(axis=0)
  seg_len = tl.load(seg_lens + batch_id)
- w_index = tl.load(weight_indices + batch_id)
  seg_start = tl.load(seg_indptr + batch_id)
- rank = tl.load(lora_ranks + w_index)
  scaling = tl.load(scalings + w_index)
  # Adjust K (rank) according to the specific LoRA adapter
  K = tl.minimum(K, rank)
@@ -74,8 +91,7 @@ def _sgemm_lora_b_kernel(
  for k in range(0, tl.cdiv(K, BLOCK_K)):
  x_tile = tl.load(
  x_ptrs,
- mask=(s_offset[:, None] < seg_len)
- and (k_offset[None, :] < K - k * BLOCK_K),
+ mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K),
  other=0.0,
  )
  w_tile = tl.load(
@@ -95,8 +111,7 @@ def _sgemm_lora_b_kernel(
  s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
  )
  output_mask = s_offset[:, None] < seg_len
- if fuse_scaling_add:
- partial_sum += tl.load(output_ptr, mask=output_mask)
+ partial_sum += tl.load(output_ptr, mask=output_mask)
  tl.store(output_ptr, partial_sum, mask=output_mask)


@@ -132,11 +147,9 @@ def sgemm_lora_b_fwd(
  )

  if base_output is None:
- output = torch.empty((S, N), device=x.device, dtype=x.dtype)
- fuse_scaling_add = False
+ output = torch.zeros((S, N), device=x.device, dtype=x.dtype)
  else:
  output = base_output
- fuse_scaling_add = True

  _sgemm_lora_b_kernel[grid](
  x,
@@ -158,7 +171,6 @@ def sgemm_lora_b_fwd(
  BLOCK_S,
  BLOCK_N,
  BLOCK_R,
- fuse_scaling_add,
  batch_info.scalings,
  )
  return output
sglang/srt/managers/cache_controller.py
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- import concurrent.futures
  import logging
  import math
  import threading
@@ -169,12 +168,23 @@ class HiCacheController:
  page_size: int,
  load_cache_event: threading.Event = None,
  write_policy: str = "write_through_selective",
+ io_backend: str = "",
  ):
  self.mem_pool_device_allocator = token_to_kv_pool_allocator
  self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
  self.mem_pool_host = mem_pool_host
  self.write_policy = write_policy
  self.page_size = page_size
+ # using kernel for small page KV cache transfer and DMA for large pages
+ if not io_backend:
+ IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+ self.io_backend = (
+ "direct"
+ if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+ else "kernel"
+ )
+ else:
+ self.io_backend = io_backend

  self.load_cache_event = load_cache_event
  self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -203,12 +213,7 @@ class HiCacheController:
  self.load_stream = torch.cuda.Stream()

  self.write_thread = threading.Thread(
- target=(
- self.write_thread_func_buffer
- if self.page_size == 1
- else self.write_thread_func_direct
- ),
- daemon=True,
+ target=self.write_thread_func_direct, daemon=True
  )
  self.load_thread = threading.Thread(
  target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
  self.ack_load_queue.queue.clear()

  self.write_thread = threading.Thread(
- target=(
- self.write_thread_func_buffer
- if self.page_size == 1
- else self.write_thread_func_direct
- ),
- daemon=True,
+ target=self.write_thread_func_direct, daemon=True
  )
  self.load_thread = threading.Thread(
  target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
  )
  return device_indices

+ def move_indices(self, host_indices, device_indices):
+ # move indices to GPU if using kernels, to host if using direct indexing
+ if self.io_backend == "kernel":
+ return host_indices.to(self.mem_pool_device.device), device_indices
+ elif self.io_backend == "direct":
+ return host_indices, device_indices.cpu()
+ else:
+ raise ValueError(f"Unsupported io backend")
+
  def write_thread_func_direct(self):
  """
  Directly write through KV caches to host memory without buffering.
@@ -289,10 +298,14 @@ class HiCacheController:
  while not self.stop_event.is_set():
  try:
  operation = self.write_queue.get(block=True, timeout=1)
- self.mem_pool_host.write_page_all_layers(
- operation.host_indices,
- operation.device_indices,
- self.mem_pool_device,
+ host_indices, device_indices = self.move_indices(
+ operation.host_indices, operation.device_indices
+ )
+ self.mem_pool_device.backup_to_host_all_layer(
+ self.mem_pool_host,
+ host_indices,
+ device_indices,
+ self.io_backend,
  )
  self.write_stream.synchronize()
  self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
  except Exception as e:
  logger.error(e)

- def load_thread_func_direct(self):
- """
- Directly load KV caches from host memory to device memory without buffering.
- """
- torch.cuda.set_stream(self.load_stream)
- while not self.stop_event.is_set():
- try:
- operation = self.load_queue.get(block=True, timeout=1)
- operation.data = self.mem_pool_host.get_flat_data(
- operation.host_indices
- )
- self.mem_pool_device.transfer(operation.device_indices, operation.data)
- self.mem_pool_host.complete_io(operation.host_indices)
- for node_id in operation.node_ids:
- if node_id != 0:
- self.ack_load_queue.put(node_id)
- except Empty:
- continue
- except Exception as e:
- logger.error(e)
-
  def load_thread_func_layer_by_layer(self):
  """
  Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:

  # start layer-wise KV cache transfer from CPU to GPU
  self.layer_done_counter.reset()
+ host_indices, device_indices = self.move_indices(
+ batch_operation.host_indices, batch_operation.device_indices
+ )
  for i in range(self.mem_pool_host.layer_num):
- if self.page_size == 1:
- flat_data = self.mem_pool_host.get_flat_data_by_layer(
- batch_operation.host_indices, i
- )
- self.mem_pool_device.transfer_per_layer(
- batch_operation.device_indices, flat_data, i
- )
- else:
- self.mem_pool_host.load_page_per_layer(
- batch_operation.host_indices,
- batch_operation.device_indices,
- self.mem_pool_device,
- i,
- )
- self.load_stream.synchronize()
+ self.mem_pool_device.load_from_host_per_layer(
+ self.mem_pool_host,
+ host_indices,
+ device_indices,
+ i,
+ self.io_backend,
+ )
+ self.load_stream.synchronize()
  self.layer_done_counter.increment()

  self.mem_pool_host.complete_io(batch_operation.host_indices)
@@ -372,148 +360,6 @@ class HiCacheController:
  if node_id != 0:
  self.ack_load_queue.put(node_id)

- def write_aux_func(self, no_wait=False):
- """
- Auxiliary function to prepare the buffer for write operations.
- """
- torch.cuda.set_stream(self.write_stream)
-
- def _to_op(op_):
- assert op_.device_indices.is_cuda, "Device indices should be on GPU"
- op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
- self.mem_pool_host.device
- )
- self.write_buffer.put(op_)
- return op_
-
- buffer = None
- while not self.stop_event.is_set():
- try:
- operation = self.write_queue.get(block=True, timeout=1)
- factor = (
- len(operation.device_indices) // self.write_buffer.max_buffer_size
- )
-
- if factor >= 1:
- if buffer is not None:
- _to_op(buffer)
- buffer = None
-
- if factor < 2:
- _to_op(operation)
- else:
- split_ops = operation.split(factor)
- for op_ in split_ops:
- _to_op(op_)
- continue
-
- if buffer is None:
- buffer = operation
- else:
- buffer.merge(operation)
- if (
- no_wait
- or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
- or self.write_queue.empty()
- or self.write_buffer.empty()
- ):
- _to_op(buffer)
- buffer = None
- except Empty:
- continue
- except Exception as e:
- logger.error(e)
-
- def load_aux_func(self):
- """
- Auxiliary function to prepare the buffer for load operations.
- """
-
- def _pin_op(op_, put=True):
- op_.data = (
- self.mem_pool_host.get_flat_data(op_.host_indices)
- .contiguous()
- .pin_memory()
- )
- if put:
- self.load_buffer.put(op_)
- return op_
-
- buffer = None
- while not self.stop_event.is_set():
- try:
- operation = self.load_queue.get(block=True, timeout=1)
- factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
- if factor >= 1:
- if buffer is not None:
- _pin_op(buffer)
- buffer = None
-
- if factor < 2:
- _pin_op(operation)
- else:
- split_ops = operation.split(factor)
- split_args = [(op_, True) for op_ in split_ops[:-1]]
- split_args.append((split_ops[-1], False))
- # Spawn threads to pin each op concurrently
- with concurrent.futures.ThreadPoolExecutor() as executor:
- pinned_ops = list(
- executor.map(
- lambda x: _pin_op(x[0], put=x[1]), split_args
- )
- )
- # preserve the order of last op to ensure correct ack
- self.load_buffer.put(pinned_ops[-1])
- continue
-
- if buffer is None:
- buffer = operation
- else:
- buffer.merge(operation)
- if (
- len(buffer.host_indices) >= self.load_buffer.max_buffer_size
- or self.load_queue.empty()
- or self.load_buffer.empty()
- ):
- _pin_op(buffer)
- buffer = None
- except Empty:
- continue
- except Exception as e:
- logger.error(e)
-
- # todo (zhiqiang): double buffering to be deprecated
- def write_thread_func_buffer(self):
- aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
- aux_thread.start()
-
- while not self.stop_event.is_set():
- operation = self.write_buffer.get()
- if operation is None:
- continue
- self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
- self.mem_pool_host.complete_io(operation.host_indices)
- for node_id in operation.node_ids:
- if node_id != 0:
- self.ack_write_queue.put(node_id)
- aux_thread.join()
-
- def load_thread_func_buffer(self):
- torch.cuda.set_stream(self.load_stream)
- aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
- aux_thread.start()
- while not self.stop_event.is_set():
- operation = self.load_buffer.get()
- if operation is None:
- continue
- self.mem_pool_device.transfer(operation.device_indices, operation.data)
- self.mem_pool_host.complete_io(operation.host_indices)
- for node_id in operation.node_ids:
- if node_id != 0:
- self.ack_load_queue.put(node_id)
- aux_thread.join()
-
  def evict_device(
  self, device_indices: torch.Tensor, host_indices: torch.Tensor
  ) -> int:
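The io_backend plumbing above lets HiCacheController pick between a gather/scatter kernel for small pages and direct indexing for large pages, using 64 tokens per page as the cutoff when no backend is specified. A condensed restatement of that selection logic as a standalone helper; the function itself is illustrative, not part of the module:

IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # pages at or above this size use "direct"

def choose_io_backend(page_size: int, io_backend: str = "") -> str:
    # Mirrors the constructor logic: an explicit io_backend wins; otherwise
    # large pages use direct (DMA-style) copies and small pages use the kernel.
    if io_backend:
        return io_backend
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

The chosen backend is then passed to backup_to_host_all_layer and load_from_host_per_layer, with move_indices placing the index tensors on the device each backend expects.
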
sglang/srt/managers/configure_logging.py
@@ -28,7 +28,7 @@ if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--url", type=str, default="http://localhost:30000")
  parser.add_argument("--log-requests", action="store_true")
- parser.add_argument("--log-requests-level", type=int, default=2)
+ parser.add_argument("--log-requests-level", type=int, default=3)
  parser.add_argument(
  "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
  )