compressed-tensors 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff shows the changes between these two publicly released package versions as they appear in their public registry, and is provided for informational purposes only.
Files changed (26)
  1. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  2. compressed_tensors/config/base.py +60 -2
  3. compressed_tensors/quantization/__init__.py +0 -1
  4. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  5. compressed_tensors/quantization/lifecycle/apply.py +1 -16
  6. compressed_tensors/quantization/lifecycle/forward.py +25 -86
  7. compressed_tensors/quantization/lifecycle/initialize.py +23 -25
  8. compressed_tensors/quantization/quant_args.py +28 -15
  9. compressed_tensors/quantization/quant_scheme.py +3 -0
  10. compressed_tensors/quantization/utils/helpers.py +125 -8
  11. compressed_tensors/registry/registry.py +1 -1
  12. compressed_tensors/version.py +1 -1
  13. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/METADATA +1 -1
  14. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/RECORD +17 -26
  15. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/WHEEL +1 -1
  16. compressed_tensors/quantization/cache.py +0 -201
  17. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  18. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  19. compressed_tensors/quantization/observers/__init__.py +0 -22
  20. compressed_tensors/quantization/observers/base.py +0 -213
  21. compressed_tensors/quantization/observers/helpers.py +0 -111
  22. compressed_tensors/quantization/observers/memoryless.py +0 -56
  23. compressed_tensors/quantization/observers/min_max.py +0 -104
  24. compressed_tensors/quantization/observers/mse.py +0 -162
  25. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/LICENSE +0 -0
  26. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
 
     if (
-        quantization_args.strategy is QuantizationStrategy.GROUP
+        quantization_args.strategy == QuantizationStrategy.GROUP
         and quantization_args.group_size < size_k
     ):
         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
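Note on the change above: QuantizationArgs is a pydantic model declared with use_enum_values=True, so by the time pack_scales_24 sees it, strategy typically holds the enum's string value rather than the enum member, and an identity check with `is` never matches. A minimal sketch of the difference, assuming a str-backed Enum like the one this package uses (Strategy below is a stand-in, not the real QuantizationStrategy definition):

    from enum import Enum

    class Strategy(str, Enum):
        # stand-in for QuantizationStrategy; the real enum lives in quant_args.py
        TENSOR = "tensor"
        GROUP = "group"

    strategy = "group"  # what a use_enum_values=True pydantic model stores

    print(strategy is Strategy.GROUP)   # False: different objects
    print(strategy == Strategy.GROUP)   # True: a str Enum compares equal by value

Equality works whether strategy arrives as the raw string or the enum member, which is why `==` is the safer comparison here.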
compressed_tensors/config/base.py
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
+from enum import Enum, unique
 from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
 
 
+@unique
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
 
 
+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
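A short usage sketch for the new enum, assuming the usual star re-export from compressed_tensors.config (the helper below is illustrative, not part of the package):

    from compressed_tensors.config import SparsityStructure

    # value lookup is case-insensitive, and None falls back to UNSTRUCTURED
    assert SparsityStructure("2:4") is SparsityStructure.TWO_FOUR
    assert SparsityStructure("UNSTRUCTURED") is SparsityStructure.UNSTRUCTURED
    assert SparsityStructure(None) is SparsityStructure.UNSTRUCTURED

    def is_24_sparse(sparsity_structure) -> bool:
        # accepts "2:4", mixed-case strings, an existing member, or None
        return SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR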
compressed_tensors/quantization/__init__.py
@@ -19,4 +19,3 @@ from .quant_args import *
 from .quant_config import *
 from .quant_scheme import *
 from .lifecycle import *
-from .cache import QuantizedKVParameterCache
compressed_tensors/quantization/lifecycle/__init__.py
@@ -15,9 +15,7 @@
 # flake8: noqa
 # isort: skip_file
 
-from .calibration import *
 from .forward import *
-from .frozen import *
 from .initialize import *
 from .compressed import *
 from .apply import *
compressed_tensors/quantization/lifecycle/apply.py
@@ -22,13 +22,9 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
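The transition checks in this function use Python's chained comparisons: `status >= QuantizationStatus.INITIALIZED > current_status` expands to pairwise comparisons joined with `and`, i.e. "the target status is at least INITIALIZED and the model has not reached INITIALIZED yet". A tiny illustration with plain ints standing in for the ordered status values:

    # a >= b > c  is evaluated as  (a >= b) and (b > c)
    current_status, status, threshold = 0, 2, 1
    assert (status >= threshold > current_status) == (
        (status >= threshold) and (threshold > current_status)
    )

The same pattern drives the COMPRESSED transition in the next hunk, which is all that remains now that the CALIBRATION and FROZEN handling has been removed from this function.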
@@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
 
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,11 +14,9 @@
 
 from functools import wraps
 from math import ceil
-from typing import Callable, Optional
+from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -26,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import safe_permute, update_parameter_data
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -35,7 +37,7 @@ __all__ = [
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "maybe_calibrate_or_quantize",
+    "forward_quantize",
 ]
 
 
@@ -272,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
-            # calibrate and (fake) quantize input activations when applicable
-            input_ = maybe_calibrate_or_quantize(
-                module, input_, "input", scheme.input_activations
-            )
+            # prehook should calibrate activations before forward call
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data = maybe_calibrate_or_quantize(
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
@@ -288,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
-
-            # calibrate and (fake) quantize output activations when applicable
-            # kv_cache scales updated on model self_attn forward call in
-            # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
-                module, output, "output", scheme.output_activations
-            )
 
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
-        return output
-
-    # bind wrapped forward to module class so reference to `self` is correct
-    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
-    # set forward to wrapped forward
-    setattr(module, "forward", bound_wrapped_forward)
-
-
-def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
-    # expects a module already initialized and injected with the parameters in
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                return output
+
+            output = forward_quantize(
+                module, output, "output", scheme.output_activations
             )
-
-            return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+        return output
 
     # bind wrapped forward to module class so reference to `self` is correct
     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
@@ -353,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def maybe_calibrate_or_quantize(
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value
 
     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -376,30 +332,13 @@
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - determine the scale/zp on the fly
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get previous scale and zero point from layer
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-    if (
-        module.quantization_status == QuantizationStatus.CALIBRATION
-        and base_name != "weight"
-    ):
-        # calibration mode - get new quant params from observer
-        observer = getattr(module, f"{base_name}_observer")
-
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-        # update scale and zero point
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-        scale = updated_scale
-        zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
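With the observer path gone from forward_quantize, dynamic quantization computes its parameters inline from the tensor being quantized. A rough sketch of what "determine the scale/zp on the fly" means for a symmetric per-token int8 case; this is an illustration only, not the body of compute_dynamic_scales_and_zp:

    import torch

    def per_token_symmetric_scale(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
        # one scale per token (reduce over the last dim); symmetric, so zero_point stays 0
        q_max = 2 ** (num_bits - 1) - 1
        max_abs = x.abs().amax(dim=-1, keepdim=True)
        return (max_abs / q_max).clamp(min=torch.finfo(x.dtype).tiny)

    activations = torch.randn(4, 16)
    scale = per_token_symmetric_scale(activations)  # shape (4, 1)

Static quantization, by contrast, simply reads the `{base_name}_scale` / `{base_name}_zero_point` parameters registered at initialization.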
compressed_tensors/quantization/lifecycle/initialize.py
@@ -14,13 +14,12 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter
 __all__ = [
     "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
 ]
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
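KVCacheScaleType now lives next to the initialization code that uses it; its values are the names of the scale parameters attached to attention modules. A small sketch, assuming the package's usual star re-exports:

    from compressed_tensors.quantization.lifecycle import KVCacheScaleType

    # the enum values double as parameter names registered on attention modules
    assert KVCacheScaleType.KEY.value == "k_scale"
    assert KVCacheScaleType.VALUE.value == "v_scale"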
@@ -66,15 +72,13 @@
         return
 
     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)
 
     else:
 
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +89,7 @@
             weight_shape = None
             if isinstance(module, torch.nn.Linear):
                 weight_shape = module.weight.shape
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "weight",
                 scheme.weights,
@@ -101,7 +105,7 @@
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point_observer(
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )
 
@@ -109,6 +113,7 @@
     module.quantization_status = QuantizationStatus.INITIALIZED
 
     offloaded = False
+    # What is this doing/why isn't this in the attn case?
     if is_module_offloaded(module):
         try:
             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -146,19 +151,23 @@
         module._hf_hook.weights_map = new_prefix_dict
 
 
-def _initialize_scale_zero_point_observer(
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
    force_zero_point: bool = True,
 ):
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
-
     if quantization_args.dynamic:
-        return  # no need to register a scale and zero point for a dynamic observer
+        return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
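is_attention_module is now part of the public API and simply combines a class-name check with the presence of projection submodules. A quick sketch of what it matches (DummySelfAttention is made up for illustration; the import assumes the lifecycle star re-export):

    import torch
    from compressed_tensors.quantization.lifecycle import is_attention_module

    class DummySelfAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.k_proj = torch.nn.Linear(8, 8)
            self.v_proj = torch.nn.Linear(8, 8)

    assert is_attention_module(DummySelfAttention())       # "attention" in name + has k_proj
    assert not is_attention_module(torch.nn.Linear(8, 8))  # name check fails, so no match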
@@ -173,10 +182,7 @@
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
             num_groups = weight_shape[1] // quantization_args.group_size
-            expected_shape = (
-                weight_shape[0],
-                max(num_groups, 1)
-            )
+            expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
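For the group strategy, the collapsed expression keeps one scale/zero-point entry per (output channel, group). A quick worked example of the shape it produces:

    weight_shape = (4096, 4096)  # (out_features, in_features) of a hypothetical Linear
    group_size = 128
    num_groups = weight_shape[1] // group_size             # 32
    expected_shape = (weight_shape[0], max(num_groups, 1))
    assert expected_shape == (4096, 32)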
@@ -208,14 +214,6 @@
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
 
 
-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
compressed_tensors/quantization/quant_args.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -113,20 +114,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-        from compressed_tensors.quantization.observers.base import Observer
-
-        if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer
 
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -171,6 +159,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -207,8 +197,31 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        # infer observer w.r.t. dynamic
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
+                    "quantization",
+                )
+            if observer is not None:
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None
+
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
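Taken together, the validator means dynamic args end up with observer=None (warning if a config still names an observer other than "memoryless"), while static args fall back to "minmax". A hedged sketch of the resulting behavior; any constructor details beyond the fields shown in this diff are assumptions:

    from compressed_tensors.quantization import QuantizationArgs

    # dynamic: a token/tensor strategy is required and no observer is kept
    dynamic_args = QuantizationArgs(
        num_bits=8, type="int", strategy="token", dynamic=True, observer=None
    )
    assert dynamic_args.observer is None

    # static: the observer defaults to "minmax"
    static_args = QuantizationArgs(num_bits=8, type="int", strategy="tensor")
    assert static_args.observer == "minmax"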
compressed_tensors/quantization/quant_scheme.py
@@ -122,6 +122,7 @@ INT8_W8A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -164,6 +165,7 @@ INT8_W4A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ FP8_DYNAMIC = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )