metacontroller-pytorch 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/PKG-INFO +2 -2
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/metacontroller/metacontroller.py +27 -15
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/pyproject.toml +2 -2
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/README.md +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/tests/test_metacontroller.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.4
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
-
Requires-Dist: assoc-scan
|
|
37
|
+
Requires-Dist: assoc-scan>=0.0.3
|
|
38
38
|
Requires-Dist: discrete-continuous-embed-readout>=0.1.11
|
|
39
39
|
Requires-Dist: einops>=0.8.1
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
{metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/metacontroller/metacontroller.py
RENAMED
|
@@ -43,6 +43,11 @@ def default(*args):
|
|
|
43
43
|
return arg
|
|
44
44
|
return None
|
|
45
45
|
|
|
46
|
+
# tensor helpers
|
|
47
|
+
|
|
48
|
+
def straight_through(src, tgt):
|
|
49
|
+
return tgt + src - src.detach()
|
|
50
|
+
|
|
46
51
|
# meta controller
|
|
47
52
|
|
|
48
53
|
class MetaController(Module):
|
|
@@ -116,7 +121,8 @@ class MetaController(Module):
|
|
|
116
121
|
def forward(
|
|
117
122
|
self,
|
|
118
123
|
residual_stream,
|
|
119
|
-
discovery_phase = False
|
|
124
|
+
discovery_phase = False,
|
|
125
|
+
hard_switch = False
|
|
120
126
|
):
|
|
121
127
|
|
|
122
128
|
if discovery_phase:
|
|
@@ -136,6 +142,16 @@ class MetaController(Module):
|
|
|
136
142
|
|
|
137
143
|
sampled_action = readout.sample(action_dist)
|
|
138
144
|
|
|
145
|
+
# switching unit timer
|
|
146
|
+
|
|
147
|
+
batch, _, dim = sampled_action.shape
|
|
148
|
+
|
|
149
|
+
switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
|
|
150
|
+
|
|
151
|
+
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
152
|
+
|
|
153
|
+
# switch_beta = switch_beta.expand_as(sampled_action)
|
|
154
|
+
|
|
139
155
|
# need to encourage normal distribution
|
|
140
156
|
|
|
141
157
|
vae_kl_loss = self.zero
|
|
@@ -148,34 +164,30 @@ class MetaController(Module):
|
|
|
148
164
|
+ mean.square()
|
|
149
165
|
- log_var
|
|
150
166
|
- 1.
|
|
151
|
-
))
|
|
167
|
+
))
|
|
152
168
|
|
|
153
|
-
|
|
169
|
+
vae_kl_loss = vae_kl_loss * switch_beta
|
|
170
|
+
vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
|
|
154
171
|
|
|
155
|
-
|
|
172
|
+
# maybe hard switch, then use associative scan
|
|
156
173
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
|
|
162
|
-
switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
|
|
174
|
+
if hard_switch:
|
|
175
|
+
hard_switch = (switch_beta > 0.5).float()
|
|
176
|
+
switch_beta = straight_through(switch_beta, hard_switch)
|
|
163
177
|
|
|
164
178
|
forget = 1. - switch_beta
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)
|
|
179
|
+
gated_action = self.switch_gating(switch_beta, sampled_action * forget)
|
|
168
180
|
|
|
169
181
|
# decoder
|
|
170
182
|
|
|
171
|
-
decoder_out = self.decoder(
|
|
183
|
+
decoder_out = self.decoder(gated_action)
|
|
172
184
|
|
|
173
185
|
w1, w2 = self.to_hyper_network_weights(decoder_out)
|
|
174
186
|
hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
|
|
175
187
|
|
|
176
188
|
# generating the residual stream controlling signal
|
|
177
189
|
|
|
178
|
-
control_signal = einsum(
|
|
190
|
+
control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
|
|
179
191
|
|
|
180
192
|
modified_residual_stream = residual_stream + control_signal
|
|
181
193
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "metacontroller-pytorch"
|
|
3
|
-
version = "0.0.4"
|
|
3
|
+
version = "0.0.6"
|
|
4
4
|
description = "Transformer Metacontroller"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -24,7 +24,7 @@ classifiers=[
|
|
|
24
24
|
]
|
|
25
25
|
|
|
26
26
|
dependencies = [
|
|
27
|
-
"assoc-scan",
|
|
27
|
+
"assoc-scan>=0.0.3",
|
|
28
28
|
"einx>=0.3.0",
|
|
29
29
|
"einops>=0.8.1",
|
|
30
30
|
"discrete-continuous-embed-readout>=0.1.11",
|
{metacontroller_pytorch-0.0.4 → metacontroller_pytorch-0.0.6}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|