sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py CHANGED
@@ -10,24 +10,14 @@ import numpy as np
 import torch
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 
 DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024  # 4 GiB
-DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
+DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024  # 16 MB
 
 logger = logging.getLogger(__name__)
 
 
-def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
-    prefix_str = ""
-    if prior_hash:
-        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
-    current_token_ids_bytes = np.array(token_ids).tobytes()
-    current_hash_object = hashlib.sha256(current_token_ids_bytes)
-    current_hash_hex = current_hash_object.hexdigest()
-    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
-
-
 @dataclass
 class MooncakeStoreConfig:
     local_hostname: str
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
             global_segment_size=config.get(
                 "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
             ),
-            local_buffer_size=config.get(
-                "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=config.get("protocol", "tcp"),
             device_name=config.get("device_name", "auto"),
             master_server_address=config.get("master_server_address"),
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
             global_segment_size=int(
                 os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
            ),
-            local_buffer_size=int(
-                os.getenv("MOONCAKE_LOCAL_BUFFER_SIZE", DEFAULT_LOCAL_BUFFER_SIZE)
-            ),
+            # Zero copy interface does not need local buffer
+            local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
             protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
             device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
             master_server_address=os.getenv("MOONCAKE_MASTER"),
@@ -96,7 +84,7 @@ class MooncakeStoreConfig:
 
 
 class MooncakeStore(HiCacheStorage):
-    def __init__(self, is_mla: bool = False):
+    def __init__(self, storage_config: HiCacheStorageConfig = None):
         try:
             from mooncake.store import MooncakeDistributedStore
         except ImportError as e:
@@ -126,7 +114,13 @@ class MooncakeStore(HiCacheStorage):
             logger.info("Connect to Mooncake store successfully.")
             self.warmup()
             logger.info("Mooncake store warmup successfully.")
-            self.is_mla = is_mla
+
+            if storage_config is not None:
+                self.is_mla_backend = storage_config.is_mla_model
+                self.local_rank = storage_config.tp_rank
+            else:
+                self.is_mla_backend = False
+                self.local_rank = 0
 
         except ValueError as e:
             logger.error("Configuration loading failed: %s", e)
@@ -137,12 +131,10 @@ class MooncakeStore(HiCacheStorage):
 
     def warmup(self):
         warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
-        # 10 MB
-        warmup_value = bytes(10 * 1024 * 1024)
-        self.store.put(warmup_key, warmup_value)
+        warmup_value = bytes(4 * 1024)  # 4 KB
+        assert self.store.put(warmup_key, warmup_value) == 0
         assert self.store.is_exist(warmup_key) == 1
-        self.store.get(warmup_key)
-        self.store.remove(warmup_key)
+        assert self.store.get(warmup_key) == warmup_value
 
     def register_buffer(self, buffer: torch.Tensor) -> None:
         try:
@@ -162,78 +154,95 @@ class MooncakeStore(HiCacheStorage):
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        self._put_batch_zero_copy_impl(key, target_location, target_sizes)
+        return self.batch_set([key], [value], [target_location], [target_sizes])
 
     def batch_set(
         self,
         keys: List[str],
-        value: Optional[Any] = None,
         target_location: Optional[List[int]] = None,
         target_sizes: Optional[List[int]] = None,
     ) -> bool:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
+            return False
 
         for i in range(len(keys)):
             if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
+                return False
 
-        self._put_batch_zero_copy_impl(keys, target_location, target_sizes)
+        exist_result = self._batch_exist(keys)
+        set_keys = []
+        set_target_locations = []
+        set_target_sizes = []
+        set_indices = []
+        for i in range(len(keys)):
+            if exist_result[i] != 1:
+                set_keys.append(keys[i])
+                set_target_locations.append(target_location[i])
+                set_target_sizes.append(target_sizes[i])
+                set_indices.append(i)
+        # Only set non-existing keys to storage
+        put_result = self._put_batch_zero_copy_impl(
+            set_keys, set_target_locations, set_target_sizes
+        )
+        for i in range(len(set_indices)):
+            if put_result[i] == 0:
+                exist_result[set_indices[i]] = 1
+
+        success_count = 0
+        for i in range(len(keys)):
+            if exist_result[i] == 0:
+                break
+            success_count += 1
+        # TODO: return the number of consecutive successful operations from the start.
+        return success_count == len(keys)
 
     def get(
         self,
         key,
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
-        assert len(key) == len(target_location) == len(target_sizes)
-        if len(key) == 0:
-            return
-
-        for i in range(len(key)):
-            if key[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
+    ) -> bool:
+        return self.batch_get([key], [target_location], [target_sizes]) == 1
 
     def batch_get(
         self,
         keys: List[str],
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> torch.Tensor | None:
+    ) -> int:
         assert len(keys) == len(target_location) == len(target_sizes)
         if len(keys) == 0:
-            return
-
+            return 0
+        get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
+        if self.is_mla_backend:
+            key_multiplier = 1
+        else:
+            key_multiplier = 2
         for i in range(len(keys)):
-            if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
-                return
-
-        return self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
-
-    def exists(self, keys) -> bool | dict:
-        _keys = []
-        local_rank = get_tensor_model_parallel_rank()
-        for key in keys:
-            if key is None:
-                return None
-
-            if self.is_mla:
-                _keys.append(f"{key}_k")
-            else:
-                _keys.append(f"{key}_{local_rank}_k")
-        result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
-        return result
+            if get_result[i] < 0:
+                return i // key_multiplier
+        return len(keys) // key_multiplier
+
+    def exists(self, key) -> bool:
+        return self.batch_exists([key]) > 0
+
+    def batch_exists(self, keys) -> int:
+        if self.is_mla_backend:
+            query_keys = [f"{key}_k" for key in keys]
+            key_multiplier = 1
+        else:
+            query_keys = []
+            for key in keys:
+                query_keys.append(f"{key}_{self.local_rank}_k")
+                query_keys.append(f"{key}_{self.local_rank}_v")
+            key_multiplier = 2
+
+        exist_result = self._batch_exist(query_keys)
+        for i in range(len(query_keys)):
+            if exist_result[i] != 1:
+                return i // key_multiplier
+        return len(query_keys) // key_multiplier
 
     def delete(self, key) -> None:
         raise (NotImplementedError)
@@ -248,18 +257,13 @@ class MooncakeStore(HiCacheStorage):
 
     def _put_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to put value to Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Put Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
 
     def _get_batch_zero_copy_impl(
         self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
-    ) -> None:
-        try:
-            self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
-        except TypeError as err:
-            logger.error("Failed to get value from Mooncake Store: %s", err)
-            raise TypeError("Mooncake Store Get Type Error.") from err
+    ) -> List[int]:
+        return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
+
+    def _batch_exist(self, key_strs: List[str]) -> List[int]:
+        return self.store.batch_is_exist(key_strs)
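
Note: the rewritten MooncakeStore methods above change the return-value contract. batch_exists and batch_get now report how many keys, counted consecutively from index 0, are present or were fetched (for non-MLA models each logical key expands to a `_k`/`_v` pair, hence the division by key_multiplier), and batch_set writes only keys that do not already exist and returns True only when every key ends up stored. A minimal standalone restatement of that prefix-counting pattern, using a made-up results list purely for illustration:

    def consecutive_success_count(results):
        # Mirrors the loops in batch_set/batch_get/batch_exists above: count
        # successes from index 0 and stop at the first failure, so a hole in
        # the middle ends the usable prefix.
        count = 0
        for ok in results:
            if not ok:
                break
            count += 1
        return count

    assert consecutive_success_count([True, True, False, True]) == 2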
sglang/srt/metrics/collector.py CHANGED
@@ -142,7 +142,7 @@ class SchedulerStats:
     spec_accept_length: float = 0.0
     avg_request_queue_latency: float = 0.0
     num_prefill_prealloc_queue_reqs: int = 0
-    num_prefill_infight_queue_reqs: int = 0
+    num_prefill_inflight_queue_reqs: int = 0
     num_decode_prealloc_queue_reqs: int = 0
     num_decode_transfer_queue_reqs: int = 0
     total_retracted_reqs: int = 0
@@ -235,9 +235,9 @@ class SchedulerMetricsCollector:
             multiprocess_mode="mostrecent",
         )
 
-        self.num_prefill_infight_queue_reqs = Gauge(
-            name="sglang:num_prefill_infight_queue_reqs",
-            documentation="The number of requests in the prefill infight queue.",
+        self.num_prefill_inflight_queue_reqs = Gauge(
+            name="sglang:num_prefill_inflight_queue_reqs",
+            documentation="The number of requests in the prefill inflight queue.",
             labelnames=labels.keys(),
             multiprocess_mode="mostrecent",
         )
@@ -294,7 +294,7 @@ class SchedulerMetricsCollector:
             self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
         )
         self._log_gauge(
-            self.num_prefill_infight_queue_reqs, stats.num_prefill_infight_queue_reqs
+            self.num_prefill_inflight_queue_reqs, stats.num_prefill_inflight_queue_reqs
         )
         self._log_gauge(
             self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -54,7 +54,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_device_memory_capacity,
-    rank0_log,
+    log_info_on_rank0,
     require_attn_tp_gather,
     require_gathered_buffer,
     require_mlp_sync,
@@ -267,7 +267,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
+        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
sglang/srt/model_executor/model_runner.py CHANGED
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
     is_npu,
+    is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
sglang/srt/models/deepseek_v2.py CHANGED
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -994,7 +995,14 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.current_attention_backend = attention_backend
 
         if attention_backend == "ascend":
-            return AttnForwardMethod.MLA
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
@@ -1292,6 +1300,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
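
Note: the ascend branch above now routes ordinary prefill/extend batches to MHA and keeps MLA for decode, target-verify, and draft-extend. A sketch of the same dispatch rule with the forward-mode checks reduced to plain booleans (the flag names here are illustrative, not the ForwardMode API):

    def ascend_attn_method(is_extend, is_target_verify, is_draft_extend):
        # Plain prefill/extend goes through MHA; decode and the speculative
        # paths (target verify, draft extend) stay on MLA.
        if is_extend and not is_target_verify and not is_draft_extend:
            return "MHA"
        return "MLA"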
sglang/srt/models/gpt_oss.py CHANGED
@@ -58,7 +58,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -71,6 +71,7 @@ from sglang.srt.utils import (
     add_prefix,
     is_cuda,
     is_flashinfer_available,
+    is_sm100_supported,
     make_layers,
 )
 
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -526,6 +526,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
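
Note: the only change to qwen2_5_vl.py is wrapping forward in torch.no_grad(), so no autograd graph is recorded during inference. A tiny self-contained illustration of the effect (the linear layer is a placeholder, not the model):

    import torch

    layer = torch.nn.Linear(8, 8)

    @torch.no_grad()
    def run(x):
        # Inside no_grad, outputs do not require grad and no graph is kept.
        return layer(x)

    assert not run(torch.randn(2, 8)).requires_grad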
sglang/srt/offloader.py CHANGED
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
     @staticmethod
     def create(mode: str, **kwargs) -> "_BaseParamOffloader":
         return {
+            "meta": _MetaParamOffloader,
             "cpu": _CpuParamOffloader,
             "shm_cpu": _ShmCpuParamOffloader,
             "sharded_gpu": _ShardedGpuParamOffloader,
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
         raise NotImplementedError
 
 
+class _MetaParamOffloader(_BaseParamOffloader):
+    """Usually used for debugging."""
+
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        _move_param_to_meta(module, param_name)
+
+    def create_device_tensor(self):
+        return torch.empty_like(self._param.data, device="cuda")
+
+
 class _CpuParamOffloader(_BaseParamOffloader):
     def __init__(self, module, param_name):
         super().__init__(module, param_name)
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
         device=device,
         pin_memory=pin_memory,
     )
+
+
+# ----------------------------------------- ShardedGpu ------------------------------------------------------
+
+
+# TODO unify with ShmCpu mode
+class _ShardedGpuParamOffloader(_BaseParamOffloader):
+    def __init__(self, module, param_name):
+        super().__init__(module, param_name)
+        self._rank = get_naive_distributed().get_rank()
+        self._world_size = get_naive_distributed().get_world_size()
+
+        from sglang.srt.distributed import get_tensor_model_parallel_world_size
+
+        assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        if self._rank == 0:
+            _move_param_to_cpu(self._param, pin_memory=True)
+        else:
+            _move_param_to_meta(self._module, self._param_name)
+
+        self.sharded_param_handles = None
+
+    def post_init(self):
+        # check again since it may be changed
+        assert (
+            self._param.data.is_contiguous()
+        ), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
+
+        scatter_src = self._param.data
+
+        logger.info(
+            f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
+        )
+
+        if self._rank == 0:
+            scatter_src = scatter_src.to("cuda")
+        scatter_list = _even_chunk(scatter_src, self._world_size)
+
+        sharded_param = torch.empty(
+            scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
+        )
+        self.sharded_param_handles = _create_shared_buffer_tensors(
+            local_tensor=sharded_param
+        )
+
+        get_naive_distributed().scatter(
+            sharded_param, scatter_list if self._rank == 0 else None
+        )
+
+        _move_param_to_meta(self._module, self._param_name)
+
+    def create_device_tensor(self):
+        output = _empty_strided_like(self._param, device="cuda")
+        output_chunks = output.chunk(self._world_size)
+
+        for index in range(self._world_size):
+            src_rank = (self._rank + index) % self._world_size
+            src_buf = self.sharded_param_handles[src_rank]
+            output_chunks[src_rank].copy_(src_buf)
+
+        return output
+
+
+def _even_chunk(x: torch.Tensor, chunks: int):
+    assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
+    return list(x.chunk(chunks))
+
+
+def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
+    self_rank = get_naive_distributed().get_rank()
+    world_size = get_naive_distributed().get_world_size()
+
+    object_list = get_naive_distributed().all_gather_object(
+        dict(
+            dup_serialized_local_tensor=[
+                (
+                    None
+                    if interesting_rank == self_rank
+                    else MultiprocessingSerializer.serialize(local_tensor)
+                )
+                for interesting_rank in range(world_size)
+            ]
+        )
+    )
+
+    output_tensors = []
+    for output_rank in range(world_size):
+        remote_serialized_tensor = object_list[output_rank][
+            "dup_serialized_local_tensor"
+        ][self_rank]
+        if output_rank == self_rank:
+            assert remote_serialized_tensor is None
+            output_tensors.append(local_tensor)
+        else:
+            output_tensors.append(
+                MultiprocessingSerializer.deserialize(remote_serialized_tensor)
+            )
+
+    return output_tensors
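
Note: with the hunks above, the offloader factory accepts a new "meta" mode next to "cpu", "shm_cpu", and "sharded_gpu". A hedged illustration of what _MetaParamOffloader relies on: a meta-device tensor is a shape/dtype placeholder that allocates no storage, and a real (uninitialized) CUDA buffer is materialized only when a device copy is requested, which is why the mode is described as "for debugging". The parameter below is a placeholder, not an sglang module:

    import torch

    placeholder = torch.empty(4, 4, device="meta")  # no storage is allocated
    assert placeholder.is_meta

    def create_device_tensor(param: torch.Tensor) -> torch.Tensor:
        # Same idea as _MetaParamOffloader.create_device_tensor above: an
        # uninitialized CUDA tensor with the parameter's shape and dtype.
        return torch.empty_like(param, device="cuda")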