broccoli-ml 9.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/__init__.py +6 -0
- broccoli/activation.py +118 -0
- broccoli/cnn.py +157 -0
- broccoli/linear.py +352 -0
- broccoli/rope.py +407 -0
- broccoli/tensor.py +128 -0
- broccoli/transformer.py +779 -0
- broccoli/utils.py +15 -0
- broccoli/vit.py +600 -0
- broccoli_ml-9.5.1.dist-info/LICENSE +21 -0
- broccoli_ml-9.5.1.dist-info/METADATA +43 -0
- broccoli_ml-9.5.1.dist-info/RECORD +13 -0
- broccoli_ml-9.5.1.dist-info/WHEEL +4 -0
broccoli/rope.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This code is from https://github.com/lucidrains/rotary-embedding-torch
|
|
3
|
+
|
|
4
|
+
It is provided along with the following copyright notice and license:
|
|
5
|
+
|
|
6
|
+
MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2021 Phil Wang
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
from math import pi
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
|
|
34
|
+
from torch.nn import Module
|
|
35
|
+
from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor
|
|
36
|
+
|
|
37
|
+
# Gracefully find the best way to import autocast
|
|
38
|
+
try:
    # Preferred path (PyTorch >= 1.10): torch.amp.autocast accepts the device
    # type ("cuda"/"cpu") as its first positional argument, matching how
    # autocast_factory is used as a decorator below.
    from torch.amp import autocast as autocast_factory
except ImportError:
    # Fallback: For PyTorch 1.6 to 1.9
    from torch.cuda.amp import autocast

    def autocast_factory(_, enabled=True):
        """
        A wrapper that mimics the modern autocast signature but calls the older
        torch.cuda.amp.autocast, ignoring the device_type argument.
        """
        return autocast(enabled=enabled)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
