compressed-tensors-nightly 0.9.1.20250214__py3-none-any.whl → 0.9.1.20250215__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,7 +19,7 @@ import os
19
19
  import re
20
20
  from contextlib import contextmanager
21
21
  from copy import deepcopy
22
- from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
22
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
23
23
 
24
24
  import compressed_tensors
25
25
  import torch
@@ -39,13 +39,17 @@ from compressed_tensors.quantization import (
39
39
  apply_quantization_config,
40
40
  load_pretrained_quantization,
41
41
  )
42
- from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
42
+ from compressed_tensors.quantization.lifecycle import expand_target_names
43
43
  from compressed_tensors.quantization.quant_args import QuantizationArgs
44
44
  from compressed_tensors.quantization.utils import (
45
45
  is_module_quantized,
46
46
  iter_named_leaf_modules,
47
47
  )
48
- from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
48
+ from compressed_tensors.utils import (
49
+ get_safetensors_folder,
50
+ merge_names,
51
+ update_parameter_data,
52
+ )
49
53
  from compressed_tensors.utils.helpers import (
50
54
  fix_fsdp_module_name,
51
55
  is_compressed_tensors_config,
@@ -254,6 +258,107 @@ class ModelCompressor:
254
258
  quantization_config.format, config=quantization_config
255
259
  )
256
260
 
261
+ def get_missing_module_keys(self, model: Module) -> List[str]:
262
+ """
263
+ Identifies the expected missing weight keys in the compressed state_dict.
264
+
265
+ When a model undergoes sparsity or quantization compression, certain
266
+ weight tensors may be absent from the checkpoint by virtue of compression.
267
+ This function determines which weight keys are missing based on the
268
+ applied compression techniques.
269
+
270
+
271
+ :param model: The PyTorch model to check for missing keys.
272
+ :return: A list of missing keys expected in the compressed state_dict.
273
+ """
274
+ missing_keys = set()
275
+
276
+ # Determine missing keys due to sparsity compression
277
+ if (
278
+ self.sparsity_compressor
279
+ and self.sparsity_config.format != CompressionFormat.dense.value
280
+ ):
281
+ sparse_targets = expand_target_names(
282
+ model=model,
283
+ targets=self.sparsity_config.targets,
284
+ ignore=self.sparsity_config.ignore,
285
+ )
286
+ missing_keys.update(
287
+ merge_names(target, "weight") for target in sparse_targets
288
+ )
289
+
290
+ # Determine missing keys due to pack quantization
291
+ if (
292
+ self.quantization_compressor
293
+ and self.quantization_config.format
294
+ == CompressionFormat.pack_quantized.value
295
+ ):
296
+ for scheme in self.quantization_config.config_groups.values():
297
+ quant_targets = expand_target_names(
298
+ model=model,
299
+ targets=scheme.targets,
300
+ ignore=self.quantization_config.ignore,
301
+ )
302
+ missing_keys.update(
303
+ merge_names(target, "weight") for target in quant_targets
304
+ )
305
+
306
+ return list(missing_keys)
307
+
308
+ def get_unexpected_file_keys(self, model: Module) -> List[str]:
309
+ """
310
+ Identifies extra keys introduced by the compression process in the
311
+ compressed state_dict that are not expected by the model graph.
312
+
313
+ During sparsity or quantization compression, additional metadata or
314
+ auxiliary parameters may be stored in the checkpoint, which do not
315
+ correspond to any parameter in the original model. These keys are
316
+ typically introduced to support the reconstruction of compressed weights.
317
+
318
+ For example, Sparse24Bitmask compression may introduce keys such as
319
+ 'compressed', 'bitmask', and 'shape' in the checkpoint, which are
320
+ not part of the original model parameters.
321
+
322
+ :param model: The PyTorch model to check for unexpected keys.
323
+ :return: A list of extra keys introduced by the compression process
324
+ that are not expected by the model.
325
+ """
326
+
327
+ unexpected_keys = set()
328
+
329
+ # Identify unexpected keys from sparsity compression
330
+ if (
331
+ self.sparsity_compressor
332
+ and self.sparsity_config.format != CompressionFormat.dense.value
333
+ ):
334
+ sparse_targets: Set[str] = expand_target_names(
335
+ model=model,
336
+ targets=self.sparsity_config.targets,
337
+ ignore=self.sparsity_config.ignore,
338
+ )
339
+ unexpected_keys.update(
340
+ merge_names(target, param)
341
+ for target in sparse_targets
342
+ for param in self.sparsity_compressor.compression_param_names
343
+ )
344
+
345
+ # Identify unexpected keys from quantization compression
346
+ if self.quantization_compressor:
347
+ for scheme in self.quantization_config.config_groups.values():
348
+ quant_targets: Set[str] = expand_target_names(
349
+ model=model,
350
+ targets=scheme.targets,
351
+ ignore=self.quantization_config.ignore,
352
+ )
353
+ unexpected_keys.update(
354
+ merge_names(target, param)
355
+ for target in quant_targets
356
+ for param in self.quantization_compressor.compression_param_names
357
+ if param != "weight"
358
+ )
359
+
360
+ return list(unexpected_keys)
361
+
257
362
  def compress(
258
363
  self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
259
364
  ) -> Dict[str, Tensor]:
