hippoformer 0.0.12.tar.gz → 0.0.14.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {hippoformer-0.0.12 → hippoformer-0.0.14}/PKG-INFO +12 -1
- {hippoformer-0.0.12 → hippoformer-0.0.14}/README.md +11 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/hippoformer/hippoformer.py +105 -23
- {hippoformer-0.0.12 → hippoformer-0.0.14}/pyproject.toml +1 -1
- {hippoformer-0.0.12 → hippoformer-0.0.14}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/.gitignore +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/LICENSE +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/hippoformer-fig6.png +0 -0
- {hippoformer-0.0.12 → hippoformer-0.0.14}/tests/test_hippoformer.py +0 -0
{hippoformer-0.0.12 → hippoformer-0.0.14}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.12
+Version: 0.0.14
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -63,3 +63,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
     note = {under review}
 }
 ```
+
+```bibtex
+@article{Li2020GridCA,
+    title = {Grid Cells Are Ubiquitous in Neural Networks},
+    author = {Songlin Li and Yangdong Deng and Zhihua Wang},
+    journal = {ArXiv},
+    year = {2020},
+    volume = {abs/2003.03482},
+    url = {https://api.semanticscholar.org/CorpusID:212634300}
+}
+```
{hippoformer-0.0.12 → hippoformer-0.0.14}/README.md

@@ -16,3 +16,14 @@ Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Int
     note = {under review}
 }
 ```
+
+```bibtex
+@article{Li2020GridCA,
+    title = {Grid Cells Are Ubiquitous in Neural Networks},
+    author = {Songlin Li and Yangdong Deng and Zhihua Wang},
+    journal = {ArXiv},
+    year = {2020},
+    volume = {abs/2003.03482},
+    url = {https://api.semanticscholar.org/CorpusID:212634300}
+}
+```
{hippoformer-0.0.12 → hippoformer-0.0.14}/hippoformer/hippoformer.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import torch
 from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
 import torch.nn.functional as F
-from torch.nn import Module
+from torch.nn import Module, ModuleList
 from torch.jit import ScriptModule, script_method
 from torch.func import vmap, grad, functional_call

@@ -244,13 +244,23 @@ class Attention(Module):
         window_size,
         dim_head = 64,
         heads = 8,
+        implicit_mlp_expansion = 2 # for fair comparison, the attention should have an implicit mlp of 2 layers with a non-linearity, just like the meta-memory mlp in titans (linear attention)
     ):
         super().__init__()
         dim_inner = dim_head * heads
+        dim_mlp_inner = dim_head * heads * implicit_mlp_expansion
+
         self.scale = dim_head ** -0.5

         self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
+
+        self.to_w1_keys = nn.Linear(dim_kv, dim_inner, bias = False)
+        self.to_w1_values = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+
+        self.implicit_mlp_activation = nn.SiLU()
+
+        self.to_w2_keys = nn.Linear(dim_kv, dim_mlp_inner, bias = False)
+        self.to_w2_values = nn.Linear(dim_kv, dim_inner, bias = False)

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
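The new `to_w1_*` and `to_w2_*` projections give the attention an implicit two-layer MLP: a first attention step reads out hidden activations widened by `implicit_mlp_expansion`, a SiLU sits in between, and a second attention step projects back to the head dimension, mirroring the two-layer meta-memory MLP in Titans. Below is a minimal standalone sketch of that data flow; the function name, shapes, per-call scaling, and the unmasked softmax are illustrative assumptions, not the package's exact forward pass.

```python
import torch
import torch.nn.functional as F

def implicit_mlp_attention(q, k1, v1, k2, v2):
    # q, k1: (batch, heads, seq, dim_head)   v1, k2: (..., dim_mlp_inner)   v2: (..., dim_head)
    def attend(q, k, v):
        scale = q.shape[-1] ** -0.5
        sim = torch.einsum('b h i d, b h j d -> b h i j', q * scale, k)
        return torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)

    hiddens = attend(q, k1, v1)      # first implicit "layer": dim_head -> dim_mlp_inner
    hiddens = F.silu(hiddens)        # the non-linearity between the two implicit layers
    return attend(hiddens, k2, v2)   # second implicit "layer": dim_mlp_inner -> dim_head

b, h, n, d, d_mlp = 1, 8, 16, 64, 128
out = implicit_mlp_attention(
    torch.randn(b, h, n, d),
    torch.randn(b, h, n, d), torch.randn(b, h, n, d_mlp),
    torch.randn(b, h, n, d_mlp), torch.randn(b, h, n, d)
)
assert out.shape == (b, h, n, d)
```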
@@ -270,47 +280,59 @@ class Attention(Module):

         q = self.to_queries(queries_input)

+        k1, v1, k2, v2 = [fn(key_values_input) for fn in (self.to_w1_keys, self.to_w1_values, self.to_w2_keys, self.to_w2_values)]

-        q,
+        q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

         if exists(kv_cache):
+            ck1, cv1, ck2, cv2 = kv_cache
+            k1 = cat((ck1, k1), dim = -2)
+            v1 = cat((cv1, v1), dim = -2)
+            k2 = cat((ck2, k2), dim = -2)
+            v2 = cat((cv2, v2), dim = -2)
+
+        def attend(q, k, v):
+            q = q * self.scale

+            sim = einsum('b h i d, b h j d -> b h i j', q, k)

+            # the diagonal is masked out

+            i, j = sim.shape[-2:]

+            j_seq = arange(j, device = device)[:, None]
+            i_seq = arange(i, device = device)[None, :] + (j - i)

-        i_seq = arange(i, device = device)[None, :] + (j - i)
+            windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

+            sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

+            # attention sink, for token as well as for attention sinking - from gpt-oss

+            attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)

+            sim = cat((attn_sink, sim), dim = -1)

+            attn = sim.softmax(dim = -1)

+            attn = attn[..., 1:] # remove sink

+            # aggregate

+            out = einsum('b h i j, b h j d -> b h i d', attn, v)
+            return out

+        # implicit memory mlp w1
+
+        hiddens = attend(q, k1, v1)
+        hiddens = self.implicit_mlp_activation(hiddens)
+        out = attend(hiddens, k2, v2)
+
+        # merge heads

         out = self.merge_heads(out)

-        return self.to_out(out),
+        return self.to_out(out), (k1, v1, k2, v2)

 class TEMTransformerBlock(Module):
     def __init__(
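Two details of the `attend` closure above are worth isolating: the causal window also excludes the diagonal (a query never attends to its own position), and a learned per-head sink logit is prepended before the softmax and dropped afterwards, in the style of gpt-oss, so a row whose entire window is masked can park its probability mass on the sink rather than produce NaNs. The sketch below is a self-contained illustration of that idea; the function and argument names and the exact mask construction are assumptions, not the package's code.

```python
import torch

def sink_windowed_attention(q, k, v, head_sink, window_size):
    # q, k, v: (batch, heads, seq, dim_head); head_sink: (heads,) learned sink logit per head
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('b h i d, b h j d -> b h i j', q * scale, k)

    b, h, i, j = sim.shape
    q_idx = torch.arange(i, device = sim.device)[:, None] + (j - i)  # query positions in key coordinates
    k_idx = torch.arange(j, device = sim.device)[None, :]

    # keep only strictly-past keys within the window; the diagonal (k_idx == q_idx) stays masked
    allowed = (k_idx < q_idx) & ((q_idx - k_idx) <= window_size)
    sim = sim.masked_fill(~allowed, -torch.finfo(sim.dtype).max)

    # attention sink: prepend one logit per head, softmax over (sink + keys), then drop the sink column
    sink = head_sink.view(1, h, 1, 1).expand(b, h, i, 1)
    attn = torch.cat((sink, sim), dim = -1).softmax(dim = -1)[..., 1:]

    return torch.einsum('b h i j, b h j d -> b h i d', attn, v)

q = k = v = torch.randn(1, 8, 16, 64)
out = sink_windowed_attention(q, k, v, head_sink = torch.zeros(8), window_size = 4)
```

Note the first query position has no allowed keys at all under this mask; with the sink column present the softmax simply assigns it to the sink, and the corresponding output row is (numerically) zero instead of undefined.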
@@ -347,6 +369,66 @@ class TEMTransformerBlock(Module):

         return x, next_kv_cache

+class TEMTransformer(Module):
+    def __init__(
+        self,
+        sensory_encoder_decoder: tuple[Module, Module],
+        dim_sensory,
+        dim_action,
+        dim_encoded_sensory,
+        dim_structure,
+        depth = 4,
+        transformer_kwargs: dict = dict(
+            dim_head = 64,
+            heads = 8,
+            ff_expansion_factor = 4,
+            window_size = 32
+        ),
+    ):
+        super().__init__()
+
+        self.sensory_encoder, self.sensory_decoder = sensory_encoder_decoder
+
+        self.path_integrator = nn.GRU(dim_action, dim_structure)
+
+        self.layers = ModuleList([])
+
+        for _ in range(depth):
+
+            block = TEMTransformerBlock(
+                dim_structure,
+                dim_encoded_sensory,
+                **transformer_kwargs
+            )
+
+            self.layers.append(block)
+
+    def forward(
+        self,
+        sensory,
+        actions,
+        prev_hiddens = None, # for the GRU based path integrator
+        prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
+    ):
+
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
+
+        encoded_sensory = self.sensory_encoder(sensory)
+
+        next_kv_cache = []
+
+        for layer in self.layers:
+            structure, layer_next_cache = layer(structure, encoded_sensory)
+            next_kv_cache.append(layer_next_cache)
+
+        decoded_sensory = self.sensory_decoder(structure)
+
+        next_memories = (next_hiddens, stack(next_kv_cache))
+
+        pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
+
+        return pred_loss
+
 # proposed mmTEM

 class mmTEM(Module):
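The new `TEMTransformer` ties the pieces together: a GRU path-integrates the action sequence into structure codes, each `TEMTransformerBlock` lets those codes attend over the encoded sensory stream, and the decoded output is scored against the encoded sensory input with an MSE loss. Below is a hypothetical usage sketch going only by the constructor and forward signatures shown in this diff; the linear encoder/decoder pair, the dimensions, and the (sequence, batch, feature) layout implied by the default `nn.GRU` are assumptions, and the end-to-end shapes ultimately depend on `TEMTransformerBlock`, whose body is not part of this diff.

```python
import torch
from torch import nn
from hippoformer.hippoformer import TEMTransformer  # module path as shown in this diff

dim_sensory, dim_action, dim_encoded_sensory, dim_structure = 32, 4, 64, 128

# toy encoder / decoder pair - any modules with matching in/out dimensions should do
sensory_encoder = nn.Linear(dim_sensory, dim_encoded_sensory)
sensory_decoder = nn.Linear(dim_structure, dim_encoded_sensory)

model = TEMTransformer(
    sensory_encoder_decoder = (sensory_encoder, sensory_decoder),
    dim_sensory = dim_sensory,
    dim_action = dim_action,
    dim_encoded_sensory = dim_encoded_sensory,
    dim_structure = dim_structure,
    depth = 2
)

# one rollout of 16 steps, batch of 1, laid out (seq, batch, dim) for the default nn.GRU
sensory = torch.randn(16, 1, dim_sensory)
actions = torch.randn(16, 1, dim_action)

pred_loss = model(sensory, actions)   # forward returns the sensory prediction loss
pred_loss.backward()
```

The decoder here maps structure codes back to the encoded-sensory dimension because `forward` in the diff compares `decoded_sensory` directly against `encoded_sensory`.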