from einops import rearrange, repeat
|
|
53
|
+
|
|
54
|
+
from typing import Literal
|
|
55
|
+
|
|
56
|
+
# helper functions
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def exists(val):
    """Return True when *val* is anything other than None."""
    return not (val is None)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def default(val, d):
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# broadcat, as tortoise-tts was using it
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def broadcat(tensors, dim=-1):
    """Broadcast *tensors* to a common shape, then concatenate along *dim*.

    Kept for compatibility with how tortoise-tts used this helper.
    """
    expanded = torch.broadcast_tensors(*tensors)
    return torch.cat(expanded, dim=dim)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply the slice *dim_slice* along dimension *dim* of tensor *t*.

    Negative *dim* values count from the end, as with normal indexing.
    """
    if dim < 0:
        dim = dim + t.ndim
    # Full slices everywhere except the requested dimension.
    index = [slice(None) for _ in range(t.ndim)]
    index[dim] = dim_slice
    return t[tuple(index)]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# rotary embedding helper functions
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def rotate_half(x):
    """Rotate adjacent feature pairs: (x1, x2) -> (-x2, x1).

    Interprets the last dimension as interleaved pairs and applies a
    90-degree rotation to each pair, as required by rotary embeddings.
    """
    pairs = x.unflatten(-1, (-1, 2))
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rotated.flatten(-2)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@autocast_factory("cuda", enabled=False)
def apply_rotary_emb(
    freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None
):
    """Apply rotary position embeddings *freqs* to tensor *t*.

    Only the feature slice ``[start_index : start_index + freqs.shape[-1]]``
    is rotated; features outside that slice pass through unchanged. Runs
    with autocast disabled so the rotation happens in full precision, then
    casts the result back to t's original dtype.

    Args:
        freqs: rotation angles; last dim must fit within t's feature dim.
        t: tensor to rotate (e.g. queries or keys).
        start_index: first feature index to rotate.
        scale: xpos-style scaling applied alongside the rotation.
        seq_dim: which dimension of t is the sequence dimension.
        freqs_seq_dim: sequence dimension of freqs; inferred when omitted.
    """
    dtype = t.dtype

    # Infer the freqs sequence dim for the common (seq, dim) / 3-D t cases.
    if not exists(freqs_seq_dim):
        if freqs.ndim == 2 or t.ndim == 3:
            freqs_seq_dim = 0

    # Align freqs with the *tail* of the sequence (supports key/value caches
    # where t holds only the newest positions).
    if t.ndim == 3 or exists(freqs_seq_dim):
        seq_len = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    # Fix: the second fragment previously lacked the f-prefix, so {rot_dim}
    # was printed literally instead of being interpolated.
    assert rot_dim <= t.shape[-1], (
        f"feature dimension {t.shape[-1]} is not of sufficient size "
        f"to rotate in all the positions {rot_dim}"
    )

    # Split t into three parts: left, middle (to be transformed), and right
    t_left = t[..., :start_index]
    t_middle = t[..., start_index:end_index]
    t_right = t[..., end_index:]

    # Apply rotary embeddings without modifying t in place
    t_transformed = (t_middle * freqs.cos() * scale) + (
        rotate_half(t_middle) * freqs.sin() * scale
    )

    out = torch.cat((t_left, t_transformed, t_right), dim=-1)

    return out.type(dtype)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
# learned rotation helpers
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Expand learned per-position rotation angles and apply them to *t*."""
    if freq_ranges is not None:
        # Outer-product each angle with the per-frequency range, then merge
        # the (rotation, frequency) axes into one feature axis.
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")

    # Duplicate each angle so it covers an (even, odd) feature pair.
    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(rotations, t, start_index=start_index)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# classes
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class RotaryEmbedding(Module):
    """Rotary positional embedding (RoPE) module.

    Computes per-position rotation frequencies and applies them to query/key
    tensors. Supports custom or learned frequencies, xpos length
    extrapolation, NTK-style theta rescaling, positional interpolation, and a
    non-persistent cache of computed frequencies up to ``cache_max_seq_len``.
    """

    def __init__(
        self,
        dim,
        custom_freqs: Tensor | None = None,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        super().__init__()
        # proposed by reddit user bloc97,
        # to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        # Base frequency spectrum: caller-supplied, the classic 1/theta^(2i/d)
        # language schedule, a linear "pixel" schedule scaled by pi, or a
        # constant vector of ones.
        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (
                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
            )
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        # Preallocated (non-persistent) cache of computed frequencies;
        # cached_freqs_seq_len tracks how many rows are currently valid.
        self.register_buffer(
            "cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False
        )
        self.cached_freqs_seq_len = 0

        # Registered as a Parameter even when frozen, so it moves with .to().
        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.learned_freq = learned_freq

        # dummy for device
        self.register_buffer("dummy", torch.tensor(0), persistent=False)

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos

        if not use_xpos:
            return

        # xpos per-feature scale base (https://arxiv.org/abs/2212.10554),
        # plus a scale cache mirroring the freqs cache above.
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer(
            "cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False
        )
        self.cached_scales_seq_len = 0

        # add apply_rotary_emb as static method
        # NOTE(review): because of the early `return` above, this attribute is
        # only set when use_xpos=True; also, staticmethod objects are only
        # directly callable on Python >= 3.10 — confirm intended.
        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        # Device is tracked via the registered dummy buffer.
        return self.dummy.device

    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
        """Return position indices [offset, offset+seq_len), divided by the
        interpolation factor (positional interpolation for longer contexts)."""
        device = default(device, self.device)
        dtype = default(dtype, self.cached_freqs.dtype)

        return (
            torch.arange(seq_len, device=device, dtype=dtype) + offset
        ) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
        """Rotate a single tensor of queries or keys.

        With xpos enabled, a scale must be supplied (normally via
        `rotate_queries_and_keys`), since q and k need reciprocal scales.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos or exists(scale), (
            "you must use `.rotate_queries_and_keys` method instead and pass "
            "in both queries and keys, for length extrapolatable rotary embeddings"
        )

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        freqs = self.forward(seq, seq_len=seq_len, offset=offset)

        # Insert a heads axis when the layout is (batch, seq, heads, dim).
        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")

        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotate q and k when k already includes cached (past) positions.

        q is offset so its rotations align with the tail of k's sequence.
        """
        dtype, device, seq_dim = (
            q.dtype,
            q.device,
            default(seq_dim, self.default_seq_dim),
        )

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.0

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        rotated_q = self.rotate_queries_or_keys(
            q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim=None):
        """Rotate q and k together with xpos scaling (scale for q, inverse
        scale for k), for length-extrapolatable attention."""
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        freqs = self.forward(seq, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Compute (or fetch from cache) the xpos scales for positions *t*."""
        assert self.use_xpos

        should_cache = (
            self.cache_if_possible
            and exists(seq_len)
            and (offset + seq_len) <= self.cache_max_seq_len
        )

        # Cache hit: the requested range is already populated.
        if (
            should_cache
            and exists(self.cached_scales)
            and (seq_len + offset) <= self.cached_scales_seq_len
        ):
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            # Power is centred on the middle of the sequence, so scales decay
            # symmetrically away from the centre.
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = repeat(scale, "n d -> n (d r)", r=2)

        # Only cache ranges starting at position 0, so cached rows always
        # correspond to absolute positions.
        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len = seq_len

        return scale

    def get_axial_freqs(
        self, *dims, offsets: tuple[int | float, ...] | Tensor | None = None
    ):
        """Build frequencies for an N-dimensional (axial) grid of positions,
        one axis per entry in *dims*, concatenated along the feature dim."""
        Colon = slice(None)
        all_freqs = []

        # handle offset

        if exists(offsets):
            if not is_tensor(offsets):
                offsets = tensor(offsets)

            assert len(offsets) == len(dims)

        # get frequencies for each axis

        for ind, dim in enumerate(dims):

            offset = 0
            if exists(offsets):
                offset = offsets[ind]

            # Pixel mode samples positions in [-1, 1]; otherwise integers.
            if self.freqs_for == "pixel":
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            pos = pos + offset

            freqs = self.forward(pos, seq_len=dim)

            # Place this axis's freqs on its own broadcast dimension.
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        # concat all freqs

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast_factory("cuda", enabled=False)
    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Map positions *t* to rotation angles of shape (..., dim).

        Caching is skipped for learned frequencies (they change every step)
        and for pixel mode (positions are not integer-indexed).
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for != "pixel"
            and (offset + seq_len) <= self.cache_max_seq_len
        )

        if (
            should_cache
            and exists(self.cached_freqs)
            and (offset + seq_len) <= self.cached_freqs_seq_len
        ):
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = self.freqs

        # Outer product positions x base frequencies, then duplicate each
        # angle so it covers an (even, odd) feature pair.
        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len = seq_len

        return freqs
|
broccoli/tensor.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn import functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SigmaReparamTensor(nn.Module):
    """
    Spectrally-normalised weight with a learnable scale.

    Inspired by Apple's Spectral Normed Linear Layers
    (https://github.com/apple/ml-sigma-reparam)

    The weight is divided by a running estimate of its largest singular
    value (refreshed by one power-iteration step per training forward pass)
    and multiplied by a learnable scalar initialised to that singular value.
    """

    def __init__(self, init_tensor: torch.Tensor):
        assert init_tensor.ndim == 2

        super().__init__()

        self.sigma_reparam_tensor = nn.Parameter(init_tensor, requires_grad=True)

        # Seed the spectral-norm estimate and the dominant right-singular
        # vector from a one-off exact SVD of the initial weight.
        with torch.no_grad():
            _, singular_values, vh = torch.linalg.svd(
                self.sigma_reparam_tensor, full_matrices=False
            )

        self.register_buffer("approx_spectral_norm", singular_values[:1])
        self.register_buffer("right_singular", vh[0])
        self.sigma_reparam_scale = nn.Parameter(
            self.approx_spectral_norm.clone().detach(), requires_grad=True
        )

    def power_iteration(self):
        """Refresh the spectral-norm estimate with one power-iteration step."""
        with torch.no_grad():
            # Left singular direction: normalise W @ v.
            left = F.normalize(
                self.sigma_reparam_tensor.mv(self.right_singular), dim=0
            )
            # Right singular direction: normalise Wᵀ @ u; persist it for the
            # next iteration.
            right = F.normalize(self.sigma_reparam_tensor.T.mv(left), dim=0)
            self.right_singular.data.copy_(right)
            # Rayleigh quotient uᵀ W v approximates the top singular value.
            estimate = torch.einsum(
                "m,mn,n->",
                left,
                self.sigma_reparam_tensor,
                right,
            )
            self.approx_spectral_norm.data.copy_(estimate)

    def forward(self):
        """Return scale * W / sigma_max(W) (approximate spectral norm)."""
        if self.training:
            self.power_iteration()
        normed = self.sigma_reparam_tensor / self.approx_spectral_norm
        return self.sigma_reparam_scale * normed
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AnchoredReparamTensor(nn.Module):
    """
    Reparameterises a tensor by decoupling its magnitude and direction.

    The direction is the learnable weight normalised by the Rayleigh
    quotient uᵀWv₀, where v₀ is the dominant right-singular vector of the
    *initial* weight (fixed at construction) and u = normalize(W v₀). The
    magnitude is a separate learnable scalar (named ``nondecay_*`` so it can
    be excluded from weight decay):

        W_reparam = scale * W / (uᵀ W v₀ + 1e-6)
    """

    def __init__(self, init_tensor: torch.Tensor):
        assert init_tensor.ndim == 2

        super().__init__()

        self.weight = nn.Parameter(init_tensor, requires_grad=True)

        # One-off SVD of the initial weight: anchor v₀ and seed the norm.
        with torch.no_grad():
            _, singular_values, vh = torch.linalg.svd(
                self.weight, full_matrices=False
            )

        self.register_buffer("rayleigh_norm", singular_values[:1])
        self.register_buffer("initial_right_singular", vh[0])
        self.nondecay_scale = nn.Parameter(
            singular_values[:1].clone().detach(), requires_grad=True
        )

    def _update_rayleigh_norm(self):
        """Recompute uᵀWv₀ against the frozen anchor vector v₀."""
        with torch.no_grad():
            mapped = self.weight.mv(self.initial_right_singular)
            unit_mapped = F.normalize(mapped, dim=0)
            quotient = torch.einsum(
                "m,mn,n->",
                unit_mapped,
                self.weight,
                self.initial_right_singular,
            )
            self.rayleigh_norm.data.copy_(quotient)

    def forward(self):
        """Return scale * W / (rayleigh_norm + 1e-6)."""
        if self.training:
            self._update_rayleigh_norm()
        return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class NormReparamTensor(nn.Module):
    """
    Reparameterise a tensor as a unit-Frobenius-norm direction multiplied by
    a learnable magnitude:  W_reparam = scale * W / (||W||_F + 1e-6).
    """

    def __init__(self, init_tensor: torch.Tensor):
        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
        super().__init__()

        # Use the gradboard convention of calling something nondecay_* if we should
        # exclude it from weight decay
        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
        initial_norm = torch.linalg.norm(self.weight).clone().detach()
        self.nondecay_scale = nn.Parameter(initial_norm, requires_grad=True)

    def forward(self) -> torch.Tensor:
        # Small epsilon keeps the division safe for a near-zero weight.
        frobenius = torch.linalg.norm(self.weight)
        return self.nondecay_scale * (self.weight / (frobenius + 1e-6))
|