hippoformer 0.0.7__py3-none-any.whl → 0.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,8 @@ from torch.nn import Module
7
7
  from torch.jit import ScriptModule, script_method
8
8
  from torch.func import vmap, grad, functional_call
9
9
 
10
+ from beartype import beartype
11
+
10
12
  from einx import multiply
11
13
  from einops import repeat, rearrange, pack, unpack
12
14
  from einops.layers.torch import Rearrange
@@ -36,6 +38,114 @@ def pack_with_inverse(t, pattern):
36
38
  def l2norm(t):
37
39
  return F.normalize(t, dim = -1)
38
40
 
41
+ # Muon - Jordan et al., from the open-source community - applied to the latest version of Titans
42
+
43
def newtonschulz5(
    t,
    steps = 5,
    eps = 1e-7,
    coefs = (3.4445, -4.7750, 2.0315)
):
    """Approximately orthogonalize the last two dims of `t` with a quintic Newton-Schulz iteration (Muon).

    Tensors with three or fewer dims are treated as non-weight tensors and
    returned untouched.
    """

    if t.ndim <= 3:
        return t

    # iterate on the "wide" orientation (rows <= cols), transposing back at the end

    needs_transpose = t.shape[-2] > t.shape[-1]

    if needs_transpose:
        t = t.transpose(-1, -2)

    t, inverse = pack_with_inverse(t, '* i j')

    # scale each matrix by its frobenius norm so the iteration converges

    t = t / t.norm(dim = (-1, -2), keepdim = True).clamp(min = eps)

    a, b, c = coefs

    for _ in range(steps):
        gram = t @ t.transpose(-1, -2)
        poly = b * gram + c * gram @ gram
        t = a * t + poly @ t

    if needs_transpose:
        t = t.transpose(-1, -2)

    return inverse(t)
74
+
75
# sensory encoder decoder for 2d

# encoder projects 9-dim raw sensory input to a 32-dim code; decoder inverts it.
# both are depth-3 mlps with hidden width 64 (32 * 2)

grid_sensory_enc_dec = (
    create_mlp(dim = 32 * 2, dim_in = 9, dim_out = 32, depth = 3),
    create_mlp(dim = 32 * 2, dim_in = 32, dim_out = 9, depth = 3),
)
91
+
92
+ # sensory encoder decoder for 3d maze
93
+
94
class EncoderPackTime(Module):
    """Apply a per-frame encoder to a video tensor by folding time into batch.

    Takes (b, c, t, h, w), runs `fn` over every frame as (b * t, c, h, w),
    and returns per-frame latents of shape (b, t, d).
    """

    def __init__(self, fn: Module):
        super().__init__()
        # per-frame encoder, expected to map (b, c, h, w) -> (b, d) — confirm against `fn`
        self.fn = fn

    def forward(self, x):
        # move time next to batch so they can be packed together
        x = rearrange(x, 'b c t h w -> b t c h w')
        x, packed_shape = pack([x], '* c h w')

        x = self.fn(x)

        # restore the (b, t) leading dims on the latents
        # fix: removed leftover debug `print(x.shape)`
        x, = unpack(x, packed_shape, '* d')
        return x
108
+
109
class DecoderPackTime(Module):
    """Apply a per-frame decoder to latents by folding time into batch.

    Takes (b, t, d), runs `fn` over every latent as (b * t, d), and returns
    frames as (b, c, t, h, w) — the mirror of EncoderPackTime.
    """

    def __init__(self, fn: Module):
        super().__init__()
        # per-frame decoder, expected to map (b, d) -> (b, c, h, w) — confirm against `fn`
        self.fn = fn

    def forward(self, x):
        # fix: einops `pack` takes a list of tensors — mirror the encoder's pack([x], ...)
        x, packed_shape = pack([x], '* d')

        x = self.fn(x)

        # fix: `unpack` returns a list — destructure the single tensor before
        # rearranging, otherwise a list (with an extra stacked dim) flows downstream
        x, = unpack(x, packed_shape, '* c h w')
        x = rearrange(x, 'b t c h w -> b c t h w')
        return x
122
+
123
maze_sensory_enc_dec = (
    # encoder: strided convs downsample each frame, then a linear projects to a 32-dim latent.
    # Linear(2048, 32) implies 128 channels over a 4x4 grid, i.e. 64x64 input
    # frames after four stride-2 convs — TODO confirm against caller
    EncoderPackTime(nn.Sequential(
        nn.Conv2d(3, 16, 7, 2, padding = 3),   # stride 2 halves spatial dims
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, 2, 1),            # halve again
        nn.ReLU(),
        nn.Conv2d(32, 64, 3, 2, 1),
        nn.ReLU(),
        nn.Conv2d(64, 128, 3, 2, 1),
        nn.ReLU(),
        Rearrange('b ... -> b (...)'),         # flatten channels + spatial into one vector
        nn.Linear(2048, 32)
    )),
    # decoder: mirror of the encoder - linear back to 2048, reshape to (128, 4, w),
    # then transposed convs upsample back to 3-channel frames
    DecoderPackTime(nn.Sequential(
        nn.Linear(32, 2048),
        Rearrange('b (c h w) -> b c h w', c = 128, h = 4),  # w inferred as 2048 / (128 * 4) = 4
        nn.ConvTranspose2d(128, 64, 3, 2, 1, output_padding = (1, 1)),  # stride 2 doubles spatial dims
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding = (1, 1)),
        nn.ReLU(),
        nn.ConvTranspose2d(32, 16, 3, 2, 1, output_padding = (1, 1)),
        nn.ReLU(),
        nn.ConvTranspose2d(16, 3, 3, 2, 1, output_padding = (1, 1))
    ))
)
148
+
39
149
  # path integration
