hippoformer 0.0.9__py3-none-any.whl → 0.0.11__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -221,6 +221,122 @@ class PathIntegration(Module):
 
         return self.rnn(transitions, prev_structural)
 
+# custom transformer
+# with the connections proposed by James Whittington that bridge transformers to hippocampal models
+# https://arxiv.org/abs/2112.04035
+
+def FeedForward(dim, mult = 4.):
+    dim_inner = int(dim * mult)
+    return nn.Sequential(
+        nn.Linear(dim, dim_inner),
+        nn.GELU(),
+        nn.Linear(dim_inner, dim)
+    )
+
+class Attention(Module):
+    def __init__(
+        self,
+        dim_q,
+        dim_kv,
+        dim_head = 64,
+        heads = 8
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+        self.scale = dim_head ** -0.5
+
+        self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
+        self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
+        self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and to act as an attention sink
+
+    def forward(
+        self,
+        queries_input,
+        key_values_input,
+        kv_cache = None
+    ):
+        batch, seq_len, device = *queries_input.shape[:2], queries_input.device
+
+        q = self.to_queries(queries_input)
+
+        k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+
+        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+
+        if exists(kv_cache):
+            ck, cv = kv_cache
+            k = cat((ck, k), dim = -2)
+            v = cat((cv, v), dim = -2)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # the diagonal is masked out
+
+        i, j = sim.shape[-2:]
+        causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
+
+        sim = sim.masked_fill(causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+
+        # attention sink, both for the masked out diagonal token as well as for attention sinking - from gpt-oss
+
+        attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+
+        sim = cat((attn_sink, sim), dim = -1)
+
+        attn = sim.softmax(dim = -1)
+
+        attn = attn[..., 1:] # remove sink
+
+        # aggregate
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = self.merge_heads(out)
+
+        return self.to_out(out), stack((k, v))
+
+class TEMTransformerBlock(Module):
+    def __init__(
+        self,
+        dim_structure,
+        dim_encoded_sensory,
+        dim_head = 64,
+        heads = 8,
+        ff_expansion_factor = 4.,
+        window_size = 64
+    ):
+        super().__init__()
+
+        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+        self.ff = FeedForward(dim_structure, ff_expansion_factor)
+
+        self.window_size = window_size
+
+    def forward(
+        self,
+        structural_codes,
+        encoded_sensory,
+        kv_cache = None
+    ):
+        structure_and_sensory = cat((structural_codes, encoded_sensory), dim = -1)
+
+        retrieved, next_kv_cache = self.attn(structural_codes, structure_and_sensory, kv_cache = kv_cache)
+
+        x = retrieved + structural_codes
+
+        x = self.ff(x) + x
+
+        next_kv_cache = next_kv_cache[..., -self.window_size:, :] # window the cache along the sequence dimension
+
+        return x, next_kv_cache
+
 
 # proposed mmTEM
 
 class mmTEM(Module):
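
For orientation, here is a minimal usage sketch of the new block. It is not part of the package; the import path is taken from the RECORD below, and the batch size, sequence length, and dimensions are made-up example values.

    import torch
    from hippoformer.hippoformer import TEMTransformerBlock  # assumed import path

    block = TEMTransformerBlock(dim_structure = 128, dim_encoded_sensory = 256, window_size = 64)

    structural = torch.randn(2, 32, 128)   # (batch, time, dim_structure)
    sensory    = torch.randn(2, 32, 256)   # (batch, time, dim_encoded_sensory)

    # the first segment has no cache; the second call reuses the windowed key / value cache
    out, kv_cache = block(structural, sensory)
    out, kv_cache = block(structural, sensory, kv_cache)

    assert out.shape == structural.shape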
@@ -455,6 +571,10 @@ class mmTEM(Module):
 
             update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
 
+            # store next momentum
+
+            next_momentum[key] = update[:, -1]
+
             # maybe muon
 
             if self.muon_update:
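
The stored momentum is simply the last state of the associative scan over the gradients. The sketch below spells out the recurrence this is assumed to implement (out_t = gate_t * out_{t-1} + inp_t, with the third argument as the initial carry); the scan itself belongs to the package, so treat the recurrence as an assumption rather than its implementation.

    import torch

    def naive_scan(inp, gate, init = None):
        # assumed recurrence: out_t = gate_t * out_{t-1} + inp_t
        carry = init if init is not None else torch.zeros_like(inp[:, 0])
        outs = []
        for t in range(inp.shape[1]):
            carry = gate[:, t] * carry + inp[:, t]
            outs.append(carry)
        return torch.stack(outs, dim = 1)

    grad = torch.randn(2, 16, 8)             # (batch, time, flattened weight)
    beta = torch.randn(2, 16, 8)

    update = naive_scan(grad, beta.sigmoid())
    next_momentum = update[:, -1]             # carried into the next segment, as in the diff above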
@@ -464,14 +584,13 @@ class mmTEM(Module):
 
             expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
 
-            acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
+            acc_update = self.assoc_scan(-update, expanded_forget.sigmoid(), param)
 
             acc_update = inverse_pack(acc_update)
 
             # set the next params and momentum, which can be passed back in
 
-            next_params[key] = param - acc_update[:, -1]
-            next_momentum[key] = update[:, -1]
+            next_params[key] = acc_update[:, -1]
 
             # losses
 
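The rewritten parameter update changes what the forget gate acts on. Under the same assumed recurrence as above (out_t = gate_t * out_{t-1} + inp_t), the old code decayed only the accumulated update and subtracted it from the starting parameters at the end, whereas the new code makes the parameters themselves the scan carry, so the forget gate decays the weights directly (weight-decay style) and the final scan state is already the next set of parameters. A small self-contained sketch with illustrative shapes:

    import torch

    param  = torch.randn(2, 8)                # one flattened fast weight, (batch, width)
    update = torch.randn(2, 16, 8)            # momentum-smoothed updates over time
    forget = torch.randn(2, 16, 8).sigmoid()

    # before (0.0.9): the forget gate only decays the accumulated update
    acc = torch.zeros_like(param)
    for t in range(update.shape[1]):
        acc = forget[:, t] * acc + update[:, t]
    next_params_old = param - acc

    # after (0.0.11): the parameters are the carry, so the forget gate decays them directly
    p = param
    for t in range(update.shape[1]):
        p = forget[:, t] * p - update[:, t]
    next_params_new = p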
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.9
+Version: 0.0.11
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=PLMfdype8AMwlVWrtItDBkE3gU_BCUaL42NMjB4vhAY,17795
+hippoformer-0.0.11.dist-info/METADATA,sha256=6NlqhZSEApQkUKsncBxmDIE03x_xZktHH-JCeYlYfcg,2773
+hippoformer-0.0.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+hippoformer-0.0.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.11.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=m7luQGFdMWOkZUorjd5v34hx_vjOQqpJOAGCL0njHUE,14426
3
- hippoformer-0.0.9.dist-info/METADATA,sha256=owgkDcdTf0_N5IbUr3e_yt7u5sIWfOMha-hA5LQWnus,2772
4
- hippoformer-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.9.dist-info/RECORD,,