titans-pytorch 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
titans_pytorch/titans.py CHANGED
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
17
17
  )
18
18
 
19
19
  import einx
20
- from einops import rearrange, pack, unpack
20
+ from einops import rearrange, repeat, pack, unpack
21
21
  from einops.layers.torch import Rearrange, Reduce
22
22
 
23
23
  """
@@ -152,6 +152,11 @@ class NeuralMemory(Module):
152
152
  self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
153
153
  self.store_memory_loss_fn = store_memory_loss_fn
154
154
 
155
+ # empty memory embed
156
+
157
+ self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
158
+ nn.init.normal_(self.empty_memory_embed, std = 0.02)
159
+
155
160
  # learned adaptive learning rate and momentum
156
161
  # todo - explore mlp layerwise learned lr / momentum
157
162
 
@@ -187,6 +192,9 @@ class NeuralMemory(Module):
187
192
 
188
193
  return init_weights, init_momentum
189
194
 
195
+ def init_empty_memory_embed(self, batch, seq_len):
196
+ return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
197
+
190
198
  def store_memories(
191
199
  self,
192
200
  seq,
@@ -269,7 +277,7 @@ class NeuralMemory(Module):
269
277
  gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
270
278
  inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
271
279
 
272
- outputs = scan(gates, inputs)
280
+ outputs = scan(gates.contiguous(), inputs.contiguous())
273
281
 
274
282
  outputs = outputs[..., :seq_len]
275
283
  outputs = rearrange(outputs, 'b d n -> b n d')
@@ -372,11 +380,12 @@ class NeuralMemory(Module):
372
380
 
373
381
  values = self.post_rmsnorm(values)
374
382
 
375
- # restore
383
+ # restore, pad with empty memory embed
376
384
 
377
- values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
378
- values = values[:, :-padding]
385
+ empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
386
+ values = torch.cat((empty_memory_embeds, values), dim = -2)
379
387
 
388
+ values = values[:, :-padding]
380
389
  return values
381
390
 
382
391
  def forward(
@@ -389,7 +398,7 @@ class NeuralMemory(Module):
389
398
  batch, seq_len = seq.shape[:2]
390
399
 
391
400
  if seq_len < self.chunk_size:
392
- return torch.zeros_like(seq)
401
+ return self.init_empty_memory_embed(batch, seq_len)
393
402
 
394
403
  if exists(past_state):
395
404
  past_state = tuple(TensorDict(d) for d in past_state)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.17
3
+ Version: 0.0.19
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=CxbJgNdIS9NbbCDdgotFXAnrV16xmvufUErerKe7qJA,12636
4
+ titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
+ titans_pytorch-0.0.19.dist-info/METADATA,sha256=5Wpk79HYI4z8LeNRV__UaamKppiGcJ2HdIlll1JSZr8,3811
6
+ titans_pytorch-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.19.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
4
- titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
5
- titans_pytorch-0.0.17.dist-info/METADATA,sha256=s8PEQdaW8WSjZkjlho1Gv1gcsk7GD2lu9oka7bP5Rf8,3811
6
- titans_pytorch-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.17.dist-info/RECORD,,