sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
|
|
16
16
|
TokenToKVPoolAllocator,
|
17
17
|
)
|
18
18
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
19
|
-
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
|
20
19
|
|
21
20
|
logger = logging.getLogger(__name__)
|
22
21
|
|
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
|
|
31
30
|
page_size: int,
|
32
31
|
hicache_ratio: float,
|
33
32
|
):
|
34
|
-
if page_size != 1:
|
35
|
-
raise ValueError(
|
36
|
-
"Page size larger than 1 is not yet supported in HiRadixCache."
|
37
|
-
)
|
38
33
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
39
34
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
40
35
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
41
|
-
self.kv_cache, hicache_ratio
|
36
|
+
self.kv_cache, hicache_ratio, page_size
|
42
37
|
)
|
43
38
|
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
44
39
|
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
45
|
-
self.kv_cache, hicache_ratio
|
40
|
+
self.kv_cache, hicache_ratio, page_size
|
46
41
|
)
|
47
42
|
else:
|
48
|
-
raise ValueError(f"
|
43
|
+
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
49
44
|
|
50
45
|
self.tp_group = tp_cache_group
|
51
|
-
self.page_size = page_size
|
52
46
|
|
53
47
|
self.load_cache_event = threading.Event()
|
54
48
|
self.cache_controller = HiCacheController(
|
55
49
|
token_to_kv_pool_allocator,
|
56
50
|
self.token_to_kv_pool_host,
|
51
|
+
page_size,
|
57
52
|
load_cache_event=self.load_cache_event,
|
58
53
|
)
|
59
54
|
|
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
|
|
65
60
|
self.write_through_threshold = 1
|
66
61
|
self.load_back_threshold = 10
|
67
62
|
super().__init__(
|
68
|
-
req_to_token_pool, token_to_kv_pool_allocator,
|
63
|
+
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
|
69
64
|
)
|
70
65
|
|
71
66
|
def reset(self):
|
@@ -210,9 +205,9 @@ class HiRadixCache(RadixCache):
|
|
210
205
|
# only evict the host value of evicted nodes
|
211
206
|
if not x.evicted:
|
212
207
|
continue
|
213
|
-
assert x.lock_ref == 0 and x.host_value is not None
|
214
208
|
|
215
|
-
|
209
|
+
num_evicted += self.cache_controller.evict_host(x.host_value)
|
210
|
+
|
216
211
|
for k, v in x.parent.children.items():
|
217
212
|
if v == x:
|
218
213
|
break
|
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
|
|
299
294
|
|
300
295
|
return last_node, prefix_indices
|
301
296
|
|
302
|
-
def
|
297
|
+
def ready_to_load_cache(self):
|
303
298
|
self.load_cache_event.set()
|
304
299
|
|
305
300
|
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
|
306
|
-
|
307
|
-
|
301
|
+
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
302
|
+
if self.disable or len(key) == 0:
|
303
|
+
if include_evicted:
|
304
|
+
return empty_value, self.root_node, self.root_node
|
305
|
+
else:
|
306
|
+
return empty_value, self.root_node
|
307
|
+
|
308
|
+
if self.page_size != 1:
|
309
|
+
page_aligned_len = len(key) // self.page_size * self.page_size
|
310
|
+
key = key[:page_aligned_len]
|
308
311
|
|
309
312
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
310
313
|
if value:
|
311
314
|
value = torch.cat(value)
|
312
315
|
else:
|
313
|
-
value =
|
316
|
+
value = empty_value
|
314
317
|
|
315
318
|
last_node_global = last_node
|
316
319
|
while last_node.evicted:
|
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
|
|
323
326
|
|
324
327
|
def _match_prefix_helper(self, node: TreeNode, key: List):
|
325
328
|
node.last_access_time = time.time()
|
329
|
+
child_key = self.get_child_key_fn(key)
|
326
330
|
value = []
|
327
|
-
|
328
|
-
|
331
|
+
|
332
|
+
while len(key) > 0 and child_key in node.children.keys():
|
333
|
+
child = node.children[child_key]
|
329
334
|
child.last_access_time = time.time()
|
330
|
-
prefix_len =
|
335
|
+
prefix_len = self.key_match_fn(child.key, key)
|
331
336
|
if prefix_len < len(child.key):
|
332
337
|
new_node = self._split_node(child.key, child, prefix_len)
|
333
338
|
if not new_node.evicted:
|
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
|
|
339
344
|
value.append(child.value)
|
340
345
|
node = child
|
341
346
|
key = key[prefix_len:]
|
347
|
+
|
348
|
+
if len(key):
|
349
|
+
child_key = self.get_child_key_fn(key)
|
350
|
+
|
342
351
|
return value, node
|
343
352
|
|
344
353
|
def _split_node(self, key, child: TreeNode, split_len: int):
|
345
354
|
# child node split into new_node -> child
|
346
355
|
new_node = TreeNode()
|
347
|
-
new_node.children = {key[split_len]: child}
|
356
|
+
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
|
348
357
|
new_node.parent = child.parent
|
349
358
|
new_node.lock_ref = child.lock_ref
|
350
359
|
new_node.key = child.key[:split_len]
|
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
|
|
361
370
|
child.host_value = child.host_value[split_len:]
|
362
371
|
child.parent = new_node
|
363
372
|
child.key = child.key[split_len:]
|
364
|
-
new_node.parent.children[key
|
373
|
+
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
365
374
|
return new_node
|
366
375
|
|
367
376
|
def _insert_helper(self, node: TreeNode, key: List, value):
|
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
|
|
369
378
|
if len(key) == 0:
|
370
379
|
return 0
|
371
380
|
|
372
|
-
|
373
|
-
|
374
|
-
prefix_len = _key_match(child.key, key)
|
381
|
+
child_key = self.get_child_key_fn(key)
|
382
|
+
total_prefix_length = 0
|
375
383
|
|
376
|
-
|
377
|
-
|
384
|
+
while len(key) > 0 and child_key in node.children.keys():
|
385
|
+
node = node.children[child_key]
|
386
|
+
node.last_access_time = time.time()
|
387
|
+
prefix_len = self.key_match_fn(node.key, key)
|
388
|
+
|
389
|
+
if prefix_len == len(node.key):
|
390
|
+
if node.evicted:
|
378
391
|
# change the reference if the node is evicted
|
379
392
|
# this often happens in the case of KV cache recomputation
|
380
|
-
|
381
|
-
self.token_to_kv_pool_host.update_synced(
|
382
|
-
self.evictable_size_ += len(value
|
383
|
-
return self._insert_helper(
|
384
|
-
child, key[prefix_len:], value[prefix_len:]
|
385
|
-
)
|
393
|
+
node.value = value[:prefix_len]
|
394
|
+
self.token_to_kv_pool_host.update_synced(node.host_value)
|
395
|
+
self.evictable_size_ += len(node.value)
|
386
396
|
else:
|
387
|
-
self.inc_hit_count(
|
388
|
-
|
389
|
-
child, key[prefix_len:], value[prefix_len:]
|
390
|
-
)
|
391
|
-
|
392
|
-
# partial match, split the node
|
393
|
-
new_node = self._split_node(child.key, child, prefix_len)
|
394
|
-
if new_node.evicted:
|
395
|
-
new_node.value = value[:prefix_len]
|
396
|
-
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
397
|
-
self.evictable_size_ += len(new_node.value)
|
398
|
-
return self._insert_helper(
|
399
|
-
new_node, key[prefix_len:], value[prefix_len:]
|
400
|
-
)
|
397
|
+
self.inc_hit_count(node)
|
398
|
+
total_prefix_length += prefix_len
|
401
399
|
else:
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
400
|
+
# partial match, split the node
|
401
|
+
new_node = self._split_node(node.key, node, prefix_len)
|
402
|
+
if new_node.evicted:
|
403
|
+
new_node.value = value[:prefix_len]
|
404
|
+
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
405
|
+
self.evictable_size_ += len(new_node.value)
|
406
|
+
else:
|
407
|
+
self.inc_hit_count(new_node)
|
408
|
+
total_prefix_length += prefix_len
|
409
|
+
node = new_node
|
410
|
+
|
411
|
+
key = key[prefix_len:]
|
412
|
+
value = value[prefix_len:]
|
413
|
+
|
414
|
+
if len(key):
|
415
|
+
child_key = self.get_child_key_fn(key)
|
406
416
|
|
407
417
|
if len(key):
|
408
418
|
new_node = TreeNode()
|
409
419
|
new_node.parent = node
|
410
420
|
new_node.key = key
|
411
421
|
new_node.value = value
|
412
|
-
node.children[
|
422
|
+
node.children[child_key] = new_node
|
413
423
|
self.evictable_size_ += len(value)
|
414
424
|
|
415
425
|
if self.cache_controller.write_policy == "write_through":
|
416
426
|
self.write_backup(new_node)
|
417
|
-
return
|
427
|
+
return total_prefix_length
|
418
428
|
|
419
429
|
def _collect_leaves_device(self):
|
420
430
|
def is_leaf(node):
|
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
|
|
185
185
|
if self.free_group:
|
186
186
|
self.free(torch.cat(self.free_group))
|
187
187
|
|
188
|
+
def backup_state(self):
|
189
|
+
return self.free_slots
|
190
|
+
|
191
|
+
def restore_state(self, free_slots):
|
192
|
+
self.free_slots = free_slots
|
193
|
+
|
188
194
|
def clear(self):
|
189
195
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
190
196
|
self.free_slots = torch.arange(
|
@@ -602,8 +608,9 @@ class HostKVCache(abc.ABC):
|
|
602
608
|
self,
|
603
609
|
device_pool: MHATokenToKVPool,
|
604
610
|
host_to_device_ratio: float,
|
605
|
-
pin_memory: bool
|
606
|
-
device: str
|
611
|
+
pin_memory: bool,
|
612
|
+
device: str,
|
613
|
+
page_size: int,
|
607
614
|
):
|
608
615
|
assert (
|
609
616
|
host_to_device_ratio >= 1
|
@@ -614,8 +621,11 @@ class HostKVCache(abc.ABC):
|
|
614
621
|
self.host_to_device_ratio = host_to_device_ratio
|
615
622
|
self.pin_memory = pin_memory
|
616
623
|
self.device = device
|
624
|
+
self.page_size = page_size
|
617
625
|
|
618
626
|
self.size = int(device_pool.size * host_to_device_ratio)
|
627
|
+
# Align the host memory pool size to the page size
|
628
|
+
self.size = self.size - (self.size % self.page_size)
|
619
629
|
self.dtype = device_pool.store_dtype
|
620
630
|
self.size_per_token = self.get_size_per_token()
|
621
631
|
|
@@ -769,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
769
779
|
self,
|
770
780
|
device_pool: MHATokenToKVPool,
|
771
781
|
host_to_device_ratio: float,
|
772
|
-
|
782
|
+
page_size: int,
|
783
|
+
pin_memory: bool = True,
|
773
784
|
device: str = "cpu",
|
774
785
|
):
|
775
|
-
super().__init__(
|
786
|
+
super().__init__(
|
787
|
+
device_pool, host_to_device_ratio, pin_memory, device, page_size
|
788
|
+
)
|
776
789
|
|
777
790
|
def get_size_per_token(self):
|
778
791
|
self.head_num = self.device_pool.head_num
|
@@ -805,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|
805
818
|
def assign_flat_data(self, indices, flat_data):
|
806
819
|
self.kv_buffer[:, :, indices] = flat_data
|
807
820
|
|
821
|
+
def write_page_all_layers(self, host_indices, device_indices, device_pool):
|
822
|
+
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
823
|
+
for i in range(len(device_indices_cpu)):
|
824
|
+
h_index = host_indices[i * self.page_size]
|
825
|
+
d_index = device_indices_cpu[i]
|
826
|
+
for j in range(self.layer_num):
|
827
|
+
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
|
828
|
+
device_pool.k_buffer[j][d_index : d_index + self.page_size],
|
829
|
+
non_blocking=True,
|
830
|
+
)
|
831
|
+
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
|
832
|
+
device_pool.v_buffer[j][d_index : d_index + self.page_size],
|
833
|
+
non_blocking=True,
|
834
|
+
)
|
835
|
+
|
836
|
+
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
|
837
|
+
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
838
|
+
for i in range(len(device_indices_cpu)):
|
839
|
+
h_index = host_indices[i * self.page_size]
|
840
|
+
d_index = device_indices_cpu[i]
|
841
|
+
device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
842
|
+
self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
|
843
|
+
non_blocking=True,
|
844
|
+
)
|
845
|
+
device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
846
|
+
self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
|
847
|
+
non_blocking=True,
|
848
|
+
)
|
849
|
+
|
808
850
|
|
809
851
|
class MLATokenToKVPoolHost(HostKVCache):
|
810
852
|
def __init__(
|
811
853
|
self,
|
812
854
|
device_pool: MLATokenToKVPool,
|
813
855
|
host_to_device_ratio: float,
|
814
|
-
|
856
|
+
page_size: int,
|
857
|
+
pin_memory: bool = True,
|
815
858
|
device: str = "cpu",
|
816
859
|
):
|
817
|
-
super().__init__(
|
860
|
+
super().__init__(
|
861
|
+
device_pool, host_to_device_ratio, pin_memory, device, page_size
|
862
|
+
)
|
818
863
|
|
819
864
|
def get_size_per_token(self):
|
820
865
|
self.kv_lora_rank = self.device_pool.kv_lora_rank
|
@@ -851,3 +896,24 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|
851
896
|
|
852
897
|
def assign_flat_data(self, indices, flat_data):
|
853
898
|
self.kv_buffer[:, indices] = flat_data
|
899
|
+
|
900
|
+
def write_page_all_layers(self, host_indices, device_indices, device_pool):
|
901
|
+
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
902
|
+
for i in range(len(device_indices_cpu)):
|
903
|
+
h_index = host_indices[i * self.page_size]
|
904
|
+
d_index = device_indices_cpu[i]
|
905
|
+
for j in range(self.layer_num):
|
906
|
+
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
|
907
|
+
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
|
908
|
+
non_blocking=True,
|
909
|
+
)
|
910
|
+
|
911
|
+
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
|
912
|
+
device_indices_cpu = device_indices[:: self.page_size].cpu()
|
913
|
+
for i in range(len(device_indices_cpu)):
|
914
|
+
h_index = host_indices[i * self.page_size]
|
915
|
+
d_index = device_indices_cpu[i]
|
916
|
+
device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
|
917
|
+
self.kv_buffer[layer_id, h_index : h_index + self.page_size],
|
918
|
+
non_blocking=True,
|
919
|
+
)
|
@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
|
|
190
190
|
def available_size(self):
|
191
191
|
return len(self.free_pages) * self.page_size
|
192
192
|
|
193
|
+
def get_kvcache(self):
|
194
|
+
return self._kvcache
|
195
|
+
|
196
|
+
def alloc(self, need_size: int):
|
197
|
+
# page-aligned allocation, returning contiguous indices of pages
|
198
|
+
if self.debug_mode:
|
199
|
+
assert (
|
200
|
+
need_size % self.page_size == 0
|
201
|
+
), "The allocation size should be page-aligned"
|
202
|
+
|
203
|
+
num_pages = need_size // self.page_size
|
204
|
+
if num_pages > len(self.free_pages):
|
205
|
+
return None
|
206
|
+
|
207
|
+
out_pages = self.free_pages[:num_pages]
|
208
|
+
self.free_pages = self.free_pages[num_pages:]
|
209
|
+
|
210
|
+
out_indices = (
|
211
|
+
out_pages[:, None] * self.page_size
|
212
|
+
+ torch.arange(self.page_size, device=self.device)
|
213
|
+
).reshape(-1)
|
214
|
+
|
215
|
+
return out_indices
|
216
|
+
|
193
217
|
def alloc_extend(
|
194
218
|
self,
|
195
219
|
prefix_lens: torch.Tensor,
|
@@ -218,6 +242,9 @@ class PagedTokenToKVPoolAllocator:
|
|
218
242
|
next_power_of_2(extend_num_tokens),
|
219
243
|
)
|
220
244
|
|
245
|
+
if self.debug_mode:
|
246
|
+
assert len(torch.unique(out_indices)) == len(out_indices)
|
247
|
+
|
221
248
|
merged_value = self.ret_values.item()
|
222
249
|
num_new_pages = merged_value >> 32
|
223
250
|
if num_new_pages > len(self.free_pages):
|
@@ -248,6 +275,9 @@ class PagedTokenToKVPoolAllocator:
|
|
248
275
|
self.page_size,
|
249
276
|
)
|
250
277
|
|
278
|
+
if self.debug_mode:
|
279
|
+
assert len(torch.unique(out_indices)) == len(out_indices)
|
280
|
+
|
251
281
|
num_new_pages = self.ret_values.item()
|
252
282
|
if num_new_pages > len(self.free_pages):
|
253
283
|
return None
|
@@ -265,6 +295,9 @@ class PagedTokenToKVPoolAllocator:
|
|
265
295
|
else:
|
266
296
|
self.free_group.append(free_index)
|
267
297
|
|
298
|
+
if self.debug_mode:
|
299
|
+
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
|
300
|
+
|
268
301
|
def free_group_begin(self):
|
269
302
|
self.is_not_in_free_group = False
|
270
303
|
self.free_group = []
|
@@ -274,6 +307,12 @@ class PagedTokenToKVPoolAllocator:
|
|
274
307
|
if self.free_group:
|
275
308
|
self.free(torch.cat(self.free_group))
|
276
309
|
|
310
|
+
def backup_state(self):
|
311
|
+
return self.free_pages
|
312
|
+
|
313
|
+
def restore_state(self, free_pages):
|
314
|
+
self.free_pages = free_pages
|
315
|
+
|
277
316
|
def clear(self):
|
278
317
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
279
318
|
self.free_pages = torch.arange(
|
sglang/srt/metrics/collector.py
CHANGED
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
|
|
33
33
|
|
34
34
|
def __init__(self, labels: Dict[str, str]) -> None:
|
35
35
|
# We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
|
36
|
-
from prometheus_client import Gauge
|
36
|
+
from prometheus_client import Gauge, Histogram
|
37
37
|
|
38
38
|
self.labels = labels
|
39
39
|
self.last_log_time = time.time()
|
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
|
|
139
139
|
labelnames=labels.keys(),
|
140
140
|
buckets=[
|
141
141
|
0.1,
|
142
|
-
0.
|
143
|
-
0.
|
144
|
-
0.
|
145
|
-
0.
|
142
|
+
0.2,
|
143
|
+
0.4,
|
144
|
+
0.6,
|
145
|
+
0.8,
|
146
146
|
1,
|
147
147
|
2,
|
148
148
|
4,
|
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
|
|
153
153
|
40,
|
154
154
|
60,
|
155
155
|
80,
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
)
|
160
|
-
|
161
|
-
self.histogram_time_per_output_token = Histogram(
|
162
|
-
name="sglang:time_per_output_token_seconds",
|
163
|
-
documentation="Histogram of time per output token in seconds.",
|
164
|
-
labelnames=labels.keys(),
|
165
|
-
buckets=[
|
166
|
-
0.002,
|
167
|
-
0.005,
|
168
|
-
0.010,
|
169
|
-
0.020,
|
170
|
-
0.030,
|
171
|
-
0.040,
|
172
|
-
0.050,
|
173
|
-
0.060,
|
174
|
-
0.070,
|
175
|
-
0.080,
|
176
|
-
0.090,
|
177
|
-
0.100,
|
178
|
-
0.150,
|
179
|
-
0.200,
|
180
|
-
0.300,
|
181
|
-
0.400,
|
182
|
-
0.600,
|
183
|
-
0.800,
|
184
|
-
1.000,
|
185
|
-
2.000,
|
156
|
+
100,
|
157
|
+
200,
|
158
|
+
400,
|
186
159
|
],
|
187
160
|
)
|
188
161
|
|
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
|
|
202
175
|
0.030,
|
203
176
|
0.035,
|
204
177
|
0.040,
|
205
|
-
0.
|
206
|
-
0.
|
178
|
+
0.060,
|
179
|
+
0.080,
|
207
180
|
0.100,
|
208
|
-
0.150,
|
209
181
|
0.200,
|
210
|
-
0.300,
|
211
182
|
0.400,
|
212
|
-
0.
|
213
|
-
0.
|
183
|
+
0.600,
|
184
|
+
0.800,
|
214
185
|
1.000,
|
215
186
|
2.000,
|
187
|
+
4.000,
|
188
|
+
6.000,
|
189
|
+
8.000,
|
216
190
|
],
|
217
191
|
)
|
218
192
|
|
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
|
|
224
198
|
0.1,
|
225
199
|
0.2,
|
226
200
|
0.4,
|
201
|
+
0.6,
|
227
202
|
0.8,
|
228
203
|
1,
|
229
204
|
2,
|
230
|
-
|
205
|
+
4,
|
206
|
+
6,
|
207
|
+
8,
|
231
208
|
10,
|
232
209
|
20,
|
233
210
|
40,
|
234
211
|
60,
|
235
212
|
80,
|
236
213
|
100,
|
237
|
-
150,
|
238
214
|
200,
|
239
|
-
|
240
|
-
|
241
|
-
350,
|
242
|
-
500,
|
243
|
-
1000,
|
215
|
+
400,
|
216
|
+
800,
|
244
217
|
],
|
245
218
|
)
|
246
219
|
|
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
|
|
256
229
|
):
|
257
230
|
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
|
258
231
|
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
|
259
|
-
|
232
|
+
if cached_tokens > 0:
|
233
|
+
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
|
260
234
|
self.num_requests_total.labels(**self.labels).inc(1)
|
261
235
|
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
|
262
|
-
if generation_tokens >= 1:
|
263
|
-
self.histogram_time_per_output_token.labels(**self.labels).observe(
|
264
|
-
e2e_latency / generation_tokens
|
265
|
-
)
|
266
236
|
|
267
237
|
def observe_time_to_first_token(self, value: float):
|
268
238
|
self.histogram_time_to_first_token.labels(**self.labels).observe(value)
|
@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
|
116
116
|
if capture_bs is None:
|
117
117
|
if server_args.speculative_algorithm is None:
|
118
118
|
if server_args.disable_cuda_graph_padding:
|
119
|
-
capture_bs = list(range(1, 33)) +
|
119
|
+
capture_bs = list(range(1, 33)) + list(range(40, 161, 16))  # materialize the range: list + range raises TypeError
|
120
120
|
else:
|
121
|
-
capture_bs = [1, 2, 4] +
|
121
|
+
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
122
122
|
else:
|
123
123
|
# Since speculative decoding requires more cuda graph memory, we
|
124
124
|
# capture less.
|
125
|
-
capture_bs =
|
125
|
+
capture_bs = (
|
126
|
+
list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
|
127
|
+
)
|
126
128
|
|
127
129
|
if _is_hip:
|
128
|
-
capture_bs +=
|
130
|
+
capture_bs += list(range(160, 257, 8))
|
129
131
|
|
130
132
|
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
131
133
|
# In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
@@ -489,10 +491,10 @@ class CudaGraphRunner:
|
|
489
491
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
490
492
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
491
493
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
492
|
-
if forward_batch.
|
494
|
+
if forward_batch.seq_lens_cpu is not None:
|
493
495
|
if bs != raw_bs:
|
494
496
|
self.seq_lens_cpu.fill_(1)
|
495
|
-
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.
|
497
|
+
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
496
498
|
|
497
499
|
if self.is_encoder_decoder:
|
498
500
|
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
|
|
104
104
|
or self == ForwardMode.IDLE
|
105
105
|
)
|
106
106
|
|
107
|
+
def is_extend_or_draft_extend(self):
|
108
|
+
return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
|
109
|
+
|
107
110
|
def is_dummy_first(self):
|
108
111
|
return self == ForwardMode.DUMMY_FIRST
|
109
112
|
|
@@ -148,6 +151,9 @@ class ForwardBatch:
|
|
148
151
|
# The sum of all sequence lengths
|
149
152
|
seq_lens_sum: int
|
150
153
|
|
154
|
+
# Optional seq_lens on cpu
|
155
|
+
seq_lens_cpu: Optional[torch.Tensor] = None
|
156
|
+
|
151
157
|
# For logprob
|
152
158
|
return_logprob: bool = False
|
153
159
|
top_logprobs_nums: Optional[List[int]] = None
|
@@ -162,9 +168,6 @@ class ForwardBatch:
|
|
162
168
|
# Position information
|
163
169
|
positions: torch.Tensor = None
|
164
170
|
|
165
|
-
# For decode
|
166
|
-
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
167
|
-
|
168
171
|
# For extend
|
169
172
|
extend_num_tokens: Optional[int] = None
|
170
173
|
extend_seq_lens: Optional[torch.Tensor] = None
|
@@ -293,12 +296,14 @@ class ForwardBatch:
|
|
293
296
|
):
|
294
297
|
ret.positions = ret.spec_info.positions
|
295
298
|
|
299
|
+
# Get seq_lens_cpu if needed
|
300
|
+
if ret.seq_lens_cpu is None:
|
301
|
+
ret.seq_lens_cpu = batch.seq_lens_cpu
|
302
|
+
|
296
303
|
# Init position information
|
297
304
|
if ret.forward_mode.is_decode():
|
298
305
|
if ret.positions is None:
|
299
306
|
ret.positions = clamp_position(batch.seq_lens)
|
300
|
-
if ret.decode_seq_lens_cpu is None:
|
301
|
-
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
302
307
|
else:
|
303
308
|
ret.extend_seq_lens = torch.tensor(
|
304
309
|
batch.extend_seq_lens, dtype=torch.int32
|
@@ -353,11 +358,6 @@ class ForwardBatch:
|
|
353
358
|
for mm_input in valid_inputs[1:]:
|
354
359
|
merged.merge(mm_input)
|
355
360
|
|
356
|
-
if isinstance(merged.pixel_values, np.ndarray):
|
357
|
-
merged.pixel_values = torch.from_numpy(merged.pixel_values)
|
358
|
-
if isinstance(merged.audio_features, np.ndarray):
|
359
|
-
merged.audio_features = torch.from_numpy(merged.audio_features)
|
360
|
-
|
361
361
|
return merged
|
362
362
|
|
363
363
|
def contains_image_inputs(self) -> bool:
|