sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/bench_one_batch.py +3 -0
  3. sglang/srt/configs/__init__.py +8 -0
  4. sglang/srt/configs/model_config.py +4 -0
  5. sglang/srt/configs/step3_vl.py +172 -0
  6. sglang/srt/conversation.py +23 -0
  7. sglang/srt/disaggregation/decode.py +2 -8
  8. sglang/srt/disaggregation/launch_lb.py +5 -20
  9. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  10. sglang/srt/disaggregation/prefill.py +2 -6
  11. sglang/srt/distributed/parallel_state.py +86 -1
  12. sglang/srt/entrypoints/engine.py +14 -18
  13. sglang/srt/entrypoints/http_server.py +10 -2
  14. sglang/srt/entrypoints/openai/serving_chat.py +2 -21
  15. sglang/srt/eplb/expert_distribution.py +5 -0
  16. sglang/srt/eplb/expert_location.py +17 -6
  17. sglang/srt/eplb/expert_location_dispatch.py +1 -0
  18. sglang/srt/eplb/expert_location_updater.py +2 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/step3_detector.py +436 -0
  21. sglang/srt/hf_transformers_utils.py +2 -0
  22. sglang/srt/jinja_template_utils.py +4 -1
  23. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  24. sglang/srt/layers/attention/utils.py +6 -1
  25. sglang/srt/layers/moe/cutlass_moe.py +2 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +39 -674
  27. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
  29. sglang/srt/layers/quantization/fp8.py +52 -18
  30. sglang/srt/layers/quantization/unquant.py +0 -8
  31. sglang/srt/layers/quantization/w4afp8.py +1 -0
  32. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  33. sglang/srt/managers/cache_controller.py +165 -67
  34. sglang/srt/managers/data_parallel_controller.py +2 -0
  35. sglang/srt/managers/io_struct.py +0 -2
  36. sglang/srt/managers/scheduler.py +90 -671
  37. sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
  38. sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
  39. sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
  40. sglang/srt/managers/template_manager.py +62 -19
  41. sglang/srt/managers/tokenizer_manager.py +123 -74
  42. sglang/srt/managers/tp_worker.py +4 -0
  43. sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
  44. sglang/srt/mem_cache/hicache_storage.py +60 -17
  45. sglang/srt/mem_cache/hiradix_cache.py +36 -8
  46. sglang/srt/mem_cache/memory_pool.py +15 -118
  47. sglang/srt/mem_cache/memory_pool_host.py +418 -29
  48. sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
  49. sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
  50. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  51. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  52. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  53. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
  54. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
  55. sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  57. sglang/srt/model_executor/model_runner.py +13 -1
  58. sglang/srt/model_loader/weight_utils.py +2 -0
  59. sglang/srt/models/arcee.py +532 -0
  60. sglang/srt/models/deepseek_v2.py +7 -6
  61. sglang/srt/models/glm4_moe.py +6 -4
  62. sglang/srt/models/granitemoe.py +3 -0
  63. sglang/srt/models/grok.py +3 -0
  64. sglang/srt/models/hunyuan.py +1 -0
  65. sglang/srt/models/llama4.py +3 -0
  66. sglang/srt/models/mixtral.py +3 -0
  67. sglang/srt/models/olmoe.py +3 -0
  68. sglang/srt/models/phimoe.py +1 -0
  69. sglang/srt/models/step3_vl.py +991 -0
  70. sglang/srt/multimodal/processors/base_processor.py +15 -16
  71. sglang/srt/multimodal/processors/step3_vl.py +515 -0
  72. sglang/srt/reasoning_parser.py +2 -1
  73. sglang/srt/server_args.py +49 -18
  74. sglang/srt/speculative/eagle_worker.py +2 -0
  75. sglang/srt/utils.py +1 -0
  76. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  77. sglang/utils.py +0 -11
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
  80. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
  81. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  82. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
35
35
  hicache_size: int,
36
36
  hicache_write_policy: str,
37
37
  hicache_io_backend: str,
38
+ hicache_mem_layout: str,
38
39
  hicache_storage_backend: Optional[str] = None,
