compressed-tensors 0.10.3a20250806__py3-none-any.whl → 0.10.3a20250812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,9 +12,14 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- SPARSITY_CONFIG_NAME = "sparsity_config"
15
+ # configs
16
16
  QUANTIZATION_CONFIG_NAME = "quantization_config"
17
- COMPRESSION_CONFIG_NAME = "compression_config"
18
- KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
17
+ SPARSITY_CONFIG_NAME = "sparsity_config"
18
+ TRANSFORM_CONFIG_NAME = "transform_config"
19
+
20
+ # required fields
19
21
  COMPRESSION_VERSION_NAME = "version"
20
22
  QUANTIZATION_METHOD_NAME = "quant_method"
23
+
24
+ # auxillary configs
25
+ KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
@@ -29,6 +29,7 @@ from compressed_tensors.base import (
29
29
  QUANTIZATION_CONFIG_NAME,
30
30
  QUANTIZATION_METHOD_NAME,
31
31
  SPARSITY_CONFIG_NAME,
32
+ TRANSFORM_CONFIG_NAME,
32
33
  )
33
34
  from compressed_tensors.compressors.base import BaseCompressor
34
35
  from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@ from compressed_tensors.quantization import (
43
44
  )
44
45
  from compressed_tensors.quantization.lifecycle import expand_target_names
45
46
  from compressed_tensors.quantization.utils import is_module_quantized
