compressed-tensors 0.11.1a20250820__py3-none-any.whl → 0.11.1a20250828__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (24)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +178 -156
  2. compressed_tensors/compressors/quantized_compressors/base.py +2 -2
  3. compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +9 -9
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -3
  5. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +1 -1
  6. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  7. compressed_tensors/quantization/lifecycle/apply.py +48 -142
  8. compressed_tensors/quantization/lifecycle/forward.py +5 -4
  9. compressed_tensors/quantization/lifecycle/initialize.py +7 -6
  10. compressed_tensors/quantization/quant_args.py +7 -5
  11. compressed_tensors/quantization/quant_scheme.py +4 -3
  12. compressed_tensors/quantization/utils/helpers.py +0 -1
  13. compressed_tensors/registry/registry.py +1 -1
  14. compressed_tensors/transform/transform_config.py +1 -1
  15. compressed_tensors/transform/utils/matrix.py +1 -1
  16. compressed_tensors/utils/match.py +57 -8
  17. compressed_tensors/utils/offload.py +0 -1
  18. compressed_tensors/utils/safetensors_load.py +0 -1
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/METADATA +1 -1
  21. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/RECORD +24 -24
  22. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/WHEEL +0 -0
  23. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/licenses/LICENSE +0 -0
  24. {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/apply.py
@@ -13,12 +13,11 @@
  # limitations under the License.

  import logging
- import re
- from collections import OrderedDict, defaultdict
+ from collections import OrderedDict
  from copy import deepcopy
  from typing import Dict, Iterable, List, Optional
  from typing import OrderedDict as OrderedDictType
- from typing import Set, Union
+ from typing import Union

  import torch
  from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,8 @@ from compressed_tensors.quantization.utils import (
  infer_quantization_status,
  is_kv_cache_quant_scheme,
  )
- from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+ from compressed_tensors.utils.helpers import deprecated, replace_module
+ from compressed_tensors.utils.match import match_named_modules, match_targets
  from compressed_tensors.utils.offload import update_parameter_data
  from compressed_tensors.utils.safetensors_load import get_safetensors_folder
  from safetensors import safe_open
@@ -51,8 +51,6 @@ __all__ = [
  "apply_quantization_config",
  "apply_quantization_status",
  "find_name_or_class_matches",
- "expand_target_names",
- "is_target",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -73,14 +71,14 @@ def load_pretrained_quantization_parameters(
  Loads the quantization parameters (scale and zero point) from model_name_or_path to
  a model that has already been initialized with a quantization config.

- NOTE: Will always load inputs/output parameters.
- Will conditioanlly load weight parameters, if load_weight_quantization is set to True.
+ NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
+ parameters, if load_weight_quantization is set to True.

  :param model: model to load pretrained quantization parameters to
  :param model_name_or_path: Hugging Face stub or local folder containing a quantized
  model, which is used to load quantization parameters
- :param load_weight_quantization: whether or not the weight quantization parameters shoud
- be laoded
+ :param load_weight_quantization: whether or not the weight quantization parameters
+ should be loaded
  """
  model_path = get_safetensors_folder(model_name_or_path)
  mapping = get_quantization_parameter_to_path_mapping(model_path)
@@ -117,7 +115,7 @@ def load_pretrained_quantization_parameters(

  def apply_quantization_config(
  model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
- ) -> Dict[str, QuantizationScheme]:
+ ):
  """
  Initializes the model for quantization in-place based on the given config.
  Optionally coverts quantizable modules to compressed_linear modules
@@ -127,71 +125,49 @@ def apply_quantization_config(
  :param run_compressed: Whether the model will be run in compressed mode or
  decompressed fully on load
  """
- # Workaround for when HF Quantizer passes None, see PR #180
- if config is None:
- return dict()
+ from compressed_tensors.linear.compressed_linear import CompressedLinear

- # remove reference to the original `config`
- # argument. This function can mutate it, and we'd
- # like to keep the original `config` as it is.
  config = deepcopy(config)
+ if config is None:  # see PR #180
+ return dict()
+
+ # preprocess to support kv cache scheme
+ config = process_quantization_config(config)
+
  # build mapping of targets to schemes for easier matching
  # use ordered dict to preserve target ordering in config
  target_to_scheme = OrderedDict()
- config = process_quantization_config(config)
- names_to_scheme = dict()
  for scheme in config.config_groups.values():
  for target in scheme.targets:
  target_to_scheme[target] = scheme

- if run_compressed:
- from compressed_tensors.linear.compressed_linear import CompressedLinear
-
- # list of submodules to ignore
- ignored_submodules = defaultdict(list)
  # mark appropriate layers for quantization by setting their quantization schemes
- for name, submodule in model.named_modules():
- # potentially fix module name to remove FSDP wrapper prefix
- name = fix_fsdp_module_name(name)
- if matches := find_name_or_class_matches(name, submodule, config.ignore):
- for match in matches:
- ignored_submodules[match].append(name)
- continue # layer matches ignore list, continue
-
- targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
- if targets:
- # mark modules to be quantized by adding
- # quant scheme to the matching layers
- scheme = _scheme_from_targets(target_to_scheme, targets, name)
- if run_compressed:
- format = config.format
- if format != CompressionFormat.dense.value:
- if isinstance(submodule, torch.nn.Linear):
- # TODO: expand to more module types
- compressed_linear = CompressedLinear.from_linear(
- submodule,
- quantization_scheme=scheme,
- quantization_format=format,
- )
- replace_module(model, name, compressed_linear)
-
- # target matched - add layer and scheme to target list
- submodule.quantization_scheme = scheme
-
- names_to_scheme[name] = submodule.quantization_scheme
-
- if config.ignore is not None and ignored_submodules is not None:
- if set(config.ignore) - set(ignored_submodules):
- _LOGGER.warning(
- "Some layers that were to be ignored were "
- "not found in the model: "
- f"{set(config.ignore) - set(ignored_submodules)}"
- )
+ for name, submodule in match_named_modules(
+ model, target_to_scheme, config.ignore, warn_on_fail=True
+ ):
+ # mark modules to be quantized by adding
+ # quant scheme to the matching layers
+ matched_targets = match_targets(name, submodule, target_to_scheme)
+ scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+ # target matched - add layer and scheme to target list
+ submodule.quantization_scheme = scheme
+
+ # replace with run compressed if applicable
+ # FUTURE: move this to model compressor
+ if isinstance(submodule, torch.nn.Linear) and run_compressed:
+ format = config.format
+ if format != CompressionFormat.dense.value:
+ if isinstance(submodule, torch.nn.Linear):
+ # TODO: expand to more module types
+ compressed_linear = CompressedLinear.from_linear(
+ submodule,
+ quantization_scheme=scheme,
+ quantization_format=format,
+ )
+ replace_module(model, name, compressed_linear)

  # apply current quantization status across all targeted layers
  apply_quantization_status(model, config.quantization_status)
- return names_to_scheme


  def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
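Editor's note: a minimal, hedged sketch of the matching flow that `apply_quantization_config` now delegates to. The toy model, names, and targets below are invented for illustration; only `match_named_modules` and `match_targets` (both added to `compressed_tensors.utils.match` in this release) are real APIs.

```python
import torch
from compressed_tensors.utils.match import match_named_modules, match_targets

# toy model: submodule names are "0", "1", "2"
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),  # name "0"
    torch.nn.ReLU(),        # name "1"
    torch.nn.Linear(8, 2),  # name "2"
)
targets = ["re:.*2", "Linear"]  # one regex name target, one class-name target
ignore = ["0"]                  # excluded even though it matches "Linear"

for name, module in match_named_modules(model, targets, ignore, warn_on_fail=True):
    # results are ordered most-specific first: exact names, then regexes,
    # then class names
    print(name, match_targets(name, module, targets))
# expected output: 2 ['re:.*2', 'Linear']
```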
@@ -262,54 +238,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
  model.apply(compress_quantized_weights)


- def expand_target_names(
- model: Module,
- targets: Optional[Iterable[str]] = None,
- ignore: Optional[Iterable[str]] = None,
- ) -> Set[str]:
- """
- Finds all unique module names in the model that match the given
- targets and ignore lists.
-
- Note: Targets must be regexes, layer types, or full layer names.
-
- :param model: model to search for targets in
- :param targets: Iterable of targets to search for
- :param ignore: Iterable of targets to ignore
- :return: set of all targets that match the given targets and should
- not be ignored
- """
- return {
- name
- for name, module in model.named_modules()
- if is_target(name, module, targets, ignore)
- }
-
-
- def is_target(
- name: str,
- module: Module,
- targets: Optional[Iterable[str]] = None,
- ignore: Optional[Iterable[str]] = None,
- ) -> bool:
- """
- Determines if a module should be included in the targets based on the
- targets and ignore lists.
-
- Note: Targets must be regexes, layer types, or full layer names.
-
- :param name: name of the module
- :param module: the module itself
- :param targets: Iterable of targets to search for
- :param ignore: Iterable of targets to ignore
- :return: True if the module is a target and not ignored, False otherwise
- """
- return bool(
- find_name_or_class_matches(name, module, targets or [])
- and not find_name_or_class_matches(name, module, ignore or [])
- )
-
-
+ @deprecated(
+ message="This function is deprecated and will be removed in a future release."
+ "Please use `match_targets` from `compressed_tensors.utils.match` instead."
+ )
  def find_name_or_class_matches(
  name: str, module: Module, targets: Iterable[str], check_contains: bool = False
  ) -> List[str]:
@@ -322,38 +254,13 @@ def find_name_or_class_matches(
  2. matches on regex patterns
  3. matches on module names
  """
- from compressed_tensors import InternalModule
-
- if isinstance(module, InternalModule):
- return []
-
- targets = sorted(targets, key=lambda x: ("re:" in x, x))
- if isinstance(targets, Iterable):
- matches = _find_matches(name, targets) + _find_matches(
- module.__class__.__name__, targets, check_contains
+ if check_contains:
+ raise NotImplementedError(
+ "This function is deprecated, and the check_contains=True option has been"
+ " removed."
  )
- matches = [match for match in matches if match is not None]
- return matches

-
- def _find_matches(
- value: str, targets: Iterable[str], check_contains: bool = False
- ) -> List[str]:
- # returns all the targets that match value either
- # exactly or as a regex after 're:'. if check_contains is set to True,
- # additionally checks if the target string is contained with value.
- matches = []
- for target in targets:
- if target.startswith("re:"):
- pattern = target[3:]
- if re.match(pattern, value):
- matches.append(target)
- elif check_contains:
- if target.lower() in value.lower():
- matches.append(target)
- elif target == value:
- matches.append(target)
- return matches
+ return match_targets(name, module, targets)


  def _infer_status(model: Module) -> Optional[QuantizationStatus]:
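Editor's note: since the deprecated helper now forwards to `match_targets` (and `check_contains=True` is removed), call sites can migrate one-for-one. A hedged sketch; the module and name are made up:

```python
import torch
from compressed_tensors.utils.match import match_targets

module = torch.nn.Linear(4, 4)  # hypothetical target module

# before: find_name_or_class_matches("model.fc", module, ["re:.*fc", "Linear"])
# (still works, but now emits a deprecation warning and delegates)
# after:
matches = match_targets("model.fc", module, ["re:.*fc", "Linear"])
print(matches)  # ['re:.*fc', 'Linear'] -- regex name match ranks before class match
```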
@@ -429,7 +336,6 @@
  def _merge_schemes(
  schemes_to_merge: List[QuantizationScheme], name: str
  ) -> QuantizationScheme:
-
  kv_cache_quantization_scheme = [
  scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
  ]
compressed_tensors/quantization/lifecycle/forward.py
@@ -205,7 +205,8 @@ def _process_quantization(
  q_min, q_max = calculate_range(args, x.device)
  group_size = args.group_size

- # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+ # blockwise FP8: quantize per 2D block, supports block_structure for static block
+ # quantization
  if args.strategy == QuantizationStrategy.BLOCK:
  original_shape = x.shape
  rows, cols = x.shape[-2], x.shape[-1]
@@ -214,8 +215,8 @@
  # Ensure exact division (tensor dimensions must be divisible by block size)
  if rows % block_height != 0:
  raise ValueError(
- f"Tensor height {rows} is not divisible by block_height {block_height}. "
- f"Block quantization requires exact division."
+ f"Tensor height {rows} is not divisible by block_height {block_height}."
+ f" Block quantization requires exact division."
  )
  if cols % block_width != 0:
  raise ValueError(
@@ -295,7 +296,7 @@
  perm = torch.argsort(g_idx)
  x = safe_permute(x, perm, dim=1)

- # Maintain all dimensions apart from the last dim, which is divided by the group_size
+ # Maintain all dimensions except the last dim, which is divided by group_size
  reshaped_dims = (
  ceil(x.shape[-1] / group_size),
  group_size,
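Editor's note: a small illustration of the group-wise reshape this comment describes; the shapes are invented and chosen divisible for simplicity:

```python
import torch
from math import ceil

x = torch.randn(4, 16, 256)  # leading dims are preserved
group_size = 64

num_groups = ceil(x.shape[-1] / group_size)          # 4
grouped = x.unflatten(-1, (num_groups, group_size))  # shape (4, 16, 4, 64)

# a per-group quantity such as a scale then broadcasts over the last dim
scales = grouped.abs().amax(dim=-1, keepdim=True)    # shape (4, 16, 4, 1)
```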
compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,7 +17,7 @@ import logging
  import math
  import warnings
  from enum import Enum
- from typing import List, Optional
+ from typing import Optional

  import torch
  from compressed_tensors.quantization.lifecycle.forward import (
@@ -87,7 +87,6 @@ def initialize_module_for_quantization(
  _initialize_attn_scales(module)

  else:
-
  if scheme.input_activations is not None:
  _initialize_scale_zero_point(
  module,
@@ -183,7 +182,8 @@ def _initialize_scale_zero_point(
  num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
  expected_shape = (weight_shape[0], max(num_groups, 1))
  elif quantization_args.strategy == QuantizationStrategy.BLOCK:
- # For block quantization, scale shape should match number of blocks - only for weights
+ # For block quantization, scale shape should match number of blocks - only
+ # for weights
  if quantization_args.block_structure is None:
  raise ValueError(
  "Block quantization requires block_structure to be specified"
@@ -196,9 +196,10 @@
  # Warn if dimensions don't divide evenly
  if rows % block_height != 0 or cols % block_width != 0:
  warnings.warn(
- f"Block quantization: tensor shape {weight_shape} does not divide evenly "
- f"by block structure {quantization_args.block_structure}. "
- f"Some blocks will be incomplete which may affect quantization quality.",
+ f"Block quantization: tensor shape {weight_shape} does not divide"
+ f"evenly by block structure {quantization_args.block_structure}. "
+ f"Some blocks will be incomplete which may affect quantization"
+ "quality.",
  UserWarning,
  )

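Editor's note: for concreteness, the scale shapes implied by the two branches above, worked with illustrative sizes. The BLOCK shape assumes one scale per tile, rounding up for incomplete blocks, which is a reading of the warning rather than code shown in this diff:

```python
import math

weight_shape = (4096, 11008)  # hypothetical (out_features, in_features)

# GROUP: one scale per group of columns, per output row (as in the hunk above)
group_size = 128
num_groups = math.ceil(weight_shape[1] / group_size)
print((weight_shape[0], max(num_groups, 1)))  # (4096, 86)

# BLOCK: one scale per (block_height x block_width) tile of the weight
block_height, block_width = 128, 128
print((math.ceil(weight_shape[0] / block_height),
       math.ceil(weight_shape[1] / block_width)))  # (32, 86)
```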
compressed_tensors/quantization/quant_args.py
@@ -217,16 +217,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
  return [int(x) for x in value.split("x")]
  except Exception:
  raise ValueError(
- f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ f"Invalid block_structure '{value}'. Must be a list of ints "
+ "[rows, cols]."
  )
  if isinstance(value, (list, tuple)):
  if len(value) != 2 or not all(isinstance(v, int) for v in value):
  raise ValueError(
- f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ f"Invalid block_structure '{value}'. Must be a list of ints "
+ "[rows, cols]."
  )
  return list(value)
  raise ValueError(
- f"Invalid block_structure '{value}'. Must be a list of two ints [rows, cols]."
+ f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]."
  )

  @field_validator("strategy", mode="before")
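Editor's note: the accepted `block_structure` spellings, shown as a simplified standalone re-creation of the validator's logic (it omits the try/except around the string parse):

```python
def parse_block_structure(value):
    # "RxC" strings are split on "x" and cast to ints;
    # two-int lists/tuples pass through; anything else is rejected
    if isinstance(value, str):
        return [int(x) for x in value.split("x")]
    if isinstance(value, (list, tuple)):
        if len(value) != 2 or not all(isinstance(v, int) for v in value):
            raise ValueError(f"Invalid block_structure '{value}'.")
        return list(value)
    raise ValueError(f"Invalid block_structure '{value}'.")

print(parse_block_structure("128x128"))  # [128, 128]
print(parse_block_structure((64, 64)))   # [64, 64]
```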
@@ -307,7 +309,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
  )
  if strategy not in supported_strategies:
  raise ValueError(
- f"One of {supported_strategies} must be used for dynamic quantization"
+ f"One of {supported_strategies} must be used for dynamic quant."
  )

  if (
@@ -322,7 +324,7 @@
  observer != "memoryless"
  ): # avoid annoying users with old configs
  warnings.warn(
- "No observer is used for dynamic quantization, setting to None"
+ "No observer is used for dynamic quant., setting to None"
  )
  observer = None
  else:
compressed_tensors/quantization/quant_scheme.py
@@ -81,9 +81,10 @@ class QuantizationScheme(BaseModel):
  ):
  warnings.warn(
  "Using GROUP strategy for both weights and input_activations "
- f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
- "may complicate fused kernel implementations. Consider using "
- "TENSOR_GROUP strategy for both or matching group sizes.",
+ f"with different group sizes ({weights.group_size} vs "
+ f"{inputs.group_size}) may complicate fused kernel implementations. "
+ "Consider using TENSOR_GROUP strategy for both or matching group"
+ " sizes.",
  UserWarning,
  stacklevel=2,
  )
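Editor's note: a hedged sketch of the configuration this warning targets, i.e. GROUP strategy on both weights and input activations with mismatched group sizes. Constructor fields follow `compressed_tensors.quantization`, but exact defaults and validation may differ across versions:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
    input_activations=QuantizationArgs(num_bits=8, strategy="group", group_size=64),
)
# expected: UserWarning about differing group sizes (128 vs 64); prefer
# matching sizes, or TENSOR_GROUP strategy for both
```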
compressed_tensors/quantization/utils/helpers.py
@@ -29,7 +29,6 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
  from compressed_tensors.utils import deprecated
  from torch import FloatTensor, IntTensor, Tensor
  from torch.nn import Module
- from tqdm import tqdm


  __all__ = [
compressed_tensors/registry/registry.py
@@ -55,7 +55,7 @@ def standardize_lookup_name(name: str) -> str:


  def standardize_alias_name(
- name: Union[None, str, List[str]]
+ name: Union[None, str, List[str]],
  ) -> Union[None, str, List[str]]:
  if name is None:
  return None
compressed_tensors/transform/transform_config.py
@@ -14,7 +14,7 @@

  from typing import Dict

- from compressed_tensors.transform import TransformArgs, TransformScheme
+ from compressed_tensors.transform import TransformScheme
  from pydantic import BaseModel, ConfigDict

compressed_tensors/transform/utils/matrix.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Callable, Optional, Tuple
+ from typing import Optional

  import torch
  from compressed_tensors.transform import TransformLocation
compressed_tensors/utils/match.py
@@ -27,6 +27,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
  __all__ = [
  "match_named_modules",
  "match_named_parameters",
+ "match_targets",
  "match_modules_set",
  "is_match",
  ]
@@ -37,8 +38,8 @@ FusedMappping = Mapping[str, Iterable[str]]

  def match_named_modules(
  model: torch.nn.Module,
- targets: Iterable[str],
- ignore: Iterable[str] = tuple(),
+ targets: Optional[Iterable[str]],
+ ignore: Optional[Iterable[str]] = None,
  fused: Optional[FusedMappping] = None,
  warn_on_fail: bool = False,
  ) -> Generator[Tuple[str, torch.nn.Module]]:
@@ -54,14 +55,18 @@
  :param warn_on_fail: if True, warns if any targets do not match any modules in model
  :return: generator of module names and modules
  """
+ targets = targets or []
+ ignore = ignore or []
+
  unmatched_targets = set(targets)
+
  for name, module in model.named_modules():
  for target in targets:
  if is_match(name, module, target, fused=fused):
  unmatched_targets -= {target}
-
  if not is_match(name, module, ignore, fused=fused):
  yield name, module
+ break

  if warn_on_fail:
  for target in unmatched_targets:
@@ -72,8 +77,8 @@

  def match_named_parameters(
  model: torch.nn.Module,
- targets: Iterable[str],
- ignore: Iterable[str] = tuple(),
+ targets: Optional[Iterable[str]],
+ ignore: Optional[Iterable[str]] = None,
  fused: Optional[FusedMappping] = None,
  warn_on_fail: bool = False,
  ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
@@ -89,6 +94,9 @@
  :param warn_on_fail: if True, warns if any targets do not match any params in model
  :return: generator of fully-qualified param names, parent modules, and params
  """
+ targets = targets or []
+ ignore = ignore or []
+
  unmatched_targets = set(targets)
  for module_name, module in model.named_modules():
  if isinstance(module, InternalModule):
@@ -110,16 +118,54 @@
  )


+ def match_targets(
+ name: str, module: torch.nn.Module, targets: Optional[Iterable[str]]
+ ) -> List[str]:
+ """
+ Returns the targets that match the given name and module.
+
+ :param name: the name of the module
+ :param module: the module to match
+ :param targets: the target strings, potentially containing "re:" prefixes
+ :return: the targets that match the given name and module
+
+ Outputs are ordered by type: exact name match, regex name match, class name match
+ """
+ targets = targets or []
+
+ if isinstance(module, InternalModule):
+ return []
+
+ # The order of the output `matches` list matters, the are arranged from most
+ # specific to least specific, and this order will be used when merging configs.
+ # The entries are sorted in the following order:
+ # 1. matches on exact strings
+ # 2. matches on regex patterns
+ # 3. matches on module names
+
+ targets = sorted(targets, key=lambda x: ("re:" in x, x))
+ matched_targets = []
+ for target in targets:
+ if _match_name(name, target):
+ matched_targets.append(target)
+
+ for target in targets:
+ if _match_class(module, target) and target not in matched_targets:
+ matched_targets.append(target)
+
+ return matched_targets
+
+
  def match_modules_set(
  model: torch.nn.Module,
- targets: Iterable[str],
- ignore: Iterable[str] = tuple(),
+ targets: Optional[Iterable[str]],
+ ignore: Optional[Iterable[str]] = None,
  ) -> Generator[Iterable[torch.nn.Module]]:
  """
  Yields modules grouped with the same order and size as `targets`.
  Values are returned in order of `model.named_modules()`

- For example, the following targets would yield module belonging to the following layers:
+ E.g. the following targets would yield module belonging to the following layers:
  ```python3
  match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
  (
@@ -151,6 +197,9 @@
  :param targets: target strings, potentially containing "re:" prefixes
  :param ignore: targets to ignore, potentially containing "re:" prefixes
  """
+ targets = targets or []
+ ignore = ignore or []
+
  matches = dict.fromkeys(targets, None)
  for name, module in model.named_modules():
  # match until we get a full set
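Editor's note: a runnable sketch of `match_modules_set` on a flat toy model whose submodule names exactly equal the targets; real checkpoints with nested names would typically use "re:" patterns instead, as the docstring above suggests:

```python
import torch
from compressed_tensors.utils.match import match_modules_set

model = torch.nn.ModuleDict({
    "q_proj": torch.nn.Linear(8, 8),
    "k_proj": torch.nn.Linear(8, 8),
    "v_proj": torch.nn.Linear(8, 8),
})

# yields one group per complete set, in named_modules() order,
# with the same order and size as `targets`
for q_proj, k_proj, v_proj in match_modules_set(model, ["q_proj", "k_proj", "v_proj"]):
    print(q_proj, k_proj, v_proj)
```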
compressed_tensors/utils/offload.py
@@ -296,7 +296,6 @@ def disable_hf_hook(module: torch.nn.Module):
  hooks = {}

  def collect_hooks(module):
- nonlocal hooks
  if hasattr(module, "_hf_hook"):
  hooks[module] = module._hf_hook
  remove_hook_from_module(module)
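Editor's note: the removed `nonlocal` was redundant. A closure may mutate an enclosing dict in place; `nonlocal` is only required when the inner function rebinds the name. A standalone illustration:

```python
def outer():
    hooks = {}

    def collect(key, value):
        hooks[key] = value  # in-place mutation; no rebinding, so no nonlocal

    collect("a", 1)
    return hooks

print(outer())  # {'a': 1}
```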
compressed_tensors/utils/safetensors_load.py
@@ -18,7 +18,6 @@ import re
  import struct
  from typing import Dict, Iterable, Optional, Tuple, Union

- from safetensors import safe_open
  from torch import Tensor
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file

compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.11.1.a20250820'
+ __version__ = version = '0.11.1.a20250828'
  __version_tuple__ = version_tuple = (0, 11, 1)
{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.11.1a20250820
+ Version: 0.11.1a20250828
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.