sglang-0.4.0-py3-none-any.whl → sglang-0.4.0.post2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
@@ -20,9 +20,11 @@ from contextlib import contextmanager
20
20
  from enum import Enum, auto
21
21
  from typing import Dict, List, Optional
22
22
 
23
+ import torch
24
+
23
25
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
24
26
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
25
- from sglang.srt.mem_cache.radix_cache import TreeNode
27
+ from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
26
28
 
27
29
  # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
28
30
  # This can prevent the server from being too conservative.
@@ -32,6 +34,21 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
32
34
  os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
33
35
  )
34
36
 
37
+ # Threshold for in-batch prefix cache.
38
+ # If a request has a matched prefix length (against existing cache) less than this value,
39
+ # the scheduler runs the in-batch prefix caching check for this request.
40
+ # If we set it to -1, it means we disable in-batch prefix caching.
41
+ IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
42
+ os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
43
+ )
44
+
45
+ # Threshold for in-batch prefix cache.
46
+ # If a request has a matched prefix length (within the waiting queue) larger than this value,
47
+ # the scheduler deprioritizes this request
48
+ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
49
+ os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
50
+ )
51
+
35
52
 
36
53
  class SchedulePolicy:
37
54
  def __init__(self, policy: str, tree_cache: BasePrefixCache):
@@ -42,6 +59,11 @@ class SchedulePolicy:
42
59
  self.policy = policy
43
60
  self.tree_cache = tree_cache
44
61
 
62
+ # It is used to find the matching prefix for in-batch prefix caching.
63
+ self.waiting_queue_radix_tree = RadixCache(
64
+ req_to_token_pool=None, token_to_kv_pool=None, disable=False
65
+ )
66
+
45
67
  def calc_priority(self, waiting_queue: List[Req]):
46
68
  if len(waiting_queue) > 128 and self.policy == "lpm":
47
69
  # Turn off the expensive prefix matching and sorting when the #queue is large.
@@ -52,17 +74,53 @@ class SchedulePolicy:
52
74
  # Compute matched prefix length
53
75
  prefix_computed = False
54
76
  if policy == "lpm" or policy == "dfs-weight":
77
+ # rid to deprioritize in the current run for in-batch prefix caching.
78
+ temporary_deprioritized = set()
79
+ self.waiting_queue_radix_tree.reset()
80
+
55
81
  for r in waiting_queue:
82
+ prefix_ids = r.adjust_max_prefix_ids()
83
+
56
84
  # NOTE: the prefix_indices must always be aligned with last_node
57
85
  r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
58
- rid=r.rid, key=r.adjust_max_prefix_ids()
86
+ rid=r.rid, key=prefix_ids
59
87
  )
60
88
 
89
+ # NOTE(sang): This logic is for in-batch prefix caching;
90
+ # If there are more than 1 request that have small matching prefix from
91
+ # existing cache, but all those requests share the same prefix, we prefer
92
+ # to schedule only one of them so that we can increase the cache hit rate.
93
+ # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
94
+ # threshold means we cannot use in-batch prefix caching for short prefixes.
95
+ # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
96
+ if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
97
+ in_batch_matching_prefixes, _ = (
98
+ self.waiting_queue_radix_tree.match_prefix(
99
+ rid=r.rid, key=prefix_ids
100
+ )
101
+ )
102
+ if (
103
+ len(in_batch_matching_prefixes)
104
+ >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
105
+ ):
106
+ temporary_deprioritized.add(r.rid)
107
+ else:
108
+ # Insert with a dummy key
109
+ self.waiting_queue_radix_tree.insert(
110
+ prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
111
+ )
112
+
61
113
  prefix_computed = True
62
114
 
63
115
  if policy == "lpm":
64
116
  # Longest Prefix Match
65
- waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
117
+ waiting_queue.sort(
118
+ key=lambda r: (
119
+ -len(r.prefix_indices)
120
+ if r.rid not in temporary_deprioritized
121
+ else float("inf")
122
+ )
123
+ )
66
124
  elif policy == "fcfs":
67
125
  # first come first serve
68
126
  pass
@@ -72,6 +130,7 @@ class SchedulePolicy:
72
130
  elif policy == "random":
73
131
  random.shuffle(waiting_queue)
74
132
  elif policy == "dfs-weight":
133
+ # Experimental policy based on custom weights
75
134
  last_node_to_reqs = defaultdict(list)
76
135
  for req in waiting_queue:
77
136
  last_node_to_reqs[req.last_node].append(req)
@@ -101,8 +160,8 @@ class SchedulePolicy:
101
160
  def get_dfs_priority(
102
161
  self,
103
162
  cur_node: TreeNode,
104
- node_to_priority: Dict,
105
- last_node_to_reqs: Dict,
163
+ node_to_priority: Dict[TreeNode, int],
164
+ last_node_to_reqs: Dict[TreeNode, List[Req]],
106
165
  q: List,
107
166
  ):
108
167
  childs = [child for child in cur_node.children.values()]