47
+ from compressed_tensors.transform import TransformConfig
46
48
  from compressed_tensors.utils import (
47
49
  align_module_device,
48
50
  delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
105
107
 
106
108
  sparsity_config: Optional[SparsityCompressionConfig] = None
107
109
  quantization_config: Optional[QuantizationConfig] = None
110
+ transform_config: Optional[TransformConfig] = None
108
111
 
109
112
  @classmethod
110
113
  def from_pretrained(
@@ -144,6 +147,8 @@ class ModelCompressor:
144
147
 
145
148
  sparsity_config = cls.parse_sparsity_config(compression_config)
146
149
  quantization_config = cls.parse_quantization_config(compression_config)
150
+ # TODO: transform config is not support by CompressedTensorsConfig yet
151
+
147
152
  if sparsity_config is None and quantization_config is None:
148
153
  return None
149
154
 
@@ -177,25 +182,32 @@ class ModelCompressor:
177
182
  algorithm
178
183
  :return: compressor for the configs, or None if model is not compressed
179
184
  """
185
+ # reconstruct config from schemes attached to modules
180
186
  quantization_config = QuantizationConfig.from_pretrained(
181
187
  model, format=quantization_format
182
188
  )
183
189
 
190
+ # use config passed as argument
184
191
  if isinstance(sparsity_config, str): # we passed in a sparsity format
185
192
  sparsity_config = SparsityCompressionConfig.load_from_registry(
186
193
  sparsity_config
187
194
  )
188
195
 
189
- if sparsity_config is None and quantization_config is None:
196
+ # use config attached to model
197
+ transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
198
+
199
+ if not any((quantization_config, sparsity_config, transform_config)):
190
200
  return None
191
201
 
192
202
  return cls(
193
- sparsity_config=sparsity_config, quantization_config=quantization_config
203
+ sparsity_config=sparsity_config,
204
+ quantization_config=quantization_config,
205
+ transform_config=transform_config,
194
206
  )
195
207
 
196
208
  @staticmethod
197
209
  def parse_sparsity_config(
198
- compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
210
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
199
211
  ) -> Union[Dict[str, Any], None]:
200
212
  """
201
213
  Parse sparsity config from quantization/compression config. Sparsity
@@ -215,7 +227,7 @@ class ModelCompressor:
215
227
 
216
228
  @staticmethod
217
229
  def parse_quantization_config(
218
- compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
230
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
219
231
  ) -> Union[Dict[str, Any], None]:
220
232
  """
221
233
  Parse quantization config from quantization/compression config. The
@@ -234,6 +246,7 @@ class ModelCompressor:
234
246
 
235
247
  quantization_config = deepcopy(compression_config)
236
248
  quantization_config.pop(SPARSITY_CONFIG_NAME, None)
249
+ quantization_config.pop(TRANSFORM_CONFIG_NAME, None)
237
250
 
238
251
  # some fields are required, even if a qconfig is not present
239
252
  # pop them off and if nothing remains, then there is no qconfig
@@ -254,13 +267,17 @@ class ModelCompressor:
254
267
  self,
255
268
  sparsity_config: Optional[SparsityCompressionConfig] = None,
256
269
  quantization_config: Optional[QuantizationConfig] = None,
270
+ transform_config: Optional[TransformConfig] = None,
257
271
  ):
258
272
  self.sparsity_config = sparsity_config
259
273
  self.quantization_config = quantization_config
274
+ self.transform_config = transform_config
275
+
260
276
  self.sparsity_compressor = None
261
277
  self.quantization_compressor: Optional[
262
278
  Union[BaseQuantizationCompressor, DenseCompressor]
263
279
  ] = None
280
+ # no transform compressor is required
264
281
 
265
282
  if sparsity_config is not None:
266
283
  self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +657,49 @@ class ModelCompressor:
640
657
 
641
658
  :param save_directory: path to a folder containing a HF model config
642
659
  """
643
- if self.quantization_config is None and self.sparsity_config is None:
660
+ # this check is also done in `from_pretrained_model`,
661
+ # but not in `from_pretrained`` or `from_compression_config``
662
+ if not any(
663
+ (self.quantization_config, self.sparsity_config, self.transform_config)
664
+ ):
644
665
  return
645
666
 
667
+ # write to config.json file, regardless of whether it exists already
668
+ # overwrite previous config and version if already existing
646
669
  config_file_path = os.path.join(save_directory, CONFIG_NAME)
647
- if not os.path.exists(config_file_path):
648
- _LOGGER.warning(
649
- f"Could not find a valid model config file in "
650
- f"{save_directory}. Compression config will not be saved."
651
- )
652
- return
670
+ if os.path.exists(config_file_path):
671
+ with open(config_file_path, "r") as file:
672
+ config_data = json.load(file)
673
+ else:
674
+ config_data = {}
653
675
 
654
- with open(config_file_path, "r") as config_file:
655
- config_data = json.load(config_file)
676
+ # serialize configs into json
677
+ qconfig_data = (
678
+ self.quantization_config.model_dump(exclude=["quant_method"])
679
+ if self.quantization_config is not None
680
+ else {}
681
+ )
682
+ sconfig_data = (
683
+ self.sparsity_config.model_dump()
684
+ if self.sparsity_config is not None
685
+ else {}
686
+ )
687
+ tconfig_data = (
688
+ self.transform_config.model_dump()
689
+ if self.transform_config is not None
690
+ else {}
691
+ )
656
692
 
657
- # required metadata whenever a quantization or sparsity config is present
658
- # overwrite previous config and version if already existing
659
- config_data[QUANTIZATION_CONFIG_NAME] = {}
660
- config_data[QUANTIZATION_CONFIG_NAME][
661
- COMPRESSION_VERSION_NAME
662
- ] = compressed_tensors.__version__
663
- if self.quantization_config is not None:
664
- self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
665
- else:
666
- config_data[QUANTIZATION_CONFIG_NAME][
667
- QUANTIZATION_METHOD_NAME
668
- ] = DEFAULT_QUANTIZATION_METHOD
669
-
670
- # quantization and sparsity configs
671
- if self.quantization_config is not None:
672
- quant_config_data = self.quantization_config.model_dump()
673
- config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
674
- if self.sparsity_config is not None:
675
- sparsity_config_data = self.sparsity_config.model_dump()
676
- config_data[QUANTIZATION_CONFIG_NAME][
677
- SPARSITY_CONFIG_NAME
678
- ] = sparsity_config_data
693
+ # construct compression (quantization) config
694
+ config_data[QUANTIZATION_CONFIG_NAME] = {
695
+ COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
696
+ QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
697
+ SPARSITY_CONFIG_NAME: sconfig_data,
698
+ TRANSFORM_CONFIG_NAME: tconfig_data,
699
+ **qconfig_data,
700
+ }
679
701
 
702
+ # write results to config.json file
680
703
  with open(config_file_path, "w") as config_file:
681
704
  json.dump(config_data, config_file, indent=2, sort_keys=True)
682
705
 
@@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union
19
19
  import torch
20
20
  from compressed_tensors.utils import Aliasable
21
21
  from compressed_tensors.utils.helpers import deprecated
22
- from pydantic import BaseModel, Field, field_validator, model_validator
22
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
23
23
 
24
24
 
25
25
  __all__ = [
@@ -358,6 +358,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
358
358
  def get_observer(self) -> str:
359
359
  return self.observer
360
360
 
361
+ model_config = ConfigDict(extra="forbid")
362
+
361
363
 
362
364
  def round_to_quantized_type(
363
365
  tensor: torch.Tensor, args: QuantizationArgs
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  from enum import Enum
16
- from typing import Dict, List, Optional, Union
16
+ from typing import Annotated, Any, Dict, List, Optional, Union
17
17
 
18
18
  from compressed_tensors.config import CompressionFormat
19
19
  from compressed_tensors.quantization.quant_args import DynamicType, QuantizationArgs
@@ -26,7 +26,7 @@ from compressed_tensors.quantization.utils import (
26
26
  module_type,
27
27
  parse_out_kv_cache_args,
28
28
  )
29
- from pydantic import BaseModel, Field
29
+ from pydantic import BaseModel, ConfigDict, Field
30
30
  from torch.nn import Module
31
31
 
32
32
 
@@ -142,6 +142,9 @@ class QuantizationConfig(BaseModel):
142
142
  quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
143
143
  global_compression_ratio: Optional[float] = None
144
144
  ignore: Optional[List[str]] = Field(default_factory=list)
145
+ # `run_compressed` is a dummy, unused arg for backwards compatibility
146
+ # see: https://github.com/huggingface/transformers/pull/39324
147
+ run_compressed: Annotated[Any, Field(exclude=True)] = None
145
148
 
146
149
  def model_post_init(self, __context):
147
150
  """
@@ -254,3 +257,6 @@ class QuantizationConfig(BaseModel):
254
257
  return True
255
258
 
256
259
  return False
260
+
261
+ # TODO set `extra="forbid"` when upstream transformers is compatible
262
+ model_config = ConfigDict(extra="ignore")
@@ -14,7 +14,7 @@
14
14
 
15
15
  import warnings
16
16
  from copy import deepcopy
17
- from typing import Any, Dict, List, Optional
17
+ from typing import List, Optional
18
18
 
19
19
  from compressed_tensors.quantization.quant_args import (
20
20
  DynamicType,
@@ -22,7 +22,7 @@ from compressed_tensors.quantization.quant_args import (
22
22
  QuantizationStrategy,
23
23
  QuantizationType,
24
24
  )
25
- from pydantic import BaseModel, model_validator
25
+ from pydantic import BaseModel, ConfigDict, model_validator
26
26
 
27
27
 
28
28
  __all__ = [
@@ -81,6 +81,8 @@ class QuantizationScheme(BaseModel):
81
81
 
82
82
  return model
83
83
 
84
+ model_config = ConfigDict(extra="forbid")
85
+
84
86
 
85
87
  """
86
88
  Pre-Set Quantization Scheme Args
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import torch
16
+ from compressed_tensors import TRANSFORM_CONFIG_NAME
16
17
  from compressed_tensors.transform import TransformConfig, TransformFactory
17
18
 
18
19
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
30
31
  for name, scheme in config.config_groups.items():
31
32
  factory = TransformFactory.from_scheme(scheme, name=name)
32
33
  factory.apply_to_model(model)
34
+
35
+ # attach config to model for compression/serialization
36
+ setattr(model, TRANSFORM_CONFIG_NAME, config)
@@ -14,11 +14,10 @@
14
14
 
15
15
  from abc import ABC, abstractmethod
16
16
  from collections import defaultdict
17
- from typing import List, Optional, Tuple, Set
17
+ from typing import List, Optional, Set, Tuple
18
18
 
19
19
  import torch
20
20
  import torch.nn.utils.parametrize as P
21
- from compressed_tensors import InternalModule
22
21
  from compressed_tensors.registry.registry import RegistryMixin, T
23
22
  from compressed_tensors.transform import (
24
23
  TransformArgs,
@@ -34,6 +33,7 @@ from compressed_tensors.utils import (
34
33
  register_offload_module,
35
34
  update_offload_parameter,
36
35
  )
36
+ from compressed_tensors.utils.internal import InternalModule
37
37
  from torch import Tensor
38
38
  from torch.nn import Module, Parameter
39
39
 
@@ -12,8 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import math
16
- from typing import Optional, Union
15
+ from typing import Optional
17
16
 
18
17
  import torch
19
18
  from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@ from compressed_tensors.transform.utils.matrix import (
26
25
  from compressed_tensors.utils import get_execution_device, get_offloaded_device
27
26
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
28
27
  from torch import Tensor, device, dtype
29
- from torch.nn import Linear, Module, Parameter
28
+ from torch.nn import Module, Parameter
30
29
 
31
30
 
32
31
  @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ class HadamardFactory(TransformFactory):
54
53
  """
55
54
  assert hasattr(module, "weight")
56
55
  size = get_transform_size(module, args.location, self.scheme.head_dim)
57
- dtype = module.weight.dtype
56
+ dtype = self.scheme.precision
58
57
  device = get_offloaded_device(module)
59
58
  exec_device = get_execution_device(module)
60
59
 
61
60
  factory_kwargs = {"construct_device": exec_device}
62
61
  weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
63
62
  perm = self.perms[weight] if self.scheme.randomize else None
64
- return HadamardTransform(weight, perm, args, type(module))
63
+ return HadamardTransform(weight, perm, self.scheme, args, type(module))
65
64
 
66
65
  def _create_weight(
67
66
  self,
@@ -85,15 +84,18 @@ class HadamardTransform(TransformBase):
85
84
  self,
86
85
  weight: Parameter,
87
86
  perm: Optional[Parameter],
87
+ scheme: TransformScheme,
88
88
  args: TransformArgs,
89
89
  module_type: type[torch.nn.Module],
90
90
  ):
91
91
  super().__init__()
92
92
  self.weight = weight
93
93
  self.perm = perm
94
+ self.scheme = scheme
94
95
  self.args = args
95
96
  self.module_type = module_type
96
- self._scale = math.sqrt(weight.size(0))
97
+ self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
98
+ self._precision = scheme.precision if args.is_online() else torch.float64
97
99
 
98
100
  def forward(self, value: Tensor) -> Tensor:
99
101
  weight = self.weight
@@ -105,6 +107,11 @@ class HadamardTransform(TransformBase):
105
107
  weight = weight.T
106
108
 
107
109
  return (
108
- apply_transform_weight(weight, value, self.args.location, self.module_type)
110
+ apply_transform_weight(
111
+ weight.to(self._precision),
112
+ value.to(self._precision),
113
+ self.args.location,
114
+ self.module_type,
115
+ )
109
116
  / self._scale
110
- )
117
+ ).to(value.dtype)
@@ -24,7 +24,7 @@ from compressed_tensors.transform.utils.matrix import (
24
24
  from compressed_tensors.utils import get_offloaded_device
25
25
  from compressed_tensors.utils.helpers import ParameterizedDefaultDict
26
26
  from torch import Tensor, device, dtype
27
- from torch.nn import Linear, Module, Parameter
27
+ from torch.nn import Module, Parameter
28
28
 
29
29
 
30
30
  @TransformFactory.register("random-matrix")
@@ -52,14 +52,14 @@ class RandomMatrixFactory(TransformFactory):
52
52
  """
53
53
  assert hasattr(module, "weight")
54
54
  size = get_transform_size(module, args.location, self.scheme.head_dim)
55
- dtype = module.weight.dtype
55
+ dtype = self.scheme.precision
56
56
  device = get_offloaded_device(module)
57
57
 
58
58
  weight = self.weights[size, dtype, device]
59
59
  if args.inverse:
60
60
  weight = self.inverses[weight]
61
61
 
62
- return RandomMatrixTransform(weight, args, type(module))
62
+ return RandomMatrixTransform(weight, self.scheme, args, type(module))
63
63
 
64
64
  def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65
65
  # TODO: verify that weight is invertible (has non-zero determinant)
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
78
78
  def __init__(
79
79
  self,
80
80
  weight: Tensor,
81
+ scheme: TransformScheme,
81
82
  args: TransformArgs,
82
83
  module_type: type[torch.nn.Module],
83
84
  ):
84
85
  super().__init__()
85
86
  self.weight = weight # is an inverse if args.inverse
87
+ self.scheme = scheme
86
88
  self.args = args
87
89
  self.module_type = module_type
90
+ self._precision = scheme.precision if args.is_online() else torch.float64
88
91
 
89
92
  def forward(self, value: Tensor) -> Parameter:
90
93
  return apply_transform_weight(
91
- self.weight, value, self.args.location, self.module_type
92
- )
94
+ self.weight.to(self._precision),
95
+ value.to(self._precision),
96
+ self.args.location,
97
+ self.module_type,
98
+ ).to(value.dtype)
93
99
 
94
100
  def right_inverse(self, value: Tensor) -> Tensor:
95
101
  inverse = high_precision_invert(self.weight)
96
102
  return apply_transform_weight(
97
- inverse, value, self.args.location, self.module_type
98
- )
103
+ inverse.to(self._precision),
104
+ value.to(self._precision),
105
+ self.args.location,
106
+ self.module_type,
107
+ ).to(value.dtype)
99
108
 
100
109
 
101
110
  def high_precision_invert(weight: Tensor) -> Tensor:
102
- return torch.linalg.inv(weight.to(torch.float32)).to(weight.dtype)
111
+ return torch.linalg.inv(weight.to(torch.float64)).to(weight.dtype)
@@ -15,7 +15,7 @@
15
15
  from enum import Enum
16
16
  from typing import List
17
17
 
18
- from pydantic import BaseModel, Field, field_validator
18
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
19
19
 
20
20
 
21
21
  __all__ = ["TransformArgs", "TransformLocation"]
@@ -68,3 +68,11 @@ class TransformArgs(BaseModel, use_enum_values=True):
68
68
  if isinstance(value, str):
69
69
  return [value]
70
70
  return value
71
+
72
+ def is_online(self) -> bool:
73
+ return self.location not in (
74
+ TransformLocation.WEIGHT_INPUT,
75
+ TransformLocation.WEIGHT_OUTPUT,
76
+ )
77
+
78
+ model_config = ConfigDict(extra="forbid")
@@ -15,7 +15,7 @@
15
15
  from typing import Dict
16
16
 
17
17
  from compressed_tensors.transform import TransformArgs, TransformScheme
18
- from pydantic import BaseModel
18
+ from pydantic import BaseModel, ConfigDict
19
19
 
20
20
 
21
21
  __all__ = ["TransformConfig"]
@@ -32,42 +32,4 @@ class TransformConfig(BaseModel):
32
32
 
33
33
  config_groups: Dict[str, TransformScheme]
34
34
 
35
-
36
- # quip / quip sharp
37
- QUIP = TransformConfig(
38
- config_groups={
39
- "v": TransformScheme(
40
- type="hadamard",
41
- apply=[
42
- TransformArgs(
43
- targets=["Linear"],
44
- location="input", # non-mergable
45
- ),
46
- TransformArgs(
47
- targets=["Linear"],
48
- location="weight_input",
49
- inverse=True,
50
- ),
51
- ],
52
- randomize=True,
53
- ),
54
- "u": TransformScheme(
55
- type="hadamard",
56
- apply=[
57
- TransformArgs(
58
- targets=["Linear"],
59
- location="weight_output",
60
- ),
61
- TransformArgs(
62
- targets=["Linear"], location="output", inverse=True # non-mergable
63
- ),
64
- ],
65
- randomize=True,
66
- ),
67
- }
68
- )
69
-
70
-
71
- PRESET_CONFIGS = {
72
- "QUIP": QUIP,
73
- }
35
+ model_config = ConfigDict(extra="forbid")
@@ -14,8 +14,10 @@
14
14
 
15
15
  from typing import List, Optional
16
16
 
17
+ import torch
17
18
  from compressed_tensors.transform import TransformArgs
18
- from pydantic import BaseModel, Field
19
+ from compressed_tensors.utils import TorchDtype
20
+ from pydantic import BaseModel, ConfigDict, Field
19
21
 
20
22
 
21
23
  __all__ = ["TransformScheme"]
@@ -34,6 +36,8 @@ class TransformScheme(BaseModel):
34
36
  :param randomize: True if uniquely randomized transform weights should be used,
35
37
  otherwise use identical transform weights where applicable
36
38
  :param requires_grad: True if weights include gradients for training
39
+ :param precision: Precision at which this transform should be applied during online
40
+ rotations. Fused (offline) rotations are always performed in float64
37
41
  """
38
42
 
39
43
  type: str
@@ -41,3 +45,6 @@ class TransformScheme(BaseModel):
41
45
  randomize: bool = Field(default=False)
42
46
  requires_grad: bool = Field(default=False)
43
47
  head_dim: Optional[int] = Field(default=None)
48
+ precision: TorchDtype = Field(default=torch.float32)
49
+
50
+ model_config = ConfigDict(extra="forbid")
@@ -21,3 +21,4 @@ from .permutations_24 import *
21
21
  from .permute import *
22
22
  from .safetensors_load import *
23
23
  from .semi_structured_conversions import *
24
+ from .type import *
@@ -86,6 +86,7 @@ __all__ = [
86
86
  "offloaded_dispatch",
87
87
  "disable_offloading",
88
88
  "remove_dispatch",
89
+ "cast_to_device",
89
90
  ]
90
91
 
91
92
 
@@ -169,6 +170,19 @@ def update_parameter_data(
169
170
  """ Candidates for Upstreaming """
170
171
 
171
172
 
173
+ def cast_to_device(device_spec: Union[int, torch.device]) -> torch.device:
174
+ """
175
+ Convert an integer device index or torch.device into a torch.device object.
176
+
177
+ :param device_spec: Device index (int) or torch.device object.
178
+ Negative integers map to CPU.
179
+ :return: torch.device corresponding to the given device specification.
180
+ """
181
+ if isinstance(device_spec, int):
182
+ return torch.device(f"cuda:{device_spec}" if device_spec >= 0 else "cpu")
183
+ return device_spec
184
+
185
+
172
186
  def get_execution_device(module: torch.nn.Module) -> torch.device:
173
187
  """
174
188
  Get the device which inputs should be moved to before module execution.
@@ -179,7 +193,7 @@ def get_execution_device(module: torch.nn.Module) -> torch.device:
179
193
  """
180
194
  for submodule in module.modules():
181
195
  if has_offloaded_params(submodule):
182
- return submodule._hf_hook.execution_device
196
+ return cast_to_device(submodule._hf_hook.execution_device)
183
197
 
184
198
  param = next(submodule.parameters(recurse=False), None)
185
199
  if param is not None:
@@ -0,0 +1,74 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Annotated, Any
16
+
17
+ import torch
18
+ from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
19
+ from pydantic.json_schema import JsonSchemaValue
20
+ from pydantic_core import core_schema
21
+
22
+
23
+ __all__ = ["TorchDtype"]
24
+
25
+
26
+ class _TorchDtypeAnnotation:
27
+ @classmethod
28
+ def __get_pydantic_core_schema__(
29
+ cls,
30
+ _source_type: Any,
31
+ _handler: GetCoreSchemaHandler,
32
+ ) -> core_schema.CoreSchema:
33
+ # support strings of the form `torch.xxx` or `xxx`
34
+ def validate_from_str(name: str) -> torch.dtype:
35
+ name = name.removeprefix("torch.")
36
+ try:
37
+ value = getattr(torch, name)
38
+ assert isinstance(value, torch.dtype)
39
+ except Exception:
40
+ raise ValueError(f"No such torch dtype `torch.{name}`")
41
+
42
+ return value
43
+
44
+ # package validation into a schema (which also validates str type)
45
+ from_str_schema = core_schema.chain_schema(
46
+ [
47
+ core_schema.str_schema(),
48
+ core_schema.no_info_plain_validator_function(validate_from_str),
49
+ ]
50
+ )
51
+
52
+ return core_schema.json_or_python_schema(
53
+ json_schema=from_str_schema,
54
+ python_schema=core_schema.union_schema(
55
+ [
56
+ # support both torch.dtype or strings
57
+ core_schema.is_instance_schema(torch.dtype),
58
+ from_str_schema,
59
+ ]
60
+ ),
61
+ # serialize as `torch.xxx`
62
+ serialization=core_schema.plain_serializer_function_ser_schema(
63
+ lambda instance: str(instance)
64
+ ),
65
+ )
66
+
67
+ @classmethod
68
+ def __get_pydantic_json_schema__(
69
+ cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
70
+ ) -> JsonSchemaValue:
71
+ return handler(core_schema.str_schema())
72
+
73
+
74
+ TorchDtype = Annotated[torch.dtype, _TorchDtypeAnnotation]
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.10.3.a20250806'
20
+ __version__ = version = '0.10.3.a20250812'
21
21
  __version_tuple__ = version_tuple = (0, 10, 3)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.10.3a20250806
3
+ Version: 0.10.3a20250812
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,11 +1,11 @@
1
1
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
2
- compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
3
- compressed_tensors/version.py,sha256=AuoKIjSgjjAcZIPZe3HN5zhNJ7enhDAjwQrqUHPg76o,523
2
+ compressed_tensors/base.py,sha256=-gxWvDF4LCkyeDP8YlGzvBBKxo4Dk9h4NINPD61drFU,921
3
+ compressed_tensors/version.py,sha256=xvYY_n0Nd8St_3mnUNQJLLacYomaXp2WG5TBIJ3wSDo,523
4
4
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
5
5
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
7
7
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
8
- compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=e-2nLkLLzbPlm3xubZPv44BIXC16xaHDZmCp4eXCjZU,33316
8
+ compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=NEnRjTRhmOHP_TVgrhQb58t-XjftDNoZBNyTXQQBRao,34042
9
9
  compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
10
10
  compressed_tensors/compressors/quantized_compressors/base.py,sha256=YGUMzbxekj_36ChgQnVZN6T8uDjXtGG1zfMIBGBLWco,10354
11
11
  compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
@@ -26,9 +26,9 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
26
26
  compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
27
27
  compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
28
28
  compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
29
- compressed_tensors/quantization/quant_args.py,sha256=yKTj_4lAy_pnXeTCyUADpyz2qAzJXYJU2P03NF_TP68,12835
30
- compressed_tensors/quantization/quant_config.py,sha256=w6sEEZGVGIF0Ub2r_cqRfZwbkBT8WzfY3ug52olmjGY,10049
31
- compressed_tensors/quantization/quant_scheme.py,sha256=xk2LPn18tjS1PEOyf0WKvavBq3rzAVHFLB3H2mQQWnc,8473
29
+ compressed_tensors/quantization/quant_args.py,sha256=PMoaa6hpyJLGGSeCWefGmzGVxbOtxAdDunHJi_L5gNs,12894
30
+ compressed_tensors/quantization/quant_config.py,sha256=StEpCvc70JasE1srLaHqI-TJlasLWGtHU2o0E_gDJhQ,10400
31
+ compressed_tensors/quantization/quant_scheme.py,sha256=3EUGCw5_e7nnmvYPK_UlQKaaskOLIAo30dHYn0z7HmQ,8521
32
32
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
33
33
  compressed_tensors/quantization/lifecycle/apply.py,sha256=wM8mVcbKvZjBo18pSXMp28i30YWwUXJPSS7_HCakH9U,17892
34
34
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
@@ -40,30 +40,31 @@ compressed_tensors/quantization/utils/helpers.py,sha256=7a89X0kg6xDGplw6trOrkRQz
40
40
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
41
41
  compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
42
42
  compressed_tensors/transform/__init__.py,sha256=v2wfl4CMfA6KbD7Hxx_MbRev63y_6QLDlccZq-WTtdw,907
43
- compressed_tensors/transform/apply.py,sha256=Cnc7Q8d8FzpLGtXixvdPzqApfjAXpfShxvVl_7nNJ4E,1259
44
- compressed_tensors/transform/transform_args.py,sha256=jJY-Qt996w45LWQ10AHd7tUtNrnV9mjD9M5D4SZ5B3E,3199
45
- compressed_tensors/transform/transform_config.py,sha256=A3RuLNDqBNEByQNeu40Kg7sItwE6kWgnX18Umg1uONI,2128
46
- compressed_tensors/transform/transform_scheme.py,sha256=uGLC4avdbhrVqNC3-Eo0p7WzNRQK92Fpg0N9hWiuCRQ,1752
43
+ compressed_tensors/transform/apply.py,sha256=nCJvhHleIyWPNYPr-SZvXhmTKpqHVpJrG8VfIW-K6d8,1422
44
+ compressed_tensors/transform/transform_args.py,sha256=rVgReFp7wMXcYugkfd325e2tTFh8pGV3FnYTGCEv5jY,3429
45
+ compressed_tensors/transform/transform_config.py,sha256=h2EYyMrUwAzyak84JY1lsAgZ7Eupotw_cYLq8Ov5SH4,1219
46
+ compressed_tensors/transform/transform_scheme.py,sha256=S7vYLnuv7xZ_bwphkpCiGqZLjnnTnb4lj1T8a6WwnE0,2094
47
47
  compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
48
- compressed_tensors/transform/factory/base.py,sha256=NJ3lI95tJk6gHOeZEVheQ_Ae7NHhhUG_9FHXu613x30,7740
49
- compressed_tensors/transform/factory/hadamard.py,sha256=B0BVjbF3y707MO6L2XfEoZJTQU965vU9dUPLOiUSXII,4193
50
- compressed_tensors/transform/factory/matrix_multiply.py,sha256=kCB7cfM_PCgJDyyhg2d1rKTEiyuscwzhprXY7VfIx6E,3989
48
+ compressed_tensors/transform/factory/base.py,sha256=Txkr1nWKtlMU1MmBcQ85-JqJzD356Z9nYbaF24tJ5rw,7755
49
+ compressed_tensors/transform/factory/hadamard.py,sha256=CEy98vOIip_Pomh1XB62BqcjU8GQ9fUZSpnZH4GrBnE,4499
50
+ compressed_tensors/transform/factory/matrix_multiply.py,sha256=boZLMkaNrgXQ9cU-tFzJ-1N1tLgbKMJzAxiYZAr4Pu8,4326
51
51
  compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Umme71pnjMPgwYoGlwjKlU27UHZ4,1634
52
52
  compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
53
53
  compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
54
54
  compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
55
55
  compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiPAAjCM9Ejz77wb-w,6181
56
- compressed_tensors/utils/__init__.py,sha256=KZctuotCmX4byXhwDvSeXgp-Ny_awpziAX-WUkZfodI,853
56
+ compressed_tensors/utils/__init__.py,sha256=spzbjUO4-hZ2jXGST27r3MIt2yzIXsjdbEaYyaMcizo,873
57
57
  compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
58
58
  compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
59
59
  compressed_tensors/utils/match.py,sha256=9x-yZIlq7ndSLf2aQwNT7IpBQDe-8H6utiJkji8wPrQ,9397
60
- compressed_tensors/utils/offload.py,sha256=3XiBuWbUkBAt8v1t5i57qDcbB3VJQs_FDeayi-JzIWg,23896
60
+ compressed_tensors/utils/offload.py,sha256=gFoEDaissHsLM5-JDbgPxh5hiE9VFN4HFxvszYvReos,24446
61
61
  compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
62
62
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
63
63
  compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
64
64
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
65
- compressed_tensors-0.10.3a20250806.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
66
- compressed_tensors-0.10.3a20250806.dist-info/METADATA,sha256=e8DIx-6UDn2Wj7fGLEBgVru2k9Tme9dOPgxS_ciZDcw,7031
67
- compressed_tensors-0.10.3a20250806.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
68
- compressed_tensors-0.10.3a20250806.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
69
- compressed_tensors-0.10.3a20250806.dist-info/RECORD,,
65
+ compressed_tensors/utils/type.py,sha256=bNwoo_FWlvLuDpYAGGzZJITRg0JA_Ngk9LGPo-kvjeU,2554
66
+ compressed_tensors-0.10.3a20250812.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
67
+ compressed_tensors-0.10.3a20250812.dist-info/METADATA,sha256=FKQ6iJRJ_Q7GMoCLI73bfQXEKxW5E2hgmDl9lCTbtSY,7031
68
+ compressed_tensors-0.10.3a20250812.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
69
+ compressed_tensors-0.10.3a20250812.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
70
+ compressed_tensors-0.10.3a20250812.dist-info/RECORD,,