39
40
  ):
41
+
42
+ if hicache_io_backend == "direct":
43
+ if hicache_mem_layout == "page_first":
44
+ hicache_mem_layout = "layer_first"
45
+ logger.warning(
46
+ "Page first layout is not supported with direct IO backend, switching to layer first layout"
47
+ )
48
+
40
49
  self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
41
50
  if isinstance(self.kv_cache, MHATokenToKVPool):
42
51
  self.token_to_kv_pool_host = MHATokenToKVPoolHost(
43
- self.kv_cache, hicache_ratio, hicache_size, page_size
52
+ self.kv_cache,
53
+ hicache_ratio,
54
+ hicache_size,
55
+ page_size,
56
+ hicache_mem_layout,
44
57
  )
45
58
  elif isinstance(self.kv_cache, MLATokenToKVPool):
46
59
  self.token_to_kv_pool_host = MLATokenToKVPoolHost(
47
- self.kv_cache, hicache_ratio, hicache_size, page_size
60
+ self.kv_cache,
61
+ hicache_ratio,
62
+ hicache_size,
63
+ page_size,
64
+ hicache_mem_layout,
48
65
  )
49
66
  else:
50
67
  raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -79,7 +96,9 @@ class HiRadixCache(RadixCache):
79
96
  self.write_through_threshold = (
80
97
  1 if hicache_write_policy == "write_through" else 3
81
98
  )
82
- self.write_through_threshold_storage = 3
99
+ self.write_through_threshold_storage = (
100
+ 1 if hicache_write_policy == "write_through" else 3
101
+ )
83
102
  self.load_back_threshold = 10
84
103
  super().__init__(
85
104
  req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -111,6 +130,7 @@ class HiRadixCache(RadixCache):
111
130
  )
112
131
  if host_indices is not None:
113
132
  node.host_value = host_indices
133
+ assert len(node.host_value) > 0
114
134
  self.ongoing_write_through[node.id] = node
115
135
  if not write_back:
116
136
  # no need to lock nodes if write back
@@ -388,10 +408,14 @@ class HiRadixCache(RadixCache):
388
408
  self.cache_controller.ack_backup_queue.get()
389
409
  )
390
410
  host_node = self.ongoing_backup[ack_id]
391
- if completed_tokens < len(host_node.key):
411
+ if completed_tokens == 0:
412
+ host_node.hash_value = None
413
+ elif completed_tokens < len(host_node.key):
392
414
  # backup is only partially successful, split the node
393
415
  new_node = self._split_node(host_node.key, host_node, completed_tokens)
394
416
  new_node.hash_value = hash_value
417
+ else:
418
+ host_node.hash_value = hash_value
395
419
  host_node.release_host()
396
420
  del self.ongoing_backup[ack_id]
397
421
 
@@ -429,8 +453,10 @@ class HiRadixCache(RadixCache):
429
453
  last_host_node,
430
454
  fetched_token_ids,
431
455
  written_indices,
432
- hash_value[:min_completed_tokens],
456
+ hash_value[: min_completed_tokens // self.page_size],
433
457
  )
458
+ if len(written_indices):
459
+ self.cache_controller.mem_pool_host.update_prefetch(written_indices)
434
460
 
