compressed-tensors 0.11.1a20250820__py3-none-any.whl → 0.11.1a20250828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +178 -156
- compressed_tensors/compressors/quantized_compressors/base.py +2 -2
- compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py +9 -9
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -3
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +1 -1
- compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
- compressed_tensors/quantization/lifecycle/apply.py +48 -142
- compressed_tensors/quantization/lifecycle/forward.py +5 -4
- compressed_tensors/quantization/lifecycle/initialize.py +7 -6
- compressed_tensors/quantization/quant_args.py +7 -5
- compressed_tensors/quantization/quant_scheme.py +4 -3
- compressed_tensors/quantization/utils/helpers.py +0 -1
- compressed_tensors/registry/registry.py +1 -1
- compressed_tensors/transform/transform_config.py +1 -1
- compressed_tensors/transform/utils/matrix.py +1 -1
- compressed_tensors/utils/match.py +57 -8
- compressed_tensors/utils/offload.py +0 -1
- compressed_tensors/utils/safetensors_load.py +0 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/METADATA +1 -1
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/RECORD +24 -24
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/apply.py
CHANGED

@@ -13,12 +13,11 @@
 # limitations under the License.
 
 import logging
-import re
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,8 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import replace_module
+from compressed_tensors.utils.helpers import deprecated, replace_module
+from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -51,8 +51,6 @@ __all__ = [
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "expand_target_names",
-    "is_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -73,14 +71,14 @@ def load_pretrained_quantization_parameters(
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
     a model that has already been initialized with a quantization config.
 
-    NOTE: Will always load inputs/output parameters.
-
+    NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
+        parameters, if load_weight_quantization is set to True.
 
     :param model: model to load pretrained quantization parameters to
     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
         model, which is used to load quantization parameters
-    :param load_weight_quantization: whether or not the weight quantization parameters
-        be
+    :param load_weight_quantization: whether or not the weight quantization parameters
+        should be loaded
     """
     model_path = get_safetensors_folder(model_name_or_path)
     mapping = get_quantization_parameter_to_path_mapping(model_path)
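A usage sketch of the signature documented above (hedged: `model` and the checkpoint path are placeholders for a model already initialized with a quantization config, and this assumes the wheel is installed):

```python
# Sketch only: `model` must already be initialized with a quantization config,
# and "path/to/quantized-model" is a placeholder HF stub or local folder.
from compressed_tensors.quantization.lifecycle.apply import (
    load_pretrained_quantization_parameters,
)

load_pretrained_quantization_parameters(
    model,
    "path/to/quantized-model",
    load_weight_quantization=True,  # also load weight scales/zero points
)
```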
@@ -117,7 +115,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> Dict[str, QuantizationScheme]:
+):
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally coverts quantizable modules to compressed_linear modules
@@ -127,71 +125,49 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
-
-    if config is None:
-        return dict()
+    from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # remove reference to the original `config`
-    # argument. This function can mutate it, and we'd
-    # like to keep the original `config` as it is.
     config = deepcopy(config)
+    if config is None:  # see PR #180
+        return dict()
+
+    # preprocess to support kv cache scheme
+    config = process_quantization_config(config)
+
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
-    config = process_quantization_config(config)
-    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
-
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in
-
-
-
-
-
-
-
-
-
-        if
-
-
-
-        if
-
-
-
-
-
-
-
-
-            )
-                    replace_module(model, name, compressed_linear)
-
-            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = scheme
-
-            names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    for name, submodule in match_named_modules(
+        model, target_to_scheme, config.ignore, warn_on_fail=True
+    ):
+        # mark modules to be quantized by adding
+        # quant scheme to the matching layers
+        matched_targets = match_targets(name, submodule, target_to_scheme)
+        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        # replace with run compressed if applicable
+        # FUTURE: move this to model compressor
+        if isinstance(submodule, torch.nn.Linear) and run_compressed:
+            format = config.format
+            if format != CompressionFormat.dense.value:
+                if isinstance(submodule, torch.nn.Linear):
+                    # TODO: expand to more module types
+                    compressed_linear = CompressedLinear.from_linear(
+                        submodule,
+                        quantization_scheme=scheme,
+                        quantization_format=format,
+                    )
+                    replace_module(model, name, compressed_linear)
 
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
-    return names_to_scheme
 
 
 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
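To make the behavioral change concrete, here is a minimal sketch of the new call pattern, assuming this wheel's public `compressed_tensors.quantization` API; note the function now mutates the model in place instead of returning a names-to-scheme mapping:

```python
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# a tiny illustrative model; any torch.nn.Module works
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
)

# mutates `model` in place; callers that relied on the old return value must
# now collect schemes from the modules themselves
apply_quantization_config(model, config)
names_to_scheme = {
    name: module.quantization_scheme
    for name, module in model.named_modules()
    if hasattr(module, "quantization_scheme")
}
print(sorted(names_to_scheme))  # ['0', '1']
```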
@@ -262,54 +238,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
-def expand_target_names(
-    model: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> Set[str]:
-    """
-    Finds all unique module names in the model that match the given
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param model: model to search for targets in
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: set of all targets that match the given targets and should
-        not be ignored
-    """
-    return {
-        name
-        for name, module in model.named_modules()
-        if is_target(name, module, targets, ignore)
-    }
-
-
-def is_target(
-    name: str,
-    module: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> bool:
-    """
-    Determines if a module should be included in the targets based on the
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param name: name of the module
-    :param module: the module itself
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: True if the module is a target and not ignored, False otherwise
-    """
-    return bool(
-        find_name_or_class_matches(name, module, targets or [])
-        and not find_name_or_class_matches(name, module, ignore or [])
-    )
-
-
+@deprecated(
+    message="This function is deprecated and will be removed in a future release."
+    "Please use `match_targets` from `compressed_tensors.utils.match` instead."
+)
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
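For callers of the removed helpers, a hedged migration sketch using the `compressed_tensors.utils.match` utilities that the deprecation message points to (assumes this wheel is installed):

```python
import torch
from compressed_tensors.utils.match import match_named_modules, match_targets

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
targets, ignore = ["Linear"], []

# roughly replaces the removed expand_target_names(model, targets, ignore)
matched_names = {name for name, _ in match_named_modules(model, targets, ignore)}

# roughly replaces the removed is_target(name, module, targets, ignore)
name, module = "0", model[0]
is_targeted = bool(match_targets(name, module, targets)) and not bool(
    match_targets(name, module, ignore)
)
print(matched_names, is_targeted)  # {'0'} True
```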
@@ -322,38 +254,13 @@ def find_name_or_class_matches(
     2. matches on regex patterns
     3. matches on module names
     """
-
-
-
-
-
-    targets = sorted(targets, key=lambda x: ("re:" in x, x))
-    if isinstance(targets, Iterable):
-        matches = _find_matches(name, targets) + _find_matches(
-            module.__class__.__name__, targets, check_contains
+    if check_contains:
+        raise NotImplementedError(
+            "This function is deprecated, and the check_contains=True option has been"
+            " removed."
         )
-        matches = [match for match in matches if match is not None]
-        return matches
 
-
-def _find_matches(
-    value: str, targets: Iterable[str], check_contains: bool = False
-) -> List[str]:
-    # returns all the targets that match value either
-    # exactly or as a regex after 're:'. if check_contains is set to True,
-    # additionally checks if the target string is contained with value.
-    matches = []
-    for target in targets:
-        if target.startswith("re:"):
-            pattern = target[3:]
-            if re.match(pattern, value):
-                matches.append(target)
-        elif check_contains:
-            if target.lower() in value.lower():
-                matches.append(target)
-        elif target == value:
-            matches.append(target)
-    return matches
+    return match_targets(name, module, targets)
 
 
 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
@@ -429,7 +336,6 @@ def _scheme_from_targets(
 def _merge_schemes(
     schemes_to_merge: List[QuantizationScheme], name: str
 ) -> QuantizationScheme:
-
     kv_cache_quantization_scheme = [
         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
     ]
compressed_tensors/quantization/lifecycle/forward.py
CHANGED

@@ -205,7 +205,8 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
-    # blockwise FP8: quantize per 2D block, supports block_structure for static block quantization
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block
+    # quantization
     if args.strategy == QuantizationStrategy.BLOCK:
         original_shape = x.shape
         rows, cols = x.shape[-2], x.shape[-1]
@@ -214,8 +215,8 @@ def _process_quantization(
         # Ensure exact division (tensor dimensions must be divisible by block size)
         if rows % block_height != 0:
             raise ValueError(
-                f"Tensor height {rows} is not divisible by block_height {block_height}.
-                f"Block quantization requires exact division."
+                f"Tensor height {rows} is not divisible by block_height {block_height}."
+                f" Block quantization requires exact division."
             )
         if cols % block_width != 0:
             raise ValueError(
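A standalone arithmetic illustration of the exact-division rule enforced above (not library code):

```python
# Block quantization requires tensor dims to divide evenly by the block size.
rows, cols = 4096, 4096
block_height, block_width = 128, 128
assert rows % block_height == 0 and cols % block_width == 0  # 32 x 32 full blocks

rows = 4000  # 4000 = 31 * 128 + 32, so the last row-block would be partial
if rows % block_height != 0:
    # this is the condition under which the library raises ValueError
    print(f"{rows} % {block_height} = {rows % block_height}: not divisible")
```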
@@ -295,7 +296,7 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        # Maintain all dimensions
+        # Maintain all dimensions except the last dim, which is divided by group_size
         reshaped_dims = (
             ceil(x.shape[-1] / group_size),
             group_size,
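A standalone sketch of the reshape that the clarified comment describes (not library code):

```python
# The last dim is split into (num_groups, group_size) so that per-group
# scales can broadcast across each group.
from math import ceil

import torch

x = torch.randn(4, 256)
group_size = 128
reshaped_dims = (ceil(x.shape[-1] / group_size), group_size)
x_grouped = x.reshape(x.shape[:-1] + reshaped_dims)
print(x_grouped.shape)  # torch.Size([4, 2, 128])
```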
compressed_tensors/quantization/lifecycle/initialize.py
CHANGED

@@ -17,7 +17,7 @@ import logging
 import math
 import warnings
 from enum import Enum
-from typing import
+from typing import Optional
 
 import torch
 from compressed_tensors.quantization.lifecycle.forward import (
@@ -87,7 +87,6 @@ def initialize_module_for_quantization(
         _initialize_attn_scales(module)
 
     else:
-
         if scheme.input_activations is not None:
             _initialize_scale_zero_point(
                 module,
@@ -183,7 +182,8 @@ def _initialize_scale_zero_point(
         num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
         expected_shape = (weight_shape[0], max(num_groups, 1))
     elif quantization_args.strategy == QuantizationStrategy.BLOCK:
-        # For block quantization, scale shape should match number of blocks - only for weights
+        # For block quantization, scale shape should match number of blocks - only
+        # for weights
         if quantization_args.block_structure is None:
             raise ValueError(
                 "Block quantization requires block_structure to be specified"
@@ -196,9 +196,10 @@ def _initialize_scale_zero_point(
         # Warn if dimensions don't divide evenly
         if rows % block_height != 0 or cols % block_width != 0:
             warnings.warn(
-                f"Block quantization: tensor shape {weight_shape} does not divide "
-                f"by block structure {quantization_args.block_structure}. "
-                f"Some blocks will be incomplete which may affect quantization quality.",
+                f"Block quantization: tensor shape {weight_shape} does not divide"
+                f"evenly by block structure {quantization_args.block_structure}. "
+                f"Some blocks will be incomplete which may affect quantization"
+                "quality.",
                 UserWarning,
             )
 
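A standalone sketch of the block-wise scale shape this section initializes, assuming ceiling division when blocks are incomplete (not library code):

```python
# One scale per (block_height x block_width) tile of the weight.
from math import ceil

weight_shape = (4096, 4096)
block_height, block_width = 128, 128
expected_shape = (
    ceil(weight_shape[0] / block_height),
    ceil(weight_shape[1] / block_width),
)
print(expected_shape)  # (32, 32)
```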
compressed_tensors/quantization/quant_args.py
CHANGED

@@ -217,16 +217,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 return [int(x) for x in value.split("x")]
             except Exception:
                 raise ValueError(
-                    f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]."
+                    f"Invalid block_structure '{value}'. Must be a list of ints "
+                    "[rows, cols]."
                 )
         if isinstance(value, (list, tuple)):
             if len(value) != 2 or not all(isinstance(v, int) for v in value):
                 raise ValueError(
-                    f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]."
+                    f"Invalid block_structure '{value}'. Must be a list of ints "
+                    "[rows, cols]."
                 )
             return list(value)
         raise ValueError(
-            f"Invalid block_structure '{value}'. Must be a list of
+            f"Invalid block_structure '{value}'. Must be a list of ints [rows, cols]."
         )
 
     @field_validator("strategy", mode="before")
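A hedged sketch of the two `block_structure` forms the validator accepts, assuming this wheel is installed:

```python
from compressed_tensors.quantization import QuantizationArgs

# "RxC" strings are split on "x" and parsed to ints
args = QuantizationArgs(strategy="block", block_structure="128x128")
print(args.block_structure)  # [128, 128]

# two-int lists/tuples pass through as a list
args = QuantizationArgs(strategy="block", block_structure=[64, 64])
print(args.block_structure)  # [64, 64]

# anything else raises ValueError("Invalid block_structure ...")
```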
@@ -307,7 +309,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         )
         if strategy not in supported_strategies:
             raise ValueError(
-                f"One of {supported_strategies} must be used for dynamic
+                f"One of {supported_strategies} must be used for dynamic quant."
             )
 
         if (
@@ -322,7 +324,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             observer != "memoryless"
         ):  # avoid annoying users with old configs
             warnings.warn(
-                "No observer is used for dynamic
+                "No observer is used for dynamic quant., setting to None"
             )
             observer = None
         else:
compressed_tensors/quantization/quant_scheme.py
CHANGED

@@ -81,9 +81,10 @@ class QuantizationScheme(BaseModel):
         ):
             warnings.warn(
                 "Using GROUP strategy for both weights and input_activations "
-                f"with different group sizes ({weights.group_size} vs {inputs.group_size}) "
-                "may complicate fused kernel implementations. Consider using "
-                "TENSOR_GROUP strategy for both or matching group sizes.",
+                f"with different group sizes ({weights.group_size} vs "
+                f"{inputs.group_size}) may complicate fused kernel implementations. "
+                "Consider using TENSOR_GROUP strategy for both or matching group"
+                " sizes.",
                 UserWarning,
                 stacklevel=2,
             )
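A hedged sketch of a scheme that should trigger the reworded warning, assuming this wheel is installed:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# GROUP on both weights and input activations with mismatched group sizes
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(strategy="group", group_size=128),
    input_activations=QuantizationArgs(strategy="group", group_size=64),
)
# UserWarning: Using GROUP strategy for both weights and input_activations
# with different group sizes (128 vs 64) ...
```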
compressed_tensors/transform/utils/matrix.py
CHANGED

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformLocation
compressed_tensors/utils/match.py
CHANGED

@@ -27,6 +27,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 __all__ = [
     "match_named_modules",
     "match_named_parameters",
+    "match_targets",
     "match_modules_set",
     "is_match",
 ]
@@ -37,8 +38,8 @@ FusedMappping = Mapping[str, Iterable[str]]
 
 def match_named_modules(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] =
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
@@ -54,14 +55,18 @@ def match_named_modules(
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
+
     for name, module in model.named_modules():
         for target in targets:
             if is_match(name, module, target, fused=fused):
                 unmatched_targets -= {target}
-
                 if not is_match(name, module, ignore, fused=fused):
                     yield name, module
+                break
 
     if warn_on_fail:
         for target in unmatched_targets:
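A usage sketch of the relaxed signature and the new `break`, assuming this wheel is installed: `ignore` may now be `None`, and each module is yielded at most once even when several targets match it.

```python
import torch
from compressed_tensors.utils.match import match_named_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
# "Linear" (class match) and "re:.*" (regex name match) both hit each layer;
# the added break ensures each module is still yielded only once
for name, module in match_named_modules(model, ["Linear", "re:.*"], ignore=None):
    print(repr(name), type(module).__name__)
# '' Sequential (the regex matches the root name), '0' Linear, '1' Linear
```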
@@ -72,8 +77,8 @@ def match_named_modules(
 
 def match_named_parameters(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] =
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
@@ -89,6 +94,9 @@ def match_named_parameters(
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
     for module_name, module in model.named_modules():
         if isinstance(module, InternalModule):
@@ -110,16 +118,54 @@ def match_named_parameters(
             )
 
 
+def match_targets(
+    name: str, module: torch.nn.Module, targets: Optional[Iterable[str]]
+) -> List[str]:
+    """
+    Returns the targets that match the given name and module.
+
+    :param name: the name of the module
+    :param module: the module to match
+    :param targets: the target strings, potentially containing "re:" prefixes
+    :return: the targets that match the given name and module
+
+    Outputs are ordered by type: exact name match, regex name match, class name match
+    """
+    targets = targets or []
+
+    if isinstance(module, InternalModule):
+        return []
+
+    # The order of the output `matches` list matters, the are arranged from most
+    # specific to least specific, and this order will be used when merging configs.
+    # The entries are sorted in the following order:
+    # 1. matches on exact strings
+    # 2. matches on regex patterns
+    # 3. matches on module names
+
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    matched_targets = []
+    for target in targets:
+        if _match_name(name, target):
+            matched_targets.append(target)
+
+    for target in targets:
+        if _match_class(module, target) and target not in matched_targets:
+            matched_targets.append(target)
+
+    return matched_targets
+
+
 def match_modules_set(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] =
+    targets: Optional[Iterable[str]],
+    ignore: Optional[Iterable[str]] = None,
 ) -> Generator[Iterable[torch.nn.Module]]:
     """
     Yields modules grouped with the same order and size as `targets`.
     Values are returned in order of `model.named_modules()`
 
-
+    E.g. the following targets would yield module belonging to the following layers:
     ```python3
     match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
         (
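A hedged sketch of the ordering contract stated in the new docstring, assuming this wheel is installed: exact name matches come first, then regex name matches, then class-name matches.

```python
import torch
from compressed_tensors.utils.match import match_targets

module = torch.nn.Linear(8, 8)
targets = ["Linear", "re:.*proj.*", "model.q_proj"]
print(match_targets("model.q_proj", module, targets))
# ['model.q_proj', 're:.*proj.*', 'Linear']
```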
@@ -151,6 +197,9 @@ def match_modules_set(
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
     """
+    targets = targets or []
+    ignore = ignore or []
+
     matches = dict.fromkeys(targets, None)
     for name, module in model.named_modules():
         # match until we get a full set
compressed_tensors/version.py
CHANGED

{compressed_tensors-0.11.1a20250820.dist-info → compressed_tensors-0.11.1a20250828.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.11.1a20250820
+Version: 0.11.1a20250828
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.