compressed-tensors 0.9.5a20250521__py3-none-any.whl → 0.9.5a20250528__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
@@ -18,6 +18,7 @@ from typing import Optional
 
 import torch
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -189,7 +190,11 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    if args.strategy == QuantizationStrategy.GROUP:
+    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        n_dims = x.shape
+        if len(n_dims) > 2:
+            x = x.squeeze(0)
+
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
         columns = output.shape[1]
@@ -251,6 +256,9 @@ def _process_quantization(
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
+        if len(n_dims) > 2:
+            output = output.unsqueeze(0)
+
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(
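
For orientation, a minimal standalone sketch (plain PyTorch, not the library's internal helper) of the 3-D activation handling introduced in the two hunks above: a leading batch dimension of size 1 is squeezed away before group-wise processing and restored afterwards.

import torch

def process_groupwise(x: torch.Tensor) -> torch.Tensor:
    # Remember the original rank, drop a leading batch dim for 3-D inputs,
    # then restore it once the group-wise pass is done.
    n_dims = x.shape
    if len(n_dims) > 2:
        x = x.squeeze(0)

    output = torch.zeros_like(x)
    output.copy_(x)  # placeholder for the group-wise quantize/dequantize work

    if len(n_dims) > 2:
        output = output.unsqueeze(0)
    return output

assert process_groupwise(torch.randn(1, 5, 32)).shape == (1, 5, 32)
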
@@ -352,9 +360,11 @@ def forward_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
     global_scale = getattr(module, f"{base_name}_global_scale", None)
 
-    if args.dynamic:
+    if args.dynamic in (True, DynamicType.LOCAL):
         # dynamic quantization - determine the scale/zp on the fly
-        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
+        scale, zero_point = compute_dynamic_scales_and_zp(
+            value=value, args=args, module=module, global_scale=global_scale
+        )
     else:
         # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
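
The dynamic branch above now treats DynamicType.LOCAL the same as True when deciding whether scales are computed at runtime, and forwards the module plus any global scale to the helper. A hedged sketch of just that dispatch decision (the helper name below is illustrative, not the library's):

import torch
from compressed_tensors.quantization.quant_args import DynamicType

def uses_runtime_scales(dynamic) -> bool:
    # Illustrative only: True (fully dynamic) and "local" both compute
    # scale/zero-point on the fly; purely static schemes read the
    # pre-registered "{base_name}_scale" parameter instead.
    return dynamic in (True, DynamicType.LOCAL)

linear = torch.nn.Linear(8, 4)
linear.register_buffer("input_scale", torch.ones(1, 1))  # stand-in for static init

for dynamic in (False, True, DynamicType.LOCAL):
    source = "runtime" if uses_runtime_scales(dynamic) else getattr(linear, "input_scale").shape
    print(dynamic, "->", source)
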
@@ -388,6 +398,7 @@ def _quantize(
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
+
     if zero_point is not None:
         scaled += zero_point.to(x.dtype)
 
@@ -398,6 +409,7 @@ def _quantize(
         q_max,
     )
     quantized_value = round_to_quantized_type(clamped_value, args)
+
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
 
@@ -422,6 +434,7 @@ def _dequantize(
 
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
+
     dequant_value = dequant_value * scale
 
     if dtype is not None:
@@ -156,13 +156,33 @@ def _initialize_scale_zero_point(
     force_zero_point: bool = True,
     scale_dtype: Optional[torch.dtype] = None,
 ):
-    if quantization_args.dynamic:
+    if quantization_args.dynamic is True:
        return
 
     # initialize on execution device to avoid performing quantized ops on cpu
     device = get_execution_device(module)
 
-    # infer expected scale/zero point shape
+    # 1. Create global_scales for tensor_group
+    if quantization_args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        # TODO: should move to llmcompressor
+        if base_name == "weight":
+            # When applying weight-only FP4 quantization, generate a global_scale
+            # This scale is applied during runtime to ensure that the generated
+            # local scale falls properly within the FP8 range (i.e max value is FP8_max)
+            # which is the expected dtype of NVFP4A16 scales
+            value = generate_global_scale(input_tensor=module.weight)
+            value = value.to(device)
+            init_global_scale = Parameter(value, requires_grad=False)
+        else:
+            init_global_scale = Parameter(
+                torch.empty(1, dtype=torch.float32, device=device),
+                requires_grad=False,
+            )
+        register_offload_parameter(
+            module, f"{base_name}_global_scale", init_global_scale
+        )
+
+    # 2. Infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
         expected_shape = (1, 1)
     else:
@@ -172,47 +192,35 @@ def _initialize_scale_zero_point(
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
-        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+        elif quantization_args.strategy in (
+            QuantizationStrategy.TENSOR_GROUP,
+            QuantizationStrategy.GROUP,
+        ):
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
 
+    # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
-    # TODO: consider erroring out in the future as if the dtype if not one fo these,
-    # there is likely bug
-
-    if is_fp4(quantization_args=quantization_args) and base_name == "weight":
-        scale_dtype = FP8_E4M3_DATA.dtype
-        # When applying weight-only FP4 quantization, generate a global_scale
-        # This scale is applied during runtime to ensure that the generated
-        # local scale falls properly within the FP8 range (i.e max value is FP8_max)
-        # which is the expected dtype of NVFP4A16 scales
-        value = generate_global_scale(input_tensor=module.weight)
-        value = value.to(device)
-        init_global_scale = Parameter(value, requires_grad=False)
-        register_offload_parameter(
-            module, f"{base_name}_global_scale", init_global_scale
-        )
-
-    if scale_dtype not in [
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-    ] and not is_fp4(quantization_args=quantization_args):
-        scale_dtype = torch.float16
 
-    # initializes empty scale, zero point, and g_idx parameters for the module
-    init_scale = Parameter(
-        torch.empty(expected_shape, dtype=scale_dtype, device=device),
-        requires_grad=False,
-    )
-    register_offload_parameter(module, f"{base_name}_scale", init_scale)
+    if is_fp4(quantization_args=quantization_args):
+        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+    else:
+        # TODO: consider erroring out in the future as if the dtype if not one of these,
+        # there is likely bug
+        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+            scale_dtype = torch.float16
+        zp_dtype = quantization_args.pytorch_dtype()
+
+    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
+    # do not init scales for quantzation_args.dynamic == DynamicType.local
+    if not quantization_args.dynamic:
+        init_scale = Parameter(
+            torch.empty(expected_shape, dtype=scale_dtype, device=device),
+            requires_grad=False,
+        )
+        register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
-        if is_fp4(quantization_args=quantization_args):
-            zp_dtype = FP8_E4M3_DATA.dtype
-        else:
-            zp_dtype = quantization_args.pytorch_dtype()
-
         init_zero_point = Parameter(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
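
Taken together, the initializer hunks above register a one-element {base_name}_global_scale for tensor_group schemes, use the FP8 E4M3 dtype for FP4 scale/zero-point parameters, and skip the scale entirely when the args are dynamic. A small arithmetic sketch of the resulting shapes (the layer sizes below are made up for illustration):

import math

# Hypothetical Linear layer dimensions, quantized with group_size=16 as in
# the tensor_group hunks above.
out_features, in_features, group_size = 11008, 4096, 16

num_groups = math.ceil(in_features / group_size)
scale_shape = (out_features, max(num_groups, 1))  # "{base_name}_scale"
global_scale_shape = (1,)                         # "{base_name}_global_scale", tensor_group only

print(scale_shape)         # (11008, 256)
print(global_scale_shape)  # (1,)
# For FP4 schemes both scale and zero-point use the FP8 E4M3 dtype; otherwise
# the weight dtype is kept, falling back to float16 for non-float dtypes.
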
@@ -32,6 +32,7 @@ __all__ = [
     "QuantizationArgs",
     "round_to_quantized_type",
     "ActivationOrdering",
+    "DynamicType",
 ]
 
 
@@ -98,6 +99,22 @@ class QuantizationStrategy(str, Enum):
     GROUP = "group"
     BLOCK = "block"
     TOKEN = "token"
+    TENSOR_GROUP = "tensor_group"
+
+
+class DynamicType(str, Enum):
+    """
+    Enum storing potential dynamic types.
+
+    1. If dynamic is True, all quantization parameters are generated on the fly.
+    2. If dynamic is False, all quantization parameters generated are static.
+    3. If "local" is provided, only local quantization parameters are dynamic.
+
+    Note: "local" is only currently supported for NVFP4.
+
+    """
+
+    LOCAL = "local"
 
 
 class ActivationOrdering(Aliasable, str, Enum):
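
Since DynamicType subclasses str, it round-trips cleanly through serialized configs and compares equal to its raw value; a quick check against the import path added in this release:

from compressed_tensors.quantization.quant_args import DynamicType

assert DynamicType("local") is DynamicType.LOCAL  # constructible from the raw string
assert DynamicType.LOCAL == "local"               # str-backed, so plain-string configs still match
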
@@ -152,7 +169,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
-    dynamic: bool = False
+    dynamic: Union[DynamicType, bool] = False
     actorder: Union[ActivationOrdering, bool, None] = None
     observer: Optional[str] = Field(
         default=None,
@@ -206,6 +223,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
         return value
 
+    @field_validator("dynamic", mode="before")
+    def validate_dynamic(cls, value) -> Union[DynamicType, bool]:
+        if isinstance(value, str):
+            return DynamicType(value.lower())
+        return value
+
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         # extract user-passed values from dictionary
@@ -239,7 +262,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         if (
             group_size is not None
             and group_size > 0
-            and strategy != QuantizationStrategy.GROUP
+            and strategy
+            not in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP)
         ):
             raise ValueError("group_size requires strategy to be set to 'group'")
 
@@ -255,22 +279,35 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             if strategy not in (
                 QuantizationStrategy.TOKEN,
                 QuantizationStrategy.TENSOR,
+                QuantizationStrategy.TENSOR_GROUP,
             ):
                 raise ValueError(
-                    f"One of {QuantizationStrategy.TOKEN} or "
-                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
-                    "quantization",
+                    f"One of {(QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP)} "
+                    "must be used for dynamic quantization",
                 )
+
+            if (
+                dynamic == DynamicType.LOCAL
+                and strategy != QuantizationStrategy.TENSOR_GROUP
+            ):
+                raise ValueError("local is only supported for strategy tensor_group")
+
             if observer is not None:
-                if observer != "memoryless":  # avoid annoying users with old configs
-                    warnings.warn(
-                        "No observer is used for dynamic quantization, setting to None"
-                    )
-                    observer = None
+                if dynamic is True:  # checking if dynamic is True, not "local"
+                    if (
+                        observer != "memoryless"
+                    ):  # avoid annoying users with old configs
+                        warnings.warn(
+                            "No observer is used for dynamic quantization, setting to None"
+                        )
+                        observer = None
+            else:
+                if dynamic == DynamicType.LOCAL:
+                    observer = "minmax"
 
         elif observer is None:
-            # default to mse for non-dynamic cases
-            observer = "mse"
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
 
         # write back modified values
         model.strategy = strategy
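
End to end, the validators above lower-case a string dynamic value into DynamicType, restrict "local" to the tensor_group strategy, and default the observer to "minmax" for both local-dynamic and static schemes. A sketch against this wheel's public constructor (field names taken from the hunks above):

from compressed_tensors.quantization.quant_args import (
    DynamicType,
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

args = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.FLOAT,
    strategy=QuantizationStrategy.TENSOR_GROUP,
    symmetric=True,
    dynamic="LOCAL",  # coerced to DynamicType.LOCAL by validate_dynamic
    group_size=16,
)
assert args.dynamic == DynamicType.LOCAL
print(args.observer)  # expected "minmax" per the validator above

# dynamic="local" with any strategy other than tensor_group is rejected by
# validate_model_after, as is dynamic quantization outside of the
# token / tensor / tensor_group strategies.
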
@@ -16,7 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -251,7 +251,7 @@ class QuantizationConfig(BaseModel):
 
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
-                if not scheme.input_activations.dynamic:
+                if scheme.input_activations.dynamic in (False, DynamicType.LOCAL):
                     return True
             if scheme.output_activations is not None:
                 if not scheme.output_activations.dynamic:
@@ -16,6 +16,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
+    DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
@@ -104,13 +105,33 @@ NVFP4A16 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.FLOAT,
-        strategy=QuantizationStrategy.GROUP,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
         symmetric=True,
         dynamic=False,
         group_size=16,
     )
 )
 
+
+NVFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR_GROUP,
+        symmetric=True,
+        dynamic=DynamicType.LOCAL,
+        group_size=16,
+    ),
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -237,4 +258,5 @@ PRESET_SCHEMES = {
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
     "NVFP4A16": NVFP4A16,
+    "NVFP4": NVFP4,
 }
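
The new NVFP4 preset pairs static tensor_group weights with locally-dynamic tensor_group input activations. Assuming the presets live in quant_scheme, as the RECORD changes below suggest, they can be inspected directly:

from compressed_tensors.quantization.quant_args import DynamicType
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES

nvfp4 = PRESET_SCHEMES["NVFP4"]
assert nvfp4["weights"].dynamic is False                        # static weight scales
assert nvfp4["input_activations"].dynamic == DynamicType.LOCAL  # local activation scales computed at runtime
assert nvfp4["input_activations"].group_size == 16
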
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+import math
 from typing import Generator, List, Optional, Tuple
 
 import torch
@@ -103,7 +104,9 @@ def calculate_qparams(
     if is_fp4(quantization_args=quantization_args) and global_scale is not None:
         # Conditionally scale the generated local scale by a global_scale
         scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+        scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
         scales = scales.to(FP8_E4M3_DATA.dtype)
+
     else:
         scales = max_val_pos / (float(bit_range) / 2)
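
Numerically, the clamp added above keeps the per-group FP4 scale inside the FP8 E4M3 range before the cast. A standalone sketch using the standard format maxima (6.0 for FP4 E2M1, 448.0 for FP8 E4M3) rather than the library's data classes:

import torch

FP4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def local_fp4_scale(max_val_pos: torch.Tensor, global_scale: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as the hunk above: scale the per-group absmax by the
    # global scale, then clamp so the cast to float8_e4m3fn cannot overflow.
    scales = global_scale * (max_val_pos / FP4_E2M1_MAX)
    scales = torch.clamp(scales, min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX)
    return scales.to(torch.float8_e4m3fn)

print(local_fp4_scale(torch.tensor([3.0, 9000.0]), torch.tensor(1.5)))  # second value clamps to 448
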
 
@@ -143,7 +146,12 @@ def calculate_qparams(
     return scales, zero_points
 
 
-def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+def compute_dynamic_scales_and_zp(
+    value: Tensor,
+    args: QuantizationArgs,
+    module: torch.nn.Module,
+    global_scale: Optional[Tensor] = None,
+):
     """
     Returns the computed scales and zero points for dynamic activation
     quantization.
@@ -155,24 +163,41 @@ def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
         reduced dimensions
     :return: tuple of scale and zero point derived from the observed tensor
     """
+
+    keep_dims = True
     if args.strategy == QuantizationStrategy.TOKEN:
         dim = {1, 2}
         reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
     elif args.strategy == QuantizationStrategy.TENSOR:
         reduce_dims = None
+    elif args.strategy == QuantizationStrategy.TENSOR_GROUP:
+        if len(value.shape) > 2:
+            value = value.squeeze(0)
+
+        dim = {0, 1}
+        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        keep_dims = False
+        value = torch.reshape(
+            value,
+            (
+                value.shape[0],
+                math.ceil(value.shape[1] / args.group_size),
+                args.group_size,
+            ),
+        )
     else:
         raise ValueError(
-            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
-            "must be used for dynamic quantization",
+            "Dynamic quantization is only supported for ",
+            f"{QuantizationStrategy.TOKEN, QuantizationStrategy.TENSOR, QuantizationStrategy.TENSOR_GROUP}",
        )
 
     if not reduce_dims:
         min_val, max_val = torch.aminmax(value)
     else:
-        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
-        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=keep_dims)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=keep_dims)
 
-    return calculate_qparams(min_val, max_val, args)
+    return calculate_qparams(min_val, max_val, args, global_scale=global_scale)
 
 
 def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
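
A standalone sketch of the tensor_group reduction added above: the (optionally squeezed) 2-D activation is viewed as (rows, n_groups, group_size) and min/max are reduced over the trailing group dimension, yielding one scale per 16-wide group rather than per token or per tensor.

import math
import torch

def per_group_min_max(value: torch.Tensor, group_size: int):
    # Mirrors the TENSOR_GROUP branch above; assumes the hidden dimension is a
    # multiple of group_size (true for the 16-wide NVFP4 groups).
    if value.ndim > 2:
        value = value.squeeze(0)
    rows, hidden = value.shape
    grouped = value.reshape(rows, math.ceil(hidden / group_size), group_size)
    min_val = torch.amin(grouped, dim=-1)  # keepdims=False -> one value per group
    max_val = torch.amax(grouped, dim=-1)
    return min_val, max_val

mins, maxs = per_group_min_max(torch.randn(1, 4, 64), group_size=16)
print(mins.shape, maxs.shape)  # torch.Size([4, 4]) torch.Size([4, 4])
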
@@ -0,0 +1,20 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: skip_file
+
+from .transform_args import *
+from .transform_scheme import *
+from .transform_config import *
@@ -0,0 +1,54 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import Any, List
+
+from pydantic import BaseModel, Field, field_validator
+
+
+__all__ = ["TransformArgs"]
+
+
+class TransformLocation(str, Enum):
+    INPUT = "input"
+    WEIGHT_INPUT = "weight_input"
+    WEIGHT_OUTPUT = "weight_output"
+    OUTPUT = "output"
+    K_CACHE = "k_cache"
+    Q_ATTN = "q_attn"
+
+
+class TransformArgs(BaseModel):
+    """
+    Arguments which define *how* and where a transform should be applied to a model
+
+    :param targets: list of modules to apply transforms to
+    :param location: where to apply transform on module, one of (`input`, `weight`,
+        `output`, `k_cache`, `q_attn`)
+    :param inverse: whether or not to apply the inverse of a transform
+    :param ignore: any modules which should be ignored from the targets list
+    """
+
+    targets: List[str]
+    location: TransformLocation
+    inverse: bool = Field(default=False)
+    ignore: List[str] = Field(default_factory=list)
+
+    @field_validator("targets", "ignore", mode="before")
+    @classmethod
+    def wrap_singleton(cls, value):
+        if isinstance(value, str):
+            return [value]
+        return value
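
The wrap_singleton validator lets callers pass a single module name where a list is expected; a short usage sketch against the compressed_tensors.transform namespace re-exported by the package __init__ above:

from compressed_tensors.transform import TransformArgs

args = TransformArgs(targets="Linear", location="weight_input", inverse=True)
assert args.targets == ["Linear"]  # bare string wrapped into a list by the validator
assert args.ignore == []           # defaults to an empty ignore list
print(args.location)               # coerced into the TransformLocation enum
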
@@ -0,0 +1,73 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+from compressed_tensors.transform import TransformArgs, TransformScheme
+from pydantic import BaseModel
+
+
+__all__ = ["TransformConfig"]
+
+
+class TransformConfig(BaseModel):
+    """
+    Configuration of transforms to be applied to a model. This config is to be
+    serialized within a model's `config.json` file
+
+    :param config_groups: A dictionary of `TransformSchemes` that should be applied
+        to a particular model. The keys can be any arbitrary string
+    """
+
+    config_groups: Dict[str, TransformScheme]
+
+
+# quip / quip sharp
+QUIP = TransformConfig(
+    config_groups={
+        "v": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="input",  # non-mergable
+                ),
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_input",
+                    inverse=True,
+                ),
+            ],
+            randomize_modules=True,
+        ),
+        "u": TransformScheme(
+            type="hadamard",
+            apply=[
+                TransformArgs(
+                    targets=["Linear"],
+                    location="weight_output",
+                ),
+                TransformArgs(
+                    targets=["Linear"], location="output", inverse=True  # non-mergable
+                ),
+            ],
+            randomize_modules=True,
+        ),
+    }
+)
+
+
+PRESET_CONFIGS = {
+    "QUIP": QUIP,
+}
@@ -0,0 +1,43 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+from compressed_tensors.transform import TransformArgs
+from pydantic import BaseModel, Field
+
+
+__all__ = ["TransformScheme"]
+
+
+class TransformScheme(BaseModel):
+    """
+    Scheme used to parameterize a particular transform type and specify how and where it
+    should be applied to the model
+
+    :param type: string indicating the particular transform type that should be created
+        and applied. This should be one of the registered transform types
+        (see `Transforms.registered_names()`)
+    :param apply: list of TransformationArgs containing the information about the
+        modules that should be targeted by the specified transform
+    :param randomize_modules: True if unique transforms should be applied to each
+        unique module targeted by `apply`, otherwise reuse transform weights where
+        applicable
+    :param requires_grad: True if weights include gradients for training
+    """
+
+    type: str
+    apply: List[TransformArgs] = Field(default_factory=list)
+    randomize_modules: bool = Field(default=False)
+    requires_grad: bool = Field(default=False)
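
TransformScheme and TransformConfig compose with TransformArgs exactly as the bundled QUIP preset does. A minimal sketch that rebuilds the preset's "v" group and dumps it the way it would be serialized into config.json (model_dump assumes the pydantic v2 API that the validators in this release already rely on):

from compressed_tensors.transform import TransformArgs, TransformConfig, TransformScheme

# One Hadamard transform applied to Linear inputs, with its inverse folded into
# the weights, mirroring the "v" group of the QUIP preset above.
scheme = TransformScheme(
    type="hadamard",
    apply=[
        TransformArgs(targets=["Linear"], location="input"),
        TransformArgs(targets=["Linear"], location="weight_input", inverse=True),
    ],
    randomize_modules=True,
)
config = TransformConfig(config_groups={"v": scheme})
print(config.model_dump())
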
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.9.5.a20250521'
+__version__ = version = '0.9.5.a20250528'
 __version_tuple__ = version_tuple = (0, 9, 5)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250521
+Version: 0.9.5a20250528
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=FJ5OPohL511E88TFF_Jipl_3ikvZ6NgmdrYxPbi2vo8,521
+compressed_tensors/version.py,sha256=iHdCbvf5_sP-ylnF-60aPKldM3BLsLc1pARRzA74l60,521
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -26,19 +26,23 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=5-mq43RmbI81z9Xl9pYNv4bqIP5AIT65FgT--4ERsE8,10502
-compressed_tensors/quantization/quant_config.py,sha256=MxSUcb5dOqMN6LFyD5K2h8X0TvEtcWIAoiUJqD2dHGE,10159
-compressed_tensors/quantization/quant_scheme.py,sha256=Fx7Ma4bDlFB6OWkHKhOB6_0AOVIOPRgNE_qTwmDLSbc,6586
+compressed_tensors/quantization/quant_args.py,sha256=huROC8fbY899EYa2MnEmujvcBeHYLpn-e8ZEViEFASo,11804
+compressed_tensors/quantization/quant_config.py,sha256=aFi6PKqmEX9iP9O8GVn3mEUjRDEwk_hOCbmmiq-j9oU,10198
+compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF5hLz89J2PVYWBBnXRc,7102
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=-OKZ-FFFfIIoeGTrho8lXx6HVWZQp3Xkn3Q-G0hU-CM,18294
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=WY-HY5kXY2Zs9HMpaq44bpolQUAQ1ELrNZC7GM5C4jw,14494
+compressed_tensors/quantization/lifecycle/forward.py,sha256=65USJEtsp_n8X36L5y4g4ftMnhrQyRWbwKJ8RZMMiBo,14797
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=dWXxjYLemjmtrSnb8vyuvNoNTSm8ywmUswze3soKY4o,12041
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=976sZ45ywGVzH1n4pyVhG7hnUBP1wKEWoo9cHrmKHxU,12522
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=w3Ucpdog88b0MnZdJ37VzgtYi1fqrwJafYdfWPc0hTk,16852
+compressed_tensors/quantization/utils/helpers.py,sha256=I-bJcMdBFXjIUQEpnxMMN_FfQyXjojpe5w7ZIKSZ5UU,17588
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
+compressed_tensors/transform/__init__.py,sha256=oa5VdrE-GtDYYceXNSwj5X_ropoXLLukm6Aufcc9WhY,747
+compressed_tensors/transform/transform_args.py,sha256=Sazu_4kXL7IvIEgTaimgo8dV-qacXf_t1NLEfDvPJEU,1759
+compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
+compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
 compressed_tensors/utils/helpers.py,sha256=RrNvzD08naEjEiXdU-FdZjQVda1nQywu1hA_GCDj0vg,10415
 compressed_tensors/utils/offload.py,sha256=JNQ66_6vhSsizhlUaMgyEdBuFolYxbgUuT1mAZrCfKY,15436
@@ -46,8 +50,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.9.5a20250521.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.9.5a20250521.dist-info/METADATA,sha256=Xl6EbYwMlKhFyy6VXtxD2x0TsiTDG36YszGdub5wLqM,7004
-compressed_tensors-0.9.5a20250521.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
-compressed_tensors-0.9.5a20250521.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.9.5a20250521.dist-info/RECORD,,
+compressed_tensors-0.9.5a20250528.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.9.5a20250528.dist-info/METADATA,sha256=diiEZExV1kI7i_tWkgGp3B1UPGtoNtTABx2BJZcHk8I,7004
+compressed_tensors-0.9.5a20250528.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.9.5a20250528.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.9.5a20250528.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (80.8.0)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 