metacontroller-pytorch 0.0.6__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +40 -16
- {metacontroller_pytorch-0.0.6.dist-info → metacontroller_pytorch-0.0.9.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.9.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.6.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.6.dist-info → metacontroller_pytorch-0.0.9.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.6.dist-info → metacontroller_pytorch-0.0.9.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from contextlib import nullcontext
|
|
3
|
+
|
|
2
4
|
from functools import partial
|
|
3
5
|
from collections import namedtuple
|
|
4
6
|
|
|
@@ -113,9 +115,7 @@ class MetaController(Module):
|
|
|
113
115
|
def internal_rl_parameters(self):
|
|
114
116
|
return [
|
|
115
117
|
*self.action_proposer.parameters(),
|
|
116
|
-
*self.action_proposer_mean_log_var.parameters()
|
|
117
|
-
*self.decoder.parameters(),
|
|
118
|
-
*self.switch_gating
|
|
118
|
+
*self.action_proposer_mean_log_var.parameters()
|
|
119
119
|
]
|
|
120
120
|
|
|
121
121
|
def forward(
|
|
@@ -150,8 +150,6 @@ class MetaController(Module):
|
|
|
150
150
|
|
|
151
151
|
switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
|
|
152
152
|
|
|
153
|
-
# switch_beta = switch_beta.expand_as(sampled_action)
|
|
154
|
-
|
|
155
153
|
# need to encourage normal distribution
|
|
156
154
|
|
|
157
155
|
vae_kl_loss = self.zero
|
|
@@ -233,13 +231,15 @@ class Transformer(Module):
|
|
|
233
231
|
|
|
234
232
|
def evolve(
|
|
235
233
|
self,
|
|
234
|
+
num_generations,
|
|
236
235
|
environment,
|
|
237
236
|
**kwargs
|
|
238
237
|
):
|
|
239
|
-
assert exists(self.meta_controller), '`meta_controller` must be defined on init for evolutionary strategies to be straightforwardly applied'
|
|
238
|
+
assert exists(self.meta_controller), '`meta_controller` must be passed in or defined on init for evolutionary strategies to be straightforwardly applied'
|
|
240
239
|
|
|
241
240
|
evo_strat = EvoStrategy(
|
|
242
241
|
self,
|
|
242
|
+
num_generations = num_generations,
|
|
243
243
|
environment = environment,
|
|
244
244
|
params_to_optimize = self.meta_controller.internal_rl_parameters(),
|
|
245
245
|
**kwargs
|
|
@@ -252,26 +252,50 @@ class Transformer(Module):
|
|
|
252
252
|
ids,
|
|
253
253
|
meta_controller: Module | None = None,
|
|
254
254
|
discovery_phase = False,
|
|
255
|
-
return_latents = False
|
|
255
|
+
return_latents = False,
|
|
256
|
+
no_grad_transformer = None,
|
|
257
|
+
no_grad_meta_controller = None
|
|
256
258
|
):
|
|
257
259
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
258
260
|
|
|
259
|
-
|
|
261
|
+
meta_controlling = exists(meta_controller)
|
|
262
|
+
|
|
263
|
+
# by default, if meta controller is passed in, transformer is no grad
|
|
264
|
+
|
|
265
|
+
no_grad_transformer = default(no_grad_transformer, meta_controlling)
|
|
266
|
+
no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
|
|
267
|
+
|
|
268
|
+
transformer_context = torch.no_grad if no_grad_transformer else nullcontext
|
|
269
|
+
meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
|
|
260
270
|
|
|
261
|
-
|
|
271
|
+
# transformer lower body
|
|
272
|
+
|
|
273
|
+
with transformer_context():
|
|
274
|
+
|
|
275
|
+
embed = self.embed(ids)
|
|
276
|
+
|
|
277
|
+
residual_stream = self.lower_body(embed)
|
|
262
278
|
|
|
263
279
|
# meta controller acts on residual stream here
|
|
264
280
|
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
281
|
+
with meta_controller_context():
|
|
282
|
+
|
|
283
|
+
if exists(meta_controller):
|
|
284
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
|
|
285
|
+
else:
|
|
286
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
|
|
287
|
+
|
|
288
|
+
# modified residual stream sent back to transformer upper body
|
|
289
|
+
|
|
290
|
+
with transformer_context():
|
|
291
|
+
|
|
292
|
+
attended = self.upper_body(modified_residual_stream)
|
|
269
293
|
|
|
270
|
-
|
|
294
|
+
# head readout
|
|
271
295
|
|
|
272
|
-
|
|
296
|
+
dist_params = self.readout(attended)
|
|
273
297
|
|
|
274
|
-
|
|
298
|
+
# returning
|
|
275
299
|
|
|
276
300
|
if not return_latents:
|
|
277
301
|
return dist_params
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=V2Nb7ByGj310CalTzho-grwNsoHMp55oN5spkedJihc,9189
|
|
3
|
+
metacontroller_pytorch-0.0.9.dist-info/METADATA,sha256=BA4AHlFW8DsD_NPXNv8N8rmRPISZNTkcjvGautB7xJA,3713
|
|
4
|
+
metacontroller_pytorch-0.0.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.9.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=8AWkPlZWh2A1sRT6nMV0CGHuYhQ5pHEpd5bgyRmZelg,8316
|
|
3
|
-
metacontroller_pytorch-0.0.6.dist-info/METADATA,sha256=vXT_-n3bHgpddnS5axyyc-cADGNk4l2enJv4g4cTJ7A,3713
|
|
4
|
-
metacontroller_pytorch-0.0.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.6.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.6.dist-info → metacontroller_pytorch-0.0.9.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|