hippoformer 0.0.16__tar.gz → 0.0.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.16
3
+ Version: 0.0.17
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -300,8 +300,8 @@ class Attention(Module):
300
300
 
301
301
  i, j = sim.shape[-2:]
302
302
 
303
- j_seq = arange(i, device = device)[:, None]
304
- i_seq = arange(j, device = device)[None, :] + (j - i)
303
+ i_seq = arange(i, device = device)[:, None] + (j - i)
304
+ j_seq = arange(j, device = device)[None, :]
305
305
 
306
306
  windowed_causal_mask_without_diagonal = (i_seq > j_seq) & ((i_seq - j_seq) <= self.window_size)
307
307
 
@@ -407,18 +407,22 @@ class TEMTransformer(Module):
407
407
  self,
408
408
  sensory,
409
409
  actions,
410
- prev_hiddens = None, # for the GRU based path integrator
411
- prev_kv_cache = None # for the specialized transformer blocks for inducing the grid-cells
410
+ prev_hiddens = None, # for the GRU based path integrator
411
+ prev_kv_cache = None, # for the specialized transformer blocks for inducing the grid-cells
412
+ return_memories = False
412
413
  ):
413
414
 
414
415
  structure, next_hiddens = self.path_integrator(actions, prev_hiddens)
415
416
 
416
417
  encoded_sensory = self.sensory_encoder(sensory)
417
418
 
419
+ prev_kv_cache = default(prev_kv_cache, (None,) * len(self.layers))
420
+ iter_prev_kv_cache = iter(prev_kv_cache)
421
+
418
422
  next_kv_cache = []
419
423
 
420
424
  for layer in self.layers:
421
- structure, layer_next_cache = layer(structure, encoded_sensory)
425
+ structure, layer_next_cache = layer(structure, encoded_sensory, kv_cache = next(iter_prev_kv_cache, None))
422
426
  next_kv_cache.append(layer_next_cache)
423
427
 
424
428
  decoded_sensory = self.sensory_decoder(structure)
@@ -427,7 +431,10 @@ class TEMTransformer(Module):
427
431
 
428
432
  pred_loss = F.mse_loss(encoded_sensory, decoded_sensory)
429
433
 
430
- return pred_loss
434
+ if not return_memories:
435
+ return pred_loss
436
+
437
+ return pred_loss, (next_hiddens, next_kv_cache)
431
438
 
432
439
  # proposed mmTEM
433
440
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "hippoformer"
3
- version = "0.0.16"
3
+ version = "0.0.17"
4
4
  description = "hippoformer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes