hippoformer 0.0.9__tar.gz → 0.0.11__tar.gz
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- {hippoformer-0.0.9 → hippoformer-0.0.11}/PKG-INFO +1 -1
- {hippoformer-0.0.9 → hippoformer-0.0.11}/hippoformer/hippoformer.py +122 -3
- {hippoformer-0.0.9 → hippoformer-0.0.11}/pyproject.toml +1 -1
- {hippoformer-0.0.9 → hippoformer-0.0.11}/tests/test_hippoformer.py +12 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/.github/workflows/python-publish.yml +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/.github/workflows/test.yml +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/.gitignore +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/LICENSE +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/README.md +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/hippoformer/__init__.py +0 -0
- {hippoformer-0.0.9 → hippoformer-0.0.11}/hippoformer-fig6.png +0 -0

hippoformer/hippoformer.py

@@ -221,6 +221,122 @@ class PathIntegration(Module):
 
         return self.rnn(transitions, prev_structural)
 
+# custom transformer
+# with the connections proposed by James Whittington that bridges to hippocampal models
+# https://arxiv.org/abs/2112.04035
+
+def FeedForward(dim, mult = 4.):
+    dim_inner = int(dim * mult)
+    return nn.Sequential(
+        nn.Linear(dim, dim_inner),
+        nn.GELU(),
+        nn.Linear(dim_inner, dim)
+    )
+
+class Attention(Module):
+    def __init__(
+        self,
+        dim_q,
+        dim_kv,
+        dim_head = 64,
+        heads = 8
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+        self.scale = dim_head ** -0.5
+
+        self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
+        self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
+        self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink
+
+    def forward(
+        self,
+        queries_input,
+        key_values_input,
+        kv_cache = None
+    ):
+        batch, seq_len, device = *queries_input.shape[:2], queries_input.device
+
+        q = self.to_queries(queries_input)
+
+        k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+
+        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+
+        if exists(kv_cache):
+            ck, cv = kv_cache
+            k = cat((ck, k), dim = -2)
+            v = cat((cv, v), dim = -2)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # the diagonal is masked out
+
+        i, j = sim.shape[-2:]
+        causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
+
+        sim = sim.masked_fill(causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+
+        # attention sink, for token as well as for attention sinking - from gpt-oss
+
+        attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+
+        sim = cat((attn_sink, sim), dim = -1)
+
+        attn = sim.softmax(dim = -1)
+
+        attn = attn[..., 1:] # remove sink
+
+        # aggregate
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = self.merge_heads(out)
+
+        return self.to_out(out), stack((k, v))
+
+class TEMTransformerBlock(Module):
+    def __init__(
+        self,
+        dim_structure,
+        dim_encoded_sensory,
+        dim_head = 64,
+        heads = 8,
+        ff_expansion_factor = 4.,
+        window_size = 64
+    ):
+        super().__init__()
+
+        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+        self.ff = FeedForward(dim_structure, ff_expansion_factor)
+
+        self.window_size = window_size
+
+    def forward(
+        self,
+        structural_codes,
+        encoded_sensory,
+        kv_cache = None
+    ):
+        structure_and_sensory = cat((structural_codes, encoded_sensory), dim = -1)
+
+        retrieved, next_kv_cache = self.attn(structural_codes, structure_and_sensory, kv_cache = kv_cache)
+
+        x = retrieved + structural_codes
+
+        x = self.ff(x) + x
+
+        next_kv_cache = next_kv_cache[:, -self.window_size:]
+
+        return x, next_kv_cache
+
 # proposed mmTEM
 
 class mmTEM(Module):
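For orientation, here is a minimal usage sketch of the new block, driven one timestep at a time while threading the rolling kv cache it returns back into the next call. This is an illustrative example, not part of the package diff; it assumes only the import path exercised in the new test below and the constructor arguments shown in the hunk above.

import torch
from hippoformer.hippoformer import TEMTransformerBlock

block = TEMTransformerBlock(dim_structure = 32, dim_encoded_sensory = 16)

kv_cache = None

for _ in range(5):
    structural_codes = torch.randn(1, 1, 32)  # (batch, one step, dim_structure)
    encoded_sensory = torch.randn(1, 1, 16)   # (batch, one step, dim_encoded_sensory)

    # each step attends, through the diagonal-masked attention with a learned sink,
    # to the cached structure-and-sensory keys / values from previous steps
    pred, kv_cache = block(structural_codes, encoded_sensory, kv_cache = kv_cache)

    assert pred.shape == (1, 1, 32)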
@@ -455,6 +571,10 @@ class mmTEM(Module):
 
         update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
 
+        # store next momentum
+
+        next_momentum[key] = update[:, -1]
+
         # maybe muon
 
         if self.muon_update:
@@ -464,14 +584,13 @@ class mmTEM(Module):
 
         expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
 
-        acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+        acc_update = self.assoc_scan(-update, expanded_forget.sigmoid(), param)
 
         acc_update = inverse_pack(acc_update)
 
         # set the next params and momentum, which can be passed back in
 
-        next_params[key] =
-        next_momentum[key] = update[:, -1]
+        next_params[key] = acc_update[:, -1]
 
         # losses
 
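The two changes above adjust what the recurrent memory update carries across chunks: the running momentum is now stored right after the first scan, and the second scan is seeded with the current param and fed the negated momentum. The reference loop below is one plausible reading of what this pair of gated scans computes; it assumes assoc_scan(inputs, gates, init) realises the recurrence h_t = gates_t * h_{t-1} + inputs_t, which is an assumption about the helper, not something the diff states.

import torch

def reference_scan(inputs, gates, init):
    # assumed semantics: h_t = gates_t * h_{t-1} + inputs_t, scanned over the time dim
    h, outs = init, []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)

batch, time, width = 2, 4, 8
grad = torch.randn(batch, time, width)
beta = torch.randn(batch, time, width)      # expanded momentum gate (pre-sigmoid)
forget = torch.randn(batch, time, width)    # expanded forget gate (pre-sigmoid)
momentum = torch.zeros(batch, width)
param = torch.randn(batch, width)

# first scan: gated momentum over the gradients; its last state is carried forward,
# mirroring next_momentum[key] = update[:, -1]
update = reference_scan(grad, beta.sigmoid(), momentum)
next_momentum = update[:, -1]

# second scan: forget-gated parameter trajectory seeded with the current param and
# decremented by the momentum, mirroring next_params[key] = acc_update[:, -1]
acc_update = reference_scan(-update, forget.sigmoid(), param)
next_param = acc_update[:, -1]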

tests/test_hippoformer.py

@@ -65,3 +65,15 @@ def test_mm_tem(
 
     loss = model(sensory, actions, memory_mlp_params = next_params)
     loss.backward()
+
+def test_tem_t():
+    from hippoformer.hippoformer import TEMTransformerBlock
+
+    block = TEMTransformerBlock(32, 16)
+
+    structural_codes = torch.randn(1, 7, 32)
+    encoded_sensory = torch.randn(1, 7, 16)
+
+    pred, kv_cache = block(structural_codes, encoded_sensory)
+
+    assert pred.shape == (1, 7, 32)
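The added test only checks the prediction shape on a single forward pass. A variant that also threads the returned kv cache through a second call, mimicking incremental decoding, could look like the following; this is an illustrative sketch, not part of the released test suite.

import torch
from hippoformer.hippoformer import TEMTransformerBlock

def test_tem_t_with_cache():
    block = TEMTransformerBlock(32, 16)

    # prime the block on a short sequence and keep the cache it returns
    pred, kv_cache = block(torch.randn(1, 7, 32), torch.randn(1, 7, 16))
    assert pred.shape == (1, 7, 32)

    # feed one more step, reusing the cached keys / values from the first call
    next_pred, _ = block(torch.randn(1, 1, 32), torch.randn(1, 1, 16), kv_cache = kv_cache)
    assert next_pred.shape == (1, 1, 32)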