sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/radix_attention.py

@@ -15,20 +15,16 @@ limitations under the License.
 
 """Radix attention."""
 
-from typing import Optional
-
-import torch
-from flashinfer.cascade import merge_state
 from torch import nn
 
-from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.model_executor.model_runner import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -36,8 +32,8 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size: Optional[int] = None,
-        logit_cap: int = -1,
+        sliding_window_size: int = -1,
+        logit_cap: float = 0.0,
         v_head_dim: int = -1,
     ):
         super().__init__()
@@ -49,160 +45,14 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
-
-        if (
-            not global_server_args_dict.get("disable_flashinfer", False)
-            and self.qk_head_dim == self.v_head_dim
-        ):
-            self.extend_forward = self.extend_forward_flashinfer
-            self.decode_forward = self.decode_forward_flashinfer
-        else:
-            self.extend_forward = self.extend_forward_triton
-            self.decode_forward = self.decode_forward_triton
-
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
-    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-
-        self.store_kv_cache(k, v, input_metadata)
-        extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            k.contiguous(),
-            v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_prefix_lens,
-            input_metadata.extend_start_loc,
-            input_metadata.extend_seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.triton_max_extend_len,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-        self.store_kv_cache(k, v, input_metadata)
-
-        decode_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.total_num_tokens,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_paged = prefill_wrapper_paged[0]
-        else:
-            if isinstance(prefill_wrapper_paged, list):
-                prefill_wrapper_paged = prefill_wrapper_paged[1]
-
-        if not input_metadata.flashinfer_use_ragged:
-            if k is not None:
-                assert v is not None
-                self.store_kv_cache(k, v, input_metadata)
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
-            )
-        else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-            )
-
-            if input_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                    causal=False,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            self.store_kv_cache(k, v, input_metadata)
-
-        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-            torch.cuda.synchronize()
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
-
-    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        decode_wrapper = input_metadata.flashinfer_decode_wrapper
-        if self.sliding_window_size != -1:
-            decode_wrapper = decode_wrapper[0]
-        else:
-            if isinstance(decode_wrapper, list):
-                decode_wrapper = decode_wrapper[1]
-
-        if k is not None:
-            assert v is not None
-            self.store_kv_cache(k, v, input_metadata)
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
+        self.logit_cap = logit_cap
+        self.sliding_window_size = sliding_window_size or -1
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
         if k is not None:
+            # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
-            return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.decode_forward(q, k, v, input_metadata)
-
-    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
-        )
+        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)
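
The net effect of this file is that RadixAttention no longer chooses between Triton and FlashInfer kernels itself: forward() now delegates to an attention backend object carried on InputMetadata, with the backend implementations living in the new sglang/srt/layers/attention_backend.py listed above. As a rough, hedged sketch of the shape such an interface could take (class and method names below are illustrative assumptions, not the exact API in this release), something like the following would satisfy the call site `input_metadata.attn_backend.forward(q, k, v, self, input_metadata)`:

    # Illustrative sketch only; names are assumptions, not the shipped attention_backend.py API.
    from abc import ABC, abstractmethod
    from enum import Enum, auto


    class ForwardMode(Enum):
        # Stand-in for sglang.srt.model_executor.forward_batch_info.ForwardMode
        EXTEND = auto()
        DECODE = auto()


    class AttentionBackend(ABC):
        """Minimal interface compatible with input_metadata.attn_backend.forward(...)."""

        @abstractmethod
        def forward_extend(self, q, k, v, layer, input_metadata):
            ...

        @abstractmethod
        def forward_decode(self, q, k, v, layer, input_metadata):
            ...

        def forward(self, q, k, v, layer, input_metadata):
            # Mirrors the EXTEND/DECODE dispatch that was previously inlined
            # in RadixAttention.forward.
            if input_metadata.forward_mode == ForwardMode.DECODE:
                return self.forward_decode(q, k, v, layer, input_metadata)
            return self.forward_extend(q, k, v, layer, input_metadata)
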
sglang/srt/layers/sampler.py

@@ -1,74 +1,28 @@
-import dataclasses
 import logging
-from typing import Tuple, Union
+from typing import Union
 
 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from torch.library import custom_op as torch_custom_op
-from vllm.model_executor.custom_op import CustomOp
+from torch import nn
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
-# TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )
 
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
-class Sampler(CustomOp):
-    def __init__(self):
-        super().__init__()
-        # FIXME: torch.multinomial has too many bugs
-        self.forward_native = self.forward_cuda
-        self.is_torch_compile = False
-
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        if self.is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
+class Sampler(nn.Module):
+    def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
         sampling_info: SamplingBatchInfo,
@@ -76,9 +30,17 @@
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits
 
-        probs = self._get_probs(logits, sampling_info)
+        # Post process logits
+        logits.div_(sampling_info.temperatures)
+        probs = logits[:] = torch.softmax(logits, dim=-1)
 
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if torch.any(torch.isnan(probs)):
+            logger.warning("Detected errors during sampling! NaN in the probability.")
+            probs = torch.where(
+                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+            )
+
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
@@ -90,57 +52,24 @@
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = flashinfer_top_k_top_p(
+                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                 )
-        else:
+
+            if not torch.all(success):
+                logger.warning("Detected errors during sampling!")
+                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )
 
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
-
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
-
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-
-@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={})
-def flashinfer_top_k_top_p(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    # NOTE: we do not use min_p neither in CUDA nor in torch.compile
-    return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps)
-
-
-@flashinfer_top_k_top_p.register_fake
-def _(
-    probs: torch.Tensor,
-    uniform_samples: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    bs = probs.shape[0]
-    return (
-        torch.ones(bs, dtype=torch.bool, device=probs.device),
-        torch.zeros(bs, dtype=torch.int32, device=probs.device),
-    )
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -160,19 +89,6 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     ] = 0.0
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
+    sampled_index = torch.multinomial(probs_sort, num_samples=1)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
+    return batch_next_token_ids
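
After this change the sampler picks its implementation from global_server_args_dict["sampling_backend"]: "flashinfer" uses the fused top-k/top-p/min-p kernels, and "pytorch" falls back to top_k_top_p_min_p_sampling_from_probs_torch, only whose tail is visible in the last hunk. The following is a hedged, self-contained sketch of the same idea written from scratch (the function name and the exact masking order are illustrative, not copied from the package):

    # Standalone sketch of joint top-k / top-p / min-p sampling from a batch of
    # probability rows, illustrating what the "pytorch" fallback path does.
    import torch


    def sample_top_k_top_p_min_p(
        probs: torch.Tensor,   # [batch, vocab], rows sum to 1
        top_ks: torch.Tensor,  # [batch] int
        top_ps: torch.Tensor,  # [batch] float
        min_ps: torch.Tensor,  # [batch] float
    ) -> torch.Tensor:
        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)

        # top-p: zero out tokens once the cumulative mass before them exceeds top_p
        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0

        # top-k: keep only the first k sorted tokens per row
        ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
        probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0

        # min-p: drop tokens below min_p * (max prob of the row)
        min_p_thresholds = probs_sort[:, 0] * min_ps
        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0

        # renormalize, sample, and map back to original vocab indices
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
        return torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)


    # Example usage with a tiny batch:
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(2, 8), dim=-1)
    ids = sample_top_k_top_p_min_p(
        probs,
        top_ks=torch.tensor([3, 5]),
        top_ps=torch.tensor([0.9, 0.8]),
        min_ps=torch.tensor([0.05, 0.0]),
    )
    print(ids)
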
sglang/srt/layers/torchao_utils.py

@@ -0,0 +1,75 @@
+"""
+Common utilities for torchao.
+"""
+
+from typing import Dict, Set
+
+import torch
+
+
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+        `param`: weight parameter of the linear module
+        `torchao_config`: type of quantization and their arguments we want to use to
+            quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+            128
+    """
+    # Lazy import to suppress some warnings
+    from torchao.quantization import (
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
+    )
+
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # this requires newer hardware
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded if
+    self.torchao_config is specified
+
+    Args:
+        `self`: the model we want to quantize
+        `params_dict`: dictionary mapping from param_name to the parameter Tensor
+        `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+        None, the `params_dict` is modified inplace and the weights of `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
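
torchao_utils.py is a new file, so its full content appears above. A hedged usage sketch follows, assuming torchao is installed; TinyModel and the "int8wo" config value are made up for illustration, and only the two helper functions come from the diff:

    # Usage sketch for the new torchao helpers; TinyModel and "int8wo" are illustrative.
    import torch

    from sglang.srt.layers.torchao_utils import (
        apply_torchao_config_,
        torchao_quantize_param_data,
    )


    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(64, 64, bias=False)
            # apply_torchao_config_ reads this attribute; an empty string disables quantization.
            self.torchao_config = "int8wo"


    # Quantize a single weight tensor directly.
    w = torch.nn.Parameter(torch.randn(64, 64), requires_grad=False)
    w_int8 = torchao_quantize_param_data(w, "int8wo")

    # Or quantize every 2-D parameter whose name contains one of the suffixes,
    # then reload the quantized tensors into the model in place.
    model = TinyModel()
    params_dict = dict(model.named_parameters())
    apply_torchao_config_(model, params_dict, param_suffixes={"proj.weight"})
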