compressed-tensors 0.9.5a20250603__py3-none-any.whl → 0.9.5a20250604__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
--- a/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -50,6 +50,7 @@ from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
+    get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -408,16 +409,17 @@ class ModelCompressor:
                 )
 
                 # remove any existing parameters
-                device = get_execution_device(module)
+                exec_device = get_execution_device(module)
+                offload_device = get_offloaded_device(module)
                 for name, _ in list(module.named_parameters()):
-                    delattr(module, name)
+                    delete_offload_parameter(module, name)
 
                 # replace with compressed parameters
                 for name, value in state_dict.items():
                     name = name.removeprefix(f"{prefix}.")
-                    value = value.to(device)
+                    value = value.to(exec_device)
                     param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param)
+                    register_offload_parameter(module, name, param, offload_device)
 
                 module.quantization_status = QuantizationStatus.COMPRESSED
 
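The hunk above swaps `delattr` for the offload-aware `delete_offload_parameter` and threads an explicit `offload_device` into `register_offload_parameter`, so a module whose weights live on an offload device (e.g. CPU-offloaded via accelerate hooks) keeps that placement after compression. A minimal sketch of the pattern, using the utilities imported in the first hunk; the `Linear` module and dummy `state_dict` are illustrative stand-ins:

```python
import torch
from compressed_tensors.utils import (
    delete_offload_parameter,
    get_execution_device,
    get_offloaded_device,
    register_offload_parameter,
)

module = torch.nn.Linear(4, 4)  # stand-in for a quantized submodule
state_dict = {"weight_packed": torch.zeros(4, 2, dtype=torch.int32)}

# resolve both devices before tearing the module's parameters down
exec_device = get_execution_device(module)     # where forward() runs
offload_device = get_offloaded_device(module)  # where weights are stored

# offload-aware deletion keeps accelerate's bookkeeping consistent,
# which a plain delattr would not
for name, _ in list(module.named_parameters()):
    delete_offload_parameter(module, name)

# build replacements on the execution device, then pin them to the offload device
for name, value in state_dict.items():
    param = torch.nn.Parameter(value.to(exec_device), requires_grad=False)
    register_offload_parameter(module, name, param, offload_device)
```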
@@ -460,30 +462,26 @@ class ModelCompressor:
 
             # quantization second
             if prefix in module_to_scheme:
-                generator = self.quantization_compressor.decompress_from_state_dict(
-                    state_dict,
-                    names_to_scheme=module_to_scheme,
+                state_dict = (
+                    self.quantization_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
+                    )
                 )
-                # generates (mod_path, {param_name, param_val})
-                # of compressed params and used params, but not unused params
-                # some used params are removed by get_unexpected_file_keys
-                state_dict = {
-                    merge_names(module_path, param_name): param_value
-                    for module_path, compressed_data in generator
-                    for param_name, param_value in compressed_data.items()
-                }
 
             # remove any existing parameters
-            device = get_execution_device(module)
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
             for name, _ in list(module.named_parameters()):
                 delete_offload_parameter(module, name)
 
             # replace with decompressed parameters
             for name, value in state_dict.items():
                 name = name.removeprefix(f"{prefix}.")
-                value = value.to(device)
+                value = value.to(exec_device)
                 param = torch.nn.Parameter(value, requires_grad=False)
-                register_offload_parameter(module, name, param)
+                register_offload_parameter(module, name, param, offload_device)
 
             module.quantization_status = QuantizationStatus.FROZEN
 
--- a/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/compressed_tensors/compressors/quantized_compressors/base.py
@@ -24,6 +24,7 @@ from compressed_tensors.utils import (
     get_nested_weight_mappings,
     merge_names,
 )
+from compressed_tensors.utils.safetensors_load import match_param_name
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -223,9 +224,7 @@ class BaseQuantizationCompressor(BaseCompressor):
             state_dict, self.compression_param_names
         )
        for module_path in weight_mappings.keys():
-            weight_data = {}
-            for param_name, param_value in weight_mappings[module_path].items():
-                weight_data[param_name] = param_value
+            weight_data = weight_mappings[module_path].copy()
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[module_path].weights
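`dict.copy()` is a shallow copy, so it is exactly equivalent to the removed key-by-key loop, and a copy is still needed: the following lines assign `weight_data["weight"]`, which would otherwise leak back into `weight_mappings`. A quick illustration; the mapping contents here are made up:

```python
import torch

weight_mappings = {"model.layer0": {"weight_packed": torch.zeros(2, 2)}}

weight_data = weight_mappings["model.layer0"].copy()
weight_data["weight"] = torch.ones(2, 2)  # new key lands only in the copy

assert "weight" not in weight_mappings["model.layer0"]
# shallow: the tensor objects themselves are shared, not cloned
assert weight_data["weight_packed"] is weight_mappings["model.layer0"]["weight_packed"]
```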
@@ -234,3 +233,31 @@ class BaseQuantizationCompressor(BaseCompressor):
             )
             weight_data["weight"] = decompressed
             yield module_path, weight_data
+
+    def decompress_module_from_state_dict(
+        self,
+        prefix: str,
+        state_dict: Dict[str, torch.Tensor],
+        scheme: QuantizationScheme,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Only used by in-memory decompression pathways to decompress the parameters of
+        one module
+
+        :param prefix: prefix of state_dict, typically the path to the module
+        :param state_dict: state dict containing module parameter values
+        :param scheme: quantization scheme of module to decompress
+        :return: state dict with weight decompressed if applicable
+        """
+        state_dict = {
+            key.removeprefix(f"{prefix}."): value for key, value in state_dict.items()
+        }
+
+        if "weight_scale" in state_dict:
+            state_dict["weight"] = self.decompress_weight(
+                compressed_data=state_dict, quantization_args=scheme.weights
+            )
+
+        state_dict = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+
+        return state_dict
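The new `decompress_module_from_state_dict` strips the module prefix, decompresses the weight when a `weight_scale` entry marks the module as quantized, and re-applies the prefix; this is what the ModelCompressor hunk above now calls per module instead of draining a whole-model generator. A standalone sketch of that round-trip with a stubbed `decompress_weight`, so it runs without a real compressor instance:

```python
from typing import Callable, Dict

import torch


def decompress_module_from_state_dict(
    prefix: str,
    state_dict: Dict[str, torch.Tensor],
    decompress_weight: Callable[[Dict[str, torch.Tensor]], torch.Tensor],
) -> Dict[str, torch.Tensor]:
    # strip the module prefix so parameter names are local to the module
    local = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items()}

    # a weight_scale entry marks the module as quantized
    if "weight_scale" in local:
        local["weight"] = decompress_weight(local)

    # restore the prefix so keys line up with the original state_dict
    return {f"{prefix}.{k}": v for k, v in local.items()}


def stub(data):  # the real method calls self.decompress_weight(...) instead
    return data["weight_packed"].float() * data["weight_scale"]


out = decompress_module_from_state_dict(
    "model.layer0",
    {
        "model.layer0.weight_packed": torch.ones(2, 2, dtype=torch.int8),
        "model.layer0.weight_scale": torch.tensor(0.5),
    },
    stub,
)
assert "model.layer0.weight" in out  # decompressed weight, back under the prefix
```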
--- a/compressed_tensors/quantization/lifecycle/forward.py
+++ b/compressed_tensors/quantization/lifecycle/forward.py
@@ -21,7 +21,6 @@ from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
     QuantizationStrategy,
-    QuantizationType,
     round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
@@ -405,7 +404,7 @@ def _quantize(
 
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
     scaled = x / scale
@@ -438,7 +437,7 @@ def _dequantize(
 
     # if a global scale is optionally provided, use it
     # to further scale the local `scale` parameter
-    if global_scale:
+    if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
     dequant_value = x_q.to(scale.dtype)
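The `if global_scale:` → `if global_scale is not None:` fix in both `_quantize` and `_dequantize` matters because `global_scale` is a tensor, not a flag: truthiness on a one-element tensor is value-dependent, and on a multi-element tensor it raises. A short demonstration of the pitfall:

```python
import torch

global_scale = torch.tensor([0.0])

# value-dependent: a zero-valued scale is falsy, so the old test would
# silently skip the rescaling branch even though a scale was provided
assert not bool(global_scale)

# and for anything with more than one element, truthiness is an error
try:
    bool(torch.tensor([1.0, 2.0]))
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

# presence test: runs whenever a scale is passed, regardless of its value
assert global_scale is not None
```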
--- a/compressed_tensors/quantization/utils/helpers.py
+++ b/compressed_tensors/quantization/utils/helpers.py
@@ -110,6 +110,7 @@ def calculate_qparams(
     else:
         scales = max_val_pos / (float(bit_range) / 2)
 
+    # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
     if scales.dtype == FP8_E4M3_DATA.dtype:
         # torch.clamp not supported for FP8
         # use the next largest fp8 value from 0
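The new TODO sits just before the FP8 special case: `torch.clamp` has no float8 kernel, so zero scales are bumped to the smallest positive FP8 value by other means, and per the TODO the same guard may eventually be needed for `global_scale` in MoE models. One way to express the workaround, sketched here in float32 with a cast back; an illustration, not necessarily the library's exact code:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
TINY = torch.finfo(FP8_DTYPE).tiny  # smallest positive normal fp8 value

scales = torch.tensor([0.0, 0.25, 1.0], dtype=FP8_DTYPE).float()
# a zero scale would blow up quantization (x / scale), so clamp it away from 0
scales = torch.where(scales == 0.0, torch.full_like(scales, TINY), scales)
scales = scales.to(FP8_DTYPE)
```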
@@ -495,4 +496,4 @@ def generate_gparam(
     max_vals = torch.max(updated_max_val, torch.zeros_like(updated_max_val))
     max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
     global_scale = scale_data.max * quant_data.max / max_val_pos
-    return global_scale.to(dtype)
+    return global_scale.to(dtype).reshape([1])
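`generate_gparam` now returns a `[1]`-shaped tensor instead of a 0-d scalar. The value is unchanged; only the shape is pinned down, presumably so downstream consumers can rely on a consistent one-element shape when the global scale is registered and serialized. Illustration:

```python
import torch

global_scale = torch.tensor(448.0)      # a reduction like .max() yields 0-d
print(global_scale.shape)               # torch.Size([])
print(global_scale.reshape([1]).shape)  # torch.Size([1]) -- same value, fixed shape
```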
--- a/compressed_tensors/version.py
+++ b/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.9.5.a20250603'
+__version__ = version = '0.9.5.a20250604'
 __version_tuple__ = version_tuple = (0, 9, 5)
--- a/compressed_tensors-0.9.5a20250603.dist-info/METADATA
+++ b/compressed_tensors-0.9.5a20250604.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250603
+Version: 0.9.5a20250604
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- a/compressed_tensors-0.9.5a20250603.dist-info/RECORD
+++ b/compressed_tensors-0.9.5a20250604.dist-info/RECORD
@@ -1,13 +1,13 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=fFlh5YaVrAZG4nN-6r62Ow43mZVj9W0f2ASMDww8e5k,521
+compressed_tensors/version.py,sha256=QmqChcTnn-HquSfq_8n_1b_CkQT93OOAGt5yzbeUk0A,521
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=BBJd3Ei6FtqVQLBkOm80G6pSJ11IMTGuTA-FL4n6_5g,32704
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=72h2tWDIGbbqLQF8MDzOehy18eu5TvsCLd_AuzGv_O4,32517
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
-compressed_tensors/compressors/quantized_compressors/base.py,sha256=n_sVSzySHUBgXt-nkLggM1DtB0aEgQmiKhTzcnQU9Dc,9266
+compressed_tensors/compressors/quantized_compressors/base.py,sha256=ByE3z61boZ5wdz0nhc-2CJH61bSixJQE78pfkS6XRDg,10269
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
 compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=_66tQ8bxslDUdas-ULORXblPw9kdNNn1UJJU9-ZOGPY,11380
@@ -32,11 +32,11 @@ compressed_tensors/quantization/quant_scheme.py,sha256=IDWa1GWUbUdWCo8j78Jz6svYF
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=DOoxH4jM8r0270GGGUFOpRrgwaisiJi7TV-Q6E8qM8E,18067
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=WFwvNebxXNUlpX5p1xG80oa8W9fz4-Xd6LCH_B_nptg,14881
+compressed_tensors/quantization/lifecycle/forward.py,sha256=JWOQ-03bsgh9_nnOLAjmLZ0S8bFQA-GjwDK6YUBwcrU,14883
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=9d5Ee7qt3zxaa5_PFitkvadvRDXeDqBIxYgooBqtrf8,8638
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=DLSPX-5cmrXxVQbt-keN9Qgbvn_lPOL674pXa2gR8-A,17740
+compressed_tensors/quantization/utils/helpers.py,sha256=bqxNL2NU1XVsSxNzmDVZE3zd65PlLFq1Ir-RHwff8G0,17840
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
 compressed_tensors/transform/__init__.py,sha256=oa5VdrE-GtDYYceXNSwj5X_ropoXLLukm6Aufcc9WhY,747
@@ -50,8 +50,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.9.5a20250603.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.9.5a20250603.dist-info/METADATA,sha256=VEjGe1Y3JAwrVC7SfWv3yu1Xq85-mrQsYW2UxUwKuyE,7004
-compressed_tensors-0.9.5a20250603.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.9.5a20250603.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.9.5a20250603.dist-info/RECORD,,
+compressed_tensors-0.9.5a20250604.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.9.5a20250604.dist-info/METADATA,sha256=2dI2Y96LKAAG_vshTtYzxXhvM5Fby_hSISAAFbXYJXE,7004
+compressed_tensors-0.9.5a20250604.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.9.5a20250604.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.9.5a20250604.dist-info/RECORD,,