broccoli_ml-9.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/rope.py ADDED
@@ -0,0 +1,407 @@
+"""
+This code is from https://github.com/lucidrains/rotary-embedding-torch
+
+It is provided along with the following copyright notice and license:
+
+MIT License
+
+Copyright (c) 2021 Phil Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+from __future__ import annotations
+from math import pi
+
+import torch
+
+from torch.nn import Module
+from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor
+
+# Gracefully find the best way to import autocast
+try:
+    from torch.amp import autocast as autocast_factory
+except ImportError:
+    # Fallback: For PyTorch 1.6 to 1.9
+    from torch.cuda.amp import autocast
+
+    def autocast_factory(_, enabled=True):
+        """
+        A wrapper that mimics the modern autocast signature but calls the older
+        torch.cuda.amp.autocast, ignoring the device_type argument.
+        """
+        return autocast(enabled=enabled)
+
+
+from einops import rearrange, repeat
+
+from typing import Literal
+
+# helper functions
+
+
+def exists(val):
+    return val is not None
+
+
+def default(val, d):
+    return val if exists(val) else d
+
+
+# broadcat, as tortoise-tts was using it
+
+
+def broadcat(tensors, dim=-1):
+    broadcasted_tensors = broadcast_tensors(*tensors)
+    return torch.cat(broadcasted_tensors, dim=dim)
+
+
+def slice_at_dim(t, dim_slice: slice, *, dim):
+    dim += t.ndim if dim < 0 else 0
+    colons = [slice(None)] * t.ndim
+    colons[dim] = dim_slice
+    return t[tuple(colons)]
+
+
+# rotary embedding helper functions
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (d r) -> ... d r", r=2)
+    x1, x2 = x.unbind(dim=-1)
+    x = torch.stack((-x2, x1), dim=-1)
+    return rearrange(x, "... d r -> ... (d r)")
+
+
+@autocast_factory("cuda", enabled=False)
+def apply_rotary_emb(
+    freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None
+):
+    dtype = t.dtype
+
+    if not exists(freqs_seq_dim):
+        if freqs.ndim == 2 or t.ndim == 3:
+            freqs_seq_dim = 0
+
+    if t.ndim == 3 or exists(freqs_seq_dim):
+        seq_len = t.shape[seq_dim]
+        freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim)
+
+    rot_dim = freqs.shape[-1]
+    end_index = start_index + rot_dim
+
+    assert rot_dim <= t.shape[-1], (
+        f"feature dimension {t.shape[-1]} is not of sufficient size "
+        f"to rotate in all the positions {rot_dim}"
+    )
+
+    # Split t into three parts: left, middle (to be transformed), and right
+    t_left = t[..., :start_index]
+    t_middle = t[..., start_index:end_index]
+    t_right = t[..., end_index:]
+
+    # Apply rotary embeddings without modifying t in place
+    t_transformed = (t_middle * freqs.cos() * scale) + (
+        rotate_half(t_middle) * freqs.sin() * scale
+    )
+
+    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+
+    return out.type(dtype)
+
+
+# learned rotation helpers
+
+
+def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
+    if exists(freq_ranges):
+        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
+        rotations = rearrange(rotations, "... r f -> ... (r f)")
+
+    rotations = repeat(rotations, "... n -> ... (n r)", r=2)
+    return apply_rotary_emb(rotations, t, start_index=start_index)
+
+
+# classes
+
+
+class RotaryEmbedding(Module):
+    def __init__(
+        self,
+        dim,
+        custom_freqs: Tensor | None = None,
+        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
+        theta=10000,
+        max_freq=10,
+        num_freqs=1,
+        learned_freq=False,
+        use_xpos=False,
+        xpos_scale_base=512,
+        interpolate_factor=1.0,
+        theta_rescale_factor=1.0,
+        seq_before_head_dim=False,
+        cache_if_possible=True,
+        cache_max_seq_len=8192,
+    ):
+        super().__init__()
+        # proposed by reddit user bloc97,
+        # to rescale rotary embeddings to longer sequence length without fine-tuning
+        # has some connection to NTK literature
+        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+
+        theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+        self.freqs_for = freqs_for
+
+        if exists(custom_freqs):
+            freqs = custom_freqs
+        elif freqs_for == "lang":
+            freqs = 1.0 / (
+                theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+            )
+        elif freqs_for == "pixel":
+            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+        elif freqs_for == "constant":
+            freqs = torch.ones(num_freqs).float()
+
+        self.cache_if_possible = cache_if_possible
+        self.cache_max_seq_len = cache_max_seq_len
+
+        self.register_buffer(
+            "cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False
+        )
+        self.cached_freqs_seq_len = 0
+
+        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
+
+        self.learned_freq = learned_freq
+
+        # dummy for device
+
+        self.register_buffer("dummy", torch.tensor(0), persistent=False)
+
+        # default sequence dimension
+
+        self.seq_before_head_dim = seq_before_head_dim
+        self.default_seq_dim = -3 if seq_before_head_dim else -2
+
+        # interpolation factors
+
+        assert interpolate_factor >= 1.0
+        self.interpolate_factor = interpolate_factor
+
+        # xpos
+
+        self.use_xpos = use_xpos
+
+        if not use_xpos:
+            return
+
+        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        self.scale_base = xpos_scale_base
+
+        self.register_buffer("scale", scale, persistent=False)
+        self.register_buffer(
+            "cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False
+        )
+        self.cached_scales_seq_len = 0
+
+        # add apply_rotary_emb as static method
+
+        self.apply_rotary_emb = staticmethod(apply_rotary_emb)
+
+    @property
+    def device(self):
+        return self.dummy.device
+
+    def get_seq_pos(self, seq_len, device=None, dtype=None, offset=0):
+        device = default(device, self.device)
+        dtype = default(dtype, self.cached_freqs.dtype)
+
+        return (
+            torch.arange(seq_len, device=device, dtype=dtype) + offset
+        ) / self.interpolate_factor
+
+    def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, scale=None):
+        seq_dim = default(seq_dim, self.default_seq_dim)
+
+        assert not self.use_xpos or exists(scale), (
+            "you must use `.rotate_queries_and_keys` method instead and pass "
+            "in both queries and keys, for length extrapolatable rotary embeddings"
+        )
+
+        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
+
+        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
+
+        freqs = self.forward(seq, seq_len=seq_len, offset=offset)
+
+        if seq_dim == -3:
+            freqs = rearrange(freqs, "n d -> n 1 d")
+
+        return apply_rotary_emb(freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
+
+    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
+        dtype, device, seq_dim = (
+            q.dtype,
+            q.device,
+            default(seq_dim, self.default_seq_dim),
+        )
+
+        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
+        assert q_len <= k_len
+
+        q_scale = k_scale = 1.0
+
+        if self.use_xpos:
+            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
+
+            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
+            k_scale = self.get_scale(seq).type(dtype)
+
+        rotated_q = self.rotate_queries_or_keys(
+            q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset
+        )
+        rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
+
+        rotated_q = rotated_q.type(q.dtype)
+        rotated_k = rotated_k.type(k.dtype)
+
+        return rotated_q, rotated_k
+
+    def rotate_queries_and_keys(self, q, k, seq_dim=None):
+        seq_dim = default(seq_dim, self.default_seq_dim)
+
+        assert self.use_xpos
+        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
+
+        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
+
+        freqs = self.forward(seq, seq_len=seq_len)
+        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
+
+        if seq_dim == -3:
+            freqs = rearrange(freqs, "n d -> n 1 d")
+            scale = rearrange(scale, "n d -> n 1 d")
+
+        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
+        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
+
+        rotated_q = rotated_q.type(q.dtype)
+        rotated_k = rotated_k.type(k.dtype)
+
+        return rotated_q, rotated_k
+
+    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
+        assert self.use_xpos
+
+        should_cache = (
+            self.cache_if_possible
+            and exists(seq_len)
+            and (offset + seq_len) <= self.cache_max_seq_len
+        )
+
+        if (
+            should_cache
+            and exists(self.cached_scales)
+            and (seq_len + offset) <= self.cached_scales_seq_len
+        ):
+            return self.cached_scales[offset : (offset + seq_len)]
+
+        scale = 1.0
+        if self.use_xpos:
+            power = (t - len(t) // 2) / self.scale_base
+            scale = self.scale ** rearrange(power, "n -> n 1")
+            scale = repeat(scale, "n d -> n (d r)", r=2)
+
+        if should_cache and offset == 0:
+            self.cached_scales[:seq_len] = scale.detach()
+            self.cached_scales_seq_len = seq_len
+
+        return scale
+
+    def get_axial_freqs(
+        self, *dims, offsets: tuple[int | float, ...] | Tensor | None = None
+    ):
+        Colon = slice(None)
+        all_freqs = []
+
+        # handle offset
+
+        if exists(offsets):
+            if not is_tensor(offsets):
+                offsets = tensor(offsets)
+
+            assert len(offsets) == len(dims)
+
+        # get frequencies for each axis
+
+        for ind, dim in enumerate(dims):
+
+            offset = 0
+            if exists(offsets):
+                offset = offsets[ind]
+
+            if self.freqs_for == "pixel":
+                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
+            else:
+                pos = torch.arange(dim, device=self.device)
+
+            pos = pos + offset
+
+            freqs = self.forward(pos, seq_len=dim)
+
+            all_axis = [None] * len(dims)
+            all_axis[ind] = Colon
+
+            new_axis_slice = (Ellipsis, *all_axis, Colon)
+            all_freqs.append(freqs[new_axis_slice])
+
+        # concat all freqs
+
+        all_freqs = broadcast_tensors(*all_freqs)
+        return torch.cat(all_freqs, dim=-1)
+
+    @autocast_factory("cuda", enabled=False)
+    def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
+        should_cache = (
+            self.cache_if_possible
+            and not self.learned_freq
+            and exists(seq_len)
+            and self.freqs_for != "pixel"
+            and (offset + seq_len) <= self.cache_max_seq_len
+        )
+
+        if (
+            should_cache
+            and exists(self.cached_freqs)
+            and (offset + seq_len) <= self.cached_freqs_seq_len
+        ):
+            return self.cached_freqs[offset : (offset + seq_len)].detach()
+
+        freqs = self.freqs
+
+        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
+        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+
+        if should_cache and offset == 0:
+            self.cached_freqs[:seq_len] = freqs.detach()
+            self.cached_freqs_seq_len = seq_len
+
+        return freqs
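
For orientation, the sketch below shows how the vendored rotary-embedding module above is typically driven. It is illustrative only: the batch size, head count, sequence length, head dimension, and the dim=32 choice are assumptions, not values taken from the broccoli-ml package; only the broccoli.rope import path and the RotaryEmbedding.rotate_queries_or_keys API come from the file shown.

import torch
from broccoli.rope import RotaryEmbedding

# Rotate the first 32 features of each 64-dim attention head (sizes are hypothetical).
rotary = RotaryEmbedding(dim=32)

# Default layout is (batch, heads, seq_len, head_dim), i.e. seq_dim=-2.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)

q = rotary.rotate_queries_or_keys(q)
k = rotary.rotate_queries_or_keys(k)
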
broccoli/tensor.py ADDED
@@ -0,0 +1,128 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class SigmaReparamTensor(nn.Module):
+    """
+    Inspired by Apple's Spectral Normed Linear Layers
+    (https://github.com/apple/ml-sigma-reparam)
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2
+
+        super().__init__()
+
+        self.sigma_reparam_tensor = nn.Parameter(init_tensor, requires_grad=True)
+
+        with torch.no_grad():
+            _, sigma, v_transpose = torch.linalg.svd(
+                self.sigma_reparam_tensor, full_matrices=False
+            )
+
+        self.register_buffer("approx_spectral_norm", sigma[:1])
+        self.register_buffer("right_singular", v_transpose[0])
+        self.sigma_reparam_scale = nn.Parameter(
+            self.approx_spectral_norm.clone().detach(), requires_grad=True
+        )
+
+    def power_iteration(self):
+        with torch.no_grad():
+            approx_right_singular_transpose = self.sigma_reparam_tensor.mv(
+                self.right_singular
+            )
+            approx_right_singular_transpose = F.normalize(
+                approx_right_singular_transpose, dim=0
+            )
+            updated_right_singular = self.sigma_reparam_tensor.T.mv(
+                approx_right_singular_transpose
+            )
+            updated_right_singular = F.normalize(updated_right_singular, dim=0)
+            self.right_singular.data.copy_(updated_right_singular)
+            rayleigh_quotient = torch.einsum(
+                "m,mn,n->",
+                approx_right_singular_transpose,
+                self.sigma_reparam_tensor,
+                updated_right_singular,
+            )
+            self.approx_spectral_norm.data.copy_(rayleigh_quotient)
+
+    def forward(self):
+        if self.training:
+            self.power_iteration()
+        return self.sigma_reparam_scale * (
+            self.sigma_reparam_tensor / self.approx_spectral_norm
+        )
+
+
+class AnchoredReparamTensor(nn.Module):
+    """
+    Reparameterises a tensor by decoupling its magnitude and direction.
+
+    The direction is represented by a learnable weight tensor, normalised by the
+    Rayleigh quotient with respect to its initial dominant right-singular vector.
+    The magnitude is a separate learnable scalar.
+
+    The reparameterization is:
+
+        W_reparam = scale * (W / norm)
+
+    where the norm is the Rayleigh quotient uᵀWv₀, with v₀ being the dominant
+    right-singular vector of the initial tensor and u = normalize(Wv₀).
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2
+
+        super().__init__()
+
+        self.weight = nn.Parameter(init_tensor, requires_grad=True)
+
+        with torch.no_grad():
+            _, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
+
+        self.register_buffer("rayleigh_norm", sigma[:1])
+        self.register_buffer("initial_right_singular", v_transpose[0])
+        self.nondecay_scale = nn.Parameter(
+            sigma[:1].clone().detach(), requires_grad=True
+        )
+
+    def _update_rayleigh_norm(self):
+        with torch.no_grad():
+            product = self.weight.mv(self.initial_right_singular)
+            normed_product = F.normalize(product, dim=0)
+            rayleigh_norm = torch.einsum(
+                "m,mn,n->",
+                normed_product,
+                self.weight,
+                self.initial_right_singular,
+            )
+            self.rayleigh_norm.data.copy_(rayleigh_norm)
+
+    def forward(self):
+        if self.training:
+            self._update_rayleigh_norm()
+        return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
+
+
+class NormReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        # Use the gradboard convention of calling something nondecay_* if we should
+        # exclude it from weight decay
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
+        )
+
+    def forward(self) -> torch.Tensor:
+        norm = torch.linalg.norm(self.weight)
+        return self.nondecay_scale * (self.weight / (norm + 1e-6))
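
As with the rotary module, a brief usage sketch may help: each reparameterisation module above returns the effective weight matrix from forward(), so a caller wraps it inside its own layer. Everything below (the SigmaReparamLinear class, the shapes, the 0.02 init scale) is hypothetical and not part of the package; only SigmaReparamTensor and its forward() behaviour come from the file shown.

import torch
from torch import nn
from torch.nn import functional as F

from broccoli.tensor import SigmaReparamTensor


class SigmaReparamLinear(nn.Module):
    """Hypothetical linear layer whose weight is drawn from SigmaReparamTensor."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # The wrapped module returns scale * W / approx_spectral_norm from its forward().
        self.weight_module = SigmaReparamTensor(
            torch.randn(out_features, in_features) * 0.02
        )
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        # In training mode, calling the module also runs one power-iteration step.
        return F.linear(x, self.weight_module(), self.bias)


layer = SigmaReparamLinear(64, 128)
y = layer(torch.randn(4, 64))  # -> shape (4, 128)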