sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/chunk_cache.py

@@ -2,28 +2,23 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+from typing import TYPE_CHECKING, Any
 
 import torch
 
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
 
-class ChunkCacheEntry:
-    def __init__(self, rid: str, value: torch.Tensor):
-        self.rid = rid
-        self.value = value
-
-
 class ChunkCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
@@ -33,8 +28,12 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         pass
 
-    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
-        return [], None
+    def match_prefix(self, **unused_kwargs) -> MatchResult:
+        return MatchResult(
+            device_indices=torch.empty((0,), dtype=torch.int64),
+            last_device_node=None,
+            last_host_node=None,
+        )
 
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
@@ -53,9 +52,6 @@ class ChunkCache(BasePrefixCache):
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
 
-    def insert(self):
-        raise NotImplementedError()
-
     def evict(self, num_tokens: int):
         pass
 
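The MatchResult container that these implementations now return is defined in sglang/srt/mem_cache/base_prefix_cache.py (+52 -8 in the list above) and does not appear in this excerpt. Below is a minimal sketch of its likely shape, inferred only from the call sites in this diff, not copied from the actual source; note that ChunkCache constructs it without host_hit_length, so that field must carry a default:

from typing import Any, NamedTuple

import torch


class MatchResult(NamedTuple):
    """Sketch of the structured match_prefix result (inferred, not the actual source)."""

    device_indices: torch.Tensor  # KV indices of the prefix hit already on device
    last_device_node: Any  # deepest matched tree node whose KV data is on device
    last_host_node: Any  # deepest matched node overall, possibly host-only
    host_hit_length: int = 0  # extra prefix tokens available only in host memory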
sglang/srt/mem_cache/hiradix_cache.py

@@ -7,13 +7,16 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
-    MHATokenToKVPoolHost,
     MLATokenToKVPool,
-    MLATokenToKVPoolHost,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
@@ -25,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
@@ -281,39 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length  # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node
 
             while last_node.evicted:
                 last_node = last_node.parent
 
-        return last_node, prefix_indices
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )
 
-    def ready_to_load_cache(self):
+    def ready_to_load_host_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index
+
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
 
-    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-            if include_evicted:
-                return empty_value, self.root_node, self.root_node
-            else:
-                return empty_value, self.root_node
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )
 
         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
@@ -325,14 +333,18 @@ class HiRadixCache(RadixCache):
         else:
             value = empty_value
 
-        last_node_global = last_node
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
 
-        if include_evicted:
-            return value, last_node, last_node_global
-        else:
-            return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
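With the include_evicted flag gone, match_prefix has a single calling convention across cache implementations. A hypothetical caller, sketched with illustrative names (extend_prefix, tree_cache, and token_ids are not from this diff), might consume the new result like this:

import torch


def extend_prefix(tree_cache, token_ids):
    # Illustrative only: walks the 0.4.8 MatchResult protocol end to end.
    result = tree_cache.match_prefix(key=token_ids)
    prefix_indices = result.device_indices
    last_node = result.last_device_node
    if result.host_hit_length > 0:
        # Part of the matched prefix lives only in host memory; load it back.
        loading_values, last_node = tree_cache.init_load_back(
            result.last_host_node, result.host_hit_length
        )
        if len(loading_values) > 0:
            prefix_indices = torch.cat([prefix_indices, loading_values])
    return prefix_indices, last_node

Note that the return order of init_load_back also changed in this diff: the loaded device indices now come first, and a failed load-back yields an empty tensor plus the nearest non-evicted ancestor node.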
@@ -370,6 +382,7 @@ class HiRadixCache(RadixCache):
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.loading = child.loading
+        new_node.hit_count = child.hit_count
 
         # split value and host value if exists
         if child.evicted: