sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

```diff
@@ -17,6 +17,8 @@ import dataclasses
 from typing import List, Optional, Union
 
 import torch
+import triton
+import triton.language as tl
 from torch import nn
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
@@ -33,76 +35,72 @@ from sglang.srt.model_executor.forward_batch_info import (
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
-    #
-
+    # Used by speculative decoding (EAGLE)
+    # The last hidden layers
+    hidden_states: Optional[torch.Tensor] = None
 
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
+    # The logprobs of the next tokens. shape: [#seq]
+    next_token_logprobs: Optional[torch.Tensor] = None
+    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+    next_token_top_logprobs_val: Optional[List] = None
+    next_token_top_logprobs_idx: Optional[List] = None
+
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
-    # The logprobs of input tokens. shape: [#token
+    # The logprobs of input tokens. shape: [#token]
    input_token_logprobs: torch.Tensor = None
-
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs_val: List = None
-    output_top_logprobs_idx: List = None
-
-    # Used by speculative decoding (EAGLE)
-    # The output of transformer layers
-    hidden_states: Optional[torch.Tensor] = None
 
 
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-
-
-    return_logprob: bool = False
-    return_top_logprob: bool = False
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
 
+    extend_return_logprob: bool = False
+    extend_return_top_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
-
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
-
-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            return_top_logprob = False
-
-        if forward_batch.spec_info:
-            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+            extend_return_logprob = True
+            extend_return_top_logprob = any(
+                x > 0 for x in forward_batch.top_logprobs_nums
+            )
+            extend_logprob_pruned_lens_cpu = [
+                extend_len - start_len
+                for extend_len, start_len in zip(
+                    forward_batch.extend_seq_lens_cpu,
+                    forward_batch.extend_logprob_start_lens_cpu,
+                )
+            ]
         else:
-
+            extend_return_logprob = extend_return_top_logprob = (
+                extend_logprob_pruned_lens_cpu
+            ) = False
 
         return cls(
             forward_mode=forward_batch.forward_mode,
-
-
-
+            capture_hidden_mode=forward_batch.capture_hidden_mode,
+            extend_return_logprob=extend_return_logprob,
+            extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
-
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
         )
 
 
```
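The `from_forward_batch` rewrite drops the old `return_logprob`/`return_top_logprob` flags in favor of extend-only variants, and precomputes how many prompt tokens per request actually need logprobs. A standalone sketch of that pruned-length computation with made-up batch values (in sglang these lists come from `ForwardBatch`):

```python
# Toy values: request 0 runs 5 tokens but skips logprobs for the first 2;
# request 1 runs 3 tokens and wants logprobs for all of them.
extend_seq_lens_cpu = [5, 3]
extend_logprob_start_lens_cpu = [2, 0]

extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(
        extend_seq_lens_cpu, extend_logprob_start_lens_cpu
    )
]
assert extend_logprob_pruned_lens_cpu == [3, 3]  # logprob work per request
```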
```diff
@@ -119,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None
 
     def forward(
         self,
```
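This guard makes a negative `final_logit_softcapping` in a model config mean "softcapping disabled". A hypothetical standalone helper expressing the same check (the function name is illustrative, not an sglang API):

```python
def resolve_final_logit_softcapping(config):
    # Mirrors the new __init__ guard: a missing or negative value
    # disables final-logit softcapping entirely.
    cap = getattr(config, "final_logit_softcapping", None)
    if cap is not None and cap < 0:
        return None
    return cap
```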
```diff
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
     ):
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-        assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
         if (
@@ -142,18 +144,13 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
+        # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if
-
-
-
-
-            last_logits.div_(self.final_logit_softcapping)
-            torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.final_logit_softcapping)
-
-        # Return only last_logits if logprob is not requested
-        if not logits_metadata.return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
+            # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 hidden_states=(
@@ -167,95 +164,60 @@ class LogitsProcessor(nn.Module):
                 ),
             )
         else:
-
-
+            # Slice the requested tokens to compute logprob
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            # Compute the logits of all required tokens
+            pruned_states = torch.cat(pruned_states)
+            del hidden_states
+            input_token_logits = self._get_logits(pruned_states, lm_head)
+            del pruned_states
+
+            # Normalize the logprob w/o temperature, top-p
+            input_logprobs = input_token_logits
+            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                input_logprobs, logits_metadata
             )
 
-
-
-
-
-
-
-            output_top_logprobs_val = output_top_logprobs_idx = None
-            return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                next_token_logprobs=last_logprobs,
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
-            )
+            # Get the logprob of top-k tokens
+            if logits_metadata.extend_return_top_logprob:
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
             else:
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
-            # extra logits that this padding may have produced.
-            all_logits = all_logits[:, : self.config.vocab_size].float()
-
-            if self.final_logit_softcapping:
-                all_logits.div_(self.final_logit_softcapping)
-                torch.tanh(all_logits, out=all_logits)
-                all_logits.mul_(self.final_logit_softcapping)
-
-            all_logprobs = all_logits
-            del all_logits, hidden_states
-
-            all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                all_logprobs, logits_metadata
-            )
-
-            # Get the logprob of top-k tokens
-            if logits_metadata.return_top_logprob:
-                (
-                    input_top_logprobs_val,
-                    input_top_logprobs_idx,
-                    output_top_logprobs_val,
-                    output_top_logprobs_idx,
-                ) = self.get_top_logprobs(all_logprobs, logits_metadata)
-            else:
-                input_top_logprobs_val = input_top_logprobs_idx = (
-                    output_top_logprobs_val
-                ) = output_top_logprobs_idx = None
-
-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
-            input_token_logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat(
-                    [
-                        torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
-                    ]
-                ),
-            ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )
+                input_top_logprobs_val = input_top_logprobs_idx = None
+
+            # Compute the normalized logprobs for the requested tokens.
+            # Note that we pad a zero at the end for easy batching.
+            input_token_logprobs = input_logprobs[
+                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.cat(
+                    [
+                        torch.cat(pruned_input_ids)[1:],
+                        torch.tensor([0], device="cuda"),
+                    ]
+                ),
+            ]
+            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                input_token_logprobs,
+                logits_metadata,
+            )
 
-
-
-
-
-
-
-
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
-            )
+            return LogitsProcessorOutput(
+                next_token_logits=last_logits,
+                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+            )
 
     def _get_logits(
         self,
```
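The prefill branch above scores each prompt position against the token that actually follows it, which is why the pruned ids are shifted left by one and padded with a zero: the pad keeps the gather at one row per position, and its value is discarded downstream. A CPU-sized sketch of the same gather (toy tensors, not sglang APIs):

```python
import torch

ids = torch.tensor([11, 7, 42])                       # pruned prompt token ids
logprobs = torch.log_softmax(torch.randn(3, 50), -1)  # [#token, vocab]

# Position i is scored against the token at position i + 1; the trailing
# zero is padding only, so the last gathered value is meaningless.
targets = torch.cat([ids[1:], torch.tensor([0])])
token_logprobs = logprobs[torch.arange(3), targets]
print(token_logprobs.shape)  # torch.Size([3])
```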
```diff
@@ -269,9 +231,19 @@ class LogitsProcessor(nn.Module):
         # GGUF models
         logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
 
-        # Optional scaling factor
         if self.logit_scale is not None:
-            logits.mul_(self.logit_scale)
+            logits.mul_(self.logit_scale)
+
+        if self.do_tensor_parallel_all_gather:
+            logits = tensor_model_parallel_all_gather(logits)
+
+        # Compute the normalized logprobs for the requested tokens.
+        # Note that we pad a zero at the end for easy batching.
+        logits = logits[:, : self.config.vocab_size].float()
+
+        if self.final_logit_softcapping:
+            fused_softcap(logits, self.final_logit_softcapping)
+
         return logits
 
     @staticmethod
```
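With this hunk, `_get_logits` owns the whole post-matmul pipeline: optional logit scaling, the tensor-parallel all-gather, slicing off padded vocab columns, and softcapping. A minimal single-GPU sketch of that tail, assuming no tensor parallelism and using an out-of-place softcap (the function name is illustrative):

```python
import torch

def postprocess_logits(logits, vocab_size, logit_scale=None, softcap=None):
    # Illustrative stand-in for the consolidated _get_logits tail.
    if logit_scale is not None:
        logits = logits * logit_scale
    logits = logits[:, :vocab_size].float()  # drop padded vocab columns
    if softcap:
        logits = softcap * torch.tanh(logits / softcap)  # eager softcap
    return logits
```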
```diff
@@ -302,90 +274,73 @@ class LogitsProcessor(nn.Module):
         values = ret.values.tolist()
         indices = ret.indices.tolist()
 
-
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
-            input_top_logprobs_val, input_top_logprobs_idx = [], []
-            output_top_logprobs_val, output_top_logprobs_idx = [], []
+        input_top_logprobs_val, input_top_logprobs_idx = [], []
 
-
-
-
-
-
-
-
-
-
-                output_top_logprobs_idx.append([])
-                continue
-
-                input_top_logprobs_val.append(
-                    [values[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                input_top_logprobs_idx.append(
-                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                output_top_logprobs_val.append(
-                    list(
-                        values[pt + pruned_len - 1][:k],
-                    )
-                )
-                output_top_logprobs_idx.append(
-                    list(
-                        indices[pt + pruned_len - 1][:k],
-                    )
-                )
-                pt += pruned_len
+        pt = 0
+        for k, pruned_len in zip(
+            logits_metadata.top_logprobs_nums,
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+        ):
+            if pruned_len <= 0:
+                input_top_logprobs_val.append([])
+                input_top_logprobs_idx.append([])
+                continue
 
-
-
-                    input_top_logprobs_idx,
-                    output_top_logprobs_val,
-                    output_top_logprobs_idx,
+            input_top_logprobs_val.append(
+                [values[pt + j][:k] for j in range(pruned_len - 1)]
             )
+            input_top_logprobs_idx.append(
+                [indices[pt + j][:k] for j in range(pruned_len - 1)]
+            )
+            pt += pruned_len
+
+        return input_top_logprobs_val, input_top_logprobs_idx
 
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
+        # TODO: Implement the temp and top-p normalization
        return torch.nn.functional.log_softmax(last_logits, dim=-1)
 
 
-
-
-
-
-
-
+@triton.jit
+def fused_softcap_kernel(
+    full_logits_ptr,
+    softcapping_value,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load values
+    x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+    # Perform operations in-place
+    x = x / softcapping_value
+
+    # Manual tanh implementation using exp
+    exp2x = tl.exp(2 * x)
+    x = (exp2x - 1) / (exp2x + 1)
+
+    x = x * softcapping_value
+
+    # Store result
+    tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+def fused_softcap(full_logits, final_logit_softcapping):
+    n_elements = full_logits.numel()
+    BLOCK_SIZE = 1024
+    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+    fused_softcap_kernel[grid](
+        full_logits_ptr=full_logits,
+        softcapping_value=final_logit_softcapping,
+        n_elements=n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
     )
-
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-    token_logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-    len_cumsum = torch.cumsum(seq_lens, dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-    # assert logprobs == [2, _, 2, 4, _]
-    print("token logprobs", token_logprobs)
-    print("start", start)
-    print("end", end)
-    print("sum_logp", sum_logp)
-
-
-if __name__ == "__main__":
-    test()
+    return full_logits
```
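The new kernel applies `cap * tanh(x / cap)` elementwise in place, computing tanh through the identity `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`. A sanity-check sketch, assuming a CUDA device with Triton available and that `fused_softcap` is importable from this module:

```python
import torch
from sglang.srt.layers.logits_processor import fused_softcap  # assumed import path

x = torch.randn(4, 32000, device="cuda", dtype=torch.float32)
cap = 30.0  # illustrative softcapping value
expected = cap * torch.tanh(x / cap)
fused_softcap(x, cap)  # mutates x in place
torch.testing.assert_close(x, expected, rtol=1e-4, atol=1e-4)
```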
sglang/srt/layers/moe/fused_moe_triton/configs/… (new file; one of the NVIDIA_H200 tuning configs listed above, each +146 lines)

```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
```
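Each top-level key in these tuning files is a token-batch size M, mapped to Triton tile and pipeline parameters for the fused MoE kernel. A hypothetical lookup sketch, assuming the usual fused_moe_triton convention of picking the entry whose key is closest to the runtime M (the function below is illustrative, not the fused_moe.py implementation):

```python
def pick_moe_config(configs, m):
    # Choose the tuning entry whose batch-size key is numerically closest to M.
    return configs[min(configs, key=lambda k: abs(int(k) - m))]

configs = {"128": {"num_stages": 3}, "256": {"num_stages": 4}, "512": {"num_stages": 5}}
assert pick_moe_config(configs, 200) == {"num_stages": 4}  # "256" is nearest
```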