sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/decode.py +4 -0
- sglang/srt/disaggregation/prefill.py +4 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/openai/protocol.py +27 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/entrypoints/tool.py +7 -7
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +16 -7
- sglang/srt/layers/attention/ascend_backend.py +218 -111
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
- sglang/srt/layers/attention/utils.py +15 -94
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/moe/cutlass_moe.py +0 -15
- sglang/srt/layers/moe/ep_moe/layer.py +1 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/mxfp4.py +16 -23
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/lora_manager.py +29 -12
- sglang/srt/managers/cache_controller.py +223 -156
- sglang/srt/managers/detokenizer_manager.py +5 -0
- sglang/srt/managers/io_struct.py +30 -0
- sglang/srt/managers/scheduler.py +58 -7
- sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
- sglang/srt/managers/tokenizer_manager.py +36 -3
- sglang/srt/mem_cache/hicache_storage.py +31 -20
- sglang/srt/mem_cache/hiradix_cache.py +12 -3
- sglang/srt/mem_cache/memory_pool.py +73 -14
- sglang/srt/mem_cache/memory_pool_host.py +3 -2
- sglang/srt/mem_cache/radix_cache.py +1 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
- sglang/srt/metrics/collector.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +12 -3
- sglang/srt/models/gpt_oss.py +2 -1
- sglang/srt/models/qwen2_5_vl.py +1 -0
- sglang/srt/offloader.py +115 -0
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/server_args.py +10 -5
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +59 -12
- sglang/test/test_cutlass_moe.py +33 -28
- sglang/version.py +1 -1
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
@@ -10,24 +10,14 @@ import numpy as np
|
|
10
10
|
import torch
|
11
11
|
|
12
12
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
13
|
-
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
|
13
|
+
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
|
14
14
|
|
15
15
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
|
16
|
-
DEFAULT_LOCAL_BUFFER_SIZE =
|
16
|
+
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
|
17
17
|
|
18
18
|
logger = logging.getLogger(__name__)
|
19
19
|
|
20
20
|
|
21
|
-
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
|
22
|
-
prefix_str = ""
|
23
|
-
if prior_hash:
|
24
|
-
prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
|
25
|
-
current_token_ids_bytes = np.array(token_ids).tobytes()
|
26
|
-
current_hash_object = hashlib.sha256(current_token_ids_bytes)
|
27
|
-
current_hash_hex = current_hash_object.hexdigest()
|
28
|
-
return f"{prefix_str}_{int(current_hash_hex[:16], 16)}"
|
29
|
-
|
30
|
-
|
31
21
|
@dataclass
|
32
22
|
class MooncakeStoreConfig:
|
33
23
|
local_hostname: str
|
@@ -54,9 +44,8 @@ class MooncakeStoreConfig:
|
|
54
44
|
global_segment_size=config.get(
|
55
45
|
"global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
|
56
46
|
),
|
57
|
-
|
58
|
-
|
59
|
-
),
|
47
|
+
# Zero copy interface does not need local buffer
|
48
|
+
local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
|
60
49
|
protocol=config.get("protocol", "tcp"),
|
61
50
|
device_name=config.get("device_name", "auto"),
|
62
51
|
master_server_address=config.get("master_server_address"),
|
@@ -79,9 +68,8 @@ class MooncakeStoreConfig:
|
|
79
68
|
global_segment_size=int(
|
80
69
|
os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE)
|
81
70
|
),
|
82
|
-
|
83
|
-
|
84
|
-
),
|
71
|
+
# Zero copy interface does not need local buffer
|
72
|
+
local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE,
|
85
73
|
protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"),
|
86
74
|
device_name=os.getenv("MOONCAKE_DEVICE", "auto"),
|
87
75
|
master_server_address=os.getenv("MOONCAKE_MASTER"),
|
@@ -96,7 +84,7 @@ class MooncakeStoreConfig:
|
|
96
84
|
|
97
85
|
|
98
86
|
class MooncakeStore(HiCacheStorage):
|
99
|
-
def __init__(self,
|
87
|
+
def __init__(self, storage_config: HiCacheStorageConfig = None):
|
100
88
|
try:
|
101
89
|
from mooncake.store import MooncakeDistributedStore
|
102
90
|
except ImportError as e:
|
@@ -126,7 +114,13 @@ class MooncakeStore(HiCacheStorage):
|
|
126
114
|
logger.info("Connect to Mooncake store successfully.")
|
127
115
|
self.warmup()
|
128
116
|
logger.info("Mooncake store warmup successfully.")
|
129
|
-
|
117
|
+
|
118
|
+
if storage_config is not None:
|
119
|
+
self.is_mla_backend = storage_config.is_mla_model
|
120
|
+
self.local_rank = storage_config.tp_rank
|
121
|
+
else:
|
122
|
+
self.is_mla_backend = False
|
123
|
+
self.local_rank = 0
|
130
124
|
|
131
125
|
except ValueError as e:
|
132
126
|
logger.error("Configuration loading failed: %s", e)
|
@@ -137,12 +131,10 @@ class MooncakeStore(HiCacheStorage):
|
|
137
131
|
|
138
132
|
def warmup(self):
|
139
133
|
warmup_key = "sglang_mooncake_store_warmup_key" + uuid.uuid4().hex
|
140
|
-
#
|
141
|
-
|
142
|
-
self.store.put(warmup_key, warmup_value)
|
134
|
+
warmup_value = bytes(4 * 1024) # 4 KB
|
135
|
+
assert self.store.put(warmup_key, warmup_value) == 0
|
143
136
|
assert self.store.is_exist(warmup_key) == 1
|
144
|
-
self.store.get(warmup_key)
|
145
|
-
self.store.remove(warmup_key)
|
137
|
+
assert self.store.get(warmup_key) == warmup_value
|
146
138
|
|
147
139
|
def register_buffer(self, buffer: torch.Tensor) -> None:
|
148
140
|
try:
|
@@ -162,78 +154,95 @@ class MooncakeStore(HiCacheStorage):
|
|
162
154
|
target_location: Optional[List[int]] = None,
|
163
155
|
target_sizes: Optional[List[int]] = None,
|
164
156
|
) -> bool:
|
165
|
-
|
166
|
-
if len(key) == 0:
|
167
|
-
return
|
168
|
-
|
169
|
-
for i in range(len(key)):
|
170
|
-
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
171
|
-
return
|
172
|
-
|
173
|
-
self._put_batch_zero_copy_impl(key, target_location, target_sizes)
|
157
|
+
return self.batch_set([key], [value], [target_location], [target_sizes])
|
174
158
|
|
175
159
|
def batch_set(
|
176
160
|
self,
|
177
161
|
keys: List[str],
|
178
|
-
value: Optional[Any] = None,
|
179
162
|
target_location: Optional[List[int]] = None,
|
180
163
|
target_sizes: Optional[List[int]] = None,
|
181
164
|
) -> bool:
|
182
165
|
assert len(keys) == len(target_location) == len(target_sizes)
|
183
166
|
if len(keys) == 0:
|
184
|
-
return
|
167
|
+
return False
|
185
168
|
|
186
169
|
for i in range(len(keys)):
|
187
170
|
if keys[i] is None or target_location[i] is None or target_sizes[i] is None:
|
188
|
-
return
|
171
|
+
return False
|
189
172
|
|
190
|
-
self.
|
173
|
+
exist_result = self._batch_exist(keys)
|
174
|
+
set_keys = []
|
175
|
+
set_target_locations = []
|
176
|
+
set_target_sizes = []
|
177
|
+
set_indices = []
|
178
|
+
for i in range(len(keys)):
|
179
|
+
if exist_result[i] != 1:
|
180
|
+
set_keys.append(keys[i])
|
181
|
+
set_target_locations.append(target_location[i])
|
182
|
+
set_target_sizes.append(target_sizes[i])
|
183
|
+
set_indices.append(i)
|
184
|
+
# Only set non-existing keys to storage
|
185
|
+
put_result = self._put_batch_zero_copy_impl(
|
186
|
+
set_keys, set_target_locations, set_target_sizes
|
187
|
+
)
|
188
|
+
for i in range(len(set_indices)):
|
189
|
+
if put_result[i] == 0:
|
190
|
+
exist_result[set_indices[i]] = 1
|
191
|
+
|
192
|
+
success_count = 0
|
193
|
+
for i in range(len(keys)):
|
194
|
+
if exist_result[i] == 0:
|
195
|
+
break
|
196
|
+
success_count += 1
|
197
|
+
# TODO: return the number of consecutive successful operations from the start.
|
198
|
+
return success_count == len(keys)
|
191
199
|
|
192
200
|
def get(
|
193
201
|
self,
|
194
202
|
key,
|
195
203
|
target_location: Optional[Any] = None,
|
196
204
|
target_sizes: Optional[Any] = None,
|
197
|
-
) ->
|
198
|
-
|
199
|
-
if len(key) == 0:
|
200
|
-
return
|
201
|
-
|
202
|
-
for i in range(len(key)):
|
203
|
-
if key[i] is None or target_location[i] is None or target_sizes[i] is None:
|
204
|
-
return
|
205
|
-
|
206
|
-
return self._get_batch_zero_copy_impl(key, target_location, target_sizes)
|
205
|
+
) -> bool:
|
206
|
+
return self.batch_get([key], [target_location], [target_sizes]) == 1
|
207
207
|
|
208
208
|
def batch_get(
|
209
209
|
self,
|
210
210
|
keys: List[str],
|
211
211
|
target_location: Optional[Any] = None,
|
212
212
|
target_sizes: Optional[Any] = None,
|
213
|
-
) ->
|
213
|
+
) -> int:
|
214
214
|
assert len(keys) == len(target_location) == len(target_sizes)
|
215
215
|
if len(keys) == 0:
|
216
|
-
return
|
217
|
-
|
216
|
+
return 0
|
217
|
+
get_result = self._get_batch_zero_copy_impl(keys, target_location, target_sizes)
|
218
|
+
if self.is_mla_backend:
|
219
|
+
key_multiplier = 1
|
220
|
+
else:
|
221
|
+
key_multiplier = 2
|
218
222
|
for i in range(len(keys)):
|
219
|
-
if
|
220
|
-
return
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
223
|
+
if get_result[i] < 0:
|
224
|
+
return i // key_multiplier
|
225
|
+
return len(keys) // key_multiplier
|
226
|
+
|
227
|
+
def exists(self, key) -> bool:
|
228
|
+
return self.batch_exists([key]) > 0
|
229
|
+
|
230
|
+
def batch_exists(self, keys) -> int:
|
231
|
+
if self.is_mla_backend:
|
232
|
+
query_keys = [f"{key}_k" for key in keys]
|
233
|
+
key_multiplier = 1
|
234
|
+
else:
|
235
|
+
query_keys = []
|
236
|
+
for key in keys:
|
237
|
+
query_keys.append(f"{key}_{self.local_rank}_k")
|
238
|
+
query_keys.append(f"{key}_{self.local_rank}_v")
|
239
|
+
key_multiplier = 2
|
240
|
+
|
241
|
+
exist_result = self._batch_exist(query_keys)
|
242
|
+
for i in range(len(query_keys)):
|
243
|
+
if exist_result[i] != 1:
|
244
|
+
return i // key_multiplier
|
245
|
+
return len(query_keys) // key_multiplier
|
237
246
|
|
238
247
|
def delete(self, key) -> None:
|
239
248
|
raise (NotImplementedError)
|
@@ -248,18 +257,13 @@ class MooncakeStore(HiCacheStorage):
|
|
248
257
|
|
249
258
|
def _put_batch_zero_copy_impl(
|
250
259
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
251
|
-
) ->
|
252
|
-
|
253
|
-
self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
|
254
|
-
except TypeError as err:
|
255
|
-
logger.error("Failed to put value to Mooncake Store: %s", err)
|
256
|
-
raise TypeError("Mooncake Store Put Type Error.") from err
|
260
|
+
) -> List[int]:
|
261
|
+
return self.store.batch_put_from(key_strs, buffer_ptrs, buffer_sizes)
|
257
262
|
|
258
263
|
def _get_batch_zero_copy_impl(
|
259
264
|
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
|
260
|
-
) ->
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
raise TypeError("Mooncake Store Get Type Error.") from err
|
265
|
+
) -> List[int]:
|
266
|
+
return self.store.batch_get_into(key_strs, buffer_ptrs, buffer_sizes)
|
267
|
+
|
268
|
+
def _batch_exist(self, key_strs: List[str]) -> List[int]:
|
269
|
+
return self.store.batch_is_exist(key_strs)
|
sglang/srt/metrics/collector.py
CHANGED
@@ -142,7 +142,7 @@ class SchedulerStats:
|
|
142
142
|
spec_accept_length: float = 0.0
|
143
143
|
avg_request_queue_latency: float = 0.0
|
144
144
|
num_prefill_prealloc_queue_reqs: int = 0
|
145
|
-
|
145
|
+
num_prefill_inflight_queue_reqs: int = 0
|
146
146
|
num_decode_prealloc_queue_reqs: int = 0
|
147
147
|
num_decode_transfer_queue_reqs: int = 0
|
148
148
|
total_retracted_reqs: int = 0
|
@@ -235,9 +235,9 @@ class SchedulerMetricsCollector:
|
|
235
235
|
multiprocess_mode="mostrecent",
|
236
236
|
)
|
237
237
|
|
238
|
-
self.
|
239
|
-
name="sglang:
|
240
|
-
documentation="The number of requests in the prefill
|
238
|
+
self.num_prefill_inflight_queue_reqs = Gauge(
|
239
|
+
name="sglang:num_prefill_inflight_queue_reqs",
|
240
|
+
documentation="The number of requests in the prefill inflight queue.",
|
241
241
|
labelnames=labels.keys(),
|
242
242
|
multiprocess_mode="mostrecent",
|
243
243
|
)
|
@@ -294,7 +294,7 @@ class SchedulerMetricsCollector:
|
|
294
294
|
self.num_prefill_prealloc_queue_reqs, stats.num_prefill_prealloc_queue_reqs
|
295
295
|
)
|
296
296
|
self._log_gauge(
|
297
|
-
self.
|
297
|
+
self.num_prefill_inflight_queue_reqs, stats.num_prefill_inflight_queue_reqs
|
298
298
|
)
|
299
299
|
self._log_gauge(
|
300
300
|
self.num_decode_prealloc_queue_reqs, stats.num_decode_prealloc_queue_reqs
|
@@ -54,7 +54,7 @@ from sglang.srt.utils import (
|
|
54
54
|
empty_context,
|
55
55
|
get_available_gpu_memory,
|
56
56
|
get_device_memory_capacity,
|
57
|
-
|
57
|
+
log_info_on_rank0,
|
58
58
|
require_attn_tp_gather,
|
59
59
|
require_gathered_buffer,
|
60
60
|
require_mlp_sync,
|
@@ -267,7 +267,7 @@ class CudaGraphRunner:
|
|
267
267
|
|
268
268
|
# Batch sizes to capture
|
269
269
|
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
|
270
|
-
|
270
|
+
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
|
271
271
|
self.capture_forward_mode = ForwardMode.DECODE
|
272
272
|
self.capture_hidden_mode = CaptureHiddenMode.NULL
|
273
273
|
self.num_tokens_per_bs = 1
|
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
|
|
66
66
|
)
|
67
67
|
from sglang.srt.layers.sampler import Sampler
|
68
68
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
69
|
-
from sglang.srt.layers.utils import is_sm100_supported
|
70
69
|
from sglang.srt.lora.lora_manager import LoRAManager
|
71
70
|
from sglang.srt.lora.lora_registry import LoRARef
|
72
71
|
from sglang.srt.managers.schedule_batch import (
|
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
|
|
121
120
|
is_hopper_with_cuda_12_3,
|
122
121
|
is_no_spec_infer_or_topk_one,
|
123
122
|
is_npu,
|
123
|
+
is_sm100_supported,
|
124
124
|
monkey_patch_p2p_access_check,
|
125
125
|
monkey_patch_vllm_gguf_config,
|
126
126
|
set_cuda_arch,
|
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
|
|
87
87
|
block_dequant as int8_block_dequant,
|
88
88
|
)
|
89
89
|
from sglang.srt.layers.radix_attention import RadixAttention
|
90
|
-
from sglang.srt.layers.rotary_embedding import
|
91
|
-
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
90
|
+
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
|
91
|
+
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
92
92
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
93
93
|
ParallelLMHead,
|
94
94
|
VocabParallelEmbedding,
|
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
|
|
114
114
|
is_flashinfer_available,
|
115
115
|
is_hip,
|
116
116
|
is_non_idle_and_non_empty,
|
117
|
+
is_sm100_supported,
|
117
118
|
log_info_on_rank0,
|
118
119
|
make_layers,
|
119
120
|
use_intel_amx_backend,
|
@@ -994,7 +995,14 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
994
995
|
self.current_attention_backend = attention_backend
|
995
996
|
|
996
997
|
if attention_backend == "ascend":
|
997
|
-
|
998
|
+
if (
|
999
|
+
forward_batch.forward_mode.is_extend()
|
1000
|
+
and not forward_batch.forward_mode.is_target_verify()
|
1001
|
+
and not forward_batch.forward_mode.is_draft_extend()
|
1002
|
+
):
|
1003
|
+
return AttnForwardMethod.MHA
|
1004
|
+
else:
|
1005
|
+
return AttnForwardMethod.MLA
|
998
1006
|
elif (
|
999
1007
|
attention_backend == "flashinfer"
|
1000
1008
|
or attention_backend == "fa3"
|
@@ -1292,6 +1300,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
1292
1300
|
or self.current_attention_backend == "flashinfer"
|
1293
1301
|
or self.current_attention_backend == "cutlass_mla"
|
1294
1302
|
or self.current_attention_backend == "trtllm_mla"
|
1303
|
+
or self.current_attention_backend == "ascend"
|
1295
1304
|
):
|
1296
1305
|
extra_args = {}
|
1297
1306
|
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
sglang/srt/models/gpt_oss.py
CHANGED
@@ -58,7 +58,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
58
58
|
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
59
59
|
from sglang.srt.layers.radix_attention import RadixAttention
|
60
60
|
from sglang.srt.layers.rotary_embedding import get_rope
|
61
|
-
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
61
|
+
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
62
62
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
63
63
|
ParallelLMHead,
|
64
64
|
VocabParallelEmbedding,
|
@@ -71,6 +71,7 @@ from sglang.srt.utils import (
|
|
71
71
|
add_prefix,
|
72
72
|
is_cuda,
|
73
73
|
is_flashinfer_available,
|
74
|
+
is_sm100_supported,
|
74
75
|
make_layers,
|
75
76
|
)
|
76
77
|
|
sglang/srt/models/qwen2_5_vl.py
CHANGED
sglang/srt/offloader.py
CHANGED
@@ -321,6 +321,7 @@ class _BaseParamOffloader(ABC):
|
|
321
321
|
@staticmethod
|
322
322
|
def create(mode: str, **kwargs) -> "_BaseParamOffloader":
|
323
323
|
return {
|
324
|
+
"meta": _MetaParamOffloader,
|
324
325
|
"cpu": _CpuParamOffloader,
|
325
326
|
"shm_cpu": _ShmCpuParamOffloader,
|
326
327
|
"sharded_gpu": _ShardedGpuParamOffloader,
|
@@ -341,6 +342,17 @@ class _BaseParamOffloader(ABC):
|
|
341
342
|
raise NotImplementedError
|
342
343
|
|
343
344
|
|
345
|
+
class _MetaParamOffloader(_BaseParamOffloader):
|
346
|
+
"""Usually used for debugging."""
|
347
|
+
|
348
|
+
def __init__(self, module, param_name):
|
349
|
+
super().__init__(module, param_name)
|
350
|
+
_move_param_to_meta(module, param_name)
|
351
|
+
|
352
|
+
def create_device_tensor(self):
|
353
|
+
return torch.empty_like(self._param.data, device="cuda")
|
354
|
+
|
355
|
+
|
344
356
|
class _CpuParamOffloader(_BaseParamOffloader):
|
345
357
|
def __init__(self, module, param_name):
|
346
358
|
super().__init__(module, param_name)
|
@@ -431,3 +443,106 @@ def _empty_strided_like(x: torch.Tensor, device, pin_memory=False):
|
|
431
443
|
device=device,
|
432
444
|
pin_memory=pin_memory,
|
433
445
|
)
|
446
|
+
|
447
|
+
|
448
|
+
# ----------------------------------------- ShardedGpu ------------------------------------------------------
|
449
|
+
|
450
|
+
|
451
|
+
# TODO unify with ShmCpu mode
|
452
|
+
class _ShardedGpuParamOffloader(_BaseParamOffloader):
|
453
|
+
def __init__(self, module, param_name):
|
454
|
+
super().__init__(module, param_name)
|
455
|
+
self._rank = get_naive_distributed().get_rank()
|
456
|
+
self._world_size = get_naive_distributed().get_world_size()
|
457
|
+
|
458
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
459
|
+
|
460
|
+
assert get_tensor_model_parallel_world_size() == 1, "not yet support tp_size!=1"
|
461
|
+
assert (
|
462
|
+
self._param.data.is_contiguous()
|
463
|
+
), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
|
464
|
+
|
465
|
+
if self._rank == 0:
|
466
|
+
_move_param_to_cpu(self._param, pin_memory=True)
|
467
|
+
else:
|
468
|
+
_move_param_to_meta(self._module, self._param_name)
|
469
|
+
|
470
|
+
self.sharded_param_handles = None
|
471
|
+
|
472
|
+
def post_init(self):
|
473
|
+
# check again since it may be changed
|
474
|
+
assert (
|
475
|
+
self._param.data.is_contiguous()
|
476
|
+
), f"not yet support non-contiguous tensor {self._param.shape=} {self._param.stride()=}"
|
477
|
+
|
478
|
+
scatter_src = self._param.data
|
479
|
+
|
480
|
+
logger.info(
|
481
|
+
f"[offloader] post_init {scatter_src.nbytes=} {scatter_src.dtype=} {scatter_src.shape=} {torch.cuda.memory_allocated()=}"
|
482
|
+
)
|
483
|
+
|
484
|
+
if self._rank == 0:
|
485
|
+
scatter_src = scatter_src.to("cuda")
|
486
|
+
scatter_list = _even_chunk(scatter_src, self._world_size)
|
487
|
+
|
488
|
+
sharded_param = torch.empty(
|
489
|
+
scatter_list[0].shape, dtype=scatter_list[0].dtype, device="cuda"
|
490
|
+
)
|
491
|
+
self.sharded_param_handles = _create_shared_buffer_tensors(
|
492
|
+
local_tensor=sharded_param
|
493
|
+
)
|
494
|
+
|
495
|
+
get_naive_distributed().scatter(
|
496
|
+
sharded_param, scatter_list if self._rank == 0 else None
|
497
|
+
)
|
498
|
+
|
499
|
+
_move_param_to_meta(self._module, self._param_name)
|
500
|
+
|
501
|
+
def create_device_tensor(self):
|
502
|
+
output = _empty_strided_like(self._param, device="cuda")
|
503
|
+
output_chunks = output.chunk(self._world_size)
|
504
|
+
|
505
|
+
for index in range(self._world_size):
|
506
|
+
src_rank = (self._rank + index) % self._world_size
|
507
|
+
src_buf = self.sharded_param_handles[src_rank]
|
508
|
+
output_chunks[src_rank].copy_(src_buf)
|
509
|
+
|
510
|
+
return output
|
511
|
+
|
512
|
+
|
513
|
+
def _even_chunk(x: torch.Tensor, chunks: int):
|
514
|
+
assert x.shape[0] % chunks == 0, f"{x.shape=} {chunks=}"
|
515
|
+
return list(x.chunk(chunks))
|
516
|
+
|
517
|
+
|
518
|
+
def _create_shared_buffer_tensors(local_tensor: torch.Tensor) -> List[torch.Tensor]:
|
519
|
+
self_rank = get_naive_distributed().get_rank()
|
520
|
+
world_size = get_naive_distributed().get_world_size()
|
521
|
+
|
522
|
+
object_list = get_naive_distributed().all_gather_object(
|
523
|
+
dict(
|
524
|
+
dup_serialized_local_tensor=[
|
525
|
+
(
|
526
|
+
None
|
527
|
+
if interesting_rank == self_rank
|
528
|
+
else MultiprocessingSerializer.serialize(local_tensor)
|
529
|
+
)
|
530
|
+
for interesting_rank in range(world_size)
|
531
|
+
]
|
532
|
+
)
|
533
|
+
)
|
534
|
+
|
535
|
+
output_tensors = []
|
536
|
+
for output_rank in range(world_size):
|
537
|
+
remote_serialized_tensor = object_list[output_rank][
|
538
|
+
"dup_serialized_local_tensor"
|
539
|
+
][self_rank]
|
540
|
+
if output_rank == self_rank:
|
541
|
+
assert remote_serialized_tensor is None
|
542
|
+
output_tensors.append(local_tensor)
|
543
|
+
else:
|
544
|
+
output_tensors.append(
|
545
|
+
MultiprocessingSerializer.deserialize(remote_serialized_tensor)
|
546
|
+
)
|
547
|
+
|
548
|
+
return output_tensors
|