sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as published.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
CHANGED
@@ -15,20 +15,16 @@ limitations under the License.
 
 """Radix attention."""
 
-from typing import Optional
-
-import torch
-from flashinfer.cascade import merge_state
 from torch import nn
 
-from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.model_executor.model_runner import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -36,8 +32,8 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size:
-        logit_cap:
+        sliding_window_size: int = -1,
+        logit_cap: float = 0.0,
         v_head_dim: int = -1,
     ):
         super().__init__()
@@ -49,160 +45,14 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.
-
-        if (
-            not global_server_args_dict.get("disable_flashinfer", False)
-            and self.qk_head_dim == self.v_head_dim
-        ):
-            self.extend_forward = self.extend_forward_flashinfer
-            self.decode_forward = self.decode_forward_flashinfer
-        else:
-            self.extend_forward = self.extend_forward_triton
-            self.decode_forward = self.decode_forward_triton
-
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
-    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-
-        self.store_kv_cache(k, v, input_metadata)
-        extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            k.contiguous(),
-            v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_prefix_lens,
-            input_metadata.extend_start_loc,
-            input_metadata.extend_seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.triton_max_extend_len,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-        self.store_kv_cache(k, v, input_metadata)
-
-        decode_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.total_num_tokens,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_paged = prefill_wrapper_paged[0]
-        else:
-            if isinstance(prefill_wrapper_paged, list):
-                prefill_wrapper_paged = prefill_wrapper_paged[1]
-
-        if not input_metadata.flashinfer_use_ragged:
-            if k is not None:
-                assert v is not None
-                self.store_kv_cache(k, v, input_metadata)
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
-            )
-        else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-            )
-
-            if input_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                    causal=False,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            self.store_kv_cache(k, v, input_metadata)
-
-        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-            torch.cuda.synchronize()
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
-
-    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        decode_wrapper = input_metadata.flashinfer_decode_wrapper
-        if self.sliding_window_size != -1:
-            decode_wrapper = decode_wrapper[0]
-        else:
-            if isinstance(decode_wrapper, list):
-                decode_wrapper = decode_wrapper[1]
-
-        if k is not None:
-            assert v is not None
-            self.store_kv_cache(k, v, input_metadata)
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
+        self.logit_cap = logit_cap
+        self.sliding_window_size = sliding_window_size or -1
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
         if k is not None:
+            # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
-            return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.decode_forward(q, k, v, input_metadata)
-
-    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
-        )
+        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
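The net effect of the change above: RadixAttention no longer branches between its own Triton and FlashInfer paths or writes the KV cache itself; it only reshapes k/v and delegates to `input_metadata.attn_backend.forward(q, k, v, self, input_metadata)`, with the kernels now living in the new sglang/srt/layers/attention_backend.py and sglang/srt/layers/triton_attention/ modules listed in the file table. As a rough sketch of the kind of interface such a backend exposes (the class below is an illustrative assumption, not sglang's actual backend code; only the RadixAttention attribute names are taken from the diff):

    # Illustrative stand-in backend with the forward() signature the new
    # RadixAttention.forward delegates to. It assumes a single causal prefill with
    # equal q/kv head counts; the real backends dispatch to FlashInfer or Triton
    # kernels and read/write the paged KV cache via input_metadata.
    import torch


    class NaiveSDPABackend:
        def forward(self, q, k, v, layer, input_metadata):
            # q arrives flat: (num_tokens, heads * head_dim); k/v were reshaped by the caller
            q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim).transpose(0, 1).unsqueeze(0)
            k = k.transpose(0, 1).unsqueeze(0)
            v = v.transpose(0, 1).unsqueeze(0)
            o = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, is_causal=True, scale=layer.scaling
            )
            # back to (num_tokens, heads * v_head_dim), the layout RadixAttention returns
            return o.squeeze(0).transpose(0, 1).reshape(-1, layer.tp_q_head_num * layer.v_head_dim)

Centralizing the kernel choice behind one backend object keeps the per-model attention layers unchanged when a new backend is added; the reworked model_runner.py and cuda_graph_runner.py are presumably where the concrete backend is constructed and attached to the input_metadata seen here.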
sglang/srt/layers/sampler.py
CHANGED
@@ -1,74 +1,28 @@
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -76,9 +30,17 @@ class Sampler(CustomOp):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
 
-        if
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )
+
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
@@ -90,57 +52,24 @@ class Sampler(CustomOp):
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
-        else:
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )
 
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -160,19 +89,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-
-    return batch_next_token_ids, success
+    return batch_next_token_ids
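Besides dropping the SampleOutput dataclass and the torch custom-op wrapper around FlashInfer sampling, the new Sampler is a plain nn.Module that dispatches on `global_server_args_dict["sampling_backend"]` and, on the "pytorch" path, falls back to `top_k_top_p_min_p_sampling_from_probs_torch`, whose tail is visible in the last hunk. A self-contained sketch of that fallback filtering is below; the function's opening lines (the sort/cumsum/top-p masking) are not part of this diff, so those steps are a plausible reconstruction rather than the exact source:

    import torch


    def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
        # Sort probabilities per row so the filters can be applied positionally.
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        min_p_thresholds = probs_sort[:, 0] * min_ps
        # top-p: zero out tokens once the cumulative mass before them exceeds top_p
        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
        # top-k: keep only the k highest-probability tokens per row
        probs_sort[
            torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
            >= top_ks.view(-1, 1)
        ] = 0.0
        # min-p filter, renormalize, sample, and map back to vocab ids
        # (these steps mirror the lines kept in the hunk above)
        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
        probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)


    # Toy batch of two requests over an 8-token vocabulary.
    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    next_ids = sample_top_k_top_p_min_p(
        probs,
        top_ks=torch.tensor([3, 5]),
        top_ps=torch.tensor([0.9, 0.8]),
        min_ps=torch.tensor([0.05, 0.0]),
    )
    print(next_ids)

Note also that the exception-swallowing multinomial workaround was removed: the new code samples with num_samples=1 directly and, on the FlashInfer path, reports failures by zeroing the token ids instead of returning a success mask.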
sglang/srt/layers/torchao_utils.py
ADDED
@@ -0,0 +1,75 @@
+"""
+Common utilities for torchao.
+"""
+
+from typing import Dict, Set
+
+import torch
+
+
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+        `param`: weight parameter of the linear module
+        `torchao_config`: type of quantization and their arguments we want to use to
+        quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        128
+    """
+    # Lazy import to suppress some warnings
+    from torchao.quantization import (
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
+    )
+
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # this requires newer hardware
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded if
+    self.torchao_config is specified
+
+    Args:
+        `self`: the model we want to quantize
+        `params_dict`: dictionary mapping from param_name to the parameter Tensor
+        `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+        None, the `params_dict` is modified inplace and the weights of `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
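For context on the new torchao_utils.py helpers: torchao_quantize_param_data wraps a single 2-D weight in a throwaway nn.Linear, runs torchao's quantize_ on it, and returns the quantized weight, while apply_torchao_config_ repeats that for every entry in a model's params_dict whose name contains one of the given suffixes and reloads the result with load_state_dict(assign=True). A hypothetical usage sketch (the weight shape and the CUDA/bfloat16 setup are illustrative assumptions; int4 weight-only quantization needs a recent GPU and an installed torchao):

    import torch

    from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

    # "int4wo-128" = int4 weight-only quantization with group size 128 (see the assert above).
    weight = torch.nn.Parameter(
        torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda"),
        requires_grad=False,
    )
    qweight = torchao_quantize_param_data(weight, "int4wo-128")
    print(type(qweight), qweight.shape)  # torchao-quantized tensor, same logical shape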