hippoformer 0.0.11__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
--- hippoformer/hippoformer.py
+++ hippoformer/hippoformer.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
 import torch
-from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
+from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
 import torch.nn.functional as F
-from torch.nn import Module
+from torch.nn import Module, ModuleList
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call
 
@@ -221,8 +221,11 @@ class PathIntegration(Module):
 
         return self.rnn(transitions, prev_structural)
 
-# custom transformer
-# with the connections proposed by James Whittington that bridges to hippocampal models
+# custom transformer proposed by James Whittington that bridges to hippocampal models with a few twists
+
+# the mmTEM can be seen as a linear attention / TTT variant of what he proposed
+# needed for the baseline as well as the parallel block to bolster local time prediction
+
 # https://arxiv.org/abs/2112.04035
 
 def FeedForward(dim, mult = 4.):
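The comments above lean on the linear attention / TTT (test-time training) framing without spelling it out. As a point of reference only, here is a minimal, self-contained sketch of that framing: causal linear attention read as a fast-weight memory that is written with key/value outer products and read with queries, which is the sense in which the mmTEM can be viewed as a linear-attention counterpart of the softmax-attention TEM-transformer. The function name and the feature map below are assumptions for illustration, not part of hippoformer.

```python
import torch
from torch import einsum

def toy_linear_attention(q, k, v):
    # q, k, v: (batch, seq, dim)
    # each step writes the outer product of key and value into a matrix-valued memory
    # (here a plain running sum); a TTT-style variant would update that memory with a
    # learned / gradient rule instead of a sum
    q, k = q.softmax(dim = -1), k.softmax(dim = -1)                    # simple positive feature map
    memory = einsum('b n d, b n e -> b n d e', k, v).cumsum(dim = 1)   # causal fast-weight memory per step
    return einsum('b n d, b n d e -> b n e', q, memory)                # read the memory with the queries

out = toy_linear_attention(*(torch.randn(1, 8, 16) for _ in range(3)))
print(out.shape)  # torch.Size([1, 8, 16])
```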
@@ -238,19 +241,32 @@ class Attention(Module):
         self,
         dim_q,
         dim_kv,
+        window_size,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        implicit_mlp_expansion = 2 # for fair comparison, the attention should have an implicit mlp of 2 layers with a non-linearity, just like the meta-memory mlp in titans (linear attention)
     ):
         super().__init__()
         dim_inner = dim_head * heads
+        dim_mlp_inner = dim_head * heads * implicit_mlp_expansion
+
         self.scale = dim_head ** -0.5
 
         self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
-        self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+        self.to_w1_keys = nn.Linear(dim_kv, dim_inner, bias = False)
+        self.to_w1_values = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+
+        self.implicit_mlp_activation = nn.SiLU()
+
+        self.to_w2_keys = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+        self.to_w2_values = nn.Linear(dim_kv, dim_inner, bias = False)
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
+        self.window_size = window_size
+
         self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
         self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink
 
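The new projection pairs above set up what the comment calls an implicit 2-layer MLP: queries attend over the w1 keys/values, the result passes through a SiLU, and that hidden state attends again over the w2 keys/values, so each query is effectively run through a data-dependent two-layer MLP of width `dim_head * heads * implicit_mlp_expansion`. Below is a standalone, unmasked sketch of that composition, with assumed names and toy shapes, not the package's Attention class.

```python
import torch
import torch.nn.functional as F
from torch import einsum

def implicit_mlp_attention(q, k1, v1, k2, v2):
    # q: (b, h, n, d); k1: (b, h, m, d); v1: (b, h, m, e); k2: (b, h, m, e); v2: (b, h, m, d)
    # single unmasked pass for clarity - the real module also applies a windowed causal
    # mask and an attention sink inside each pass

    scale1 = q.shape[-1] ** -0.5
    attn1 = einsum('b h i d, b h j d -> b h i j', q * scale1, k1).softmax(dim = -1)
    hidden = einsum('b h i j, b h j e -> b h i e', attn1, v1)   # "first layer" of the implicit mlp

    hidden = F.silu(hidden)                                     # non-linearity between the two layers

    scale2 = hidden.shape[-1] ** -0.5
    attn2 = einsum('b h i e, b h j e -> b h i j', hidden * scale2, k2).softmax(dim = -1)
    return einsum('b h i j, b h j d -> b h i d', attn2, v2)     # "second layer", back to the head dimension

b, h, n, d, e = 1, 2, 4, 8, 16  # e = d * implicit_mlp_expansion
out = implicit_mlp_attention(
    torch.randn(b, h, n, d),
    torch.randn(b, h, n, d), torch.randn(b, h, n, e),
    torch.randn(b, h, n, e), torch.randn(b, h, n, d)
)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```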
@@ -264,43 +280,59 @@
 
         q = self.to_queries(queries_input)
 
-        k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+        k1, v1, k2, v2 = [fn(key_values_input) for fn in (self.to_w1_keys, self.to_w1_values, self.to_w2_keys, self.to_w2_values)]
 
-        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+        q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))
 
         if exists(kv_cache):
-            ck, cv = kv_cache
-            k = cat((ck, k), dim = -2)
-            v = cat((cv, v), dim = -2)
+            ck1, cv1, ck2, cv2 = kv_cache
+            k1 = cat((ck1, k1), dim = -2)
+            v1 = cat((cv1, v1), dim = -2)
+            k2 = cat((ck2, k2), dim = -2)
+            v2 = cat((cv2, v2), dim = -2)
+
+        def attend(q, k, v):
+            q = q * self.scale
+
+            sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
-        q = q * self.scale
+            # the diagonal is masked out
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+            i, j = sim.shape[-2:]
 
-        # the diagonal is masked out
+            i_seq = arange(i, device = device)[:, None] + (j - i) # absolute query positions (rows)
+            j_seq = arange(j, device = device)[None, :]           # absolute key positions (columns)
 
-        i, j = sim.shape[-2:]
-        causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
+            windowed_causal_mask_without_diagonal = (j_seq >= i_seq) | ((i_seq - j_seq) > self.window_size)
 
-        sim = sim.masked_fill(causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+            sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
 
-        # attention sink, for token as well as for attention sinking - from gpt-oss
+            # attention sink, for token as well as for attention sinking - from gpt-oss
 
-        attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+            attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
 
-        sim = cat((attn_sink, sim), dim = -1)
+            sim = cat((attn_sink, sim), dim = -1)
 
-        attn = sim.softmax(dim = -1)
+            attn = sim.softmax(dim = -1)
 
-        attn = attn[..., 1:] # remove sink
+            attn = attn[..., 1:] # remove sink
 
-        # aggregate
+            # aggregate
 
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+            out = einsum('b h i j, b h j d -> b h i d', attn, v)
+            return out
+
+        # implicit memory mlp w1
+
+        hiddens = attend(q, k1, v1)
+        hiddens = self.implicit_mlp_activation(hiddens)
+        out = attend(hiddens, k2, v2)
+
+        # merge heads
 
         out = self.merge_heads(out)
 
-        return self.to_out(out), stack((k, v))
+        return self.to_out(out), (k1, v1, k2, v2)
 
 class TEMTransformerBlock(Module):
     def __init__(
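Two details inside the `attend` closure are easy to miss: the mask is both causal-without-diagonal and windowed (only the previous `window_size` positions are visible), and a learned per-head sink logit is prepended before the softmax and dropped afterwards, so rows with nothing left to attend to can dump their probability mass somewhere. A standalone illustration follows, with an assumed helper name and toy shapes rather than anything exported by the package.

```python
import torch

def windowed_mask_without_diagonal(i, j, window_size, device = None):
    # True = masked out; the i queries are the last i of j absolute positions,
    # as happens when a kv cache is prepended
    q_pos = torch.arange(i, device = device)[:, None] + (j - i)   # (i, 1) absolute query positions
    k_pos = torch.arange(j, device = device)[None, :]             # (1, j) absolute key positions
    return (k_pos >= q_pos) | ((q_pos - k_pos) > window_size)     # future / diagonal, or too far in the past

sim = torch.randn(1, 2, 4, 4)                                     # (batch, heads, queries, keys)
mask = windowed_mask_without_diagonal(4, 4, window_size = 2)
sim = sim.masked_fill(mask, -torch.finfo(sim.dtype).max)

sink = torch.zeros(1, 2, 4, 1)                                    # stand-in for the learned per-head sink logit
attn = torch.cat((sink, sim), dim = -1).softmax(dim = -1)[..., 1:]  # softmax with the sink column, then drop it

print(attn[0, 0])  # the first query row is fully masked, so its probability mass went entirely to the sink
```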
@@ -314,7 +346,7 @@ class TEMTransformerBlock(Module):
     ):
         super().__init__()
 
-        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, window_size, dim_head = dim_head, heads = heads)
         self.ff = FeedForward(dim_structure, ff_expansion_factor)
 
         self.window_size = window_size
@@ -337,6 +369,66 @@ class TEMTransformerBlock(Module):
 
         return x, next_kv_cache
 
+class TEMTransformer(Module):
+    def __init__(
+        self,
+        sensory_encoder_decoder: tuple[Module, Module],
+        dim_sensory,
+        dim_action,
+        dim_encoded_sensory,
+        dim_structure,
+        depth = 4,
+        transformer_kwargs: dict = dict(
+            dim_head = 64,
+            heads = 8,
+            ff_expansion_factor = 4,
+            window_size = 32
+        ),
+    ):
+        super().__init__()
+
+        self.sensory_encoder, self.sensory_decoder = sensory_encoder_decoder
+
+        self.path_integrator = nn.GRU(dim_action, dim_structure)
+
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+
+            block = TEMTransformerBlock(
+                dim_structure,
+                dim_encoded_sensory,
+                **transformer_kwargs
+            )
+
+            self.layers.append(block)
+
+    def forward(
+        self,
+        sensory,
+        actions,
+        prev_hiddens = None, # for the GRU based path integrator
+        prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
+    ):
+
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
+
+        encoded_sensory = self.sensory_encoder(sensory)
+
+        next_kv_cache = []
+
+        for layer in self.layers:
+            structure, layer_next_cache = layer(structure, encoded_sensory)
+            next_kv_cache.append(layer_next_cache)
+
+        decoded_sensory = self.sensory_decoder(structure)
+
+        next_memories = (next_hiddens, stack(next_kv_cache))
+
+        pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
+
+        return pred_loss, next_memories
+
 # proposed mmTEM
 
 class mmTEM(Module):
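The new `TEMTransformer` baseline above wires three pieces together: a GRU that path-integrates actions into structural codes, a sensory encoder whose output the structural stream must predict, and a decoder scored with an MSE loss. Below is a self-contained toy of that training signal, using stand-in modules and shapes rather than the package API; the transformer blocks that refine the structural code are elided.

```python
import torch
import torch.nn.functional as F
from torch import nn

dim_action, dim_structure, dim_sensory, dim_encoded = 4, 32, 10, 32

sensory_encoder = nn.Linear(dim_sensory, dim_encoded)     # stand-in sensory encoder
sensory_decoder = nn.Linear(dim_structure, dim_encoded)   # stand-in decoder from structural code
path_integrator = nn.GRU(dim_action, dim_structure, batch_first = True)

sensory = torch.randn(2, 16, dim_sensory)   # (batch, time, sensory dim)
actions = torch.randn(2, 16, dim_action)    # (batch, time, action dim)

structure, next_hiddens = path_integrator(actions)        # path integration of actions
encoded_sensory = sensory_encoder(sensory)                # target the structural stream should predict
# ... the TEMTransformerBlock stack would refine `structure` against `encoded_sensory` here ...
decoded_sensory = sensory_decoder(structure)              # prediction from the structural code

pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
print(pred_loss.item())
```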
--- hippoformer-0.0.11.dist-info/METADATA
+++ hippoformer-0.0.14.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.11
+Version: 0.0.14
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -63,3 +63,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
 note = {under review}
 }
 ```
+
+```bibtex
+@article{Li2020GridCA,
+title = {Grid Cells Are Ubiquitous in Neural Networks},
+author = {Songlin Li and Yangdong Deng and Zhihua Wang},
+journal = {ArXiv},
+year = {2020},
+volume = {abs/2003.03482},
+url = {https://api.semanticscholar.org/CorpusID:212634300}
+}
+```
--- /dev/null
+++ hippoformer-0.0.14.dist-info/RECORD
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=RqHobJTA36mJ-YFutteMx5_6QrrSqI370C1sw5AYrUE,20818
+hippoformer-0.0.14.dist-info/METADATA,sha256=W8k541yoURENT9RFpxZtu7O08r_R8joyzjj3ikwFbH0,3093
+hippoformer-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.14.dist-info/RECORD,,
--- hippoformer-0.0.11.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=PLMfdype8AMwlVWrtItDBkE3gU_BCUaL42NMjB4vhAY,17795
-hippoformer-0.0.11.dist-info/METADATA,sha256=6NlqhZSEApQkUKsncBxmDIE03x_xZktHH-JCeYlYfcg,2773
-hippoformer-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.11.dist-info/RECORD,,