compressed-tensors 0.12.3a20251114__py3-none-any.whl → 0.12.3a20251203__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,7 +13,9 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import logging
16
+ import os
16
17
  import re
18
+ from collections import defaultdict
17
19
  from collections.abc import Generator
18
20
  from typing import Iterable, List, Mapping, Optional, Tuple, Union
19
21
 
@@ -29,6 +31,7 @@ __all__ = [
29
31
  "match_named_parameters",
30
32
  "match_targets",
31
33
  "match_modules_set",
34
+ "get_lowest_common_ancestor_name",
32
35
  "is_match",
33
36
  "is_narrow_match",
34
37
  ]
@@ -157,34 +160,68 @@ def match_targets(
157
160
  return matched_targets
158
161
 
159
162
 
163
+ def get_lowest_common_ancestor_name(names: list[str | None]) -> str:
164
+ """
165
+ Given a list of names, returns the lowest-scope common name ignoring Nones.
166
+
167
+ Implementation is a small alteration of os.path.commonprefix
168
+ https://docs.python.org/3/library/os.path.html#os.path.commonprefix
169
+
170
+ ([s1, s2]->prefix->result)
171
+ # case 0: multiple modules: [.abc.a., .abc.b.] -> .abc. -> abc
172
+ # case 1: single module: [.abc., .abc.] -> .abc. -> abc
173
+ # case 2: substring modules: [.ab., .abc.] -> .ab -> ""
174
+ # case 3: parent & child: [.ab., .ab.a.] -> .ab. -> ab
175
+ """
176
+ names = [name for name in names if name is not None]
177
+ if len(names) == 0:
178
+ return ""
179
+
180
+ # 1) find longest shared prefix
181
+ s1 = "." + min(names) + "."
182
+ s2 = "." + max(names) + "."
183
+ common_prefix = os.path.commonprefix([s1, s2])
184
+ # 2) throw away right most dot and name fragment, throw away leftmost char
185
+ # ".keep.thro" -> "keep", "." -> ""
186
+ return common_prefix[1 : common_prefix.rfind(".")]
187
+
188
+
160
189
  def match_modules_set(
161
190
  model: torch.nn.Module,
162
191
  targets: Optional[Iterable[str]],
163
192
  ignore: Optional[Iterable[str]] = None,
164
- ) -> Generator[Iterable[torch.nn.Module]]:
193
+ error_on_module_rematch: bool = True,
194
+ ) -> Generator[List[List[torch.nn.Module]]]:
165
195
  """
166
- Yields modules grouped with the same order and size as `targets`.
167
- Values are returned in order of `model.named_modules()`
196
+ Yields modules grouped by parent context.
197
+
198
+ We group by parent context so that we can return ALL matches of a
199
+ specific target that can be paired with another target. This is most
200
+ relevant in the case of MoE modules with multiple modules for each
201
+ expert i.e. post_attention_layernorm <-> mlp.expert.N.gate_proj,
202
+ mlp.expert.N.up_proj for all N. The parent context will differ from
203
+ one layer to another while being the same for one expert to another.
168
204
 
169
- E.g. the following targets would yield module belonging to the following layers:
205
+ Each yielded group is a list of lists with the same size
206
+ and order as `targets`. The matches for each target, and
207
+ the groups themselves, are ordered in the same way
208
+ as `model.named_modules`.
209
+
210
+
211
+ E.g. the following targets would yield modules belonging to the following layers:
170
212
  ```python3
171
213
  match_modules_set(model, ["q_proj", "k_proj", "v_proj"]) == (
172
- (
173
- `model.layers.0.self_attn.q_proj`,
174
- `model.layers.0.self_attn.k_proj`,
175
- `model.layers.0.self_attn.v_proj`,
176
- ),
177
- (
178
- `model.layers.1.self_attn.q_proj`,
179
- `model.layers.1.self_attn.k_proj`,
180
- `model.layers.1.self_attn.v_proj`,
181
- ),
214
+ [
215
+ [`layers.0.self_attn.q_proj`],
216
+ [`layers.0.self_attn.k_proj`],
217
+ [`layers.0.self_attn.v_proj`],
218
+ ],
219
+ [
220
+ [`layers.1.self_attn.q_proj`],
221
+ [`layers.1.self_attn.k_proj`],
222
+ [`layers.1.self_attn.v_proj`],
223
+ ],
182
224
  ...
183
- (
184
- `model.layers.32.self_attn.q_proj`,
185
- `model.layers.32.self_attn.k_proj`,
186
- `model.layers.32.self_attn.v_proj`,
187
- ),
188
225
  )
189
226
  ```
190
227
 
@@ -192,33 +229,125 @@ def match_modules_set(
192
229
  For example, matching layer norms to their subsequent linear layers
193
230
  ```python3
194
231
  for norm, q, k, v in match_modules_set(model, (norm_tgt, q_tgt, k_tgt, v_tgt)):
195
- fuse_norm_linears(norm, [q, k, v])
232
+ fuse_norm_linears(*norm, [*q, *k, *v])
233
+ ```
234
+
235
+ Alternatively for MoE you would get multiple matches
236
+ per target per group, E.g.
237
+
238
+ ```python3
239
+
240
+ targets = [
241
+ "post_attention_layernorm",
242
+ "up_proj",
243
+ "down_proj"
244
+ ]
245
+ match_modules_set(model, targets) == (
246
+ [
247
+ [`layers.0.post_attention_layernorm`],
248
+ [
249
+ `layers.0.mlp.experts.0.up_proj`,
250
+ `layers.0.mlp.experts.1.up_proj`,
251
+ ...
252
+ ],
253
+ [
254
+ `layers.0.mlp.experts.0.down_proj`,
255
+ `layers.0.mlp.experts.1.down_proj`,
256
+ ...
257
+
258
+ ]
259
+ ], # <- first yield
260
+ [
261
+ [`layers.1.post_attention_layernorm`],
262
+ [
263
+ `layers.1.mlp.experts.0.up_proj`,
264
+ `layers.1.mlp.experts.1.up_proj`,
265
+ ...
266
+ ],
267
+ [
268
+ `layers.1.mlp.experts.0.down_proj`,
269
+ `layers.1.mlp.experts.1.down_proj`,
270
+ ...
271
+ ]
272
+ ],
273
+ ...
274
+ )
275
+ ```
196
276
 
197
277
  :param model: model containing modules to match against
198
278
  :param targets: target strings, potentially containing "re:" prefixes
199
279
  :param ignore: targets to ignore, potentially containing "re:" prefixes
280
+ :param error_on_module_rematch: if True, errors when a module gets
281
+ matched to multiple targets, if False, no error. (Defaults to True)
200
282
  """
201
283
  targets = targets or []
202
284
  ignore = ignore or []
203
285
 
204
- matches = dict.fromkeys(targets, None)
286
+ # as we iterate through modules and try to match them with targets,
287
+ # the algorithm can be in 2 possible states:
288
+ # 0) unmatched_targets > 0, i.e. some of the targets haven't been matched.
289
+ # Keep matching until all targets have at least one match
290
+ # 1) unmatched_targets == 0 i.e. we have at least one match for each target.
291
+ # At this point we are unsure if we have a full set or if we need to add
292
+ # more matches.
293
+ # There are 3 things that can happen once we're in state 1:
294
+ # A) found a new match with same parent_context,
295
+ # (add it to matches and keep going)
296
+ # B) found a new match with different parent_context, i.e. we found a match
297
+ # that requires a deeper parent context, this indicates that this match
298
+ # should be part of a new set.
299
+ # (yield current set [not including newest match] and go back to state 0)
300
+ # C) ran out of modules, we will always yield the final remaining set when
301
+ # we've iterated through all the modules in the model.
302
+ # (yield final set then exit.)
303
+ # Note: it's possible to iterate through all the modules in the model while
304
+ # not having a full matched set if the user specified a bad matching, in
305
+ # that case something has gone wrong and we error
306
+ matches = defaultdict(list)
307
+ parent_context = None
308
+ unmatched_targets = set(targets)
309
+
205
310
  for name, module in model.named_modules():
206
- # match until we get a full set
311
+ matched_targets_for_cur_module = set()
207
312
  for target in targets:
208
313
  if is_match(name, module, target, ignore):
