sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +62 -26
- sglang/srt/managers/tokenizer_manager.py +22 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +108 -57
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/torch_native_backend.py

```diff
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
```
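The four CUDA-graph stubs deleted here all raised the same ValueError. Judging by the `sglang/srt/layers/attention/__init__.py +14 -5` entry in the file list, a shared default on the `AttentionBackend` base class presumably takes their place; the sketch below is a hypothetical illustration of that pattern, not the actual sglang source.

```python
# Hypothetical sketch of the "fail loudly in the base class" pattern; class and
# method names mirror this diff, but the bodies are assumptions.


class AttentionBackend:
    """Stand-in for sglang.srt.layers.attention.AttentionBackend."""

    def init_cuda_graph_state(self, max_bs: int):
        # Backends that cannot run under CUDA graphs no longer need their own
        # stubs; one shared default raises with the same actionable hint.
        raise ValueError(
            "This attention backend does not support CUDA graph for now. "
            "Please add --disable-cuda-graph when launching the server."
        )


class TorchNativeAttnBackend(AttentionBackend):
    # No CUDA-graph overrides anymore; the base-class default applies.
    def init_forward_metadata(self, forward_batch):
        pass
```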
sglang/srt/layers/attention/triton_backend.py

```diff
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
```
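The CUDA-graph capture/replay hooks now receive `num_tokens`, a typed `encoder_lens`, a `forward_mode`, and an optional `spec_info` in addition to the old arguments. Below is a hedged sketch of how a caller such as the CUDA graph runner might exercise the new capture signature; the keyword names come from the hunk above, while the dtypes and toy batch values are assumptions, not taken from sglang.

```python
# Hedged sketch of a call into the extended capture hook; not the real call site.

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode


def capture_decode_metadata(attn_backend, bs: int, device: str = "cuda"):
    req_pool_indices = torch.arange(bs, dtype=torch.int64, device=device)
    seq_lens = torch.ones(bs, dtype=torch.int64, device=device)

    # The asserts added in this hunk require plain decode batches with no
    # encoder lengths and no speculative-decoding info.
    attn_backend.init_forward_metadata_capture_cuda_graph(
        bs=bs,
        num_tokens=bs,  # one token per sequence during decode capture
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        encoder_lens=None,
        forward_mode=ForwardMode.DECODE,
        spec_info=None,
    )
```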
sglang/srt/layers/logits_processor.py

```diff
@@ -17,6 +17,8 @@ import dataclasses
 from typing import List, Optional, Union
 
 import torch
+import triton
+import triton.language as tl
 from torch import nn
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
@@ -33,76 +35,77 @@ from sglang.srt.model_executor.forward_batch_info import (
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
-    #
-
+    # Used by speculative decoding (EAGLE)
+    # The last hidden layers
+    hidden_states: Optional[torch.Tensor] = None
+
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
+    # The logprobs of the next tokens. shape: [#seq]
+    next_token_logprobs: Optional[torch.Tensor] = None
+    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+    next_token_top_logprobs_val: Optional[List] = None
+    next_token_top_logprobs_idx: Optional[List] = None
 
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
-    # The logprobs of input tokens. shape: [#token
+    # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
-
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs_val: List = None
-    output_top_logprobs_idx: List = None
-
-    # Used by speculative decoding (EAGLE)
-    # The output of transformer layers
-    hidden_states: Optional[torch.Tensor] = None
 
 
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-
-
-    return_logprob: bool = False
-    return_top_logprob: bool = False
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
 
+    extend_return_logprob: bool = False
+    extend_return_top_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
-
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
-
-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        extend_logprob_pruned_lens_cpu = None
-
-        if forward_batch.return_logprob:
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if forward_batch.forward_mode.is_extend():
-                extend_logprob_pruned_lens_cpu = [
-                    extend_len - start_len
-                    for extend_len, start_len in zip(
-                        forward_batch.extend_seq_lens_cpu,
-                        forward_batch.extend_logprob_start_lens_cpu,
-                    )
-                ]
-        else:
-            return_top_logprob = False
-
         if forward_batch.spec_info:
             capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
         else:
             capture_hidden_mode = CaptureHiddenMode.NULL
 
+        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+            extend_return_logprob = True
+            extend_return_top_logprob = any(
+                x > 0 for x in forward_batch.top_logprobs_nums
+            )
+            extend_logprob_pruned_lens_cpu = [
+                extend_len - start_len
+                for extend_len, start_len in zip(
+                    forward_batch.extend_seq_lens_cpu,
+                    forward_batch.extend_logprob_start_lens_cpu,
+                )
+            ]
+        else:
+            extend_return_logprob = extend_return_top_logprob = (
+                extend_logprob_pruned_lens_cpu
+            ) = False
+
         return cls(
             forward_mode=forward_batch.forward_mode,
-
-
-
+            capture_hidden_mode=capture_hidden_mode,
+            extend_return_logprob=extend_return_logprob,
+            extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
-
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
        )
 
 
```
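`from_forward_batch` now folds the logprob bookkeeping into two extend-only flags plus a pruned-length list. The pruned length is simply `extend_len - logprob_start_len` per request; a small worked example with made-up lengths:

```python
# Worked example of the pruned-length computation above (the numbers are made up).
extend_seq_lens_cpu = [5, 3, 8]
extend_logprob_start_lens_cpu = [2, 0, 8]

extend_logprob_pruned_lens_cpu = [
    extend_len - start_len
    for extend_len, start_len in zip(extend_seq_lens_cpu, extend_logprob_start_lens_cpu)
]

# A request whose start length equals its extend length prunes down to zero tokens,
# which is the `pruned_len <= 0` case handled later in get_top_logprobs.
assert extend_logprob_pruned_lens_cpu == [3, 3, 0]
```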
```diff
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
     ):
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-        assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
         if (
@@ -142,18 +144,13 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
+        # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if
-
-
-
-
-            last_logits.div_(self.final_logit_softcapping)
-            torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.final_logit_softcapping)
-
-        # Return only last_logits if logprob is not requested
-        if not logits_metadata.return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
+            # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 hidden_states=(
@@ -167,95 +164,60 @@ class LogitsProcessor(nn.Module):
                 ),
             )
         else:
-
-
+            # Slice the requested tokens to compute logprob
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            # Compute the logits of all required tokens
+            pruned_states = torch.cat(pruned_states)
+            del hidden_states
+            input_token_logits = self._get_logits(pruned_states, lm_head)
+            del pruned_states
+
+            # Normalize the logprob w/o temperature, top-p
+            input_logprobs = input_token_logits
+            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                input_logprobs, logits_metadata
             )
 
-
-
-
-
-
-
-            output_top_logprobs_val = output_top_logprobs_idx = None
-            return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                next_token_logprobs=last_logprobs,
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
-            )
+            # Get the logprob of top-k tokens
+            if logits_metadata.extend_return_top_logprob:
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
             else:
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
-            # extra logits that this padding may have produced.
-            all_logits = all_logits[:, : self.config.vocab_size].float()
-
-            if self.final_logit_softcapping:
-                all_logits.div_(self.final_logit_softcapping)
-                torch.tanh(all_logits, out=all_logits)
-                all_logits.mul_(self.final_logit_softcapping)
-
-            all_logprobs = all_logits
-            del all_logits, hidden_states
-
-            all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                all_logprobs, logits_metadata
-            )
-
-            # Get the logprob of top-k tokens
-            if logits_metadata.return_top_logprob:
-                (
-                    input_top_logprobs_val,
-                    input_top_logprobs_idx,
-                    output_top_logprobs_val,
-                    output_top_logprobs_idx,
-                ) = self.get_top_logprobs(all_logprobs, logits_metadata)
-            else:
-                input_top_logprobs_val = input_top_logprobs_idx = (
-                    output_top_logprobs_val
-                ) = output_top_logprobs_idx = None
-
-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
-            input_token_logprobs = all_logprobs[
-                torch.arange(all_logprobs.shape[0], device="cuda"),
-                torch.cat(
-                    [
-                        torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
-                    ]
-                ),
-            ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )
+                input_top_logprobs_val = input_top_logprobs_idx = None
+
+            # Compute the normalized logprobs for the requested tokens.
+            # Note that we pad a zero at the end for easy batching.
+            input_token_logprobs = input_logprobs[
+                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.cat(
+                    [
+                        torch.cat(pruned_input_ids)[1:],
+                        torch.tensor([0], device="cuda"),
+                    ]
+                ),
+            ]
+            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                input_token_logprobs,
+                logits_metadata,
+            )
 
-
-
-
-
-
-
-
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
-            )
+            return LogitsProcessorOutput(
+                next_token_logits=last_logits,
+                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+            )
 
     def _get_logits(
         self,
```
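The `input_token_logprobs` indexing above gathers, at each position, the logprob of the *next* input token, padding a dummy index 0 at the end so every row has something to gather (the module-level `test()` removed at the bottom of the old file exercised the same trick). A self-contained toy version, on CPU instead of "cuda" and with arbitrary values:

```python
# Toy reproduction of the shift-by-one gather; shapes and values are arbitrary.

import torch

vocab_size = 6
input_ids = torch.tensor([1, 2, 3, 0, 1])

logprobs = torch.log_softmax(torch.randn(len(input_ids), vocab_size), dim=-1)

# Row i is indexed by token i+1; the appended 0 is padding for the last row.
gather_ids = torch.cat([input_ids[1:], torch.tensor([0])])
input_token_logprobs = logprobs[torch.arange(len(input_ids)), gather_ids]

# input_token_logprobs[i] == log P(input_ids[i + 1] | prefix up to i); the final
# entry is padding and is excluded when normalizing prompt logprobs.
print(input_token_logprobs)
```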
```diff
@@ -269,9 +231,19 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
 
-        # Optional scaling factor
         if self.logit_scale is not None:
-            logits.mul_(self.logit_scale)
+            logits.mul_(self.logit_scale)
+
+        if self.do_tensor_parallel_all_gather:
+            logits = tensor_model_parallel_all_gather(logits)
+
+        # Compute the normalized logprobs for the requested tokens.
+        # Note that we pad a zero at the end for easy batching.
+        logits = logits[:, : self.config.vocab_size].float()
+
+        if self.final_logit_softcapping:
+            fused_softcap(logits, self.final_logit_softcapping)
+
         return logits
 
     @staticmethod
@@ -302,90 +274,73 @@ class LogitsProcessor(nn.Module):
         values = ret.values.tolist()
         indices = ret.indices.tolist()
 
-
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
-            input_top_logprobs_val, input_top_logprobs_idx = [], []
-            output_top_logprobs_val, output_top_logprobs_idx = [], []
+        input_top_logprobs_val, input_top_logprobs_idx = [], []
 
-
-
-
-
-
-
-
-
-
-                output_top_logprobs_idx.append([])
-                continue
-
-            input_top_logprobs_val.append(
-                [values[pt + j][:k] for j in range(pruned_len - 1)]
-            )
-            input_top_logprobs_idx.append(
-                [indices[pt + j][:k] for j in range(pruned_len - 1)]
-            )
-            output_top_logprobs_val.append(
-                list(
-                    values[pt + pruned_len - 1][:k],
-                )
-            )
-            output_top_logprobs_idx.append(
-                list(
-                    indices[pt + pruned_len - 1][:k],
-                )
-            )
-            pt += pruned_len
+        pt = 0
+        for k, pruned_len in zip(
+            logits_metadata.top_logprobs_nums,
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+        ):
+            if pruned_len <= 0:
+                input_top_logprobs_val.append([])
+                input_top_logprobs_idx.append([])
+                continue
 
-
-
-                input_top_logprobs_idx,
-                output_top_logprobs_val,
-                output_top_logprobs_idx,
+            input_top_logprobs_val.append(
+                [values[pt + j][:k] for j in range(pruned_len - 1)]
             )
+            input_top_logprobs_idx.append(
+                [indices[pt + j][:k] for j in range(pruned_len - 1)]
+            )
+            pt += pruned_len
+
+        return input_top_logprobs_val, input_top_logprobs_idx
 
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
+        # TODO: Implement the temp and top-p normalization
         return torch.nn.functional.log_softmax(last_logits, dim=-1)
 
 
-
-
-
-
-
+@triton.jit
+def fused_softcap_kernel(
+    full_logits_ptr,
+    softcapping_value,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load values
+    x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+    # Perform operations in-place
+    x = x / softcapping_value
+
+    # Manual tanh implementation using exp
+    exp2x = tl.exp(2 * x)
+    x = (exp2x - 1) / (exp2x + 1)
+
+    x = x * softcapping_value
+
+    # Store result
+    tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+def fused_softcap(full_logits, final_logit_softcapping):
+    n_elements = full_logits.numel()
+    BLOCK_SIZE = 1024
+    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+    fused_softcap_kernel[grid](
+        full_logits_ptr=full_logits,
+        softcapping_value=final_logit_softcapping,
+        n_elements=n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
     )
-
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-    token_logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-    len_cumsum = torch.cumsum(seq_lens, dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-    # assert logprobs == [2, _, 2, 4, _]
-    print("token logprobs", token_logprobs)
-    print("start", start)
-    print("end", end)
-    print("sum_logp", sum_logp)
-
-
-if __name__ == "__main__":
-    test()
+    return full_logits
```
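The new `fused_softcap` helper applies the softcapping `cap * tanh(logits / cap)` in place with a Triton kernel, replacing the three-op eager sequence (`div_`, `tanh`, `mul_`) removed earlier in this diff. Below is a hedged sanity check against that eager formulation; it assumes a CUDA device with Triton available, and the shapes and cap value are arbitrary.

```python
# Sanity check for the in-place Triton softcap versus the eager formulation the
# old code used. Requires CUDA + Triton.

import torch

from sglang.srt.layers.logits_processor import fused_softcap

cap = 30.0
logits = torch.randn(4, 32000, device="cuda", dtype=torch.float32)

# Eager reference: cap * tanh(logits / cap)
reference = torch.tanh(logits / cap) * cap

# fused_softcap mutates its argument and returns it, so pass a clone.
fused = fused_softcap(logits.clone(), cap)

torch.testing.assert_close(fused, reference, rtol=1e-4, atol=1e-5)
```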