hippoformer 0.0.16__py3-none-any.whl → 0.0.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hippoformer/hippoformer.py +13 -6
- {hippoformer-0.0.16.dist-info → hippoformer-0.0.17.dist-info}/METADATA +1 -1
- hippoformer-0.0.17.dist-info/RECORD +6 -0
- hippoformer-0.0.16.dist-info/RECORD +0 -6
- {hippoformer-0.0.16.dist-info → hippoformer-0.0.17.dist-info}/WHEEL +0 -0
- {hippoformer-0.0.16.dist-info → hippoformer-0.0.17.dist-info}/licenses/LICENSE +0 -0
hippoformer/hippoformer.py
CHANGED
|
@@ -300,8 +300,8 @@ class Attention(Module):
|
|
|
300
300
|
|
|
301
301
|
i, j = sim.shape[-2:]
|
|
302
302
|
|
|
303
|
-
|
|
304
|
-
|
|
303
|
+
i_seq = arange(i, device = device)[:, None] + (j - i)
|
|
304
|
+
j_seq = arange(j, device = device)[None, :]
|
|
305
305
|
|
|
306
306
|
windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
|
|
307
307
|
|
|
@@ -407,18 +407,22 @@ class TEMTransformer(Module):
|
|
|
407
407
|
self,
|
|
408
408
|
sensory,
|
|
409
409
|
actions,
|
|
410
|
-
prev_hiddens = None,
|
|
411
|
-
prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
|
|
410
|
+
prev_hiddens = None, # for the GRU based path integrator
|
|
411
|
+
prev_kv_cache = None, # for the specialized transformer blocks for inducing the grid-cells
|
|
412
|
+
return_memories = False
|
|
412
413
|
):
|
|
413
414
|
|
|
414
415
|
structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
|
|
415
416
|
|
|
416
417
|
encoded_sensory = self.sensory_encoder(sensory)
|
|
417
418
|
|
|
419
|
+
prev_kv_cache = default(prev_kv_cache, (None,) * len(self.layers))
|
|
420
|
+
iter_prev_kv_cache = iter(prev_kv_cache)
|
|
421
|
+
|
|
418
422
|
next_kv_cache = []
|
|
419
423
|
|
|
420
424
|
for layer in self.layers:
|
|
421
|
-
structure, layer_next_cache = layer(structure, encoded_sensory)
|
|
425
|
+
structure, layer_next_cache = layer(structure, encoded_sensory, kv_cache = next(iter_prev_kv_cache, None))
|
|
422
426
|
next_kv_cache.append(layer_next_cache)
|
|
423
427
|
|
|
424
428
|
decoded_sensory = self.sensory_decoder(structure)
|
|
@@ -427,7 +431,10 @@ class TEMTransformer(Module):
|
|
|
427
431
|
|
|
428
432
|
pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
|
|
429
433
|
|
|
430
|
-
|
|
434
|
+
if not return_memories:
|
|
435
|
+
return pred_loss
|
|
436
|
+
|
|
437
|
+
return pred_loss, next_memories
|
|
431
438
|
|
|
432
439
|
# proposed mmTEM
|
|
433
440
|
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
+
hippoformer/hippoformer.py,sha256=fWgCak1szOJpoL-6bFuoJk8Z7RGxjlcL7-8kmGYhAfU,21121
|
|
3
|
+
hippoformer-0.0.17.dist-info/METADATA,sha256=XfCPGG4gyxvTSrZzV-H-L5I6WzN7MKhYXG0r7puPp1w,3093
|
|
4
|
+
hippoformer-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
hippoformer-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
hippoformer-0.0.17.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
|
|
2
|
-
hippoformer/hippoformer.py,sha256=CchBXGVmVioT2eBLGp44wvC8flv9e-TQLLovcz9G0ts,20842
|
|
3
|
-
hippoformer-0.0.16.dist-info/METADATA,sha256=L77ovZLV2uxu_AIRXNTEJf6h7K5xy21WGHDanJbVC6M,3093
|
|
4
|
-
hippoformer-0.0.16.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
hippoformer-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
hippoformer-0.0.16.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|