40
150
 
41
151
  class RNN(ScriptModule):
@@ -114,12 +224,12 @@ class PathIntegration(Module):
114
224
  # proposed mmTEM
115
225
 
116
226
  class mmTEM(Module):
227
+ @beartype
117
228
  def __init__(
118
229
  self,
119
230
  dim,
120
231
  *,
121
- sensory_encoder: Module,
122
- sensory_decoder: Module,
232
+ sensory_encoder_decoder: tuple[Module, Module],
123
233
  dim_sensory,
124
234
  dim_action,
125
235
  dim_encoded_sensory,
@@ -133,12 +243,15 @@ class mmTEM(Module):
133
243
  loss_weight_consistency = 1.,
134
244
  loss_weight_relational = 1.,
135
245
  integration_ratio_learned = True,
246
+ muon_update = False,
136
247
  assoc_scan_kwargs: dict = dict()
137
248
  ):
138
249
  super().__init__()
139
250
 
140
251
  # sensory
141
252
 
253
+ sensory_encoder, sensory_decoder = sensory_encoder_decoder
254
+
142
255
  self.sensory_encoder = sensory_encoder
143
256
  self.sensory_decoder = sensory_decoder
144
257
 
@@ -209,6 +322,10 @@ class mmTEM(Module):
209
322
  self.loss_weight_consistency = loss_weight_consistency
210
323
  self.register_buffer('zero', tensor(0.), persistent = False)
211
324
 
325
+ # update with muon
326
+
327
+ self.muon_update = muon_update
328
+
212
329
  # there is an integration ratio for error correction, but unclear what value this is fixed to or whether it is learned
213
330
 
214
331
  self.integration_ratio = nn.Parameter(tensor(0.), requires_grad = integration_ratio_learned)
@@ -338,6 +455,13 @@ class mmTEM(Module):
338
455
 
339
456
  update = self.assoc_scan(grad, expanded_beta.sigmoid(), momentum)
340
457
 
458
+ # maybe muon
459
+
460
+ if self.muon_update:
461
+ update = newtonschulz5(update)
462
+
463
+ # with forget gating
464
+
341
465
  expanded_forget = repeat(forget, 'b t -> b t w', w = grad.shape[-1])
342
466
 
343
467
  acc_update = self.assoc_scan(update, expanded_forget.sigmoid())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hippoformer
3
- Version: 0.0.7
3
+ Version: 0.0.9
4
4
  Summary: hippoformer
5
5
  Project-URL: Homepage, https://pypi.org/project/hippoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/hippoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan
38
+ Requires-Dist: beartype
38
39
  Requires-Dist: einops>=0.8.1
39
40
  Requires-Dist: einx>=0.3.0
40
41
  Requires-Dist: torch>=2.4
@@ -50,8 +51,6 @@ Description-Content-Type: text/markdown
50
51
 
51
52
  Implementation of [Hippoformer](https://openreview.net/forum?id=hxwV5EubAw), Integrating Hippocampus-inspired Spatial Memory with Transformers
52
53
 
53
- [Temporary Discord](https://discord.gg/MkACrrkrYR)
54
-
55
54
  ## Citations
56
55
 
57
56
  ```bibtex
@@ -0,0 +1,6 @@
1
+ hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
+ hippoformer/hippoformer.py,sha256=m7luQGFdMWOkZUorjd5v34hx_vjOQqpJOAGCL0njHUE,14426
3
+ hippoformer-0.0.9.dist-info/METADATA,sha256=owgkDcdTf0_N5IbUr3e_yt7u5sIWfOMha-hA5LQWnus,2772
4
+ hippoformer-0.0.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ hippoformer-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ hippoformer-0.0.9.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- hippoformer/__init__.py,sha256=A7N8GsRAZH4yP-L5hb7IVDnNjnhfjNyolg5MZ6vnGyE,71
2
- hippoformer/hippoformer.py,sha256=yYoJ5XO0YVAyp3LcRxpunU-0HA97mpCBeQFyi-NSkF0,11549
3
- hippoformer-0.0.7.dist-info/METADATA,sha256=Xg6NZ6VAQGmuiOo8mMwIAM39Gf6TpVOpyn7o4PMq7JE,2800
4
- hippoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- hippoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- hippoformer-0.0.7.dist-info/RECORD,,