sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/hiradix_cache.py

@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
-        self.page_size = page_size

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )

@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )

     def reset(self):
@@ -210,9 +205,9 @@ class HiRadixCache(RadixCache):
             # only evict the host value of evicted nodes
             if not x.evicted:
                 continue
-            assert x.lock_ref == 0 and x.host_value is not None

-            assert self.cache_controller.evict_host(x.host_value) > 0
+            num_evicted += self.cache_controller.evict_host(x.host_value)
+
             for k, v in x.parent.children.items():
                 if v == x:
                     break
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

-    def read_to_load_cache(self):
+    def ready_to_load_cache(self):
         self.load_cache_event.set()

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-        if self.disable:
-            return [], self.root_node
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = empty_value

         last_node_global = last_node
         while last_node.evicted:
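
For context (not part of the diff): the page-size branch above only drops a trailing partial page before matching, so with an assumed page_size of 4 a 10-token key is matched against its first 8 tokens only.

    page_size = 4
    key = list(range(10))                                  # 10 tokens
    page_aligned_len = len(key) // page_size * page_size   # 8
    key = key[:page_aligned_len]                           # trailing partial page is dropped
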
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):

     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)
         value = []
-        while len(key) > 0 and key[0] in node.children.keys():
-            child = node.children[key[0]]
+
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len = _key_match(child.key, key)
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
                     value.append(child.value)
                 node = child
                 key = key[prefix_len:]
+
+                if len(key):
+                    child_key = self.get_child_key_fn(key)
+
         return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key[0]] = new_node
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
         if len(key) == 0:
             return 0

-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        child_key = self.get_child_key_fn(key)
+        total_prefix_length = 0

-            if prefix_len == len(child.key):
-                if child.evicted:
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)
+
+            if prefix_len == len(node.key):
+                if node.evicted:
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
-                    child.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(child.host_value)
-                    self.evictable_size_ += len(value[:prefix_len])
-                    return self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
                 else:
-                    self.inc_hit_count(child)
-                    return prefix_len + self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-                return self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                    self.inc_hit_count(node)
+                    total_prefix_length += prefix_len
             else:
-                self.inc_hit_count(new_node)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                    total_prefix_length += prefix_len
+                node = new_node
+
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[key[0]] = new_node
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)

             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return 0
+        return total_prefix_length

     def _collect_leaves_device(self):
         def is_leaf(node):
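
As a rough caller-side sketch (hypothetical names, not from the diff), the iterative _insert_helper above now reports how much of the inserted key was already cached instead of always returning 0:

    # `cache`, `token_ids`, and `kv_indices` are illustrative placeholders
    matched_len = cache.insert(token_ids, kv_indices)  # total_prefix_length of the reused prefix
    new_suffix = token_ids[matched_len:]               # only this part needs fresh KV slots
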

sglang/srt/mem_cache/memory_pool.py

@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
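
A minimal usage sketch of the new snapshot pair (illustrative only; the allocator instance, allocation size, and rollback condition are assumptions, relying on the allocator's existing alloc(need_size) API):

    # `allocator`, `need_size`, and `should_roll_back` are illustrative placeholders
    saved_free_slots = allocator.backup_state()    # snapshot the free-slot tensor
    out_indices = allocator.alloc(need_size)       # tentative allocation
    if out_indices is None or should_roll_back:
        allocator.restore_state(saved_free_slots)  # put the allocator back as it was
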
@@ -602,8 +608,9 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
-        device: str = "cpu",
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -614,8 +621,11 @@ class HostKVCache(abc.ABC):
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size

         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()

@@ -769,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -805,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+

 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
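
Illustration of the page-strided indexing used by the new write/load methods above (assuming, as the code does, that device_indices is page-aligned and contiguous within each page, so every page_size-th element is one page's starting slot):

    import torch

    page_size = 4
    device_indices = torch.tensor([20, 21, 22, 23, 36, 37, 38, 39])
    page_starts = device_indices[::page_size]  # tensor([20, 36]): one d_index per page
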
@@ -851,3 +896,24 @@

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )

sglang/srt/mem_cache/paged_allocator.py

@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_pages) * self.page_size

+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
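
To see what the new alloc() returns, this simply evaluates the index expression from the diff: with page_size 4 and the first two free pages being 3 and 7, the token slots of both pages come back as one flat tensor.

    import torch

    page_size = 4
    out_pages = torch.tensor([3, 7])
    out_indices = (
        out_pages[:, None] * page_size + torch.arange(page_size)
    ).reshape(-1)  # tensor([12, 13, 14, 15, 28, 29, 30, 31])
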
@@ -218,6 +242,9 @@
             next_power_of_2(extend_num_tokens),
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):
@@ -248,6 +275,9 @@
             self.page_size,
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None
@@ -265,6 +295,9 @@
         else:
             self.free_group.append(free_index)

+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []
@@ -274,6 +307,12 @@
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(

sglang/srt/metrics/collector.py

@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:

     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge
+        from prometheus_client import Gauge, Histogram

         self.labels = labels
         self.last_log_time = time.time()
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
+                0.2,
+                0.4,
+                0.6,
+                0.8,
                 1,
                 2,
                 4,
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
-                120,
-                160,
-            ],
-        )
-
-        self.histogram_time_per_output_token = Histogram(
-            name="sglang:time_per_output_token_seconds",
-            documentation="Histogram of time per output token in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.002,
-                0.005,
-                0.010,
-                0.020,
-                0.030,
-                0.040,
-                0.050,
-                0.060,
-                0.070,
-                0.080,
-                0.090,
-                0.100,
-                0.150,
-                0.200,
-                0.300,
-                0.400,
-                0.600,
-                0.800,
-                1.000,
-                2.000,
+                100,
+                200,
+                400,
             ],
         )

@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
                 0.030,
                 0.035,
                 0.040,
-                0.050,
-                0.075,
+                0.060,
+                0.080,
                 0.100,
-                0.150,
                 0.200,
-                0.300,
                 0.400,
-                0.500,
-                0.750,
+                0.600,
+                0.800,
                 1.000,
                 2.000,
+                4.000,
+                6.000,
+                8.000,
             ],
         )

@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
                 0.1,
                 0.2,
                 0.4,
+                0.6,
                 0.8,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
                 60,
                 80,
                 100,
-                150,
                 200,
-                250,
-                300,
-                350,
-                500,
-                1000,
+                400,
+                800,
             ],
         )

@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
-        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
+        if cached_tokens > 0:
+            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
-        if generation_tokens >= 1:
-            self.histogram_time_per_output_token.labels(**self.labels).observe(
-                e2e_latency / generation_tokens
-            )

     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)

sglang/srt/model_executor/cuda_graph_runner.py

@@ -116,16 +116,18 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if capture_bs is None:
         if server_args.speculative_algorithm is None:
             if server_args.disable_cuda_graph_padding:
-                capture_bs = list(range(1, 33)) + [64, 96, 128, 160]
+                capture_bs = list(range(1, 33)) + range(40, 161, 16)
             else:
-                capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
         else:
             # Since speculative decoding requires more cuda graph memory, we
             # capture less.
-            capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
+            capture_bs = (
+                list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
+            )

     if _is_hip:
-        capture_bs += [i * 8 for i in range(21, 33)]
+        capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
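
For reference, the new default capture list (no speculative decoding, padding enabled) evaluates to 1, 2, 4, 8 and then every multiple of 8 up to 160:

    capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
    # -> [1, 2, 4, 8, 16, 24, 32, ..., 152, 160]
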
@@ -489,10 +491,10 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)

sglang/srt/model_executor/forward_batch_info.py

@@ -104,6 +104,9 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

+    def is_extend_or_draft_extend(self):
+        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
+
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

@@ -148,6 +151,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +168,6 @@ class ForwardBatch:
     # Position information
     positions: torch.Tensor = None

-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +296,14 @@ class ForwardBatch:
         ):
             ret.positions = ret.spec_info.positions

+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
+
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -353,11 +358,6 @@ class ForwardBatch:
         for mm_input in valid_inputs[1:]:
             merged.merge(mm_input)

-        if isinstance(merged.pixel_values, np.ndarray):
-            merged.pixel_values = torch.from_numpy(merged.pixel_values)
-        if isinstance(merged.audio_features, np.ndarray):
-            merged.audio_features = torch.from_numpy(merged.audio_features)
-
         return merged

     def contains_image_inputs(self) -> bool: