compressed-tensors 0.10.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
compressed_tensors/transform/transform_args.py

@@ -13,15 +13,31 @@
  # limitations under the License.
 
  from enum import Enum
- from typing import Any, List
+ from typing import List
 
  from pydantic import BaseModel, Field, field_validator
 
 
- __all__ = ["TransformArgs"]
+ __all__ = ["TransformArgs", "TransformLocation"]
 
 
  class TransformLocation(str, Enum):
+     """
+     Enum representing which parameters/activations of a given module a transform
+     weight should be applied to
+
+     | ---------------------------------------------------------------------------------------------------------- |  # noqa: E501
+     | Name            | Runtime     | Values        | Locations Where Inverse Could Be Applied                  |  # noqa: E501
+     | --------------- | ----------- | ------------- | ---------------------------------------------------------- |  # noqa: E501
+     | `INPUT`         | online      | activations   | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.WEIGHT_INPUT`  |  # noqa: E501
+     | `WEIGHT_INPUT`  | offline     | weight        | `prev.WEIGHT_OUTPUT`, `prev.OUTPUT`, `this.INPUT`         |  # noqa: E501
+     | `WEIGHT_OUTPUT` | offline     | weight        | `this.OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`          |  # noqa: E501
+     | `OUTPUT`        | online      | activations   | `this.WEIGHT_OUTPUT`, `next.INPUT`, `next.WEIGHT_INPUT`   |  # noqa: E501
+     | `K_CACHE`       | online      | key_values    | `q_proj.Q_ATTN`                                           |  # noqa: E501
+     | `Q_ATTN`        | online      | query_values  | `k_proj.K_CACHE`                                          |  # noqa: E501
+     | ---------------------------------------------------------------------------------------------------------- |  # noqa: E501
+     """
+
      INPUT = "input"
      WEIGHT_INPUT = "weight_input"
      WEIGHT_OUTPUT = "weight_output"
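Because `TransformLocation` subclasses both `str` and `Enum`, its members compare equal to their raw string values, which is what lets helpers later in this diff match against either form. A minimal sketch of that behavior (the import path is the one this diff itself uses):

```python
from compressed_tensors.transform import TransformLocation

# str-Enum members compare equal to their plain-string values
assert TransformLocation.WEIGHT_INPUT == "weight_input"

# and the string form round-trips back to the member
assert TransformLocation("output") is TransformLocation.OUTPUT
```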
compressed_tensors/transform/utils/__init__.py (new file)

@@ -0,0 +1,13 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
compressed_tensors/transform/utils/hadamard.py (new file)

@@ -0,0 +1,161 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Optional, Tuple
+
+ import numpy
+ import torch
+
+
+ __all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+
+ # adapted from:
+ # https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
+ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+     """
+     Construct an n-by-n Hadamard matrix using Sylvester's construction.
+     `n` must be a power of 2.
+
+     :param size: order of the matrix, must be a power of 2
+     :return: hadamard matrix of size `size`
+     """
+     if size <= 0:
+         raise ValueError("Cannot construct deterministic hadamard of size <= 0")
+
+     log2 = int(math.log(size, 2))
+     if size != 2**log2:
+         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
+
+     H = numpy.array([[1]], dtype=int)
+
+     # Sylvester's construction
+     for i in range(0, log2):
+         H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
+
+     return torch.from_numpy(H / math.sqrt(size))
+
+
+ # adapted from:
+ # https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
+
+ # TODO: the following library exists for online rotations and should be considered
+ # in the future:
+ # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
+ def random_hadamard_matrix(
+     size: int, gen: Optional[torch.Generator] = None
+ ) -> torch.Tensor:
+     """
+     Produces a randomly generated Hadamard matrix.
+     See https://cornell-relaxml.github.io/quip-sharp/,
+     Section "Randomized Hadamard Transformation"
+
+     :param size: the dimension of the hadamard matrix
+     :param gen: optional generator for random values
+     :return: randomly generated hadamard matrix
+     """
+     # Benefits: support other shapes / non powers of 2, support randomization
+     Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+     Q = Q * 2 - 1
+     Q = torch.diag(Q)
+     return _matmul_hadU(Q) / math.sqrt(size)
+
+
+ def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
+     # NOTE: we can easily extend the list of supported shapes/sizes
+     # by adding to these methods
+     hadK, K = None, None
+     if n % 20 == 0:
+         assert _is_pow2(n // 20)
+         K = 20
+         hadK = _get_had20().T if transpose else _get_had20()
+     elif n % 12 == 0:
+         assert _is_pow2(n // 12)
+         K = 12
+         hadK = _get_had12().T if transpose else _get_had12()
+     else:
+         assert _is_pow2(n)
+         K = 1
+
+     return hadK, K
+
+
+ def _matmul_hadU(X, transpose=False) -> torch.Tensor:
+     n = X.shape[-1]
+     # Check if we have the determined hadamard matrix
+     hadK, K = _get_hadK(n, transpose)
+     # Reshape diag matrix with randomized -1/+1
+     input = X.clone().view(-1, n, 1)
+     output = input.clone()
+
+     # for cases when hadK is not predetermined, determine hadamard matrix
+     while input.shape[1] > K:
+         input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+         output = output.view(input.shape)
+         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+         output = output.view(input.shape[0], input.shape[1], -1)
+         (input, output) = (output, input)
+     del output
+
+     # K == 1 when hadK is None; this happens when the size dim (n)
+     # is not compatible with any of the maintained hadamard matrices
+
+     if K > 1:
+         # Do not explicitly repeat - OOM
+         # input = torch.bmm(
+         #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+         # Use bcast instead
+
+         # for cases when hadK is pre-determined
+         input = hadK.view(1, K, K).to(input) @ input
+
+     # normalize
+     return input.view(X.shape)
+
+
+ def _is_pow2(n: int) -> bool:
+     return (n & (n - 1) == 0) and (n > 0)
+
+
+ def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
+     had_unpacked = numpy.unpackbits(packed_bits)
+     had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
+     had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
+     return had_unpacked
+
+
+ # http://www.neilsloane.com/hadamard/index.html
+ def _get_had12() -> torch.Tensor:
+     # fmt: off
+     had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
+                           62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
+     # fmt: on
+     # TODO: just unpack during apply
+     had_12_unpacked = _reshape_bits(had_12, original_size=12)
+     return torch.tensor(had_12_unpacked)
+
+
+ def _get_had20() -> torch.Tensor:
+     # fmt: off
+     had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
+                           216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
+                           97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
+                           205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
+     # fmt: on
+     # TODO: just unpack during apply
+     had_20_unpacked = _reshape_bits(had_20, original_size=20)
+     return torch.tensor(had_20_unpacked)
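Both constructors scale by 1 / sqrt(n), so the returned matrices should be orthogonal (H @ H.T == I). A quick sanity check, assuming the module path listed in the RECORD below (`compressed_tensors/transform/utils/hadamard.py`):

```python
import torch
from compressed_tensors.transform.utils.hadamard import (
    deterministic_hadamard_matrix,
    random_hadamard_matrix,
)

# Sylvester construction: size must be a power of 2
H = deterministic_hadamard_matrix(64)
assert torch.allclose(H @ H.T, torch.eye(64, dtype=H.dtype), atol=1e-6)

# randomized construction: 768 = 12 * 64, so the hard-coded had12 kernel is used
gen = torch.Generator().manual_seed(0)  # reproducible random sign flips
R = random_hadamard_matrix(768, gen=gen)
assert torch.allclose(R @ R.T, torch.eye(768, dtype=R.dtype), atol=1e-6)
```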
compressed_tensors/transform/utils/utils.py (new file)

@@ -0,0 +1,91 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from compressed_tensors.transform import TransformLocation
+
+
+ __all__ = ["get_matrix_size", "apply_transform_weight"]
+
+
+ def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
+     """
+     Determine the size of a matrix given its location on the module
+
+     :param module: module that matrix will be applied to
+     :param location: location on module
+     :return: size of matrix
+     """
+     assert isinstance(module, torch.nn.Linear)
+     if location in ("input", TransformLocation.WEIGHT_INPUT):
+         return module.in_features
+     else:
+         return module.out_features
+
+
+ def apply_transform_weight(
+     weight: torch.Tensor,
+     value: torch.Tensor,
+     location: TransformLocation,
+ ) -> torch.Tensor:
+     """
+     Using the transform location, determine how to apply the transform weight to the
+     given value. For more info on input and output transforms, see `TransformLocation`
+
+     The following explains how weights should be applied to values according to location:
+
+     let x be the input activation,
+         W be the weight,
+         yh, xh, Wh be the transformed output, input, and weight
+
+     note that
+         y = (x W.T)  // torch.nn.Linear
+
+     Choose values for yh, xh, and Wh which incorporate matrix transforms:
+
+     let V, Vi be transform matrices on the input side,
+         U, Ui be transform matrices on the output side
+
+     pick xh = (x V)
+          Wh = (U.T W Vi.T)
+          yh = (y U)
+
+     The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh:
+
+     (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                 = (x V) (Vi W.T U)  // transpose of a matrix product
+                 = (x W.T) U         // V Vi = I, since Vi is the inverse of V
+                 = y U
+                 = yh
+
+     :param weight: transform weight to apply
+     :param value: value to apply weight to
+     :param location: determines how weight should be applied
+     :return: value after transform weight has been applied
+     """
+
+     if location == TransformLocation.INPUT:
+         return value @ weight
+
+     elif location == TransformLocation.WEIGHT_INPUT:
+         return value @ weight.T
+
+     elif location == TransformLocation.WEIGHT_OUTPUT:
+         return weight.T @ value
+
+     elif location == TransformLocation.OUTPUT:
+         return value @ weight
+
+     else:
+         raise NotImplementedError(f"{location} has not been implemented yet")
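To make the docstring's derivation concrete, here is a small numerical check of `yh = (xh) (Wh).T` with `U = I` and an orthogonal Hadamard `V` (so `Vi = V.T`); the deep import paths are read off the RECORD section below and are an assumption of this sketch:

```python
import torch
from compressed_tensors.transform import TransformLocation
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from compressed_tensors.transform.utils.utils import apply_transform_weight

torch.manual_seed(0)
x = torch.randn(4, 64, dtype=torch.float64)   # batch of input activations
W = torch.randn(32, 64, dtype=torch.float64)  # Linear weight, shape (out, in)

V = random_hadamard_matrix(64)  # orthogonal, so its inverse is V.T
Vi = V.T

xh = apply_transform_weight(V, x, TransformLocation.INPUT)          # x @ V
Wh = apply_transform_weight(Vi, W, TransformLocation.WEIGHT_INPUT)  # W @ Vi.T

# the input-side transforms cancel: yh == y == x @ W.T
assert torch.allclose(xh @ Wh.T, x @ W.T, atol=1e-8)
```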
compressed_tensors/utils/helpers.py

@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ import contextlib
  import warnings
  from functools import wraps
  from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
@@ -38,6 +39,8 @@ __all__ = [
      "shard_tensor",
      "pack_bitmasks",
      "unpack_bitmasks",
+     "patch_attr",
+     "ParameterizedDefaultDict",
  ]
 
  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -328,3 +331,53 @@ def unpack_bitmasks(
      )
 
      return unpacked_bitmasks_torch
+
+
+ @contextlib.contextmanager
+ def patch_attr(base: object, attr: str, value: Any):
+     """
+     Patch the value of an object attribute. The original value is restored upon exit
+
+     :param base: object which has the attribute to patch
+     :param attr: name of the attribute to patch
+     :param value: used to replace the original value
+
+     Usage:
+     >>> from types import SimpleNamespace
+     >>> obj = SimpleNamespace()
+     >>> with patch_attr(obj, "attribute", "value"):
+     ...     assert obj.attribute == "value"
+     >>> assert not hasattr(obj, "attribute")
+     """
+     _sentinel = object()
+     original_value = getattr(base, attr, _sentinel)
+
+     setattr(base, attr, value)
+     try:
+         yield
+     finally:
+         if original_value is not _sentinel:
+             setattr(base, attr, original_value)
+         else:
+             delattr(base, attr)
+
+
+ class ParameterizedDefaultDict(dict):
+     """
+     Similar to `collections.defaultdict`, but upon fetching a key which is missing,
+     the key is passed as arguments to the `default_factory`
+
+     :param default_factory: function which takes a key as input and returns the
+         corresponding default value
+     """
+
+     def __init__(self, default_factory: Callable[[Any], Any]):
+         self.default_factory = default_factory
+
+     def __missing__(self, key):
+         if isinstance(key, tuple):
+             value = self.default_factory(*key)
+         else:
+             value = self.default_factory(key)
+         self[key] = value
+         return value
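`patch_attr` carries its own doctest; for `ParameterizedDefaultDict`, here is a small usage sketch of my own (the import assumes the `compressed_tensors.utils` re-export implied by the `__all__` change above):

```python
from compressed_tensors.utils import ParameterizedDefaultDict

# the missing key itself is passed to the factory, and the result is cached
squares = ParameterizedDefaultDict(lambda n: n * n)
assert squares[4] == 16
assert 4 in squares  # cached after the first access

# tuple keys are splatted into the factory as positional arguments
pairs = ParameterizedDefaultDict(lambda a, b: f"{a}-{b}")
assert pairs[("x", "y")] == "x-y"
```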
compressed_tensors/utils/offload.py

@@ -87,13 +87,15 @@ def check_accelerate(fallback: Any):
          if not _has_accelerate:
 
              if fallback == "error":
-                 raise ValueError(
-                     "Please install `accelerate` in order to use this function"
-                 )
-
-             @wraps(func)
-             def fallback_fn(*args, **kwargs):
-                 return fallback
+                 @wraps(func)
+                 def fallback_fn(*args, **kwargs):
+                     raise ValueError(
+                         "Please install `accelerate` in order to use this function"
+                     )
+             else:
+                 @wraps(func)
+                 def fallback_fn(*args, **kwargs):
+                     return fallback
 
              return fallback_fn
 
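The net effect of this hunk: with `fallback="error"`, the `ValueError` now fires when the decorated function is called rather than when it is decorated, so a module can define accelerate-dependent helpers and still import cleanly. A self-contained sketch of the pattern (simplified and standalone, not the library's exact code — the surrounding `decorator(func)` nesting is assumed):

```python
from functools import wraps
from typing import Any, Callable

_has_accelerate = False  # pretend accelerate is not installed


def check_accelerate(fallback: Any):
    def decorator(func: Callable):
        if not _has_accelerate:
            if fallback == "error":
                # defer the failure to call time
                @wraps(func)
                def fallback_fn(*args, **kwargs):
                    raise ValueError(
                        "Please install `accelerate` in order to use this function"
                    )
            else:
                # silently return the fallback value instead
                @wraps(func)
                def fallback_fn(*args, **kwargs):
                    return fallback
            return fallback_fn
        return func
    return decorator


@check_accelerate(fallback="error")
def offload_to_cpu(module):  # defining this no longer raises at import
    ...


try:
    offload_to_cpu(None)  # the error surfaces only here, at call time
except ValueError as err:
    print(err)
```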
compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '0.10.0'
- __version_tuple__ = version_tuple = (0, 10, 0)
+ __version__ = version = '0.10.1'
+ __version_tuple__ = version_tuple = (0, 10, 1)
compressed_tensors-0.10.1.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.10.0
+ Version: 0.10.1
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
compressed_tensors-0.10.1.dist-info/RECORD

@@ -1,6 +1,6 @@
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
  compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
- compressed_tensors/version.py,sha256=P22A-D7Hg0yC0IOerZDcj2-6YOrcxCN9Sq5s06MywPA,513
+ compressed_tensors/version.py,sha256=StiR6uxiq6hqMzT3MUIl_ZooIq2cetH9oWrHUI_qWFU,513
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431

@@ -40,18 +40,21 @@ compressed_tensors/quantization/utils/helpers.py,sha256=bqxNL2NU1XVsSxNzmDVZE3zd
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
  compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
  compressed_tensors/transform/__init__.py,sha256=oa5VdrE-GtDYYceXNSwj5X_ropoXLLukm6Aufcc9WhY,747
- compressed_tensors/transform/transform_args.py,sha256=Sazu_4kXL7IvIEgTaimgo8dV-qacXf_t1NLEfDvPJEU,1759
+ compressed_tensors/transform/transform_args.py,sha256=8-Ab5_dFfdObfwVCgrWrEWcoVRzXmMBSDSUxjftI-Ss,3177
  compressed_tensors/transform/transform_config.py,sha256=6JA8VFcoz4EGHOev6thj51OuB7K2gKUUazWjrVPYDLc,2144
  compressed_tensors/transform/transform_scheme.py,sha256=c7NAuLDL0itFgUfBMNShegMI9bzKL7s4LR3QJTHsXLs,1733
+ compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+ compressed_tensors/transform/utils/hadamard.py,sha256=SmPZmnHtc5N36gJA5EbM1T65uf4w1_flgl7SWBeg_W8,5642
+ compressed_tensors/transform/utils/utils.py,sha256=PRPTYwPs2nnNaQMq2GEbC4QYKHFKlZwaRyPgdDhl66g,2992
  compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
- compressed_tensors/utils/helpers.py,sha256=RrNvzD08naEjEiXdU-FdZjQVda1nQywu1hA_GCDj0vg,10415
- compressed_tensors/utils/offload.py,sha256=hAGjp9aS0HpFVhjYMGf-WTm76WMY6cS-YXhVEn80qPE,20196
+ compressed_tensors/utils/helpers.py,sha256=cPg-ikdeA92aIGwBONg8GmPNvcGlFhozyJVwsRiXBTA,11981
+ compressed_tensors/utils/offload.py,sha256=fT7WiUQmRmJ2Reb3I5kNcsHy4YdmZJHSOTNdS0tbKQo,20316
  compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
  compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
- compressed_tensors-0.10.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- compressed_tensors-0.10.0.dist-info/METADATA,sha256=LCNJhwDW8s0vzcb1XkGUzuKz2NTFKN1sbc5-xTx9pP4,6996
- compressed_tensors-0.10.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- compressed_tensors-0.10.0.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
- compressed_tensors-0.10.0.dist-info/RECORD,,
+ compressed_tensors-0.10.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ compressed_tensors-0.10.1.dist-info/METADATA,sha256=2y4RJsufdvf5Bap5PKk73UA3STedxdzbD0yRuZF21uc,6996
+ compressed_tensors-0.10.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ compressed_tensors-0.10.1.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+ compressed_tensors-0.10.1.dist-info/RECORD,,