titans-pytorch 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch/titans.py +15 -6
- {titans_pytorch-0.0.17.dist-info → titans_pytorch-0.0.19.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.19.dist-info/RECORD +8 -0
- titans_pytorch-0.0.17.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.17.dist-info → titans_pytorch-0.0.19.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.17.dist-info → titans_pytorch-0.0.19.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
|
|
17
17
|
)
|
18
18
|
|
19
19
|
import einx
|
20
|
-
from einops import rearrange, pack, unpack
|
20
|
+
from einops import rearrange, repeat, pack, unpack
|
21
21
|
from einops.layers.torch import Rearrange, Reduce
|
22
22
|
|
23
23
|
"""
|
@@ -152,6 +152,11 @@ class NeuralMemory(Module):
|
|
152
152
|
self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
|
153
153
|
self.store_memory_loss_fn = store_memory_loss_fn
|
154
154
|
|
155
|
+
# empty memory embed
|
156
|
+
|
157
|
+
self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
|
158
|
+
nn.init.normal_(self.empty_memory_embed, std = 0.02)
|
159
|
+
|
155
160
|
# learned adaptive learning rate and momentum
|
156
161
|
# todo - explore mlp layerwise learned lr / momentum
|
157
162
|
|
@@ -187,6 +192,9 @@ class NeuralMemory(Module):
|
|
187
192
|
|
188
193
|
return init_weights, init_momentum
|
189
194
|
|
195
|
+
def init_empty_memory_embed(self, batch, seq_len):
|
196
|
+
return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
|
197
|
+
|
190
198
|
def store_memories(
|
191
199
|
self,
|
192
200
|
seq,
|
@@ -269,7 +277,7 @@ class NeuralMemory(Module):
|
|
269
277
|
gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
|
270
278
|
inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
|
271
279
|
|
272
|
-
outputs = scan(gates, inputs)
|
280
|
+
outputs = scan(gates.contiguous(), inputs.contiguous())
|
273
281
|
|
274
282
|
outputs = outputs[..., :seq_len]
|
275
283
|
outputs = rearrange(outputs, 'b d n -> b n d')
|
@@ -372,11 +380,12 @@ class NeuralMemory(Module):
|
|
372
380
|
|
373
381
|
values = self.post_rmsnorm(values)
|
374
382
|
|
375
|
-
# restore
|
383
|
+
# restore, pad with empty memory embed
|
376
384
|
|
377
|
-
|
378
|
-
values = values
|
385
|
+
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
|
386
|
+
values = torch.cat((empty_memory_embeds, values), dim = -2)
|
379
387
|
|
388
|
+
values = values[:, :-padding]
|
380
389
|
return values
|
381
390
|
|
382
391
|
def forward(
|
@@ -389,7 +398,7 @@ class NeuralMemory(Module):
|
|
389
398
|
batch, seq_len = seq.shape[:2]
|
390
399
|
|
391
400
|
if seq_len < self.chunk_size:
|
392
|
-
return
|
401
|
+
return self.init_empty_memory_embed(batch, seq_len)
|
393
402
|
|
394
403
|
if exists(past_state):
|
395
404
|
past_state = tuple(TensorDict(d) for d in past_state)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
|
4
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
+
titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
|
6
|
+
titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.0.19.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
|
4
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
-
titans_pytorch-0.0.17.dist-info/METADATA,sha256=s8PEQdaW8WSjZkjlho1Gv1gcsk7GD2lu9oka7bP5Rf8,3811
|
6
|
-
titans_pytorch-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.0.17.dist-info/RECORD,,
|
File without changes
|
File without changes
|