hippoformer 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ from __future__ import annotations
 import torch
 from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
 import torch.nn.functional as F
-from torch.nn import Module
+from torch.nn import Module, ModuleList
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call
 
@@ -244,13 +244,23 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
+        implicit_mlp_expansion = 2  # for fair comparison, the attention should have an implicit mlp of 2 layers with a non-linearity, just like the meta-memory mlp in titans (linear attention)
     ):
         super().__init__()
         dim_inner = dim_head * heads
+        dim_mlp_inner = dim_head * heads * implicit_mlp_expansion
+
         self.scale = dim_head ** -0.5
 
         self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
-        self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+        self.to_w1_keys = nn.Linear(dim_kv, dim_inner, bias = False)
+        self.to_w1_values = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+
+        self.implicit_mlp_activation = nn.SiLU()
+
+        self.to_w2_keys = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+        self.to_w2_values = nn.Linear(dim_kv, dim_inner, bias = False)
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
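
With the defaults above (dim_head = 64, heads = 8, implicit_mlp_expansion = 2), the single fused key/value projection is replaced by four projections whose output widths mirror a 2-layer MLP: the first attention pass keeps keys at the head width while its values expand to the hidden width, and the second pass reverses this. A quick shape check (my own illustration with a hypothetical dim_kv, not code from the package):

```python
import torch
from torch import nn

dim_kv, dim_head, heads, implicit_mlp_expansion = 256, 64, 8, 2  # dim_kv is hypothetical

dim_inner = dim_head * heads                         # 512
dim_mlp_inner = dim_inner * implicit_mlp_expansion   # 1024

to_w1_keys   = nn.Linear(dim_kv, dim_inner, bias = False)       # first pass: keys at head width
to_w1_values = nn.Linear(dim_kv, dim_mlp_inner, bias = False)   # first pass: values expand to hidden width
to_w2_keys   = nn.Linear(dim_kv, dim_mlp_inner, bias = False)   # second pass: keys at hidden width
to_w2_values = nn.Linear(dim_kv, dim_inner, bias = False)       # second pass: values project back down

x = torch.randn(2, 16, dim_kv)
print(to_w1_values(x).shape)  # torch.Size([2, 16, 1024])
print(to_w2_values(x).shape)  # torch.Size([2, 16, 512])
```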
@@ -270,47 +280,59 @@ class Attention(Module):
 
         q = self.to_queries(queries_input)
 
-        k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+        k1, v1, k2, v2 = [fn(key_values_input) for fn in (self.to_w1_keys, self.to_w1_values, self.to_w2_keys, self.to_w2_values)]
 
-        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+        q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))
 
         if exists(kv_cache):
-            ck, cv = kv_cache
-            k = cat((ck, k), dim = -2)
-            v = cat((cv, v), dim = -2)
+            ck1, cv1, ck2, cv2 = kv_cache
+            k1 = cat((ck1, k1), dim = -2)
+            v1 = cat((cv1, v1), dim = -2)
+            k2 = cat((ck2, k2), dim = -2)
+            v2 = cat((cv2, v2), dim = -2)
+
+        def attend(q, k, v):
+            q = q * self.scale
 
-        q = q * self.scale
+            sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+            # the diagonal is masked out
 
-        # the diagonal is masked out
+            i, j = sim.shape[-2:]
 
-        i, j = sim.shape[-2:]
+            j_seq = arange(j, device = device)[:, None]
+            i_seq = arange(i, device = device)[None, :] + (j - i)
 
-        j_seq = arange(j, device = device)[:, None]
-        i_seq = arange(i, device = device)[None, :] + (j - i)
+            windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
 
-        windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
+            sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
 
-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+            # attention sink, for token as well as for attention sinking - from gpt-oss
 
-        # attention sink, for token as well as for attention sinking - from gpt-oss
+            attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
 
-        attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+            sim = cat((attn_sink, sim), dim = -1)
 
-        sim = cat((attn_sink, sim), dim = -1)
+            attn = sim.softmax(dim = -1)
 
-        attn = sim.softmax(dim = -1)
+            attn = attn[..., 1:] # remove sink
 
-        attn = attn[..., 1:] # remove sink
+            # aggregate
 
-        # aggregate
+            out = einsum('b h i j, b h j d -> b h i d', attn, v)
+            return out
 
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        # implicit memory mlp - attend through w1, non-linearity, then w2
+
+        hiddens = attend(q, k1, v1)
+        hiddens = self.implicit_mlp_activation(hiddens)
+        out = attend(hiddens, k2, v2)
+
+        # merge heads
 
         out = self.merge_heads(out)
 
-        return self.to_out(out), stack((k, v))
+        return self.to_out(out), (k1, v1, k2, v2)
 
 class TEMTransformerBlock(Module):
     def __init__(
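
Putting the forward-pass changes together: each token now attends twice, with the first attention producing an expanded hidden state, a SiLU applied to it, and a second attention projecting back down, which is what makes the softmax attention act like an implicit 2-layer MLP memory. A minimal standalone sketch of that pattern, written here for illustration only (assumptions: plain causal masking, no sliding window, no attention sink, heads already split; the real module additionally handles the kv cache and the diagonal masking shown above):

```python
import torch
import torch.nn.functional as F
from torch import einsum

def implicit_mlp_attention(q, k1, v1, k2, v2):
    # q, k1: (b, h, n, d); v1, k2: (b, h, n, e); v2: (b, h, n, d)
    def attend(q, k, v):
        scale = q.shape[-1] ** -0.5
        sim = einsum('b h i d, b h j d -> b h i j', q * scale, k)
        i, j = sim.shape[-2:]
        causal_mask = torch.ones(i, j, dtype = torch.bool, device = q.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim = -1)
        return einsum('b h i j, b h j d -> b h i d', attn, v)

    hiddens = attend(q, k1, v1)      # first "layer": expand into the hidden dimension
    hiddens = F.silu(hiddens)        # non-linearity, analogous to the SiLU in the diff
    return attend(hiddens, k2, v2)   # second "layer": project back to the head dimension

# toy shapes: batch 1, 8 heads, 16 tokens, head dim 64, expansion 2
b, h, n, d, e = 1, 8, 16, 64, 128
q, k1 = torch.randn(b, h, n, d), torch.randn(b, h, n, d)
v1, k2 = torch.randn(b, h, n, e), torch.randn(b, h, n, e)
v2 = torch.randn(b, h, n, d)
out = implicit_mlp_attention(q, k1, v1, k2, v2)
assert out.shape == (b, h, n, d)
```

Note that the returned kv cache is correspondingly a 4-tuple (k1, v1, k2, v2) rather than the previous stacked (k, v) pair.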
@@ -347,6 +369,66 @@ class TEMTransformerBlock(Module):
 
         return x, next_kv_cache
 
+class TEMTransformer(Module):
+    def __init__(
+        self,
+        sensory_encoder_decoder: tuple[Module, Module],
+        dim_sensory,
+        dim_action,
+        dim_encoded_sensory,
+        dim_structure,
+        depth = 4,
+        transformer_kwargs: dict = dict(
+            dim_head = 64,
+            heads = 8,
+            ff_expansion_factor = 4,
+            window_size = 32
+        ),
+    ):
+        super().__init__()
+
+        self.sensory_encoder, self.sensory_decoder = sensory_encoder_decoder
+
+        self.path_integrator = nn.GRU(dim_action, dim_structure)
+
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+
+            block = TEMTransformerBlock(
+                dim_structure,
+                dim_encoded_sensory,
+                **transformer_kwargs
+            )
+
+            self.layers.append(block)
+
+    def forward(
+        self,
+        sensory,
+        actions,
+        prev_hiddens = None,  # for the GRU based path integrator
+        prev_kv_cache = None  # for the specialized transformer blocks for inducing the grid-cells
+    ):
+
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
+
+        encoded_sensory = self.sensory_encoder(sensory)
+
+        next_kv_cache = []
+
+        for layer in self.layers:
+            structure, layer_next_cache = layer(structure, encoded_sensory)
+            next_kv_cache.append(layer_next_cache)
+
+        decoded_sensory = self.sensory_decoder(structure)
+
+        next_memories = (next_hiddens, stack(next_kv_cache))
+
+        pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
+
+        return pred_loss
+
 # proposed mmTEM
 
 class mmTEM(Module):
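
The new TEMTransformer ties the pieces together: a GRU path-integrates the action stream into structure codes, the sensory stream is encoded, the TEMTransformerBlock stack refines structure against the encoded sensory, and an MSE prediction loss is taken in the encoded space. A rough, self-contained sketch of that dataflow, not the package API (hypothetical dimensions, plain nn.Linear encoder/decoder stand-ins, and the transformer block stack elided):

```python
import torch
from torch import nn
import torch.nn.functional as F

dim_action, dim_structure, dim_encoded_sensory, dim_sensory = 4, 128, 128, 32

path_integrator = nn.GRU(dim_action, dim_structure)            # expects (seq, batch, dim_action)
sensory_encoder = nn.Linear(dim_sensory, dim_encoded_sensory)  # stand-in for the provided encoder
sensory_decoder = nn.Linear(dim_structure, dim_encoded_sensory)

seq_len, batch = 16, 2
actions = torch.randn(seq_len, batch, dim_action)
sensory = torch.randn(seq_len, batch, dim_sensory)

structure, next_hiddens = path_integrator(actions)   # path integration over the action sequence
encoded_sensory = sensory_encoder(sensory)

# ... the TEMTransformerBlock stack would refine `structure` against `encoded_sensory` here ...

decoded_sensory = sensory_decoder(structure)
pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
pred_loss.backward()
```

As in the diff, the prediction loss compares the decoded structure against the encoded sensory stream rather than the raw sensory input.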
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.12
+Version: 0.0.14
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -63,3 +63,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
     note = {under review}
 }
 ```
+
+```bibtex
+@article{Li2020GridCA,
+    title   = {Grid Cells Are Ubiquitous in Neural Networks},
+    author  = {Songlin Li and Yangdong Deng and Zhihua Wang},
+    journal = {ArXiv},
+    year    = {2020},
+    volume  = {abs/2003.03482},
+    url     = {https://api.semanticscholar.org/CorpusID:212634300}
+}
+```
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=RqHobJTA36mJ-YFutteMx5_6QrrSqI370C1sw5AYrUE,20818
+hippoformer-0.0.14.dist-info/METADATA,sha256=W8k541yoURENT9RFpxZtu7O08r_R8joyzjj3ikwFbH0,3093
+hippoformer-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.14.dist-info/RECORD,,
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=DjKAUfMpcoaAOqyuWauKp8n8e2YTzGVSOctNXxagkiA,18166
-hippoformer-0.0.12.dist-info/METADATA,sha256=bH3GaJniFX2zCgxFFg8v0amGEPdWlBbG401Ml3_hDCs,2773
-hippoformer-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.12.dist-info/RECORD,,