hippoformer 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hippoformer/hippoformer.py
@@ -285,7 +285,7 @@ class Attention(Module):
         q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

         if exists(kv_cache):
-            ck1, cv1, vk2, cv2 = kv_cache
+            ck1, cv1, ck2, cv2 = kv_cache
             k1 = cat((ck1, k1), dim = -2)
             v1 = cat((cv1, v1), dim = -2)
             k2 = cat((ck2, k2), dim = -2)
@@ -300,12 +300,12 @@ class Attention(Module):

         i, j = sim.shape[-2:]

-        j_seq = arange(j, device = device)[:, None]
-        i_seq = arange(i, device = device)[None, :] + (j - i)
+        i_seq = arange(i, device = device)[:, None] + (j - i)
+        j_seq = arange(j, device = device)[None, :]

         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+        sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

         # attention sink, for token as well as for attention sinking - from gpt-oss

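For reference, a minimal standalone sketch of the corrected masking, with assumed toy sizes (i queries attending over j keys, cached prefix included). The old code built the two position grids on swapped axes and then filled exactly the positions that should have been kept; the fix realigns the axes and inverts the mask before masked_fill:

    import torch

    i, j, window_size = 4, 6, 2   # toy sizes: i queries, j keys (cached prefix included)

    i_seq = torch.arange(i)[:, None] + (j - i)   # query positions, aligned to the end of the keys
    j_seq = torch.arange(j)[None, :]             # key positions

    # True where the key is strictly in the past and within the window
    # (the strict > excludes the diagonal, matching the variable name)
    mask = (i_seq > j_seq) & ((i_seq - j_seq) <= window_size)

    sim = torch.randn(i, j)
    sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)   # mask everything *outside* the window
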
@@ -401,24 +401,28 @@ class TEMTransformer(Module):
             **transformer_kwargs
         )

-        layers.append(block)
+        self.layers.append(block)

     def forward(
         self,
         sensory,
         actions,
-        prev_hiddens = None, # for the GRU based path integrator
-        prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
+        prev_hiddens = None,  # for the GRU based path integrator
+        prev_kv_cache = None, # for the specialized transformer blocks for inducing the grid-cells
+        return_memories = False
     ):

-        structure, next_hiddens = self.gru_path_integrator(actions, prev_hiddens)
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)

         encoded_sensory = self.sensory_encoder(sensory)

+        prev_kv_cache = default(prev_kv_cache, (None,) * len(self.layers))
+        iter_prev_kv_cache = iter(prev_kv_cache)
+
         next_kv_cache = []

         for layer in self.layers:
-            structure, layer_next_cache = layer(structure, encoded_sensory)
+            structure, layer_next_cache = layer(structure, encoded_sensory, kv_cache = next(iter_prev_kv_cache, None))
             next_kv_cache.append(layer_next_cache)

         decoded_sensory = self.sensory_decoder(structure)
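The iterator pattern introduced here keeps the loop body uniform whether or not a cache exists yet: `default` (presumably the package's usual return-first-non-None helper) substitutes a tuple of Nones on the first step, and `next(it, None)` hands each layer its own slot. A minimal illustration with stand-in names:

    def default(val, fallback):                  # stand-in for the package's helper
        return val if val is not None else fallback

    layers = ['block0', 'block1', 'block2']      # stand-ins for the transformer blocks

    prev_kv_cache = None                         # first step: nothing cached yet
    prev_kv_cache = default(prev_kv_cache, (None,) * len(layers))

    iter_prev_kv_cache = iter(prev_kv_cache)

    for layer in layers:
        kv_cache = next(iter_prev_kv_cache, None)    # this layer's cache, or None
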
@@ -427,7 +431,10 @@ class TEMTransformer(Module):

         pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)

-        return pred_loss
+        if not return_memories:
+            return pred_loss
+
+        return pred_loss, next_memories

 # proposed mmTEM

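A hypothetical caller-side sketch of the new recurrent interface. `model` and `episode` are stand-ins, and it assumes `next_memories` bundles `(next_hiddens, next_kv_cache)` — that definition sits outside this hunk, so treat the unpacking as an assumption:

    memories = None

    for sensory, actions in episode:             # assumed iterable of (sensory, actions) pairs
        prev_hiddens, prev_kv_cache = memories if memories is not None else (None, None)

        loss, memories = model(
            sensory,
            actions,
            prev_hiddens = prev_hiddens,
            prev_kv_cache = prev_kv_cache,
            return_memories = True
        )

        loss.backward()
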
@@ -484,7 +491,7 @@ class mmTEM(Module):
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

         self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-        self.assoc_scan = AssocScan(*assoc_scan_kwargs)
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)

         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
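The one-character fix above changes behavior, not just style: `*` iterates the dict and passes its keys as positional strings, while `**` unpacks it into keyword arguments. A self-contained illustration with a stand-in class, since AssocScan's real signature is not shown in this diff:

    class AssocScan:                             # stand-in, not the real class
        def __init__(self, use_accelerated = False):
            self.use_accelerated = use_accelerated

    assoc_scan_kwargs = dict(use_accelerated = True)

    scan = AssocScan(**assoc_scan_kwargs)        # correct: keyword unpacking
    # AssocScan(*assoc_scan_kwargs)              # wrong: passes the key 'use_accelerated' positionally
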
@@ -495,7 +502,7 @@ class mmTEM(Module):
         )

         def forward_with_mse_loss(params, keys, values):
-            pred = functional_call(self.meta_memory_mlp, params, keys)
+            pred = functional_call(self.meta_memory_mlp, params, (keys,))
             return F.mse_loss(pred, values)

         grad_fn = grad(forward_with_mse_loss)
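A runnable sketch of the `functional_call` + `grad` pattern above, with an assumed toy MLP standing in for `meta_memory_mlp`. `torch.func.functional_call` treats a non-tuple third argument as a single positional arg, so a bare tensor usually works too; wrapping `keys` in a one-element tuple makes the call unambiguous:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.func import functional_call, grad

    mlp = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))   # assumed stand-in

    def forward_with_mse_loss(params, keys, values):
        pred = functional_call(mlp, params, (keys,))   # run mlp with the given parameter dict
        return F.mse_loss(pred, values)

    grad_fn = grad(forward_with_mse_loss)              # differentiates w.r.t. params (first arg)

    params = dict(mlp.named_parameters())
    keys, values = torch.randn(4, 8), torch.randn(4, 8)

    param_grads = grad_fn(params, keys, values)        # dict of grads, keyed like params
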
@@ -596,7 +603,7 @@ class mmTEM(Module):

         # 2b. structure from structure

-        decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

         structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)

{hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: hippoformer
-Version: 0.0.15
+Version: 0.0.17
 Summary: hippoformer
 Project-URL: Homepage, https://pypi.org/project/hippoformer/
 Project-URL: Repository, https://github.com/lucidrains/hippoformer
hippoformer-0.0.17.dist-info/RECORD (new)
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=fWgCak1szOJpoL-6bFuoJk8Z7RGxjlcL7-8kmGYhAfU,21121
+hippoformer-0.0.17.dist-info/METADATA,sha256=XfCPGG4gyxvTSrZzV-H-L5I6WzN7MKhYXG0r7puPp1w,3093
+hippoformer-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hippoformer-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.17.dist-info/RECORD,,
{hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/WHEEL
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.27.0
+Generator: hatchling 1.28.0
 Root-Is-Purelib: true
 Tag: py3-none-any
hippoformer-0.0.15.dist-info/RECORD (removed)
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=axRDjuVWDKNTz0UL40LmzAEy1n59X65rkYjGHWzoN9w,20836
-hippoformer-0.0.15.dist-info/METADATA,sha256=mjFwFIPpy4dpFljqO-saQUivOuybq9YRBzQIkYe__6g,3093
-hippoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.15.dist-info/RECORD,,