hippoformer 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
hippoformer/hippoformer.py
@@ -1,7 +1,7 @@
  from __future__ import annotations

  import torch
- from torch import nn, Tensor, cat, stack, zeros_like, einsum, tensor
+ from torch import nn, Tensor, cat, stack, arange, zeros_like, einsum, tensor
  import torch.nn.functional as F
  from torch.nn import Module
  from torch.jit import ScriptModule, script_method
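The only change to the imports is the addition of `arange`, which the new attention code further down uses to build index grids for its windowed causal mask. As a generic sketch of that idiom only (illustrative values, not the package's exact mask construction): comparing two broadcast `arange` grids yields a boolean attention mask.

import torch

i, j = 4, 6                                   # query length, key length (illustrative)
q_pos = torch.arange(i)[:, None] + (j - i)    # absolute position of each query row
k_pos = torch.arange(j)[None, :]              # absolute position of each key column
causal = k_pos <= q_pos                       # (i, j) boolean grid: which keys each query may see

Per the comments in the new code below, the package's own mask additionally drops the diagonal and restricts attention to a `window_size` of past positions, which is why it also adds an attention sink.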
@@ -221,6 +221,132 @@ class PathIntegration(Module):

          return self.rnn(transitions, prev_structural)

+ # custom transformer proposed by James Whittington that bridges to hippocampal models with a few twists
+
+ # the mmTEM can be seen as a linear attention / TTT variant of what he proposed
+ # needed for the baseline as well as the parallel block to bolster local time prediction
+
+ # https://arxiv.org/abs/2112.04035
+
+ def FeedForward(dim, mult = 4.):
+     dim_inner = int(dim * mult)
+     return nn.Sequential(
+         nn.Linear(dim, dim_inner),
+         nn.GELU(),
+         nn.Linear(dim_inner, dim)
+     )
+
+ class Attention(Module):
+     def __init__(
+         self,
+         dim_q,
+         dim_kv,
+         window_size,
+         dim_head = 64,
+         heads = 8,
+     ):
+         super().__init__()
+         dim_inner = dim_head * heads
+         self.scale = dim_head ** -0.5
+
+         self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
+         self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+         self.window_size = window_size
+
+         self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
+         self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and for attention sink
+
+     def forward(
+         self,
+         queries_input,
+         key_values_input,
+         kv_cache = None
+     ):
+         batch, seq_len, device = *queries_input.shape[:2], queries_input.device
+
+         q = self.to_queries(queries_input)
+
+         k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+
+         q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+
+         if exists(kv_cache):
+             ck, cv = kv_cache
+             k = cat((ck, k), dim = -2)
+             v = cat((cv, v), dim = -2)
+
+         q = q * self.scale
+
+         sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+         # the diagonal is masked out
+
+         i, j = sim.shape[-2:]
+
+         j_seq = arange(j, device = device)[:, None]
+         i_seq = arange(i, device = device)[None, :] + (j - i)
+
+         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
+
+         sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+
+         # attention sink, for token as well as for attention sinking - from gpt-oss
+
+         attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+
+         sim = cat((attn_sink, sim), dim = -1)
+
+         attn = sim.softmax(dim = -1)
+
+         attn = attn[..., 1:] # remove sink
+
+         # aggregate
+
+         out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+         out = self.merge_heads(out)
+
+         return self.to_out(out), stack((k, v))
+
+ class TEMTransformerBlock(Module):
+     def __init__(
+         self,
+         dim_structure,
+         dim_encoded_sensory,
+         dim_head = 64,
+         heads = 8,
+         ff_expansion_factor = 4.,
+         window_size = 64
+     ):
+         super().__init__()
+
+         self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, window_size, dim_head = dim_head, heads = heads)
+         self.ff = FeedForward(dim_structure, ff_expansion_factor)
+
+         self.window_size = window_size
+
+     def forward(
+         self,
+         structural_codes,
+         encoded_sensory,
+         kv_cache = None
+     ):
+         structure_and_sensory = cat((structural_codes, encoded_sensory), dim = -1)
+
+         retrieved, next_kv_cache = self.attn(structural_codes, structure_and_sensory, kv_cache = kv_cache)
+
+         x = retrieved + structural_codes
+
+         x = self.ff(x) + x
+
+         next_kv_cache = next_kv_cache[:, -self.window_size:]
+
+         return x, next_kv_cache
+
  # proposed mmTEM

  class mmTEM(Module):
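The new `TEMTransformerBlock` wraps the windowed attention (queries from the structural codes, keys/values from the structural codes concatenated with the encoded sensory codes) and a feedforward, each with a residual connection. A minimal usage sketch, assuming the class is imported straight from the `hippoformer/hippoformer.py` module listed in RECORD; the dimensions, batch size, and sequence length are illustrative assumptions, not values taken from the package.

import torch
from hippoformer.hippoformer import TEMTransformerBlock

block = TEMTransformerBlock(
    dim_structure = 128,       # assumed width of the structural codes
    dim_encoded_sensory = 96,  # assumed width of the encoded sensory codes
    window_size = 64
)

structural_codes = torch.randn(2, 32, 128)  # (batch, time, dim_structure)
encoded_sensory = torch.randn(2, 32, 96)    # (batch, time, dim_encoded_sensory)

# queries come from the structural codes; keys/values from their concatenation with the sensory codes
out, kv_cache = block(structural_codes, encoded_sensory)

assert out.shape == structural_codes.shape  # the block keeps the structural width

The second return value is the stacked key/value tensor intended to be passed back through the `kv_cache` argument on a subsequent call.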
hippoformer-0.0.10.dist-info/METADATA → hippoformer-0.0.12.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: hippoformer
- Version: 0.0.10
+ Version: 0.0.12
  Summary: hippoformer
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
hippoformer-0.0.12.dist-info/RECORD
@@ -0,0 +1,6 @@
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+ hippoformer/hippoformer.py,sha256=DjKAUfMpcoaAOqyuWauKp8n8e2YTzGVSOctNXxagkiA,18166
+ hippoformer-0.0.12.dist-info/METADATA,sha256=bH3GaJniFX2zCgxFFg8v0amGEPdWlBbG401Ml3_hDCs,2773
+ hippoformer-0.0.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ hippoformer-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ hippoformer-0.0.12.dist-info/RECORD,,
hippoformer-0.0.10.dist-info/RECORD
@@ -1,6 +0,0 @@
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
- hippoformer/hippoformer.py,sha256=GWFJy2idp0FWBoVFw8T_6inTXYtY4i47hfhKj88_I0A,14463
- hippoformer-0.0.10.dist-info/METADATA,sha256=IB7iybYMwOkee3Q5ji-B_dnOB62LyK_6t1FPM_UT-FM,2773
- hippoformer-0.0.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- hippoformer-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- hippoformer-0.0.10.dist-info/RECORD,,