hippoformer 0.0.10__tar.gz → 0.0.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.10
+Version: 0.0.11
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -221,6 +221,122 @@ class PathIntegration(Module):
 
         return self.rnn(transitions, prev_structural)
 
+# custom transformer
+# with the connections proposed by James Whittington that bridge to hippocampal models
+# https://arxiv.org/abs/2112.04035
+
+def FeedForward(dim, mult = 4.):
+    dim_inner = int(dim * mult)
+    return nn.Sequential(
+        nn.Linear(dim, dim_inner),
+        nn.GELU(),
+        nn.Linear(dim_inner, dim)
+    )
+
+class Attention(Module):
+    def __init__(
+        self,
+        dim_q,
+        dim_kv,
+        dim_head = 64,
+        heads = 8
+    ):
+        super().__init__()
+        dim_inner = dim_head * heads
+        self.scale = dim_head ** -0.5
+
+        self.to_queries = nn.Linear(dim_q, dim_inner, bias = False)
+        self.to_key_values = nn.Linear(dim_kv, dim_inner * 2, bias = False)
+
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
+        self.to_out = nn.Linear(dim_inner, dim_q, bias = False)
+        self.attn_head_sink = nn.Parameter(torch.randn(heads) * 1e-2) # needed as the diagonal is masked out, and to serve as an attention sink
+
+    def forward(
+        self,
+        queries_input,
+        key_values_input,
+        kv_cache = None
+    ):
+        batch, seq_len, device = *queries_input.shape[:2], queries_input.device
+
+        q = self.to_queries(queries_input)
+
+        k, v = self.to_key_values(key_values_input).chunk(2, dim = -1)
+
+        q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
+
+        if exists(kv_cache):
+            ck, cv = kv_cache
+            k = cat((ck, k), dim = -2)
+            v = cat((cv, v), dim = -2)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+        # causal mask, with the diagonal masked out as well
+
+        i, j = sim.shape[-2:]
+        causal_mask_without_diagonal = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i)
+
+        sim = sim.masked_fill(causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+
+        # per-head attention sink logit, so a token can attend to nothing - as in gpt-oss
+
+        attn_sink = repeat(self.attn_head_sink, 'h -> b h i 1', b = batch, i = seq_len)
+
+        sim = cat((attn_sink, sim), dim = -1)
+
+        attn = sim.softmax(dim = -1)
+
+        attn = attn[..., 1:] # remove sink
+
+        # aggregate
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = self.merge_heads(out)
+
+        return self.to_out(out), stack((k, v))
+
+class TEMTransformerBlock(Module):
+    def __init__(
+        self,
+        dim_structure,
+        dim_encoded_sensory,
+        dim_head = 64,
+        heads = 8,
+        ff_expansion_factor = 4.,
+        window_size = 64
+    ):
+        super().__init__()
+
+        self.attn = Attention(dim_structure, dim_structure + dim_encoded_sensory, dim_head = dim_head, heads = heads)
+        self.ff = FeedForward(dim_structure, ff_expansion_factor)
+
+        self.window_size = window_size
+
+    def forward(
+        self,
+        structural_codes,
+        encoded_sensory,
+        kv_cache = None
+    ):
+        structure_and_sensory = cat((structural_codes, encoded_sensory), dim = -1)
+
+        retrieved, next_kv_cache = self.attn(structural_codes, structure_and_sensory, kv_cache = kv_cache)
+
+        x = retrieved + structural_codes
+
+        x = self.ff(x) + x
+
+        next_kv_cache = next_kv_cache[..., -self.window_size:, :] # window along the sequence dim; the stacked cache is (2, b, h, n, d)
+
+        return x, next_kv_cache
+
 # proposed mmTEM
 
 class mmTEM(Module):
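A note on the attention sink added above: because the causal mask also masks the diagonal, the first query of a fresh sequence can see no real keys at all. Concatenating a learned per-head sink logit before the softmax and dropping it afterwards lets such a row place all of its weight on the sink rather than producing NaNs, and lets any token allocate less than full probability to real positions. A minimal standalone sketch of that mechanism (illustrative only, not part of the package):

import torch

# a fully masked attention row: every real position is at -max
sim = torch.full((1, 8, 1, 4), -torch.finfo(torch.float32).max)

# sink logit, one per head, broadcast to the row (zeros here for simplicity)
sink = torch.zeros(1, 8, 1, 1)

# softmax over [sink, real positions], then drop the sink column - mirroring Attention.forward above
attn = torch.cat((sink, sim), dim = -1).softmax(dim = -1)[..., 1:]

assert torch.isfinite(attn).all()                   # no NaNs
assert torch.allclose(attn, torch.zeros_like(attn)) # all weight went to the sink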
@@ -1,6 +1,6 @@
 [project]
 name = "hippoformer"
-version = "0.0.10"
+version = "0.0.11"
 description = "hippoformer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -65,3 +65,15 @@ def test_mm_tem(
 
     loss = model(sensory, actions, memory_mlp_params = next_params)
     loss.backward()
+
+def test_tem_t():
+    from hippoformer.hippoformer import TEMTransformerBlock
+
+    block = TEMTransformerBlock(32, 16)
+
+    structural_codes = torch.randn(1, 7, 32)
+    encoded_sensory = torch.randn(1, 7, 16)
+
+    pred, kv_cache = block(structural_codes, encoded_sensory)
+
+    assert pred.shape == (1, 7, 32)
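The new test exercises a single forward pass. For the kv cache path, a hedged sketch of how the block might be stepped one timestep at a time, carrying the windowed cache forward (the loop, dimensions, and window_size value are illustrative assumptions, not package API beyond what the test above uses):

import torch
from hippoformer.hippoformer import TEMTransformerBlock

block = TEMTransformerBlock(32, 16, window_size = 4)

kv_cache = None

for _ in range(10):
    structural_codes = torch.randn(1, 1, 32) # one structural code per step
    encoded_sensory = torch.randn(1, 1, 16)  # one encoded sensory code per step

    out, kv_cache = block(structural_codes, encoded_sensory, kv_cache = kv_cache)

    assert out.shape == (1, 1, 32)
    assert kv_cache.shape[-2] <= 4           # cache stays within window_size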