435
461
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
436
462
  self.cache_controller.mem_pool_host.free(
@@ -520,7 +546,7 @@ class HiRadixCache(RadixCache):
520
546
  prefix_len = self.key_match_fn(node.key, key)
521
547
  key = key[prefix_len:]
522
548
  host_value = host_value[prefix_len:]
523
- hash_value = hash_value[prefix_len:]
549
+ hash_value = hash_value[prefix_len // self.page_size :]
524
550
  matched_length += prefix_len
525
551
 
526
552
  if prefix_len < len(node.key):
@@ -551,13 +577,11 @@ class HiRadixCache(RadixCache):
551
577
  prefix_len = self.key_match_fn(child.key, key)
552
578
  if prefix_len < len(child.key):
553
579
  new_node = self._split_node(child.key, child, prefix_len)
554
- self.inc_hit_count(new_node)
555
580
  if not new_node.evicted:
556
581
  value.append(new_node.value)
557
582
  node = new_node
558
583
  break
559
584
  else:
560
- self.inc_hit_count(child)
561
585
  if not child.evicted:
562
586
  value.append(child.value)
563
587
  node = child
@@ -587,6 +611,10 @@ class HiRadixCache(RadixCache):
587
611
  if child.backuped:
588
612
  new_node.host_value = child.host_value[:split_len]
589
613
  child.host_value = child.host_value[split_len:]
614
+
615
+ if child.hash_value:
616
+ new_node.hash_value = child.hash_value[: split_len // self.page_size]
617
+ child.hash_value = child.hash_value[split_len // self.page_size :]
590
618
  child.parent = new_node
591
619
  child.key = child.key[split_len:]
592
620
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
31
31
 
32
32
  import numpy as np
33
33
  import torch
34
- import torch.distributed as dist
35
34
  import triton
36
35
  import triton.language as tl
37
36
 
38
37
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
39
38
  from sglang.srt.layers.radix_attention import RadixAttention
40
- from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
39
+ from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
41
40
 
42
41
  logger = logging.getLogger(__name__)
43
42
 
44
43
  GB = 1024 * 1024 * 1024
45
44
  _is_cuda = is_cuda()
46
- _is_npu = is_npu()
47
- if not _is_npu:
48
- from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
49
45
 
50
46
 
51
47
  class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
153
149
  ) -> None:
154
150
  raise NotImplementedError()
155
151
 
156
- @abc.abstractmethod
157
- def load_from_host_per_layer(
158
- self, host_pool, host_indices, device_indices, layer_id, io_backend
159
- ):
160
- raise NotImplementedError()
161
-
162
- @abc.abstractmethod
163
- def backup_to_host_all_layer(
164
- self, host_pool, host_indices, device_indices, io_backend
165
- ):
166
- raise NotImplementedError()
167
-
168
152
  def register_layer_transfer_counter(self, layer_transfer_counter):
169
153
  self.layer_transfer_counter = layer_transfer_counter
170
154
 
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
253
237
  )
254
238
  for _ in range(self.layer_num)
255
239
  ]
256
- self.token_stride = self.head_num * self.head_dim
257
- self.data_ptrs = torch.tensor(
258
- [x.data_ptr() for x in self.k_buffer + self.v_buffer],
240
+
241
+ self.k_data_ptrs = torch.tensor(
242
+ [x.data_ptr() for x in self.k_buffer],
243
+ dtype=torch.uint64,
244
+ device=self.device,
245
+ )
246
+ self.v_data_ptrs = torch.tensor(
247
+ [x.data_ptr() for x in self.v_buffer],
259
248
  dtype=torch.uint64,
260
249
  device=self.device,
261
250
  )
251
+ self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
262
252
  self.data_strides = torch.tensor(
263
253
  [
264
254
  np.prod(x.shape[1:]) * x.dtype.itemsize
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
347
337
  self.v_buffer[layer_id][chunk_indices] = v_chunk
348
338
  torch.cuda.synchronize()
349
339
 
350
- def load_from_host_per_layer(
351
- self,
352
- host_pool,
353
- host_indices,
354
- device_indices,
355
- layer_id,
356
- io_backend,
357
- ):
358
- transfer_kv_per_layer(
359
- src_k=host_pool.k_buffer[layer_id],
360
- dst_k=self.k_buffer[layer_id],
361
- src_v=host_pool.v_buffer[layer_id],
362
- dst_v=self.v_buffer[layer_id],
363
- src_indices=host_indices,
364
- dst_indices=device_indices,
365
- io_backend=io_backend,
366
- page_size=self.page_size,
367
- item_size=self.token_stride,
368
- )
369
-
370
- def backup_to_host_all_layer(
371
- self, host_pool, host_indices, device_indices, io_backend
372
- ):
373
- # todo: specialized all layer kernels for the layer-non-contiguous memory pool
374
- for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
375
- if layer_id - self.start_layer >= len(host_pool.k_buffer):
376
- raise ValueError(
377
- f"Layer ID {layer_id} exceeds the number of layers in host pool."
378
- )
379
- transfer_kv_per_layer(
380
- src_k=self.k_buffer[layer_id],
381
- dst_k=host_pool.k_buffer[layer_id],
382
- src_v=self.v_buffer[layer_id],
383
- dst_v=host_pool.v_buffer[layer_id],
384
- src_indices=device_indices,
385
- dst_indices=host_indices,
386
- io_backend=io_backend,
387
- page_size=self.page_size,
388
- item_size=self.token_stride,
389
- )
390
-
391
340
  def _get_key_buffer(self, layer_id: int):
392
341
  # for internal use of referencing
393
342
  if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
602
551
  layer_id_override=layer_id_pool,
603
552
  )
604
553
 
605
- def load_from_host_per_layer(
606
- self, host_pool, host_indices, device_indices, layer_id, io_backend
607
- ):
608
- raise NotImplementedError("HiCache not supported for SWAKVPool.")
609
-
610
- def backup_to_host_all_layer(
611
- self, host_pool, host_indices, device_indices, io_backend
612
- ):
613
- raise NotImplementedError("HiCache not supported for SWAKVPool.")
614
-
615
554
 
616
555
  class AscendTokenToKVPool(MHATokenToKVPool):
617
556
 
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
823
762
  for _ in range(layer_num)
824
763
  ]
825
764
 
826
- self.token_stride = kv_lora_rank + qk_rope_head_dim
765
+ self.data_ptrs = torch.tensor(
766
+ [x.data_ptr() for x in self.kv_buffer],
767
+ dtype=torch.uint64,
768
+ device=self.device,
769
+ )
827
770
  self.layer_transfer_counter = None
828
771
 
829
772
  kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
909
852
  self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
910
853
  )
911
854
 
912
- def load_from_host_per_layer(
913
- self, host_pool, host_indices, device_indices, layer_id, io_backend
914
- ):
915
- transfer_kv_per_layer_mla(
916
- src=host_pool.kv_buffer[layer_id],
917
- dst=self.kv_buffer[layer_id],
918
- src_indices=host_indices,
919
- dst_indices=device_indices,
920
- io_backend=io_backend,
921
- page_size=self.page_size,
922
- item_size=self.token_stride,
923
- )
924
-
925
- def backup_to_host_all_layer(
926
- self, host_pool, host_indices, device_indices, io_backend
927
- ):
928
- # todo: specialized all layer kernels for the layer-non-contiguous memory pool
929
- for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
930
- if layer_id - self.start_layer >= len(host_pool.kv_buffer):
931
- raise ValueError(
932
- f"Layer ID {layer_id} exceeds the number of layers in host pool."
933
- )
934
- transfer_kv_per_layer_mla(
935
- src=self.kv_buffer[layer_id],
936
- dst=host_pool.kv_buffer[layer_id],
937
- src_indices=device_indices,
938
- dst_indices=host_indices,
939
- io_backend=io_backend,
940
- page_size=self.page_size,
941
- item_size=self.token_stride,
942
- )
943
-
944
855
  def get_cpu_copy(self, indices):
945
856
  torch.cuda.synchronize()
946
857
  kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
1131
1042
  self.v_buffer[layer_id - self.start_layer][loc] = cache_v
1132
1043
  self.label_buffer[layer_id - self.start_layer][loc] = cache_label
1133
1044
 
1134
- def load_from_host_per_layer(
1135
- self, host_pool, host_indices, device_indices, layer_id, io_backend
1136
- ):
1137
- raise NotImplementedError(
1138
- "HiCache not supported for DoubleSparseTokenToKVPool."
1139
- )
1140
-
1141
- def backup_to_host_all_layer(
1142
- self, host_pool, host_indices, device_indices, io_backend
1143
- ):
1144
- raise NotImplementedError(
1145
- "HiCache not supported for DoubleSparseTokenToKVPool."
1146
- )
1147
-
1148
1045
 
1149
1046
  @triton.jit
1150
1047
  def copy_all_layer_kv_cache(