metacontroller-pytorch 0.0.2__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/metacontroller/metacontroller.py +7 -2
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/README.md +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/tests/test_metacontroller.py +0 -0
{metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/metacontroller/metacontroller.py
RENAMED
@@ -215,6 +215,8 @@ class Transformer(Module):
 
 self.meta_controller = meta_controller
 
+self.register_buffer('zero', tensor(0.), persistent = False)
+
 def evolve(
 self,
 environment,
@@ -238,7 +240,7 @@ class Transformer(Module):
 discovery_phase = False,
 return_latents = False
 ):
-meta_controller = default(meta_controller, self.meta_controller
+meta_controller = default(meta_controller, self.meta_controller)
 
 embed = self.embed(ids)
 
@@ -246,7 +248,10 @@ class Transformer(Module):
 
 # meta controller acts on residual stream here
 
-
+if exists(meta_controller):
+modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+else:
+modified_residual_stream, vae_aux_loss = residual_stream, self.zero
 
 # modified residual stream sent back
 
{metacontroller_pytorch-0.0.2 → metacontroller_pytorch-0.0.3}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|