sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 if TYPE_CHECKING:
@@ -35,10 +34,8 @@ class TritonAttnBackend(AttentionBackend):
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
 
-        if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-            self.reduce_dtype = torch.float32
-        else:
-            self.reduce_dtype = torch.float16
+        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
 
         self.forward_metadata = None
 
@@ -50,23 +47,23 @@ class TritonAttnBackend(AttentionBackend):
         """Init auxiliary variables for triton attention backend."""
 
         if forward_batch.forward_mode.is_decode():
-            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
-            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-
-            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
-                (self.num_head, total_num_tokens),
-                dtype=self.reduce_dtype,
+                (
+                    forward_batch.batch_size,
+                    self.num_head,
+                    self.num_kv_splits,
+                    self.v_head_dim + 1,
+                ),
+                dtype=torch.float32,
                 device=self.device,
             )
 
-            max_seq_len = torch.max(forward_batch.seq_lens).item()
             max_extend_len = None
         else:
-            start_loc = attn_logits = max_seq_len = None
+            attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
+        self.forward_metadata = attn_logits, max_extend_len
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
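Note on the new buffer layout (not part of the diff): the shape (batch_size, num_head, num_kv_splits, v_head_dim + 1) suggests a flash-decoding style split-KV kernel, where each of the num_kv_splits chunks writes a partial output of width v_head_dim plus one extra float per head. Assuming that extra slot holds the split's log-sum-exp, the per-split results would be combined roughly as in the sketch below; combine_kv_splits and the exact slot layout are illustrative guesses, not the kernel's actual interface.

import torch

def combine_kv_splits(attn_logits: torch.Tensor, v_head_dim: int) -> torch.Tensor:
    # attn_logits: (batch, num_head, num_kv_splits, v_head_dim + 1), float32.
    # [..., :v_head_dim] is assumed to be each split's normalized partial output;
    # [..., v_head_dim] is assumed to be that split's log-sum-exp (LSE).
    partial_out = attn_logits[..., :v_head_dim]          # (B, H, S, D)
    lse = attn_logits[..., v_head_dim]                   # (B, H, S)
    # Each split's correct weight is exp(lse_s) / sum_s exp(lse_s),
    # i.e. a softmax over the splits dimension.
    weights = torch.softmax(lse, dim=-1).unsqueeze(-1)   # (B, H, S, 1)
    return (weights * partial_out).sum(dim=-2)           # (B, H, D)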
@@ -75,11 +72,8 @@ class TritonAttnBackend(AttentionBackend):
             (max_bs,), dtype=torch.int32, device=self.device
         )
         self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
+            (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
+            dtype=torch.float32,
             device="cuda",
         )
 
@@ -92,9 +86,7 @@ class TritonAttnBackend(AttentionBackend):
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.forward_metadata = (
-            self.cuda_graph_start_loc,
             self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
             None,
         )
 
@@ -114,7 +106,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,11 +120,12 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        _, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -146,7 +145,13 @@ class TritonAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -158,11 +163,12 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
+        attn_logits, _ = self.forward_metadata
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
 
         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -171,10 +177,9 @@ class TritonAttnBackend(AttentionBackend):
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.req_pool_indices,
-            start_loc,
             forward_batch.seq_lens,
             attn_logits,
-            max_seq_len,
+            self.num_kv_splits,
             layer.scaling,
             layer.logit_cap,
         )
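The new save_kv_cache keyword defaults to True, which preserves the previous behavior; passing False skips the set_kv_buffer write so attention can be computed without mutating the KV pool. A hypothetical direct call, purely to illustrate the keyword (in practice the backend is driven by the model runner, and q, k, v, layer, and forward_batch come from it):

# Hypothetical usage sketch; attn_backend is an already-initialized TritonAttnBackend.
o = attn_backend.forward_decode(
    q, k, v,
    layer=layer,
    forward_batch=forward_batch,
    save_kv_cache=False,  # compute attention without writing k/v into the pool
)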