hippoformer-0.0.15-py3-none-any.whl → hippoformer-0.0.17-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +20 -13
- {hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/METADATA +1 -1
- hippoformer-0.0.17.dist-info/RECORD +6 -0
- {hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/WHEEL +1 -1
- hippoformer-0.0.15.dist-info/RECORD +0 -6
- {hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
@@ -285,7 +285,7 @@ class Attention(Module):
         q, k1, v1, k2, v2 = tuple(self.split_heads(t) for t in (q, k1, v1, k2, v2))
 
         if exists(kv_cache):
-            ck1, cv1,
+            ck1, cv1, ck2, cv2 = kv_cache
             k1 = cat((ck1, k1), dim = -2)
             v1 = cat((cv1, v1), dim = -2)
             k2 = cat((ck2, k2), dim = -2)
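The 0.0.15 line was truncated mid-assignment; 0.0.17 unpacks both cached key/value pairs and prepends them along the sequence dimension. Below is a minimal standalone sketch of that cache-append pattern for one key/value pair, with made-up shapes that are not taken from the package:

import torch
from torch import cat

# hypothetical cache and new projections, shaped (batch, heads, seq, dim_head)
ck1, cv1 = torch.randn(1, 4, 6, 16), torch.randn(1, 4, 6, 16)   # 6 cached positions
k1, v1 = torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16)     # 2 new positions

k1 = cat((ck1, k1), dim = -2)   # concatenate along the sequence dimension
v1 = cat((cv1, v1), dim = -2)

assert k1.shape[-2] == v1.shape[-2] == 8   # 6 cached + 2 new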
@@ -300,12 +300,12 @@ class Attention(Module):
 
         i, j = sim.shape[-2:]
 
-
-
+        i_seq = arange(i, device = device)[:, None] + (j - i)
+        j_seq = arange(j, device = device)[None, :]
 
         windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
 
-        sim = sim.masked_fill(windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
+        sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)
 
         # attention sink, for token as well as for attention sinking - from gpt-oss
 
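For intuition on the two changes above: `i_seq` is offset by `j - i` so queries land on the correct absolute positions when cached keys make `j > i`, and the corrected `masked_fill` fills the complement of the window mask, so only keys strictly before the query and within `window_size` survive. A small sketch of the masking logic, with made-up lengths rather than the package's defaults:

import torch
from torch import arange

window_size = 2
i, j = 3, 5                 # 3 queries attending over 5 keys, 2 of them cached
device = 'cpu'

i_seq = arange(i, device = device)[:, None] + (j - i)   # absolute query positions 2, 3, 4
j_seq = arange(j, device = device)[None, :]             # key positions 0..4

windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= window_size)

sim = torch.zeros(i, j)
sim = sim.masked_fill(~windowed_causal_mask_without_diagonal, -torch.finfo(sim.dtype).max)

print(windowed_causal_mask_without_diagonal.int())
# tensor([[1, 1, 0, 0, 0],
#         [0, 1, 1, 0, 0],
#         [0, 0, 1, 1, 0]])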
@@ -401,24 +401,28 @@ class TEMTransformer(Module):
                 **transformer_kwargs
             )
 
-            layers.append(block)
+            self.layers.append(block)
 
     def forward(
         self,
         sensory,
         actions,
-        prev_hiddens = None,
-        prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
+        prev_hiddens = None, # for the GRU based path integrator
+        prev_kv_cache = None, # for the specialized transformer blocks for inducing the grid-cells
+        return_memories = False
     ):
 
-        structure, next_hiddens = self.
+        structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
 
         encoded_sensory = self.sensory_encoder(sensory)
 
+        prev_kv_cache = default(prev_kv_cache, (None,) * len(self.layers))
+        iter_prev_kv_cache = iter(prev_kv_cache)
+
         next_kv_cache = []
 
         for layer in self.layers:
-            structure, layer_next_cache = layer(structure, encoded_sensory)
+            structure, layer_next_cache = layer(structure, encoded_sensory, kv_cache = next(iter_prev_kv_cache, None))
             next_kv_cache.append(layer_next_cache)
 
         decoded_sensory = self.sensory_decoder(structure)
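The reworked forward threads an optional per-layer cache through the stack: `default` falls back to a tuple of `None`s when no cache is supplied, and `next(iter_prev_kv_cache, None)` hands each layer its own entry without explicit indexing. A stripped-down sketch of that pattern, with a stand-in layer rather than the package's transformer block:

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

class DummyLayer:
    # stand-in block: returns its input and an updated cache entry
    def __call__(self, x, kv_cache = None):
        return x, (kv_cache or 0) + 1

layers = [DummyLayer() for _ in range(3)]

def forward(x, prev_kv_cache = None):
    prev_kv_cache = default(prev_kv_cache, (None,) * len(layers))
    iter_prev_kv_cache = iter(prev_kv_cache)

    next_kv_cache = []

    for layer in layers:
        x, layer_next_cache = layer(x, kv_cache = next(iter_prev_kv_cache, None))
        next_kv_cache.append(layer_next_cache)

    return x, tuple(next_kv_cache)

x, cache = forward(0.)          # first step, no cache
x, cache = forward(x, cache)    # later step reuses the per-layer cache
print(cache)                    # (2, 2, 2)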
@@ -427,7 +431,10 @@ class TEMTransformer(Module):
 
         pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
 
-
+        if not return_memories:
+            return pred_loss
+
+        return pred_loss, next_memories
 
 # proposed mmTEM
 
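The new `return_memories` flag keeps the default call loss-only while letting a training loop carry recurrent state (presumably the path-integrator hiddens and the per-layer kv cache) across steps. A toy illustration of the same return pattern, with stand-in values rather than the package's tensors:

def forward(x, return_memories = False):
    pred_loss = x * 2        # stand-in for the reconstruction loss
    next_memories = x + 1    # stand-in for the carried state

    if not return_memories:
        return pred_loss

    return pred_loss, next_memories

assert forward(1) == 2
assert forward(1, return_memories = True) == (2, 2)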
@@ -484,7 +491,7 @@ class mmTEM(Module):
         self.to_values = nn.Linear(dim_joint_rep, dim, bias = False)
 
         self.to_learned_optim_hparams = nn.Linear(dim_joint_rep, 3, bias = False) # for learning rate, forget gate, and momentum
-        self.assoc_scan = AssocScan(
+        self.assoc_scan = AssocScan(**assoc_scan_kwargs)
 
         self.meta_memory_mlp = create_mlp(
             dim = dim * 2,
@@ -495,7 +502,7 @@ class mmTEM(Module):
         )
 
         def forward_with_mse_loss(params, keys, values):
-            pred = functional_call(self.meta_memory_mlp, params, keys)
+            pred = functional_call(self.meta_memory_mlp, params, (keys,))
             return F.mse_loss(pred, values)
 
         grad_fn = grad(forward_with_mse_loss)
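`torch.func.functional_call` takes the module's positional arguments as a tuple (a non-tuple is treated as a single argument), so wrapping `keys` as `(keys,)` makes the single-argument call explicit before the closure is differentiated with `torch.func.grad`. A self-contained example of the same functional-call-plus-grad pattern, using a placeholder MLP rather than the package's `meta_memory_mlp`:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, grad

mlp = nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 8))

def forward_with_mse_loss(params, keys, values):
    pred = functional_call(mlp, params, (keys,))   # positional args passed as a tuple
    return F.mse_loss(pred, values)

grad_fn = grad(forward_with_mse_loss)              # differentiates w.r.t. the first argument (params)

params = {name: p.detach() for name, p in mlp.named_parameters()}
keys, values = torch.randn(4, 8), torch.randn(4, 8)

grads = grad_fn(params, keys, values)              # dict of per-parameter gradients
print({name: g.shape for name, g in grads.items()})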
@@ -596,7 +603,7 @@ class mmTEM(Module):
 
         # 2b. structure from structure
 
-        decoded_structure, decoded_encoded_sensory = self.retrieve(
+        decoded_structure, decoded_encoded_sensory = self.retrieve(structural_codes, zeros_like(encoded_sensory))
 
         structure_from_structure_loss = F.mse_loss(decoded_structure, structural_codes)
 
hippoformer-0.0.17.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
+hippoformer/hippoformer.py,sha256=fWgCak1szOJpoL-6bFuoJk8Z7RGxjlcL7-8kmGYhAfU,21121
+hippoformer-0.0.17.dist-info/METADATA,sha256=XfCPGG4gyxvTSrZzV-H-L5I6WzN7MKhYXG0r7puPp1w,3093
+hippoformer-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+hippoformer-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+hippoformer-0.0.17.dist-info/RECORD,,
hippoformer-0.0.15.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
-hippoformer/hippoformer.py,sha256=axRDjuVWDKNTz0UL40LmzAEy1n59X65rkYjGHWzoN9w,20836
-hippoformer-0.0.15.dist-info/METADATA,sha256=mjFwFIPpy4dpFljqO-saQUivOuybq9YRBzQIkYe__6g,3093
-hippoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-hippoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-hippoformer-0.0.15.dist-info/RECORD,,
{hippoformer-0.0.15.dist-info → hippoformer-0.0.17.dist-info}/licenses/LICENSE
File without changes