metacontroller-pytorch 0.0.8__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/metacontroller/metacontroller.py +36 -10
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/README.md +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/tests/test_metacontroller.py +0 -0
{metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/metacontroller/metacontroller.py
RENAMED
|
@@ -1,4 +1,6 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
from contextlib import nullcontext
|
|
3
|
+
|
|
2
4
|
from functools import partial
|
|
3
5
|
from collections import namedtuple
|
|
4
6
|
|
|
@@ -250,26 +252,50 @@ class Transformer(Module):
|
|
|
250
252
|
ids,
|
|
251
253
|
meta_controller: Module | None = None,
|
|
252
254
|
discovery_phase = False,
|
|
253
|
-
return_latents = False
|
|
255
|
+
return_latents = False,
|
|
256
|
+
no_grad_transformer = None,
|
|
257
|
+
no_grad_meta_controller = None
|
|
254
258
|
):
|
|
255
259
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
256
260
|
|
|
257
|
-
|
|
261
|
+
meta_controlling = exists(meta_controller)
|
|
262
|
+
|
|
263
|
+
# by default, if meta controller is passed in, transformer is no grad
|
|
264
|
+
|
|
265
|
+
no_grad_transformer = default(no_grad_transformer, meta_controlling)
|
|
266
|
+
no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
|
|
267
|
+
|
|
268
|
+
transformer_context = torch.no_grad if no_grad_transformer else nullcontext
|
|
269
|
+
meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
|
|
270
|
+
|
|
271
|
+
# transformer lower body
|
|
272
|
+
|
|
273
|
+
with transformer_context():
|
|
258
274
|
|
|
259
|
-
|
|
275
|
+
embed = self.embed(ids)
|
|
276
|
+
|
|
277
|
+
residual_stream = self.lower_body(embed)
|
|
260
278
|
|
|
261
279
|
# meta controller acts on residual stream here
|
|
262
280
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
281
|
+
with meta_controller_context():
|
|
282
|
+
|
|
283
|
+
if exists(meta_controller):
|
|
284
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
|
|
285
|
+
else:
|
|
286
|
+
modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
|
|
287
|
+
|
|
288
|
+
# modified residual stream sent back to transformer upper body
|
|
289
|
+
|
|
290
|
+
with transformer_context():
|
|
291
|
+
|
|
292
|
+
attended = self.upper_body(modified_residual_stream)
|
|
267
293
|
|
|
268
|
-
|
|
294
|
+
# head readout
|
|
269
295
|
|
|
270
|
-
|
|
296
|
+
dist_params = self.readout(attended)
|
|
271
297
|
|
|
272
|
-
|
|
298
|
+
# returning
|
|
273
299
|
|
|
274
300
|
if not return_latents:
|
|
275
301
|
return dist_params
|
{metacontroller_pytorch-0.0.8 → metacontroller_pytorch-0.0.9}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|