sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -17,6 +17,8 @@ import dataclasses
  from typing import List, Optional, Union

  import torch
+ import triton
+ import triton.language as tl
  from torch import nn
  from vllm.distributed import (
      get_tensor_model_parallel_world_size,
@@ -33,76 +35,72 @@ from sglang.srt.model_executor.forward_batch_info import (

  @dataclasses.dataclass
  class LogitsProcessorOutput:
+     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
      # The logits of the next tokens. shape: [#seq, vocab_size]
      next_token_logits: torch.Tensor
-     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-     next_token_logprobs: torch.Tensor = None
+     # Used by speculative decoding (EAGLE)
+     # The last hidden layers
+     hidden_states: Optional[torch.Tensor] = None

+     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
+     # The logprobs of the next tokens. shape: [#seq]
+     next_token_logprobs: Optional[torch.Tensor] = None
+     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+     next_token_top_logprobs_val: Optional[List] = None
+     next_token_top_logprobs_idx: Optional[List] = None
+
+     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
      # The normlaized logprobs of prompts. shape: [#seq]
      normalized_prompt_logprobs: torch.Tensor = None
-     # The logprobs of input tokens. shape: [#token, vocab_size]
+     # The logprobs of input tokens. shape: [#token]
      input_token_logprobs: torch.Tensor = None
-
-     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
      input_top_logprobs_val: List = None
      input_top_logprobs_idx: List = None
-     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-     output_top_logprobs_val: List = None
-     output_top_logprobs_idx: List = None
-
-     # Used by speculative decoding (EAGLE)
-     # The output of transformer layers
-     hidden_states: Optional[torch.Tensor] = None


  @dataclasses.dataclass
  class LogitsMetadata:
      forward_mode: ForwardMode
-     top_logprobs_nums: Optional[List[int]]
-
-     return_logprob: bool = False
-     return_top_logprob: bool = False
+     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL

+     extend_return_logprob: bool = False
+     extend_return_top_logprob: bool = False
      extend_seq_lens: Optional[torch.Tensor] = None
      extend_seq_lens_cpu: Optional[List[int]] = None
-
      extend_logprob_start_lens_cpu: Optional[List[int]] = None
      extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
-
-     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+     top_logprobs_nums: Optional[List[int]] = None

      @classmethod
      def from_forward_batch(cls, forward_batch: ForwardBatch):
-         extend_logprob_pruned_lens_cpu = None
-
-         if forward_batch.return_logprob:
-             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-             if forward_batch.forward_mode.is_extend():
-                 extend_logprob_pruned_lens_cpu = [
-                     extend_len - start_len
-                     for extend_len, start_len in zip(
-                         forward_batch.extend_seq_lens_cpu,
-                         forward_batch.extend_logprob_start_lens_cpu,
-                     )
-                 ]
-         else:
-             return_top_logprob = False
-
-         if forward_batch.spec_info:
-             capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+         if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+             extend_return_logprob = True
+             extend_return_top_logprob = any(
+                 x > 0 for x in forward_batch.top_logprobs_nums
+             )
+             extend_logprob_pruned_lens_cpu = [
+                 extend_len - start_len
+                 for extend_len, start_len in zip(
+                     forward_batch.extend_seq_lens_cpu,
+                     forward_batch.extend_logprob_start_lens_cpu,
+                 )
+             ]
          else:
-             capture_hidden_mode = CaptureHiddenMode.NULL
+             extend_return_logprob = extend_return_top_logprob = (
+                 extend_logprob_pruned_lens_cpu
+             ) = False

          return cls(
              forward_mode=forward_batch.forward_mode,
-             top_logprobs_nums=forward_batch.top_logprobs_nums,
-             return_logprob=forward_batch.return_logprob,
-             return_top_logprob=return_top_logprob,
+             capture_hidden_mode=forward_batch.capture_hidden_mode,
+             extend_return_logprob=extend_return_logprob,
+             extend_return_top_logprob=extend_return_top_logprob,
              extend_seq_lens=forward_batch.extend_seq_lens,
              extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
              extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
              extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
-             capture_hidden_mode=capture_hidden_mode,
+             top_logprobs_nums=forward_batch.top_logprobs_nums,
          )


@@ -119,6 +117,11 @@ class LogitsProcessor(nn.Module):
          self.final_logit_softcapping = getattr(
              self.config, "final_logit_softcapping", None
          )
+         if (
+             self.final_logit_softcapping is not None
+             and self.final_logit_softcapping < 0
+         ):
+             self.final_logit_softcapping = None

      def forward(
          self,
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
      ):
          if isinstance(logits_metadata, ForwardBatch):
              logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-         assert isinstance(logits_metadata, LogitsMetadata)

          # Get the last hidden states and last logits for the next token prediction
          if (
@@ -142,18 +144,13 @@ class LogitsProcessor(nn.Module):
              last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
              last_hidden = hidden_states[last_index]

+         # Compute logits
          last_logits = self._get_logits(last_hidden, lm_head)
-         if self.do_tensor_parallel_all_gather:
-             last_logits = tensor_model_parallel_all_gather(last_logits)
-         last_logits = last_logits[:, : self.config.vocab_size].float()
-
-         if self.final_logit_softcapping:
-             last_logits.div_(self.final_logit_softcapping)
-             torch.tanh(last_logits, out=last_logits)
-             last_logits.mul_(self.final_logit_softcapping)
-
-         # Return only last_logits if logprob is not requested
-         if not logits_metadata.return_logprob:
+         if (
+             not logits_metadata.extend_return_logprob
+             or logits_metadata.capture_hidden_mode.need_capture()
+         ):
+             # Decode mode or extend mode without return_logprob.
              return LogitsProcessorOutput(
                  next_token_logits=last_logits,
                  hidden_states=(
@@ -167,95 +164,60 @@ class LogitsProcessor(nn.Module):
                  ),
              )
          else:
-             last_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                 last_logits, logits_metadata
+             # Slice the requested tokens to compute logprob
+             pt, pruned_states, pruned_input_ids = 0, [], []
+             for start_len, extend_len in zip(
+                 logits_metadata.extend_logprob_start_lens_cpu,
+                 logits_metadata.extend_seq_lens_cpu,
+             ):
+                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                 pt += extend_len
+
+             # Compute the logits of all required tokens
+             pruned_states = torch.cat(pruned_states)
+             del hidden_states
+             input_token_logits = self._get_logits(pruned_states, lm_head)
+             del pruned_states
+
+             # Normalize the logprob w/o temperature, top-p
+             input_logprobs = input_token_logits
+             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                 input_logprobs, logits_metadata
              )

-             if logits_metadata.forward_mode.is_decode():
-                 if logits_metadata.return_top_logprob:
-                     output_top_logprobs_val, output_top_logprobs_idx = (
-                         self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
-                     )
-                 else:
-                     output_top_logprobs_val = output_top_logprobs_idx = None
-                 return LogitsProcessorOutput(
-                     next_token_logits=last_logits,
-                     next_token_logprobs=last_logprobs,
-                     output_top_logprobs_val=output_top_logprobs_val,
-                     output_top_logprobs_idx=output_top_logprobs_idx,
-                 )
+             # Get the logprob of top-k tokens
+             if logits_metadata.extend_return_top_logprob:
+                 (
+                     input_top_logprobs_val,
+                     input_top_logprobs_idx,
+                 ) = self.get_top_logprobs(input_logprobs, logits_metadata)
              else:
-                 # Slice the requested tokens to compute logprob
-                 pt, states, pruned_input_ids = 0, [], []
-                 for start_len, extend_len in zip(
-                     logits_metadata.extend_logprob_start_lens_cpu,
-                     logits_metadata.extend_seq_lens_cpu,
-                 ):
-                     states.append(hidden_states[pt + start_len : pt + extend_len])
-                     pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                     pt += extend_len
-
-                 # Compute the logits and logprobs for all required tokens
-                 states = torch.cat(states, dim=0)
-                 all_logits = self._get_logits(states, lm_head)
-                 if self.do_tensor_parallel_all_gather:
-                     all_logits = tensor_model_parallel_all_gather(all_logits)
-
-                 # The LM head's weights may be zero-padded for parallelism. Remove any
-                 # extra logits that this padding may have produced.
-                 all_logits = all_logits[:, : self.config.vocab_size].float()
-
-                 if self.final_logit_softcapping:
-                     all_logits.div_(self.final_logit_softcapping)
-                     torch.tanh(all_logits, out=all_logits)
-                     all_logits.mul_(self.final_logit_softcapping)
-
-                 all_logprobs = all_logits
-                 del all_logits, hidden_states
-
-                 all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                     all_logprobs, logits_metadata
-                 )
-
-                 # Get the logprob of top-k tokens
-                 if logits_metadata.return_top_logprob:
-                     (
-                         input_top_logprobs_val,
-                         input_top_logprobs_idx,
-                         output_top_logprobs_val,
-                         output_top_logprobs_idx,
-                     ) = self.get_top_logprobs(all_logprobs, logits_metadata)
-                 else:
-                     input_top_logprobs_val = input_top_logprobs_idx = (
-                         output_top_logprobs_val
-                     ) = output_top_logprobs_idx = None
-
-                 # Compute the normalized logprobs for the requested tokens.
-                 # Note that we pad a zero at the end for easy batching.
-                 input_token_logprobs = all_logprobs[
-                     torch.arange(all_logprobs.shape[0], device="cuda"),
-                     torch.cat(
-                         [
-                             torch.cat(pruned_input_ids)[1:],
-                             torch.tensor([0], device="cuda"),
-                         ]
-                     ),
-                 ]
-                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                     input_token_logprobs,
-                     logits_metadata,
-                 )
+                 input_top_logprobs_val = input_top_logprobs_idx = None
+
+             # Compute the normalized logprobs for the requested tokens.
+             # Note that we pad a zero at the end for easy batching.
+             input_token_logprobs = input_logprobs[
+                 torch.arange(input_logprobs.shape[0], device="cuda"),
+                 torch.cat(
+                     [
+                         torch.cat(pruned_input_ids)[1:],
+                         torch.tensor([0], device="cuda"),
+                     ]
+                 ),
+             ]
+             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                 input_token_logprobs,
+                 logits_metadata,
+             )

-             return LogitsProcessorOutput(
-                 next_token_logits=last_logits,
-                 next_token_logprobs=last_logprobs,
-                 normalized_prompt_logprobs=normalized_prompt_logprobs,
-                 input_token_logprobs=input_token_logprobs,
-                 input_top_logprobs_val=input_top_logprobs_val,
-                 input_top_logprobs_idx=input_top_logprobs_idx,
-                 output_top_logprobs_val=output_top_logprobs_val,
-                 output_top_logprobs_idx=output_top_logprobs_idx,
-             )
+             return LogitsProcessorOutput(
+                 next_token_logits=last_logits,
+                 normalized_prompt_logprobs=normalized_prompt_logprobs,
+                 input_token_logprobs=input_token_logprobs,
+                 input_top_logprobs_val=input_top_logprobs_val,
+                 input_top_logprobs_idx=input_top_logprobs_idx,
+             )

      def _get_logits(
          self,
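
Note on the gather in the hunk above: `input_token_logprobs` is built by shifting the pruned input ids left by one, so the logprob row produced at position i is indexed with the id of the token it predicts (position i + 1), and a zero id is padded at the end so the last row can be gathered without a branch. A small self-contained sketch of that indexing, with illustrative CPU tensors (not part of the package):

import torch

# Three positions over a toy vocab of size 3; row i holds the distribution
# predicted *after* seeing token i.
logprobs = torch.log_softmax(
    torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]), dim=-1
)
input_ids = torch.tensor([0, 1, 2])

# Shift the ids by one and pad a zero at the end for easy batching,
# mirroring torch.cat(pruned_input_ids)[1:] plus the zero pad in the diff above.
gather_ids = torch.cat([input_ids[1:], torch.tensor([0])])
token_logprobs = logprobs[torch.arange(logprobs.shape[0]), gather_ids]
print(token_logprobs)  # score of token 1 under row 0, token 2 under row 1, padding under row 2
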
@@ -269,9 +231,19 @@ class LogitsProcessor(nn.Module):
              # GGUF models
              logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)

-         # Optional scaling factor
          if self.logit_scale is not None:
-             logits.mul_(self.logit_scale)  # In-place multiply
+             logits.mul_(self.logit_scale)
+
+         if self.do_tensor_parallel_all_gather:
+             logits = tensor_model_parallel_all_gather(logits)
+
+         # Compute the normalized logprobs for the requested tokens.
+         # Note that we pad a zero at the end for easy batching.
+         logits = logits[:, : self.config.vocab_size].float()
+
+         if self.final_logit_softcapping:
+             fused_softcap(logits, self.final_logit_softcapping)
+
          return logits

      @staticmethod
@@ -302,90 +274,73 @@ class LogitsProcessor(nn.Module):
          values = ret.values.tolist()
          indices = ret.indices.tolist()

-         if logits_metadata.forward_mode.is_decode():
-             output_top_logprobs_val = []
-             output_top_logprobs_idx = []
-             for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                 output_top_logprobs_val.append(values[i][:k])
-                 output_top_logprobs_idx.append(indices[i][:k])
-             return None, None, output_top_logprobs_val, output_top_logprobs_idx
-         else:
-             input_top_logprobs_val, input_top_logprobs_idx = [], []
-             output_top_logprobs_val, output_top_logprobs_idx = [], []
+         input_top_logprobs_val, input_top_logprobs_idx = [], []

-             pt = 0
-             for k, pruned_len in zip(
-                 logits_metadata.top_logprobs_nums,
-                 logits_metadata.extend_logprob_pruned_lens_cpu,
-             ):
-                 if pruned_len <= 0:
-                     input_top_logprobs_val.append([])
-                     input_top_logprobs_idx.append([])
-                     output_top_logprobs_val.append([])
-                     output_top_logprobs_idx.append([])
-                     continue
-
-                 input_top_logprobs_val.append(
-                     [values[pt + j][:k] for j in range(pruned_len - 1)]
-                 )
-                 input_top_logprobs_idx.append(
-                     [indices[pt + j][:k] for j in range(pruned_len - 1)]
-                 )
-                 output_top_logprobs_val.append(
-                     list(
-                         values[pt + pruned_len - 1][:k],
-                     )
-                 )
-                 output_top_logprobs_idx.append(
-                     list(
-                         indices[pt + pruned_len - 1][:k],
-                     )
-                 )
-                 pt += pruned_len
+         pt = 0
+         for k, pruned_len in zip(
+             logits_metadata.top_logprobs_nums,
+             logits_metadata.extend_logprob_pruned_lens_cpu,
+         ):
+             if pruned_len <= 0:
+                 input_top_logprobs_val.append([])
+                 input_top_logprobs_idx.append([])
+                 continue

-             return (
-                 input_top_logprobs_val,
-                 input_top_logprobs_idx,
-                 output_top_logprobs_val,
-                 output_top_logprobs_idx,
+             input_top_logprobs_val.append(
+                 [values[pt + j][:k] for j in range(pruned_len - 1)]
              )
+             input_top_logprobs_idx.append(
+                 [indices[pt + j][:k] for j in range(pruned_len - 1)]
+             )
+             pt += pruned_len
+
+         return input_top_logprobs_val, input_top_logprobs_idx

      @staticmethod
      def compute_temp_top_p_normalized_logprobs(
          last_logits: torch.Tensor, logits_metadata: LogitsMetadata
      ) -> torch.Tensor:
+         # TODO: Implement the temp and top-p normalization
          return torch.nn.functional.log_softmax(last_logits, dim=-1)


- def test():
-     all_logprobs = torch.tensor(
-         # s s s
-         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
-         dtype=torch.float32,
-         device="cuda",
+ @triton.jit
+ def fused_softcap_kernel(
+     full_logits_ptr,
+     softcapping_value,
+     n_elements,
+     BLOCK_SIZE: tl.constexpr,
+ ):
+     pid = tl.program_id(0)
+     block_start = pid * BLOCK_SIZE
+     offsets = block_start + tl.arange(0, BLOCK_SIZE)
+     mask = offsets < n_elements
+
+     # Load values
+     x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+     # Perform operations in-place
+     x = x / softcapping_value
+
+     # Manual tanh implementation using exp
+     exp2x = tl.exp(2 * x)
+     x = (exp2x - 1) / (exp2x + 1)
+
+     x = x * softcapping_value
+
+     # Store result
+     tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+ def fused_softcap(full_logits, final_logit_softcapping):
+     n_elements = full_logits.numel()
+     BLOCK_SIZE = 1024
+     grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+     fused_softcap_kernel[grid](
+         full_logits_ptr=full_logits,
+         softcapping_value=final_logit_softcapping,
+         n_elements=n_elements,
+         BLOCK_SIZE=BLOCK_SIZE,
      )
-     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
-     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-     token_logprobs = all_logprobs[
-         torch.arange(all_logprobs.shape[0], device="cuda"),
-         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-     ]
-     logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-     len_cumsum = torch.cumsum(seq_lens, dim=0)
-     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-     end = start + seq_lens - 2
-     start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-     end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-     sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-     # assert logprobs == [2, _, 2, 4, _]
-     print("token logprobs", token_logprobs)
-     print("start", start)
-     print("end", end)
-     print("sum_logp", sum_logp)
-
-
- if __name__ == "__main__":
-     test()
+     return full_logits
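
For readability: the new `fused_softcap` Triton kernel applies the same transform the removed eager path expressed with `div_`/`tanh`/`mul_`, namely `cap * tanh(logits / cap)`, in place over the flattened logits. A minimal PyTorch reference of that transform, included here only as a sketch for comparison (not part of the package):

import torch

def softcap_reference(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Same math as fused_softcap, written with eager in-place ops.
    logits.div_(cap)
    torch.tanh(logits, out=logits)
    logits.mul_(cap)
    return logits

x = torch.tensor([[-100.0, -5.0, 0.0, 5.0, 100.0]])
print(softcap_reference(x, 30.0))  # large magnitudes are squashed toward +/-30
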
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "8": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "16": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "32": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 5
+     },
+     "48": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "64": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "96": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "128": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "256": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "512": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 8,
+         "num_stages": 4
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 128,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 8,
+         "num_stages": 4
+     }
+ }
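
The new NVIDIA_H200 JSON files listed above are per-shape tuning tables for the fused MoE Triton kernel: each top-level key is a token batch size, and its value is the launch configuration tuned for that size. As a hedged sketch of how such a table can be consulted at runtime, one common approach is to pick the entry whose key is nearest to the actual batch size; the helper name and lookup rule below are illustrative assumptions, not sglang's exact API:

def nearest_tuned_config(configs: dict, m: int) -> dict:
    # Illustrative helper: keys are token batch sizes stored as strings;
    # pick the tuned entry whose key is closest to the actual M.
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]

# Two entries copied from the table above, for demonstration.
configs = {
    "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
           "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
}

print(nearest_tuned_config(configs, m=42))    # -> the "48" entry
print(nearest_tuned_config(configs, m=1500))  # -> the "2048" entry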