compressed-tensors 0.10.3a20250715__py3-none-any.whl → 0.10.3a20250716__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -117,10 +117,8 @@ class TransformFactory(RegistryMixin, ABC):
             TransformLocation.WEIGHT_INPUT,
             TransformLocation.WEIGHT_OUTPUT,
         ):
-            assert isinstance(module, torch.nn.Linear)
-            assert module.bias is None
-
             # fuse transform into weight
+            assert hasattr(module, "weight")
            with torch.no_grad(), align_module_device(module):
                 update_offload_parameter(module, "weight", transform(module.weight))
 
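The change above relaxes the fusion guard: instead of requiring a bias-free `torch.nn.Linear`, any module exposing a `weight` attribute now qualifies, which admits modules such as `torch.nn.Embedding`. A minimal sketch of the fusion pattern, using a plain orthogonal matrix as a stand-in for the factory's transform (the names below are illustrative, not library API):

```python
import torch

# Stand-in transform: an orthogonal 4x4 rotation (illustrative only).
embedding = torch.nn.Embedding(num_embeddings=8, embedding_dim=4)
rotation = torch.linalg.qr(torch.randn(4, 4)).Q

# Same pattern as the hunk: guard on the weight attribute, then rewrite
# the weight in place with the transformed values.
assert hasattr(embedding, "weight")
with torch.no_grad():
    embedding.weight.copy_(embedding.weight @ rotation)
```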
@@ -14,13 +14,14 @@
 
 from typing import Optional, Union
 
+import math
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
 from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -51,8 +52,8 @@ class HadamardFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
@@ -60,7 +61,7 @@ class HadamardFactory(TransformFactory):
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args)
+        return HadamardTransform(weight, perm, args, type(module))
 
     def _create_weight(
         self,
@@ -81,12 +82,18 @@ class HadamardFactory(TransformFactory):
 
 class HadamardTransform(TransformBase):
     def __init__(
-        self, weight: Parameter, perm: Union[Parameter, None], args: TransformArgs
+        self,
+        weight: Parameter,
+        perm: Optional[Parameter],
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.args = args
+        self.module_type = module_type
+        self._scale = math.sqrt(weight.size(0))
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -96,5 +103,7 @@ class HadamardTransform(TransformBase):
 
         if self.args.inverse:
             weight = weight.T
-
-        return apply_transform_weight(weight, value, self.args.location)
+
+        return apply_transform_weight(
+            weight, value, self.args.location, self.module_type
+        ) / self._scale
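Note the division by `self._scale` above: together with the `hadamard.py` changes further down (which stop dividing by `math.sqrt(size)` at construction time), this moves the orthonormality scaling out of the cached matrix and into the forward pass, so the stored weight stays a pure ±1 Hadamard matrix. A quick sketch of why dividing by √n restores orthonormality:

```python
import math
import torch

# Unnormalized 2x2 Hadamard block, as the utilities now return it.
H = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
n = H.size(0)

# Rows are mutually orthogonal with squared norm n, so H @ H.T == n * I
# and H / sqrt(n) is orthonormal.
assert torch.allclose(H @ H.T, n * torch.eye(n))
assert torch.allclose((H / math.sqrt(n)) @ (H / math.sqrt(n)).T, torch.eye(n))
```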
@@ -17,9 +17,9 @@ from typing import Optional
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
 from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
-from compressed_tensors.transform.utils.utils import (
+from compressed_tensors.transform.utils.matrix import (
     apply_transform_weight,
-    get_matrix_size,
+    get_transform_size,
 )
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
@@ -50,8 +50,8 @@ class RandomMatrixFactory(TransformFactory):
         :param module: parent module that transform will be applied to
         :param args: defines how the transform will be applied to the module
         """
-        assert isinstance(module, Linear)
-        size = get_matrix_size(module, args.location)
+        assert hasattr(module, "weight")
+        size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
 
@@ -59,7 +59,7 @@ class RandomMatrixFactory(TransformFactory):
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args)
+        return RandomMatrixTransform(weight, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -74,17 +74,27 @@
 
 
 class RandomMatrixTransform(TransformBase):
-    def __init__(self, weight: Tensor, args: TransformArgs):
+    def __init__(
+        self,
+        weight: Tensor,
+        args: TransformArgs,
+        module_type: type[torch.nn.Module],
+    ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.args = args
+        self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
-        return apply_transform_weight(self.weight, value, self.args.location)
+        return apply_transform_weight(
+            self.weight, value, self.args.location, self.module_type
+        )
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
-        return apply_transform_weight(inverse, value, self.args.location)
+        return apply_transform_weight(
+            inverse, value, self.args.location, self.module_type
+        )
 
 
 def high_precision_invert(weight: Tensor) -> Tensor:
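`right_inverse` is the hook that `torch.nn.utils.parametrize` calls when a value is assigned through a parametrization, so inverting in high precision lets assignment undo the forward transform. A rough round-trip sketch of that contract with plain tensors (not the factory classes):

```python
import torch

torch.manual_seed(0)
value = torch.randn(3, 4, dtype=torch.float64)
weight = torch.randn(4, 4, dtype=torch.float64)  # almost surely invertible

transformed = value @ weight                        # forward
recovered = transformed @ torch.linalg.inv(weight)  # right_inverse
assert torch.allclose(recovered, value)
```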
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 
 from compressed_tensors.transform import TransformArgs
 from pydantic import BaseModel, Field
@@ -40,3 +40,4 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
+    head_dim: Optional[int] = Field(default=None)
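The new `head_dim` field lets a scheme request per-head transforms: when set, `get_transform_size` (added below) returns `head_dim` in place of the full feature size, after checking divisibility. A hedged usage sketch; the `type` field and the `TransformArgs(targets=..., location=...)` signature are assumed from the rest of the library rather than shown in this diff:

```python
from compressed_tensors.transform import TransformArgs, TransformScheme

# Request one 128x128 Hadamard per attention head instead of a single
# transform over the whole hidden dimension (values are illustrative).
scheme = TransformScheme(
    type="hadamard",
    apply=[TransformArgs(targets=["re:.*q_proj$"], location="weight_output")],
    head_dim=128,
)
```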
@@ -59,7 +59,7 @@ def deterministic_hadamard_matrix(
     for _ in range(log2):
         H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
 
-    return H / math.sqrt(size)
+    return H
 
 
 def random_hadamard_matrix(
  def random_hadamard_matrix(
@@ -86,7 +86,7 @@ def random_hadamard_matrix(
86
86
  Q = Q.to(device=device)
87
87
  Q = Q * 2 - 1
88
88
  Q = torch.diag(Q)
89
- return _matmul_hadU(Q) / math.sqrt(size)
89
+ return _matmul_hadU(Q)
90
90
 
91
91
 
92
92
  def is_pow2(n: int) -> bool:
@@ -0,0 +1,179 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Optional, Tuple
+
+import torch
+from compressed_tensors.transform import TransformLocation
+
+
+__all__ = ["get_transform_size", "apply_transform_weight"]
+
+
+def get_transform_size(
+    module: torch.nn.Module,
+    location: TransformLocation,
+    head_dim: Optional[int] = None,
+) -> int:
+    """
+    Determine the size of a transform matrix given its location on the module
+
+    :param module: module that matrix will be applied to
+    :param location: location on module
+    :param head_dim: size of head when transform is applied to mha
+    :return: size of matrix
+    """
+    if isinstance(module, torch.nn.Linear):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.in_features
+        else:
+            size = module.out_features
+    elif isinstance(module, torch.nn.Embedding):
+        if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
+            size = module.num_embeddings
+        else:
+            size = module.embedding_dim
+    else:
+        raise NotImplementedError(f"Transforms on {type(module)} are not supported")
+
+    if head_dim is not None:
+        if size % head_dim != 0:
+            raise ValueError(
+                f"{head_dim} must divide {size} for {type(module)} at {location}"
+            )
+
+        size = head_dim
+
+    return size
+
+
+def apply_transform_weight(
+    transform_weight: torch.Tensor,
+    value: torch.Tensor,
+    location: TransformLocation,
+    module_type: type[torch.nn.Module],
+) -> torch.Tensor:
+    """
+    Using the transform location, apply the transform_weight to the
+    given value wrt linear weights. For more info on input and output transforms,
+    see `TransformLocation`
+
+    The following explains how weights should be applied to values according to location
+
+    let  x  be input activation
+         W  be weight,
+         yh, xh, Wh be transformed output, input, weight
+
+    note that
+        y  = (x W.T)  // torch.nn.Linear
+
+    Choose values for yh, xh, and Wh which incorporate matrix transforms
+
+    let  V, Vi  be transform matrices on input side
+         U, Ui  be transform matrices on output side
+
+    pick xh = (x V)
+         Wh = (U.T W Vi.T)
+         yh = (y U)
+
+    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
+
+    (xh) (Wh).T = (x V) (U.T W Vi.T).T
+                = (x V) (Vi W.T U)    // transpose matrix product identity
+                = (x W.T) U
+                = y U
+                = yh
+
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
+    :param location: determines how weight should be applied
+    :param module_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after transform_weight has been applied
+    """
+
+    assert transform_weight.shape[0] == transform_weight.shape[1]
+
+    if module_type == torch.nn.Linear:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            # equivalent to (transform_weight @ value.T).T
+            return _multihead_matmul(value, transform_weight.T)
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            # equivalent to (value.T @ transform_weight).T
+            return _multihead_matmul(transform_weight.T, value)
+
+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    # similar derivation to torch.nn.Linear, but `y = (x W)`
+    elif module_type == torch.nn.Embedding:
+        if location == TransformLocation.INPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.WEIGHT_INPUT:
+            return _multihead_matmul(
+                transform_weight,
+                value,
+            )
+
+        elif location == TransformLocation.WEIGHT_OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+        elif location == TransformLocation.OUTPUT:
+            return _multihead_matmul(value, transform_weight)
+
+    raise NotImplementedError(
+        f"Applying transforms to {module_type} {location} is not supported"
+    )
+
+
+def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    """
+    Performs A @ B for last two dims of two matrices A and B that possibly
+    have different shapes, as is the case in multi-headed dimension. If
+    shapes are different, this is equivalent to converting the last two dims
+    of the smaller matrix into a block-diagonal matrix with the same shape as
+    the last two dims of the larger matrix.
+
+    E.g. if A is half the size of B, this function will perform
+    [[A   ]  @ B
+     [   A]]
+
+    If B is a third of the size of A, this function will perform
+    A @ [[B     ]
+         [   B  ]
+         [     B]]
+
+    This function will error out if the shapes are not evenly divisible
+
+    :param A: left-hand tensor
+    :param B: right-hand tensor
+    :return: result
+    """
+    if A.shape[-1] > B.shape[-2]:
+        head_dim = B.shape[-2]
+        num_heads = A.shape[-1] // head_dim
+        A = A.unflatten(-1, (num_heads, head_dim))
+        return (A @ B).flatten(-2, -1)
+    elif A.shape[-1] < B.shape[-2]:
+        head_dim = A.shape[-1]
+        num_heads = B.shape[-2] // head_dim
+        B = B.unflatten(-2, (num_heads, head_dim))
+        return (A @ B).flatten(-3, -2)
+    else:
+        return A @ B
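A small check of the block-diagonal claim in `_multihead_matmul`'s docstring, comparing the broadcast path against an explicit `torch.block_diag` expansion (shapes chosen arbitrarily; `_multihead_matmul` is private but shown in full above):

```python
import torch
from compressed_tensors.transform.utils.matrix import _multihead_matmul

torch.manual_seed(0)
head = torch.randn(4, 4)    # one 4x4 per-head transform
value = torch.randn(3, 12)  # batch of 3 activations, 3 heads of dim 4

# Broadcasting the small matrix over heads is equivalent to multiplying
# by its block-diagonal expansion.
expected = value @ torch.block_diag(head, head, head)
assert torch.allclose(_multihead_matmul(value, head), expected)
```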
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.10.3.a20250715'
+__version__ = version = '0.10.3.a20250716'
 __version_tuple__ = version_tuple = (0, 10, 3)
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.10.3a20250715
+Version: 0.10.3a20250716
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,6 +1,6 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=4fhlyf_Dxsw-6e50R1K8Mn8AWFNP0TiRJ8S21Gge6UA,523
+compressed_tensors/version.py,sha256=DSQFEQZQHt-pmBpCtXPX8Vc2dUeG5ueZMmYLIwONR1c,523
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -43,16 +43,16 @@ compressed_tensors/transform/__init__.py,sha256=v2wfl4CMfA6KbD7Hxx_MbRev63y_6QLD
 compressed_tensors/transform/apply.py,sha256=Cnc7Q8d8FzpLGtXixvdPzqApfjAXpfShxvVl_7nNJ4E,1259
 compressed_tensors/transform/transform_args.py,sha256=jJY-Qt996w45LWQ10AHd7tUtNrnV9mjD9M5D4SZ5B3E,3199
 compressed_tensors/transform/transform_config.py,sha256=A3RuLNDqBNEByQNeu40Kg7sItwE6kWgnX18Umg1uONI,2128
-compressed_tensors/transform/transform_scheme.py,sha256=JAFQoCiNLg04diXG5KsynRGcLIB0Y0tC5s8U7HoDM7c,1692
+compressed_tensors/transform/transform_scheme.py,sha256=uGLC4avdbhrVqNC3-Eo0p7WzNRQK92Fpg0N9hWiuCRQ,1752
 compressed_tensors/transform/factory/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/factory/base.py,sha256=D5s6dOn01Yy_Z2v8YmYIYfzJie2BYV6uxBJKP_9jTCQ,6044
-compressed_tensors/transform/factory/hadamard.py,sha256=oLdDUu1p82lgD7li-sHMSvXZxz1SDjLeYf-EfXqNzvk,3918
-compressed_tensors/transform/factory/matrix_multiply.py,sha256=KYiQRGFSU33TpPWkGTKwNADTmYoU0E3hjQypOMclHbg,3689
+compressed_tensors/transform/factory/base.py,sha256=w9ic5eSxfSNn2Xju-xZvG4_iXAIsJCU56qik8w---aI,5994
+compressed_tensors/transform/factory/hadamard.py,sha256=iJ2OyKitR2Duw0z5Jqj69GTih2C1WtHRXQCTtATaTtw,4180
+compressed_tensors/transform/factory/matrix_multiply.py,sha256=LdoV2E12HTucmUWcw7UKOpRNnL8QhOOIUnNVlpOpGiI,3925
 compressed_tensors/transform/factory/random_hadamard.py,sha256=nUhTlFa4ikSpcl4Umme71pnjMPgwYoGlwjKlU27UHZ4,1634
 compressed_tensors/transform/utils/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
-compressed_tensors/transform/utils/hadamard.py,sha256=U27Kvo-eDebKcVt8oXTSIAaQ5DvPQj9tDv2hdXHCPPQ,5584
+compressed_tensors/transform/utils/hadamard.py,sha256=hDJZC0Gw2fKdxqa3f8TmFc5J0eJqxHtFRxswLU_yVJc,5548
 compressed_tensors/transform/utils/hadamards.safetensors,sha256=mFd1GzNodGG-ifA1IoH-0nHYzfraCOvrq_dX2zFI1B4,1436901
-compressed_tensors/transform/utils/utils.py,sha256=PRPTYwPs2nnNaQMq2GEbC4QYKHFKlZwaRyPgdDhl66g,2992
+compressed_tensors/transform/utils/matrix.py,sha256=FIHCUlpWVIIhdr3c6EbQec41JeiPAAjCM9Ejz77wb-w,6181
 compressed_tensors/utils/__init__.py,sha256=QFQzF6MpV3yStajPzYktZkmvZsxvfpKUZq2oGbd1Cvw,832
 compressed_tensors/utils/helpers.py,sha256=Q3iRAa2XSdmmn4vSpUplnvKOmWwn4Clao9ZkPBHXtpI,12604
 compressed_tensors/utils/internal.py,sha256=7SSWgDoNFRnlfadwkoFhLW-T2jOc7Po_WzWv5h32Sa8,982
@@ -61,8 +61,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.10.3a20250715.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.10.3a20250715.dist-info/METADATA,sha256=Pb3ZHNAmiFtJHKdWbAA7m2KvmSttY8EV2huSI71ZBGM,7031
-compressed_tensors-0.10.3a20250715.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-compressed_tensors-0.10.3a20250715.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.10.3a20250715.dist-info/RECORD,,
+compressed_tensors-0.10.3a20250716.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.10.3a20250716.dist-info/METADATA,sha256=zWpNTNT2bbSxFEQ58pvjEqdAE8MzBBZhvvtVVh6AQ14,7031
+compressed_tensors-0.10.3a20250716.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+compressed_tensors-0.10.3a20250716.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.10.3a20250716.dist-info/RECORD,,
@@ -1,91 +0,0 @@
-# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-from compressed_tensors.transform import TransformLocation
-
-
-__all__ = ["get_matrix_size", "apply_transform_weight"]
-
-
-def get_matrix_size(module: torch.nn.Module, location: TransformLocation) -> int:
-    """
-    Determine the size of a matrix given its location on the module
-
-    :param module: module that matrix will be applied to
-    :param location: location on module
-    :return: size of matrix
-    """
-    assert isinstance(module, torch.nn.Linear)
-    if location in ("input", TransformLocation.WEIGHT_INPUT):
-        return module.in_features
-    else:
-        return module.out_features
-
-
-def apply_transform_weight(
-    weight: torch.Tensor,
-    value: torch.Tensor,
-    location: TransformLocation,
-) -> torch.Tensor:
-    """
-    Using the transform location, determine how to apply the transform weight to the
-    given value. For more info on input and output transforms, see `TransformLocation`
-
-    The following explains how weights should be applied to values according to location
-
-    let  x  be input activation
-         W  be weight,
-         yh, xh, Wh be transformed output, input, weight
-
-    note that
-        y  = (x W.T)  // torch.nn.Linear
-
-    Choose values for yh, xh, and Wh which incorporate matrix transforms
-
-    let  V, Vi  be transform matrices on input side
-         U, Ui  be transform matrices on output side
-
-    pick xh = (x V)
-         Wh = (U.T W Vi.T)
-         yh = (y U)
-
-    The following shows that `yh = (xh) (Wh).T` for the chosen values of yh, xh, and Wh
-
-    (xh) (Wh).T = (x V) (U.T W Vi.T).T
-                = (x V) (Vi W.T U)    // transpose matrix product identity
-                = (x W.T) U
-                = y U
-                = yh
-
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
-    """
-
-    if location == TransformLocation.INPUT:
-        return value @ weight
-
-    elif location == TransformLocation.WEIGHT_INPUT:
-        return value @ weight.T
-
-    elif location == TransformLocation.WEIGHT_OUTPUT:
-        return weight.T @ value
-
-    elif location == TransformLocation.OUTPUT:
-        return value @ weight
-
-    else:
-        raise NotImplementedError(f"{location} has not been implemented yet")