sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
-if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
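The SGLANG_AITER_MOE flag is now read once at import time into `use_aiter_moe` and reused by both this import guard and `apply_w8a8_block_fp8_linear` further down. For illustration only, a boolean env-var helper along the lines of `sglang.srt.utils.get_bool_env_var` can be sketched as follows (the exact accepted spellings in the real helper may differ):

    import os

    def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
        # Illustrative stand-in for sglang.srt.utils.get_bool_env_var.
        return os.getenv(name, default).strip().lower() in ("true", "1")

    # Read once at import time; call sites then branch on the cached boolean.
    use_aiter_moe = get_bool_env_var_sketch("SGLANG_AITER_MOE")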
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
-_TORCH_VERSION = torch.__version__.split("+")[0]
-try:
-    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
-except ValueError:
-    _TORCH_VERSION_TUPLE = (0, 0, 0)
-
-# The condition to determine if it is on a platform that supports
-# torch._scaled_mm rowwise feature.
-# The condition is determined once as the operations
-# are time consuming.
-USE_ROWWISE_TORCH_SCALED_MM = (
-    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
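The rowwise `torch._scaled_mm` eligibility check is now wrapped in `use_rowwise_torch_scaled_mm()` and still evaluated once at import, but the device is only probed when running on HIP. The version parsing it relies on can be exercised on its own; a minimal standalone sketch:

    import torch

    def torch_version_tuple() -> tuple:
        # "2.7.0+rocm6.3" -> (2, 7, 0); versions that do not parse as integers
        # fall back to (0, 0, 0), which disables the feature.
        version = torch.__version__.split("+")[0]
        try:
            return tuple(map(int, version.split(".")[:3]))
        except ValueError:
            return (0, 0, 0)

    print(torch_version_tuple() >= (2, 7, 0))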
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-    scale = fp8_max / amax
-    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
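The new `block_quant_dequant` helper walks the (block_n, block_k) tile grid and multiplies each tile by its scale. A toy check of that tiling logic on a 4x4 tensor with 2x2 blocks (values chosen so the expected output is obvious):

    import torch

    block_n, block_k = 2, 2
    x_q = torch.ones(4, 4)                              # stand-in for fp8 values
    x_s = torch.tensor([[1.0, 10.0], [100.0, 1000.0]])  # one scale per 2x2 block

    x_dq = torch.empty_like(x_q)
    for j in range(x_s.shape[0]):
        for i in range(x_s.shape[1]):
            x_dq[j * block_n : (j + 1) * block_n, i * block_k : (i + 1) * block_k] = (
                x_q[j * block_n : (j + 1) * block_n, i * block_k : (i + 1) * block_k]
                * x_s[j, i]
            )
    # Rows 0-1 end up scaled by 1 and 10 (per column block), rows 2-3 by 100 and 1000.
    print(x_dq)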
sglang/srt/layers/quantization/kv_cache.py
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
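For context on the doubling: the e4m3fnuz format used on gfx94x hardware has roughly half the representable maximum of the OCP e4m3fn format, which is presumably why k/v scales calibrated against e4m3fn checkpoints are multiplied by 2 on that path. The two maxima can be checked directly:

    import torch

    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP e4m3fn)
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (e4m3fnuz, gfx94x)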
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
 
sglang/srt/layers/quantization/utils.py
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
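The duplicate `is_fp8_fnuz` definitions removed here and from kv_cache.py above are replaced by a single helper imported from `sglang.srt.layers.quantization.fp8_kernel`. It is the same gfx94 architecture probe; reproduced below as a standalone sketch under the assumption that the consolidated helper guards on HIP before touching `gcnArchName` (that attribute is only meaningful on ROCm builds of torch):

    import torch
    from sglang.srt.utils import is_hip

    def is_fp8_fnuz_sketch() -> bool:
        # Only device 0 is checked, assuming MI300-class (gfx94x) nodes are homogeneous.
        return is_hip() and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName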
sglang/srt/layers/quantization/w8a8_fp8.py
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import set_weight_attrs
 
-_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
             # Update the layer with the new values.
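In the fallback branch above, the weight is re-quantized tensor-wise through `input_to_float8` with the shared `fp8_dtype`. A self-contained sketch of that tensor-wise scheme, mirroring the logic of `input_to_float8` on CPU with plain e4m3fn rather than calling the real helper:

    import torch

    def quantize_tensorwise(x: torch.Tensor, dtype=torch.float8_e4m3fn):
        # One scale for the whole tensor: map max(|x|) onto the fp8 maximum.
        finfo = torch.finfo(dtype)
        amax = x.abs().amax().float().clamp(min=1e-12)
        scale = finfo.max / amax
        x_q = (x.float() * scale).clamp(min=-finfo.max, max=finfo.max).to(dtype)
        return x_q, scale.reciprocal()  # dequantize with x_q.float() * scale_inv

    x = torch.randn(16, 32)
    x_q, scale_inv = quantize_tensorwise(x)
    print((x_q.float() * scale_inv - x).abs().max())  # small quantization error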
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
sglang/srt/lora/lora_manager.py
@@ -156,18 +156,15 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
-        if hasattr(self, "max_bs_in_cuda_graph") and bs <= self.max_bs_in_cuda_graph:
-            # Do in-place updates when CUDA graph is enabled. Note that
-            # if CUDA graph is enabled, the batch whose bs <= max_bs_in_cuda_graph
-            # will also use these preallocated buffers, no matter whether
-            # the batch can use CUDA graph or not.
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
             self.cuda_graph_batch_info.bs = bs
-            if forward_batch.forward_mode.is_extend():
-                self.cuda_graph_batch_info.seg_lens[:bs].copy_(
-                    forward_batch.extend_seq_lens
-                )
-            else:
-                self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
             torch.cumsum(
                 self.cuda_graph_batch_info.seg_lens[:bs],
                 dim=0,
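With the added `forward_mode.is_cuda_graph()` condition, only batches that can actually replay a CUDA graph take this in-place path, and those are decode-style batches with one new token per sequence, so `seg_lens` is simply filled with ones before the cumulative sum. Illustrative shapes of what that prefix sum produces (buffer names here are for illustration; the real tensors are preallocated members of `cuda_graph_batch_info`):

    import torch

    bs = 3
    seg_lens = torch.ones(bs, dtype=torch.int32)     # one token per sequence
    seg_offsets = torch.zeros(bs + 1, dtype=torch.int64)
    seg_offsets[1:] = torch.cumsum(seg_lens, dim=0)  # segment start offsets
    print(seg_offsets)                               # tensor([0, 1, 2, 3])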
@@ -201,10 +198,10 @@ class LoRAManager:
         max_len = int(torch.max(seg_lens))
         weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-        lora_ranks = torch.empty(
+        lora_ranks = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
         )
-        scalings = torch.empty(
+        scalings = torch.zeros(
             (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
         )
         for i, lora_path in enumerate(forward_batch.lora_paths):
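The switch from `torch.empty` to `torch.zeros` for `lora_ranks` and `scalings` matters because only the adapters present in the batch are filled in by the loop that follows; with `empty`, slots for absent adapters would carry whatever bytes happened to be in the allocation. A minimal CPU-only illustration of the difference:

    import torch

    max_loras_per_batch = 4
    lora_ranks = torch.zeros((max_loras_per_batch,), dtype=torch.int64)
    scalings = torch.zeros((max_loras_per_batch,), dtype=torch.float)

    # Suppose only slot 0 is populated by an active adapter in this batch:
    lora_ranks[0], scalings[0] = 16, 2.0
    # Unpopulated slots now deterministically read as rank 0 / scaling 0.0
    # instead of uninitialized memory.
    print(lora_ranks, scalings)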
sglang/srt/managers/cache_controller.py
@@ -268,98 +268,97 @@ class HiCacheController:
         """
         Directly write through KV caches to host memory without buffering.
         """
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    self.mem_pool_host.write_page_all_layers(
-                        operation.host_indices,
-                        operation.device_indices,
-                        self.mem_pool_device,
-                    )
-                    self.write_stream.synchronize()
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_write_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.write_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_write_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_direct(self):
         """
         Directly load KV caches from host memory to device memory without buffering.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.load_queue.get(block=True, timeout=1)
-                    # time.sleep(18e-6 * len(operation.host_indices))
-                    operation.data = self.mem_pool_host.get_flat_data(
-                        operation.host_indices
-                    )
-                    self.mem_pool_device.transfer(
-                        operation.device_indices, operation.data
-                    )
-                    self.mem_pool_host.complete_io(operation.host_indices)
-                    for node_id in operation.node_ids:
-                        if node_id != 0:
-                            self.ack_load_queue.put(node_id)
-                except Empty:
-                    continue
-                except Exception as e:
-                    logger.error(e)
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            try:
+                operation = self.load_queue.get(block=True, timeout=1)
+                # time.sleep(18e-6 * len(operation.host_indices))
+                operation.data = self.mem_pool_host.get_flat_data(
+                    operation.host_indices
+                )
+                self.mem_pool_device.transfer(operation.device_indices, operation.data)
+                self.mem_pool_host.complete_io(operation.host_indices)
+                for node_id in operation.node_ids:
+                    if node_id != 0:
+                        self.ack_load_queue.put(node_id)
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
         """
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
+        torch.cuda.set_stream(self.load_stream)
+        while not self.stop_event.is_set():
+            self.load_cache_event.wait(timeout=1)
+            if not self.load_cache_event.is_set():
+                continue
+            self.load_cache_event.clear()
 
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
+            batch_operation = None
+            while self.load_queue.qsize() > 0:
+                op = self.load_queue.get(block=True)
                 if batch_operation is None:
-                    continue
+                    batch_operation = op
+                else:
+                    batch_operation.merge(op)
+            if batch_operation is None:
+                continue
 
-                self.layer_done_counter.reset()
-                for i in range(self.mem_pool_host.layer_num):
-                    if self.page_size == 1:
-                        flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                            batch_operation.host_indices, i
-                        )
-                        self.mem_pool_device.transfer_per_layer(
-                            batch_operation.device_indices, flat_data, i
-                        )
-                    else:
-                        self.mem_pool_host.load_page_per_layer(
-                            batch_operation.host_indices,
-                            batch_operation.device_indices,
-                            self.mem_pool_device,
-                            i,
-                        )
-                    self.load_stream.synchronize()
-                    self.layer_done_counter.increment()
-
-                self.mem_pool_host.complete_io(batch_operation.host_indices)
-                for node_id in batch_operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+            self.layer_done_counter.reset()
+            for i in range(self.mem_pool_host.layer_num):
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                self.load_stream.synchronize()
+                self.layer_done_counter.increment()
+
+            self.mem_pool_host.complete_io(batch_operation.host_indices)
+            for node_id in batch_operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
 
     def write_aux_func(self, no_wait=False):
         """
         Auxiliary function to prepare the buffer for write operations.
         """
+        torch.cuda.set_stream(self.write_stream)
 
         def _to_op(op_):
             assert op_.device_indices.is_cuda, "Device indices should be on GPU"
@@ -370,44 +369,42 @@ class HiCacheController:
             return op_
 
         buffer = None
-        with torch.cuda.stream(self.write_stream):
-            while not self.stop_event.is_set():
-                try:
-                    operation = self.write_queue.get(block=True, timeout=1)
-                    factor = (
-                        len(operation.device_indices)
-                        // self.write_buffer.max_buffer_size
-                    )
+        while not self.stop_event.is_set():
+            try:
+                operation = self.write_queue.get(block=True, timeout=1)
+                factor = (
+                    len(operation.device_indices) // self.write_buffer.max_buffer_size
+                )
 
-                    if factor >= 1:
-                        if buffer is not None:
-                            _to_op(buffer)
-                            buffer = None
-
-                        if factor < 2:
-                            _to_op(operation)
-                        else:
-                            split_ops = operation.split(factor)
-                            for op_ in split_ops:
-                                _to_op(op_)
-                        continue
-
-                    if buffer is None:
-                        buffer = operation
-                    else:
-                        buffer.merge(operation)
-                    if (
-                        no_wait
-                        or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                        or self.write_queue.empty()
-                        or self.write_buffer.empty()
-                    ):
+                if factor >= 1:
+                    if buffer is not None:
                         _to_op(buffer)
                         buffer = None
-                except Empty:
+
+                    if factor < 2:
+                        _to_op(operation)
+                    else:
+                        split_ops = operation.split(factor)
+                        for op_ in split_ops:
+                            _to_op(op_)
                     continue
-                except Exception as e:
-                    logger.error(e)
+
+                if buffer is None:
+                    buffer = operation
+                else:
+                    buffer.merge(operation)
+                if (
+                    no_wait
+                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
+                    or self.write_queue.empty()
+                    or self.write_buffer.empty()
+                ):
+                    _to_op(buffer)
+                    buffer = None
+            except Empty:
+                continue
+            except Exception as e:
+                logger.error(e)
 
     def load_aux_func(self):
         """
@@ -484,19 +481,18 @@
         aux_thread.join()
 
     def load_thread_func_buffer(self):
+        torch.cuda.set_stream(self.load_stream)
         aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
         aux_thread.start()
-
-        with torch.cuda.stream(self.load_stream):
-            while not self.stop_event.is_set():
-                operation = self.load_buffer.get()
-                if operation is None:
-                    continue
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
+        while not self.stop_event.is_set():
+            operation = self.load_buffer.get()
+            if operation is None:
+                continue
+            self.mem_pool_device.transfer(operation.device_indices, operation.data)
+            self.mem_pool_host.complete_io(operation.host_indices)
+            for node_id in operation.node_ids:
+                if node_id != 0:
+                    self.ack_load_queue.put(node_id)
         aux_thread.join()
 
     def evict_device(
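All of the HiCacheController worker loops above now call `torch.cuda.set_stream(...)` once when the thread starts instead of wrapping the whole loop in `with torch.cuda.stream(...)`; for a thread whose entire lifetime is that loop the effect is the same (work issued from the thread goes to the side stream) without re-entering a context manager on every iteration. A stripped-down sketch of the pattern, assuming a CUDA device is available:

    import queue
    import threading
    import torch

    def copy_worker(stream: torch.cuda.Stream, jobs: queue.Queue) -> None:
        # Bind this thread's current CUDA stream once, for its whole lifetime.
        torch.cuda.set_stream(stream)
        while True:
            t = jobs.get()
            if t is None:
                break
            t.to("cuda", non_blocking=True)  # issued on `stream`
            stream.synchronize()

    jobs = queue.Queue()
    worker = threading.Thread(
        target=copy_worker, args=(torch.cuda.Stream(), jobs), daemon=True
    )
    worker.start()
    jobs.put(torch.randn(1024, 1024).pin_memory())
    jobs.put(None)
    worker.join()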
sglang/srt/managers/io_struct.py
@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
     pass
 
 
+@dataclass
+class SlowDownReqInput:
+    forward_sleep_time: Optional[float]
+
+
+@dataclass
+class SlowDownReqOutput:
+    pass
+
+
 @dataclass
 class AbortReq:
     # The request id
sglang/srt/managers/multimodal_processors/base_processor.py
@@ -8,6 +8,7 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
+import torch
 from PIL import Image
 from transformers import BaseImageProcessorFast
 
@@ -89,6 +90,10 @@ class BaseMultimodalProcessor(ABC):
             return_tensors="pt",
             **kwargs,
         )
+        if "pixel_values" in result and isinstance(
+            result["pixel_values"], torch.Tensor
+        ):
+            result["pixel_values"] = result["pixel_values"].to("cpu")
         return result
 
     @abstractmethod
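The new guard normalizes `pixel_values` to CPU right after the Hugging Face processor runs, presumably because some fast image processors can return that tensor on an accelerator device while the rest of the pipeline expects host memory. Standalone, the guard behaves like this:

    import torch

    result = {"pixel_values": torch.randn(1, 3, 224, 224)}  # stand-in processor output
    if "pixel_values" in result and isinstance(result["pixel_values"], torch.Tensor):
        result["pixel_values"] = result["pixel_values"].to("cpu")
    print(result["pixel_values"].device)  # cpu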