metacontroller-pytorch 0.0.5__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +9 -14
- {metacontroller_pytorch-0.0.5.dist-info → metacontroller_pytorch-0.0.8.dist-info}/METADATA +2 -2
- metacontroller_pytorch-0.0.8.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.5.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.5.dist-info → metacontroller_pytorch-0.0.8.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.5.dist-info → metacontroller_pytorch-0.0.8.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -113,9 +113,7 @@ class MetaController(Module):
|
|
|
113
113
|
def internal_rl_parameters(self):
|
|
114
114
|
return [
|
|
115
115
|
*self.action_proposer.parameters(),
|
|
116
|
-
*self.action_proposer_mean_log_var.parameters()
|
|
117
|
-
*self.decoder.parameters(),
|
|
118
|
-
*self.switch_gating
|
|
116
|
+
*self.action_proposer_mean_log_var.parameters()
|
|
119
117
|
]
|
|
120
118
|
|
|
121
119
|
def forward(
|
|
@@ -150,9 +148,6 @@ class MetaController(Module):
|
|
|
150
148
|
|
|
151
149
|
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
152
150
|
|
|
153
|
-
action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
|
|
154
|
-
switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
|
|
155
|
-
|
|
156
151
|
# need to encourage normal distribution
|
|
157
152
|
|
|
158
153
|
vae_kl_loss = self.zero
|
|
@@ -165,10 +160,10 @@ class MetaController(Module):
|
|
|
165
160
|
+ mean.square()
|
|
166
161
|
- log_var
|
|
167
162
|
- 1.
|
|
168
|
-
))
|
|
163
|
+
))
|
|
169
164
|
|
|
170
165
|
vae_kl_loss = vae_kl_loss * switch_beta
|
|
171
|
-
vae_kl_loss = vae_kl_loss.mean()
|
|
166
|
+
vae_kl_loss = vae_kl_loss.sum(dim = -1).mean()
|
|
172
167
|
|
|
173
168
|
# maybe hard switch, then use associative scan
|
|
174
169
|
|
|
@@ -177,20 +172,18 @@ class MetaController(Module):
|
|
|
177
172
|
switch_beta = straight_through(switch_beta, hard_switch)
|
|
178
173
|
|
|
179
174
|
forget = 1. - switch_beta
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
gated_action_intent = rearrange(gated_action_intent, '(b d) n -> b n d', b = batch)
|
|
175
|
+
gated_action = self.switch_gating(switch_beta, sampled_action * forget)
|
|
183
176
|
|
|
184
177
|
# decoder
|
|
185
178
|
|
|
186
|
-
decoder_out = self.decoder(
|
|
179
|
+
decoder_out = self.decoder(gated_action)
|
|
187
180
|
|
|
188
181
|
w1, w2 = self.to_hyper_network_weights(decoder_out)
|
|
189
182
|
hypernetwork_weight = einsum(w1, w2, '... i r, ... j r -> ... i j')
|
|
190
183
|
|
|
191
184
|
# generating the residual stream controlling signal
|
|
192
185
|
|
|
193
|
-
control_signal = einsum(
|
|
186
|
+
control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
|
|
194
187
|
|
|
195
188
|
modified_residual_stream = residual_stream + control_signal
|
|
196
189
|
|
|
@@ -236,13 +229,15 @@ class Transformer(Module):
|
|
|
236
229
|
|
|
237
230
|
def evolve(
|
|
238
231
|
self,
|
|
232
|
+
num_generations,
|
|
239
233
|
environment,
|
|
240
234
|
**kwargs
|
|
241
235
|
):
|
|
242
|
-
assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'
|
|
236
|
+
assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
|
|
243
237
|
|
|
244
238
|
evo_strat = EvoStrategy(
|
|
245
239
|
self,
|
|
240
|
+
num_generations = num_generations,
|
|
246
241
|
environment = environment,
|
|
247
242
|
params_to_optimize = self.meta_controller.internal_rl_parameters(),
|
|
248
243
|
**kwargs
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.8
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -34,7 +34,7 @@ Classifier: License :: OSI Approved :: MIT License
|
|
|
34
34
|
Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
|
-
Requires-Dist: assoc-scan
|
|
37
|
+
Requires-Dist: assoc-scan>=0.0.3
|
|
38
38
|
Requires-Dist: discrete-continuous-embed-readout>=0.1.11
|
|
39
39
|
Requires-Dist: einops>=0.8.1
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=cbo0F861KIcGIhJ1j-Js6Qwfm_-8nm5Sm0LJiRO7hl0,8265
|
|
3
|
+
metacontroller_pytorch-0.0.8.dist-info/METADATA,sha256=a7aUiVugnv5PJ-AZqnCyEczWsmuUS30s-3DsBKuThNQ,3713
|
|
4
|
+
metacontroller_pytorch-0.0.8.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.8.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.8.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=-glgWcv6QQZ6wVAy6tK2Ye8QNXuWBiGmB1rs6DVWA1I,8573
|
|
3
|
-
metacontroller_pytorch-0.0.5.dist-info/METADATA,sha256=CfXW_uO8B9gz31XkUO-2aVl4TN64iYicdZPtW7DzzHc,3706
|
|
4
|
-
metacontroller_pytorch-0.0.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.5.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.5.dist-info → metacontroller_pytorch-0.0.8.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|