@@ -283,7 +388,7 @@ class ModelCompressor:
283
388
  )
284
389
 
285
390
  if self.sparsity_compressor is not None:
286
- sparse_compression_targets: Set[str] = expand_sparse_target_names(
391
+ sparse_compression_targets: Set[str] = expand_target_names(
287
392
  model=model,
288
393
  targets=self.sparsity_config.targets,
289
394
  ignore=self.sparsity_config.ignore,
@@ -52,8 +52,8 @@ __all__ = [
52
52
  "apply_quantization_config",
53
53
  "apply_quantization_status",
54
54
  "find_name_or_class_matches",
55
- "expand_sparse_target_names",
56
- "is_sparse_target",
55
+ "expand_target_names",
56
+ "is_target",
57
57
  ]
58
58
 
59
59
  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -247,8 +247,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
247
247
  model.apply(compress_quantized_weights)
248
248
 
249
249
 
250
- def expand_sparse_target_names(
251
- model: Module, targets: Iterable[str], ignore: Iterable[str]
250
+ def expand_target_names(
251
+ model: Module,
252
+ targets: Optional[Iterable[str]] = None,
253
+ ignore: Optional[Iterable[str]] = None,
252
254
  ) -> Set[str]:
253
255
  """
254
256
  Finds all unique module names in the model that match the given
@@ -257,20 +259,23 @@ def expand_sparse_target_names(
257
259
  Note: Targets must be regexes, layer types, or full layer names.
258
260
 
259
261
  :param model: model to search for targets in
260
- :param targets: list of targets to search for
261
- :param ignore: list of targets to ignore
262
+ :param targets: Iterable of targets to search for
263
+ :param ignore: Iterable of targets to ignore
262
264
  :return: set of all targets that match the given targets and should
263
265
  not be ignored
264
266
  """
265
267
  return {
266
268
  name
267
269
  for name, module in iter_named_leaf_modules(model)
268
- if is_sparse_target(name, module, targets, ignore)
270
+ if is_target(name, module, targets, ignore)
269
271
  }
270
272
 
271
273
 
272
- def is_sparse_target(
273
- name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
274
+ def is_target(
275
+ name: str,
276
+ module: Module,
277
+ targets: Optional[Iterable[str]] = None,
278
+ ignore: Optional[Iterable[str]] = None,
274
279
  ) -> bool:
275
280
  """
276
281
  Determines if a module should be included in the targets based on the
@@ -280,12 +285,12 @@ def is_sparse_target(
280
285
 
281
286
  :param name: name of the module
282
287
  :param module: the module itself
283
- :param targets: list of targets to search for
284
- :param ignore: list of targets to ignore
288
+ :param targets: Iterable of targets to search for
289
+ :param ignore: Iterable of targets to ignore
285
290
  :return: True if the module is a target and not ignored, False otherwise
286
291
  """
287
292
  return bool(
288
- find_name_or_class_matches(name, module, targets)
293
+ find_name_or_class_matches(name, module, targets or [])
289
294
  and not find_name_or_class_matches(name, module, ignore or [])
290
295
  )
291
296
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.9.1.20250214
3
+ Version: 0.9.1.20250215
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -5,7 +5,7 @@ compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1
5
5
  compressed_tensors/compressors/base.py,sha256=x8dQrWVEurynXw03yHJZTaAmrRTOsdZJoHjmvs0IKwk,7002
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
7
7
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
8
- compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=3WyzAW2Rm_uLprxwO2QH6FR76W6Mk4r2yedayaSZHhw,18396
8
+ compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=AmIE1SoNRH1fNgQALfNkdQo8y5tePVpdWUgLIOtf5rg,22569
9
9
  compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
10
10
  compressed_tensors/compressors/quantized_compressors/base.py,sha256=cp8S1Kr3HhlMHIz7k4vGo-qxxdknEC3qP1QLIhNnwRA,7217
11
11
  compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fd0KlkSx6bvZ3xwIkK3jEUdPSUPs56Eua4dEDOtzKW0,5150
@@ -29,7 +29,7 @@ compressed_tensors/quantization/quant_args.py,sha256=sKpb8DcNObidjXjNol1Tn_Iih3Z
29
29
  compressed_tensors/quantization/quant_config.py,sha256=vx06wBo91p4LCb3Vzd-2eCTUeIf_Sz2ZXRP263eQyjQ,10385
30
30
  compressed_tensors/quantization/quant_scheme.py,sha256=eQ0JrRZ80GX69fpwW87VzPzzhajhk4mUaJScjk82OY4,6010
31
31
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
32
- compressed_tensors/quantization/lifecycle/apply.py,sha256=XS4M6N1opKBybhkuQsS338QVb_CKMhUM5TUKrqoNQ0k,16517
32
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=lZmCCSm1_o79iUAy460w6Bv9FaOvntVisMdS-dN9fnk,16594
33
33
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
34
34
  compressed_tensors/quantization/lifecycle/forward.py,sha256=DOWouUqfaLA4Qhg-ojVVBdhhSAlgZqFC26vZARxE0ko,12961
35
35
  compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
@@ -45,8 +45,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
45
45
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
46
46
  compressed_tensors/utils/safetensors_load.py,sha256=5SeM2hzLh77Ne8Vk7qR6-km7cf8bhov41ExpWITqX3A,11470
47
47
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
48
- compressed_tensors_nightly-0.9.1.20250214.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
49
- compressed_tensors_nightly-0.9.1.20250214.dist-info/METADATA,sha256=IrLuKBJxl7kmRAn9ScOSOQ7k0L6ZS8L4Cf8mxyqPJXI,6992
50
- compressed_tensors_nightly-0.9.1.20250214.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
51
- compressed_tensors_nightly-0.9.1.20250214.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
52
- compressed_tensors_nightly-0.9.1.20250214.dist-info/RECORD,,
48
+ compressed_tensors_nightly-0.9.1.20250215.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
49
+ compressed_tensors_nightly-0.9.1.20250215.dist-info/METADATA,sha256=2Kc88b_CWYKco_ntUMkkVuHxp5Km3YAQg0mi3OwWc4Q,6992
50
+ compressed_tensors_nightly-0.9.1.20250215.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
51
+ compressed_tensors_nightly-0.9.1.20250215.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
52
+ compressed_tensors_nightly-0.9.1.20250215.dist-info/RECORD,,