hippoformer-0.0.15-py3-none-any.whl → hippoformer-0.0.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -285,7 +285,7 @@ class Attention(Module):
285
285
  q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))
286
286
 
287
287
  if exists(kv_cache):
288
- ck1, cv1, vk2, cv2 = kv_cache
288
+ ck1, cv1, ck2, cv2 = kv_cache
289
289
  k1 = cat((ck1, k1), dim = -2)
290
290
  v1 = cat((cv1, v1), dim = -2)
291
291
  k2 = cat((ck2, k2), dim = -2)
@@ -300,12 +300,12 @@ class Attention(Module):
300
300
 
301
301
  i, j = sim.shape[-2:]
302
302
 
303
- j_seq = arange(j, device = device)[:, None]
304
- i_seq = arange(i, device = device)[None, :] + (j - i)
303
+ j_seq = arange(i, device = device)[:, None]
304
+ i_seq = arange(j, device = device)[None, :] + (j - i)
305
305
 
306
306
  windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
307
307
 
308
- sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
308
+ sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
309
309
 
310
310
  # attention sink, for token as well as for attention sinking - from gpt-oss
311
311
 
@@ -401,7 +401,7 @@ class TEMTransformer(Module):
401
401
  **transformer_kwargs
402
402
  )
403
403
 
404
- layers.append(block)
404
+ self.layers.append(block)
405
405
 
406
406
  def forward(
407
407
  self,
@@ -411,7 +411,7 @@ class TEMTransformer(Module):
411
411
  prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
412
412
  ):
413
413
 
414
- structure, next_hiddens = self.gru_path_integrator(actions, prev_hiddens)
414
+ structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
415
415
 
416
416
  encoded_sensory = self.sensory_encoder(sensory)
417
417
 
@@ -484,7 +484,7 @@ class mmTEM(Module):
484
484
  self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
485
485
 
486
486
  self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
487
- self.assoc_scan = AssocScan(*assoc_scan_kwargs)
487
+ self.assoc_scan = AssocScan(**assoc_scan_kwargs)
488
488
 
489
489
  self.meta_memory_mlp = create_mlp(
490
490
  dim = dim * 2,
@@ -495,7 +495,7 @@ class mmTEM(Module):
495
495
  )
496
496
 
497
497
  def forward_with_mse_loss(params, keys, values):
498
- pred = functional_call(self.meta_memory_mlp, params, keys)
498
+ pred = functional_call(self.meta_memory_mlp, params, (keys,))
499
499
  return F.mse_loss(pred, values)
500
500
 
501
501
  grad_fn = grad(forward_with_mse_loss)
@@ -596,7 +596,7 @@ class mmTEM(Module):
596
596
 
597
597
  # 2b. structure from structure
598
598
 
599
- decoded_structure, decoded_encoded_sensory = self.retrieve(zeros_like(structural_codes), encoded_sensory)
599
+ decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
600
600
 
601
601
  structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
602
602
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.15
3
+ Version: 0.0.16
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=CchBXGVmVioT2eBLGp44wvC8flv9e-TQLLovcz9G0ts,20842
3
+ hippoformer-0.0.16.dist-info/METADATA,sha256=L77ovZLV2uxu_AIRXNTEJf6h7K5xy21WGHDanJbVC6M,3093
4
+ hippoformer-0.0.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ hippoformer-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.16.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=axRDjuVWDKNTz0UL40LmzAEy1n59X65rkYjGHWzoN9w,20836
3
- hippoformer-0.0.15.dist-info/METADATA,sha256=mjFwFIPpy4dpFljqO-saQUivOuybq9YRBzQIkYe__6g,3093
4
- hippoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.15.dist-info/RECORD,,