sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hiradix_cache.py

```diff
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, page_size
+                self.kv_cache, hicache_ratio, hicache_size, page_size
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
             self.token_to_kv_pool_host,
             page_size,
             load_cache_event=self.load_cache_event,
+            write_policy=hicache_write_policy,
         )

         # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
         # todo: dynamically adjust the threshold
-        self.write_through_threshold =
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 3
+        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
             height += 1
         return height

-    def write_backup(self, node: TreeNode):
+    def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
             node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
         if host_indices is not None:
             node.host_value = host_indices
             self.ongoing_write_through[node.id] = node
-
+            if not write_back:
+                # no need to lock nodes if write back
+                self.inc_lock_ref(node)
         else:
             return 0

         return len(host_indices)

     def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy
+        if node.backuped or self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.
+        if node.hit_count >= self.write_through_threshold:
             self.write_backup(node)
             node.hit_count = 0

-    def writing_check(self):
+    def writing_check(self, write_back=False):
+        if write_back:
+            # blocking till all write back complete
+            while len(self.ongoing_write_through) > 0:
+                ack_id = self.cache_controller.ack_write_queue.get()
+                del self.ongoing_write_through[ack_id]
+            return
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
```
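The constructor and write-path changes above introduce a configurable write policy. As a rough illustration of the gating they implement, here is a minimal, self-contained sketch with made-up `Node` and `HostBackup` stand-ins (not sglang's classes): under `write_through` a node is backed up to host on its first hit, the selective default waits for three hits, and `write_back` skips hit-triggered backups entirely.

```python
# Minimal, self-contained sketch of the hit-count gating added above.
# Names (Node, HostBackup) are illustrative, not sglang classes.
from dataclasses import dataclass


@dataclass
class Node:
    hit_count: int = 0
    backuped: bool = False


class HostBackup:
    def __init__(self, write_policy: str = "write_through_selective"):
        self.write_policy = write_policy
        # write_through backs up on the first hit; the selective policy
        # waits for repeated hits before spending host bandwidth.
        self.write_through_threshold = 1 if write_policy == "write_through" else 3

    def inc_hit_count(self, node: Node) -> None:
        if node.backuped or self.write_policy == "write_back":
            return  # write_back defers the backup to eviction time
        node.hit_count += 1
        if node.hit_count >= self.write_through_threshold:
            self.write_backup(node)
            node.hit_count = 0

    def write_backup(self, node: Node) -> None:
        node.backuped = True  # stand-in for copying KV pages to host memory


if __name__ == "__main__":
    backup = HostBackup("write_through_selective")
    node = Node()
    for _ in range(3):
        backup.inc_hit_count(node)
    print(node.backuped)  # True only after the third hit
```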
sglang/srt/mem_cache/hiradix_cache.py (continued)

```diff
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
         heapq.heapify(leaves)

         num_evicted = 0
-
+        write_back_nodes = []
         while num_evicted < num_tokens and len(leaves):
             x = heapq.heappop(leaves)

             if x.lock_ref > 0:
                 continue

-            if x.
+            if not x.backuped:
                 if self.cache_controller.write_policy == "write_back":
-
-
-
-                    num_evicted += self._evict_write_through_selective(x)
+                    # write to host if the node is not backuped
+                    num_evicted += self.write_backup(x, write_back=True)
+                    write_back_nodes.append(x)
                 else:
-
-                        self.cache_controller.write_policy != "write_through"
-                    ), "write_through should be inclusive"
-                    raise NotImplementedError
+                    num_evicted += self._evict_regular(x)
             else:
-                num_evicted += self.
+                num_evicted += self._evict_backuped(x)

             for child in x.parent.children.values():
-                if child in
+                if child in write_back_nodes:
                     continue
                 if not child.evicted:
                     break
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
                 heapq.heappush(leaves, x.parent)

         if self.cache_controller.write_policy == "write_back":
-
-
-
-
-            for node in pending_nodes:
-                assert node.host_value is not None
-                self._evict_write_through(node)
+            self.writing_check(write_back=True)
+            for node in write_back_nodes:
+                assert node.backuped
+                self._evict_backuped(node)

-    def
+    def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
         num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
         assert num_evicted > 0
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
         node.value = None
         return num_evicted

-    def
+    def _evict_regular(self, node: TreeNode):
         # evict a node not initiated write to host
         self.cache_controller.mem_pool_device_allocator.free(node.value)
         num_evicted = len(node.value)
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
+                self.inc_hit_count(new_node)
                 if not new_node.evicted:
                     value.append(new_node.value)
                 node = new_node
                 break
             else:
+                self.inc_hit_count(child)
                 if not child.evicted:
                     value.append(child.value)
                 node = child
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
         else:
             new_node.value = child.value[:split_len]
             child.value = child.value[split_len:]
-            if child.
+            if child.backuped:
                 new_node.host_value = child.host_value[:split_len]
                 child.host_value = child.host_value[split_len:]
         child.parent = new_node
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
         node.children[child_key] = new_node
         self.evictable_size_ += len(value)

-        if self.cache_controller.write_policy
-            self.
+        if self.cache_controller.write_policy != "write_back":
+            self.inc_hit_count(new_node)
         return total_prefix_length

     def _collect_leaves_device(self):
```
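Eviction now branches on whether a node already has a host copy. The following condensed, self-contained sketch (illustrative stand-in objects, not the real cache controller) mirrors that branching: an un-backed-up node is either written back to host first (under the `write_back` policy) or freed outright, while an already backed-up node only releases its device copy.

```python
# Condensed sketch of the eviction decision the diff introduces.
# FakeController stands in for the real cache controller; only the
# branching logic mirrors the diff.
from dataclasses import dataclass


@dataclass
class Node:
    num_tokens: int
    backuped: bool = False


class FakeController:
    def __init__(self, write_policy: str):
        self.write_policy = write_policy

    def write_backup(self, node: Node) -> int:
        node.backuped = True      # in sglang, pages are copied to host asynchronously
        return node.num_tokens

    def evict_backuped(self, node: Node) -> int:
        return node.num_tokens    # device copy released, host copy retained

    def evict_regular(self, node: Node) -> int:
        return node.num_tokens    # nothing kept on host


def evict_one(node: Node, ctl: FakeController, write_back_nodes: list) -> int:
    """Return the number of tokens freed for one unlocked leaf node."""
    if not node.backuped:
        if ctl.write_policy == "write_back":
            # Back the node up to host first; the device copy is released
            # later, once the write acks arrive (writing_check(write_back=True)).
            write_back_nodes.append(node)
            return ctl.write_backup(node)
        # Write-through policies: the node was never hot enough, drop it.
        return ctl.evict_regular(node)
    return ctl.evict_backuped(node)


if __name__ == "__main__":
    pending = []
    print(evict_one(Node(16), FakeController("write_back"), pending))  # 16
    print(len(pending))  # 1
```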
sglang/srt/mem_cache/memory_pool.py

```diff
@@ -34,6 +34,8 @@ from typing import List, Optional, Tuple, Union
 import numpy as np
 import psutil
 import torch
+import triton
+import triton.language as tl

 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import debug_timing, get_compiler_backend
@@ -405,6 +407,72 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_2[loc] = src_2.to(dtype).view(store_dtype)


+@triton.jit
+def set_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    pid_blk = tl.program_id(1)
+
+    base = pid_blk * BLOCK
+    offs = base + tl.arange(0, BLOCK)
+    total_dim = nope_dim + rope_dim
+    mask = offs < total_dim
+
+    loc = tl.load(loc_ptr + pid_loc)
+    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
+
+    if base + BLOCK <= nope_dim:
+        src = tl.load(
+            cache_k_nope_ptr + pid_loc * nope_stride + offs,
+            mask=mask,
+        )
+    else:
+        offs_rope = offs - nope_dim
+        src = tl.load(
+            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
+            mask=mask,
+        )
+
+    tl.store(dst_ptr, src, mask=mask)
+
+
+def set_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    nope_dim = cache_k_nope.shape[-1]
+    rope_dim = cache_k_rope.shape[-1]
+    total_dim = nope_dim + rope_dim
+    BLOCK = 128
+    n_loc = loc.numel()
+    grid = (n_loc, triton.cdiv(total_dim, BLOCK))
+
+    set_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+        BLOCK=BLOCK,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -446,13 +514,28 @@ class MLATokenToKVPool(KVCache):
         ]

         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes

     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens

     def get_key_buffer(self, layer_id: int):
@@ -489,6 +572,25 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def set_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k_nope: torch.Tensor,
+        cache_k_rope: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k_nope.dtype != self.dtype:
+            cache_k_nope = cache_k_nope.to(self.dtype)
+            cache_k_rope = cache_k_rope.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            cache_k_nope = cache_k_nope.view(self.store_dtype)
+            cache_k_rope = cache_k_rope.view(self.store_dtype)
+
+        set_mla_kv_buffer_triton(
+            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
+        )
+
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
```
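Functionally, the new `set_mla_kv_buffer` path scatters the concatenation of the no-RoPE and RoPE halves of each key into the destination slots given by `loc`. Below is a pure-PyTorch reference of that write, with shapes simplified to 2-D for illustration (the real buffers carry an extra head dimension and skip the dtype handling; the hot path uses the fused Triton kernel above).

```python
# Pure-PyTorch reference for what set_mla_kv_buffer_triton writes:
# for each destination slot loc[i], the row becomes cat(k_nope[i], k_rope[i]).
import torch


def set_mla_kv_buffer_reference(
    kv_buffer: torch.Tensor,     # [num_slots, nope_dim + rope_dim]
    loc: torch.Tensor,           # [n] destination slot indices
    cache_k_nope: torch.Tensor,  # [n, nope_dim]
    cache_k_rope: torch.Tensor,  # [n, rope_dim]
) -> None:
    kv_buffer[loc] = torch.cat([cache_k_nope, cache_k_rope], dim=-1)


if __name__ == "__main__":
    nope_dim, rope_dim, slots = 512, 64, 16
    buf = torch.zeros(slots, nope_dim + rope_dim)
    loc = torch.tensor([3, 7, 11])
    k_nope = torch.randn(3, nope_dim)
    k_rope = torch.randn(3, rope_dim)
    set_mla_kv_buffer_reference(buf, loc, k_nope, k_rope)
    assert torch.equal(buf[7, nope_dim:], k_rope[1])
```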
sglang/srt/mem_cache/memory_pool.py (continued)

```diff
@@ -621,26 +723,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-
-
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
@@ -792,12 +895,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
@@ -866,12 +970,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )

     def get_size_per_token(self):
```
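The new `host_size` argument sizes the host pool directly in gigabytes, falling back to the old `host_to_device_ratio` when it is zero; the result is aligned down to the page size and must still exceed the device pool. The arithmetic, pulled out into a standalone helper for illustration (the `1e9` bytes-per-GB factor mirrors the diff; the example numbers below are made up):

```python
# Standalone illustration of the host pool sizing added to HostKVCache.
# host_size_gb is interpreted as gigabytes; 0 falls back to the ratio.
def host_pool_num_tokens(
    host_size_gb: int,
    host_to_device_ratio: float,
    device_pool_tokens: int,
    size_per_token_bytes: int,
    page_size: int,
) -> int:
    if host_size_gb > 0:
        size = int(host_size_gb * 1e9 // size_per_token_bytes)
    else:
        size = int(device_pool_tokens * host_to_device_ratio)
    # Align the host memory pool size to the page size.
    size -= size % page_size
    assert size > device_pool_tokens, (
        "The host memory should be larger than the device memory "
        "with the current protocol"
    )
    return size


# e.g. a 100 GB host pool, ~70 KB per token, page size 32:
print(host_pool_num_tokens(100, 0.0, 500_000, 70 * 1024, 32))
```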
sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -35,7 +35,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -129,7 +133,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
         )

-
+    gpu_mem = get_device_memory_capacity()
+    # Batch size of each rank will not become so large when DP is on
+    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
         capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +146,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]

     capture_bs = list(sorted(set(capture_bs)))
-
-
-
-
-
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
@@ -186,6 +191,7 @@ class CudaGraphRunner:

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -273,9 +279,9 @@
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
-                "2. set --cuda-graph-max-bs to a smaller value (e.g.,
+                "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
-                "4. disable cuda graph by --disable-cuda-graph\n"
+                "4. disable cuda graph by --disable-cuda-graph. (Not recommonded. Huge perf loss)\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

```
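The batch-size selection for CUDA graph capture now also consults total device memory and the data-parallel size. A simplified, self-contained sketch of the resulting filtering is shown below, assuming (as the `81920` threshold suggests) that `get_device_memory_capacity()` reports MiB; the example arguments are made up.

```python
# Simplified sketch of the capture batch-size selection in the diff.
from typing import List, Optional


def select_capture_batch_sizes(
    gpu_mem_mib: Optional[int],
    dp_size: int,
    req_to_token_pool_size: int,
    cuda_graph_max_bs: Optional[int],
) -> List[int]:
    capture_bs = list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
    # Larger GPUs capture bigger batches, but not when data parallelism
    # already splits the batch across ranks.
    if gpu_mem_mib is not None and gpu_mem_mib > 81920 and dp_size == 1:
        capture_bs += list(range(160, 257, 8))
    capture_bs = sorted(set(capture_bs))
    assert len(capture_bs) > 0 and capture_bs[0] > 0
    capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
    if cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
    return capture_bs


print(select_capture_batch_sizes(141_560, dp_size=1,
                                 req_to_token_pool_size=4096,
                                 cuda_graph_max_bs=64))
```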
sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -364,23 +364,23 @@ class ForwardBatch:

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
         """
-        Merge all
+        Merge all multimodal inputs in the batch into a single MultiModalInputs object.

         Returns:
-            if none, current batch contains no
+            if none, current batch contains no multimodal input

         """
         if not self.mm_inputs or all(x is None for x in self.mm_inputs):
             return None
-
         # Filter out None values
         valid_inputs = [x for x in self.mm_inputs if x is not None]

-        #
-
+        # TODO: is it expensive?
+        # a workaround to avoid importing `MultimodalInputs`
+        merged = valid_inputs[0].__class__(mm_items=[])

         # Merge remaining inputs
-        for mm_input in valid_inputs
+        for mm_input in valid_inputs:
             merged.merge(mm_input)

         return merged
@@ -407,104 +407,60 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-
-
-        mrope_positions_list = [
-
-
-
-
-
-
-
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
+        # batch_size * [3 * seq_len]
+        batch_size = self.seq_lens.shape[0]
+        mrope_positions_list = [[]] * batch_size
+        for batch_idx in range(batch_size):
+            mm_input = batch.multimodal_inputs[batch_idx]
+            if self.forward_mode.is_decode():
+                mrope_position_deltas = (
+                    [0]
+                    if mm_input is None
+                    else flatten_nested_list(mm_input.mrope_position_delta.tolist())
                 )
-
-
-
-
-
-
-
+                next_input_positions = []
+                for mrope_position_delta in mrope_position_deltas:
+                    # batched deltas needs to be processed separately
+                    # Convert list of lists to tensor with shape [3, seq_len]
+                    next_input_positions += [
+                        MRotaryEmbedding.get_next_input_positions(
+                            mrope_position_delta,
+                            int(self.seq_lens[batch_idx]) - 1,
+                            int(self.seq_lens[batch_idx]),
+                        )
+                    ]
+                # 3 * N
+                mrope_positions_list[batch_idx] = torch.cat(next_input_positions, dim=1)
+            elif self.forward_mode.is_extend():
+                extend_seq_len, extend_prefix_len = (
+                    batch.extend_seq_lens[batch_idx],
+                    batch.extend_prefix_lens[batch_idx],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions =
+                    mrope_positions = torch.tensor(
                         [
-
-
-
-
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
                         ]
-
-                else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
+                        * 3
                     )
-
-
-
-
-                        if item.second_per_grid_ts is not None
+                else:
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
                     ]
-
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
-                mrope_positions_list[i] = mrope_positions
+                mrope_positions_list[batch_idx] = mrope_positions

         self.mrope_positions = torch.cat(
-            [
-
-        )
-        self.mrope_positions = self.mrope_positions.to(torch.int64)
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        ).to(dtype=torch.int64, device=model_runner.device)

     def get_max_chunk_capacity(self):
         # Maximum number of tokens in each chunk
```
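The decode branch above relies on the newly imported `flatten_nested_list` helper to turn per-item `mrope_position_delta` values into a flat list of deltas. A behavioral sketch of such a helper is given below (illustrative only; the real implementation lives in `sglang.srt.utils`).

```python
# Sketch of a recursive list flattener with the behavior the decode path
# above expects from flatten_nested_list (assumed semantics, not the
# verbatim sglang.srt.utils implementation).
def flatten_nested_list(nested):
    if isinstance(nested, list):
        out = []
        for item in nested:
            out.extend(flatten_nested_list(item))
        return out
    return [nested]


assert flatten_nested_list([[1, 2], [3, [4]]]) == [1, 2, 3, 4]
assert flatten_nested_list(5) == [5]
```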