209
- if matches[target] is not None:
210
- raise ValueError(f"Matched a {target} twice before completing set")
211
- matches[target] = module
212
-
213
- # once we have a full set, yield and reset
214
- if targets and all((matches[target] is not None for target in targets)):
215
- yield [matches[target] for target in targets] # ensure correct ordering
216
- matches = dict.fromkeys(targets, None)
217
-
218
- # check that none are left over
219
- unmatched_keys = [match for match, value in matches.items() if value is not None]
220
- if len(unmatched_keys):
221
- raise ValueError(f"Unable to match targets into set: {unmatched_keys}")
314
+ new_parent_context = get_lowest_common_ancestor_name(
315
+ [name, parent_context]
316
+ )
317
+
318
+ # code for (B)
319
+ if not unmatched_targets and new_parent_context != parent_context:
320
+ yield [matches[target] for target in targets]
321
+ matches = defaultdict(list)
322
+ new_parent_context = name
323
+ unmatched_targets = set(targets)
324
+
325
+ matches[target].append(module)
326
+ parent_context = new_parent_context
327
+ unmatched_targets -= {target}
328
+ matched_targets_for_cur_module |= {target}
329
+
330
+ if len(matched_targets_for_cur_module) > 1 and error_on_module_rematch:
331
+ raise ValueError(
332
+ f"module: {name} was matched with multiple targets: "
333
+ f"{matched_targets_for_cur_module} which is unexpected "
334
+ "disable this check by setting `error_on_module_rematch = False`"
335
+ )
336
+
337
+ # never found anything
338
+ if unmatched_targets == set(targets):
339
+ return
340
+
341
+ # code for (C)
342
+ if not unmatched_targets: # have a full matching
343
+ yield [matches[target] for target in targets]
344
+ return
345
+
346
+ raise ValueError(
347
+ f"Found a final incomplete set with matches found for keys: "
348
+ f"{set(targets) - unmatched_targets} "
349
+ f"but no matches found for keys: {unmatched_targets}"
350
+ )
222
351
 
223
352
 
224
353
  def is_match(
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.12.3.a20251114'
20
+ __version__ = version = '0.12.3.a20251203'
21
21
  __version_tuple__ = version_tuple = (0, 12, 3)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.12.3a20251114
3
+ Version: 0.12.3a20251203
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/vllm-project/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,7 +1,7 @@
1
1
  compressed_tensors/__init__.py,sha256=SRqNYFVvxAaLa4SImhoiIBKfoOSj7EUdx0CxXjGC2PA,884
2
2
  compressed_tensors/base.py,sha256=dKAVgQAp9GPH6YspvF_cbGXCrbiqAeLEIPydYAO40WE,859
3
3
  compressed_tensors/logger.py,sha256=sTm1Od1cV0aDxBm3YN-PPvsOATxY_2tBV62TQE4HiPw,4032
4
- compressed_tensors/version.py,sha256=_76JjfEalYtLnwlx_1vHVRHYO4_7nPpez11U9pkUbyk,523
4
+ compressed_tensors/version.py,sha256=muVSrnf9zuwEcHxyznnbC_TivRxUkdpFuvlk05CdEcA,523
5
5
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
6
6
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
7
7
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -63,14 +63,14 @@ compressed_tensors/transform/utils/matrix.py,sha256=BapkVu1763cN1VPP0ukvSzmG0dHo
63
63
  compressed_tensors/utils/__init__.py,sha256=eXvtlJEUiV4XPcfsxVrOwL7DyY8r-L0XG_Rr5qmZrmU,850
64
64
  compressed_tensors/utils/helpers.py,sha256=WHYh8yxMsmG2HxcfNVzcMLC-dtgWroRZaLcreADmUYE,15562
65
65
  compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
66
- compressed_tensors/utils/match.py,sha256=g1K6x56LoCGVR_sA25MOpja_U_V6_MZ-6cSY_Q_IauY,12320
66
+ compressed_tensors/utils/match.py,sha256=wlNvl_x8QYpd8pyiAsRrJITHAKihRpi_pC_1iNi5UgI,17120
67
67
  compressed_tensors/utils/offload.py,sha256=eXqLzl8kUkVDlNtcO5sn_4QoDcbAaxbCAS3tyZ-aGr8,23538
68
68
  compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
69
69
  compressed_tensors/utils/safetensors_load.py,sha256=Vql34aCTDHwmTZXJHzCyBISJo7iA7EQ78LdTlMjdpZo,12023
70
70
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
71
71
  compressed_tensors/utils/type.py,sha256=bNwoo_FWlvLuDpYAGGzZJITRg0JA_Ngk9LGPo-kvjeU,2554
72
- compressed_tensors-0.12.3a20251114.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
73
- compressed_tensors-0.12.3a20251114.dist-info/METADATA,sha256=_GmxiVbPvm29hVkbhLwcHbFVKEnTq8RJPLVRIuFKyQQ,7027
74
- compressed_tensors-0.12.3a20251114.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
75
- compressed_tensors-0.12.3a20251114.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
76
- compressed_tensors-0.12.3a20251114.dist-info/RECORD,,
72
+ compressed_tensors-0.12.3a20251203.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
73
+ compressed_tensors-0.12.3a20251203.dist-info/METADATA,sha256=SH2vzxy3ZkciavF49vLmX-6dMl07CIKvkljmba5ArcQ,7027
74
+ compressed_tensors-0.12.3a20251203.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
75
+ compressed_tensors-0.12.3a20251203.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
76
+ compressed_tensors-0.12.3a20251203.dist-info/RECORD,,