hippoformer-0.0.14-py3-none-any.whl → hippoformer-0.0.16-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +10 -10
- {hippoformer-0.0.14.dist-info → hippoformer-0.0.16.dist-info}/METADATA +1 -1
- hippoformer-0.0.16.dist-info/RECORD +6 -0
- {hippoformer-0.0.14.dist-info → hippoformer-0.0.16.dist-info}/WHEEL +1 -1
- hippoformer-0.0.14.dist-info/RECORD +0 -6
- {hippoformer-0.0.14.dist-info → hippoformer-0.0.16.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
@@ -285,7 +285,7 @@ class Attention(Module):
         q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))

         if exists(kv_cache):
-            ck1, cv1,
+            ck1, cv1, ck2, cv2 = kv_cache
             k1 = cat((ck1, k1), dim = -2)
             v1 = cat((cv1, v1), dim = -2)
             k2 = cat((ck2, k2), dim = -2)
@@ -300,12 +300,12 @@ class Attention(Module):

         i, j = sim.shape[-2:]

-        j_seq = arange(
-        i_seq = arange(
+        j_seq = arange(i, device = device)[:, None]
+        i_seq = arange(j, device = device)[None, :] + (j - i)

         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)

-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+        sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

         # attention sink, for token as well as for attention sinking - from gpt-oss
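Note on the masked_fill change above: Tensor.masked_fill writes the fill value wherever the mask is True, so a mask that marks the positions to keep must be inverted before filling. A minimal standalone sketch (the tensors here are illustrative, not the package's own):

    import torch

    sim = torch.zeros(3, 3)

    # True marks the positions a query may attend to
    keep = torch.ones(3, 3).tril().bool()

    # masked_fill fills where the mask is True, so invert it to
    # blank out everything *except* the kept positions
    sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)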
@@ -365,7 +365,7 @@ class TEMTransformerBlock(Module):

         x = self.ff(x) + x

-        next_kv_cache =
+        next_kv_cache = tuple(t[:, -self.window_size:] for t in next_kv_cache)

         return x, next_kv_cache
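For context on the line above: slicing each cache tensor with [:, -window_size:] keeps only the most recent window_size positions along the sequence axis, which bounds cache growth under windowed attention. A toy sketch, assuming a (batch, seq, dim) cache layout (the package's actual cache shapes are not shown in this diff):

    import torch

    window_size = 4
    cache = torch.randn(1, 10, 64)   # hypothetical (batch, seq, dim) cache entry

    # keep only the last `window_size` sequence positions
    trimmed = cache[:, -window_size:]
    assert trimmed.shape == (1, 4, 64)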
@@ -401,7 +401,7 @@ class TEMTransformer(Module):
                 **transformer_kwargs
             )

-            layers.append(block)
+            self.layers.append(block)

     def forward(
         self,
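Worth noting for the fix above: the diff does not show how self.layers is declared, but assuming it is an nn.ModuleList, appending each block to it registers the block's parameters with the parent module, which a stray local list would not. A small sketch under that assumption:

    import torch.nn as nn

    class Stack(nn.Module):
        def __init__(self, depth):
            super().__init__()
            self.layers = nn.ModuleList([])  # registers appended submodules
            for _ in range(depth):
                self.layers.append(nn.Linear(4, 4))

    stack = Stack(depth = 2)
    # 2 layers x (weight + bias) = 4 registered parameter tensors
    assert len(list(stack.parameters())) == 4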
@@ -411,7 +411,7 @@ class TEMTransformer(Module):
         prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
     ):

-        structure, next_hiddens = self.
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)

         encoded_sensory = self.sensory_encoder(sensory)
@@ -484,7 +484,7 @@ class mmTEM(Module):
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)

         self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-        self.assoc_scan = AssocScan(
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)

         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
@@ -495,7 +495,7 @@ class mmTEM(Module):
         )

         def forward_with_mse_loss(params, keys, values):
-            pred = functional_call(self.meta_memory_mlp, params, keys)
+            pred = functional_call(self.meta_memory_mlp, params, (keys,))
             return F.mse_loss(pred, values)

         grad_fn = grad(forward_with_mse_loss)
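For reference, a minimal sketch of the torch.func pattern in this hunk: functional_call runs a module against an explicit parameter dict, taking its positional arguments as a tuple (hence the one-element tuple (keys,)), and grad differentiates with respect to the first argument. The Linear module and sizes below are assumptions for illustration, not the package's meta_memory_mlp:

    import torch
    import torch.nn.functional as F
    from torch.func import functional_call, grad

    mlp = torch.nn.Linear(8, 8)
    params = dict(mlp.named_parameters())

    def forward_with_mse_loss(params, keys, values):
        # positional args are passed as a tuple
        pred = functional_call(mlp, params, (keys,))
        return F.mse_loss(pred, values)

    # grad(...) returns a function computing d(loss)/d(params)
    grad_fn = grad(forward_with_mse_loss)
    grads = grad_fn(params, torch.randn(2, 8), torch.randn(2, 8))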
@@ -596,7 +596,7 @@ class mmTEM(Module):

         # 2b. structure from structure

-        decoded_structure, decoded_encoded_sensory = self.retrieve(
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))

         structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
hippoformer-0.0.16.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=CchBXGVmVioT2eBLGp44wvC8flv9e-TQLLovcz9G0ts,20842
+hippoformer-0.0.16.dist-info/METADATA,sha256=L77ovZLV2uxu_AIRXNTEJf6h7K5xy21WGHDanJbVC6M,3093
+hippoformer-0.0.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hippoformer-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.16.dist-info/RECORD,,
hippoformer-0.0.14.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=RqHobJTA36mJ-YFutteMx5_6QrrSqI370C1sw5AYrUE,20818
-hippoformer-0.0.14.dist-info/METADATA,sha256=W8k541yoURENT9RFpxZtu7O08r_R8joyzjj3ikwFbH0,3093
-hippoformer-0.0.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.14.dist-info/RECORD,,
{hippoformer-0.0.14.dist-info → hippoformer-0.0.16.dist-info}/licenses/LICENSE
File without changes