sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py CHANGED
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-import concurrent.futures
 import logging
 import math
 import threading
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
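For readers skimming the hunk above: the controller now picks a host-device transfer backend from the page size unless one is passed in explicitly. A minimal standalone sketch of that rule (the helper name is hypothetical; the threshold is the one shown in the diff):

```python
# Not sglang code: a reduced restatement of the default io_backend rule above.
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # value taken from the diff

def resolve_io_backend(page_size: int, io_backend: str = "") -> str:
    if io_backend:
        # an explicit setting always wins
        return io_backend
    # large pages fall back to direct DMA-style copies, small pages use the kernel path
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

print(resolve_io_backend(1))               # kernel
print(resolve_io_backend(64))              # direct
print(resolve_io_backend(16, "direct"))    # direct (explicit override)
```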
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()
 
         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
        )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices
 
+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.
-                    operation.host_indices,
-
-
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)
 
-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:
 
             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-
-
-
-
-
-
-
-
-                self.mem_pool_host.load_page_per_layer(
-                    batch_operation.host_indices,
-                    batch_operation.device_indices,
-                    self.mem_pool_device,
-                    i,
-                )
-                self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()
 
             self.mem_pool_host.complete_io(batch_operation.host_indices)
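The reworked loop above transfers the KV cache one layer at a time and bumps a per-layer counter after each transfer, so compute on layer i only has to wait for layer i, not for the whole cache. A toy sketch of that producer/consumer shape (names are hypothetical; a list of threading.Event objects stands in for sglang's LayerDoneCounter):

```python
import threading, time

NUM_LAYERS = 4
layer_ready = [threading.Event() for _ in range(NUM_LAYERS)]

def load_layers():
    for i in range(NUM_LAYERS):
        time.sleep(0.01)          # stands in for load_from_host_per_layer(...) + synchronize
        layer_ready[i].set()      # stands in for layer_done_counter.increment()

def forward_pass():
    for i in range(NUM_LAYERS):
        layer_ready[i].wait()     # layer i computes as soon as its KV pages arrive
        print(f"layer {i}: KV ready, computing")

loader = threading.Thread(target=load_layers, daemon=True)
loader.start()
forward_pass()
loader.join()
```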
@@ -372,148 +360,6 @@ class HiCacheController:
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)
 
-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
sglang/srt/managers/io_struct.py CHANGED
@@ -200,6 +200,8 @@ class GenerateReqInput:
             self.text = [self.text]
         if self.input_ids is not None:
             self.input_ids = [self.input_ids]
+        if self.input_embeds is not None:
+            self.input_embeds = [self.input_embeds]
 
     def _normalize_single_inputs(self):
         """Normalize inputs for a single example."""
@@ -324,7 +326,9 @@ class GenerateReqInput:
             new_rids = [f"{self.rid}_{i}" for i in range(num)]
             self.rid = new_rids
         elif isinstance(self.rid, list):
-
+            # Note: the length of rid shall be the same as the batch_size,
+            # as the rid would be expanded for parallel sampling in tokenizer_manager
+            if len(self.rid) != self.batch_size:
                 raise ValueError(
                     "The specified rids length mismatch with the batch_size for batch processing."
                 )
@@ -400,6 +404,9 @@ class GenerateReqInput:
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
+            input_embeds=(
+                self.input_embeds[i] if self.input_embeds is not None else None
+            ),
             image_data=self.image_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
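The two hunks above give input_embeds the same treatment that text and input_ids already had: a single-example request is wrapped in a one-element list during normalization, and slicing a batched request forwards the i-th embedding. A reduced sketch of that normalize-then-slice pattern using plain dicts rather than the real GenerateReqInput (function names are illustrative):

```python
def normalize(req: dict) -> dict:
    # wrap single-example fields into one-element lists
    for key in ("text", "input_ids", "input_embeds"):
        value = req.get(key)
        if value is not None and not isinstance(value, list):
            req[key] = [value]
    return req

def slice_request(req: dict, i: int) -> dict:
    # take the i-th example of every per-request field (None stays None)
    return {key: (value[i] if value is not None else None) for key, value in req.items()}

batch = {"text": None,
         "input_ids": [[1, 2], [3, 4]],
         "input_embeds": [[0.1, 0.2], [0.3, 0.4]]}
print(slice_request(normalize(batch), 1))
# {'text': None, 'input_ids': [3, 4], 'input_embeds': [0.3, 0.4]}
```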
sglang/srt/managers/mm_utils.py CHANGED
@@ -248,7 +248,9 @@ def _get_chunked_prefill_embedding(
 ) -> Optional[torch.Tensor]:
     # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
     embedding_list = []
-    for i in range(len(items_size) - 1):
+    # FIXME(Xinyuan): temporary workaround for eagle3, which may have len(items_size) > len(prefix_length)
+    max_iterations = min(len(items_size) - 1, len(prefix_length))
+    for i in range(max_iterations):
         if items_size[i] == items_size[i + 1]:
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
@@ -269,7 +271,7 @@ def _get_chunked_prefill_embedding(
         embedding_per_req_chunk, _, end_index = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
-            extend_seq_len=extend_length[i],
+            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
         # remove this item from cache if chunk reaches to the end
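The two mm_utils hunks above bound the chunked-prefill loop by both items_size and prefix_length, and fall back to an extend length of 0 when the index runs past extend_length instead of raising IndexError. The guard reduces to something like this standalone sketch (the generator name is hypothetical):

```python
def safe_iteration_bounds(items_size, prefix_length, extend_length):
    # iterate over at most len(items_size) - 1 requests, but never past the
    # metadata lists, which can be shorter (observed with eagle3 per the FIXME)
    max_iterations = min(len(items_size) - 1, len(prefix_length))
    for i in range(max_iterations):
        seq_len = extend_length[i] if i < len(extend_length) else 0
        yield i, prefix_length[i], seq_len

print(list(safe_iteration_bounds([0, 2, 5, 7], [4, 8], [16])))
# [(0, 4, 16), (1, 8, 0)]
```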
sglang/srt/managers/schedule_batch.py CHANGED
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
+    "enable_triton_kernel_moe",
 ]
 
 # Put some global args for easy access
@@ -842,7 +843,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens_for_logprob: Optional[List[int]] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
-    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
 
sglang/srt/managers/scheduler.py CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """A scheduler that manages a tensor parallel GPU worker."""
 
+import datetime
 import faulthandler
 import logging
 import os
@@ -590,6 +591,12 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
+                hicache_io_backend=(
+                    "direct"
+                    if server_args.attention_backend
+                    == "fa3"  # hot fix for incompatibility
+                    else server_args.hicache_io_backend
+                ),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
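The scheduler-side hunk above force-switches the hierarchical-cache IO backend to "direct" whenever the fa3 attention backend is active, overriding whatever hicache_io_backend was configured; the inline comment in the diff calls this a hot fix for an incompatibility. As a standalone rule it amounts to the following sketch (the helper name is hypothetical):

```python
def effective_hicache_io_backend(attention_backend: str, configured: str) -> str:
    # fa3 is currently incompatible with the kernel-based host transfer path,
    # so it always falls back to "direct" regardless of the configured value
    if attention_backend == "fa3":
        return "direct"
    return configured

print(effective_hicache_io_backend("fa3", "kernel"))         # direct
print(effective_hicache_io_backend("flashinfer", "kernel"))  # kernel
```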
@@ -1313,10 +1320,12 @@ class Scheduler(
             f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
             f += f"#queue-req: {len(self.waiting_queue)}, "
             f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
         else:
             f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}"
+            f += f"#queue-req: {len(self.waiting_queue)}, "
+
+        f += f"timestamp: {datetime.datetime.now().isoformat()}"
 
         logger.info(f)
 
@@ -1378,7 +1387,8 @@ class Scheduler(
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
+            f"#queue-req: {len(self.waiting_queue)}, "
+            f"timestamp: {datetime.datetime.now().isoformat()}"
         )
 
         logger.info(msg)
@@ -2333,9 +2343,8 @@ class Scheduler(
 
     def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
         tags = recv_req.tags
-        import subprocess
 
-        if tags is None:
+        if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
@@ -2346,17 +2355,20 @@ class Scheduler(
             self.stashed_model_static_state = _export_static_state(
                 self.tp_worker.worker.model_runner.model
             )
+            torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
     def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
         tags = recv_req.tags
+
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
 
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
                 self.tp_worker.worker.model_runner.model,
                 self.stashed_model_static_state,
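The two torch.distributed.barrier calls added above synchronize the TP ranks around weight release and resume: every rank finishes exporting its static state before any rank frees the weights, and every rank has weights re-allocated before any rank imports the stashed state. A rough single-process sketch of that ordering, with stand-in functions for the sglang pieces (gloo backend, world size 1, names illustrative only):

```python
import torch.distributed as dist

def export_static_state():
    return {"static_buffers": "stashed"}       # stands in for _export_static_state(model)

def pause_weights():
    print("weights released")                  # stands in for memory_saver_adapter.pause(...)

def resume_weights():
    print("weights re-allocated")              # stands in for memory_saver_adapter.resume(...)

def import_static_state(stashed):
    print("static state restored:", stashed)   # stands in for _import_static_state(model, stashed)

def release():
    stashed = export_static_state()
    dist.barrier()   # all ranks finish exporting before any rank frees its weights
    pause_weights()
    return stashed

def resume(stashed):
    resume_weights()
    dist.barrier()   # all ranks have weights back before importing the stashed state
    import_static_state(stashed)

if __name__ == "__main__":
    # single-rank process group just to make the sketch runnable
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=0, world_size=1)
    resume(release())
    dist.destroy_process_group()
```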
sglang/srt/mem_cache/hiradix_cache.py CHANGED
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
         hicache_ratio: float,
         hicache_size: int,
         hicache_write_policy: str,
+        hicache_io_backend: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
             page_size,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
+            io_backend=hicache_io_backend,
         )
 
         # record the nodes with ongoing write through
|