metacontroller-pytorch 0.0.10__tar.gz → 0.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/PKG-INFO +3 -2
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/metacontroller/metacontroller.py +9 -4
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/pyproject.toml +3 -2
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/README.md +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/metacontroller/__init__.py +0 -0
- {metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/tests/test_metacontroller.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.10
|
|
3
|
+
Version: 0.0.12
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: assoc-scan>=0.0.3
|
|
38
|
-
Requires-Dist: discrete-continuous-embed-readout>=0.1.
|
|
38
|
+
Requires-Dist: discrete-continuous-embed-readout>=0.1.12
|
|
39
39
|
Requires-Dist: einops>=0.8.1
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
41
|
+
Requires-Dist: loguru
|
|
41
42
|
Requires-Dist: torch>=2.5
|
|
42
43
|
Requires-Dist: x-evolution>=0.1.23
|
|
43
44
|
Requires-Dist: x-mlps-pytorch
|
{metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/metacontroller/metacontroller.py
RENAMED
|
@@ -3,6 +3,7 @@ from contextlib import nullcontext
|
|
|
3
3
|
|
|
4
4
|
from functools import partial
|
|
5
5
|
from collections import namedtuple
|
|
6
|
+
from loguru import logger
|
|
6
7
|
|
|
7
8
|
import torch
|
|
8
9
|
from torch import nn, cat, stack, tensor
|
|
@@ -130,7 +131,8 @@ class MetaController(Module):
|
|
|
130
131
|
residual_stream,
|
|
131
132
|
cache: MetaControllerOutput | None = None,
|
|
132
133
|
discovery_phase = False,
|
|
133
|
-
hard_switch = False
|
|
134
|
+
hard_switch = False,
|
|
135
|
+
temperature = 1.
|
|
134
136
|
):
|
|
135
137
|
|
|
136
138
|
# destruct prev cache
|
|
@@ -142,6 +144,8 @@ class MetaController(Module):
|
|
|
142
144
|
next_action_proposer_hidden = None
|
|
143
145
|
|
|
144
146
|
if discovery_phase:
|
|
147
|
+
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
148
|
+
|
|
145
149
|
temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
|
|
146
150
|
temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
|
|
147
151
|
|
|
@@ -157,7 +161,7 @@ class MetaController(Module):
|
|
|
157
161
|
|
|
158
162
|
action_dist = readout(proposed_action_hidden)
|
|
159
163
|
|
|
160
|
-
sampled_action = readout.sample(action_dist)
|
|
164
|
+
sampled_action = readout.sample(action_dist, temperature = temperature)
|
|
161
165
|
|
|
162
166
|
# switching unit timer
|
|
163
167
|
|
|
@@ -287,8 +291,9 @@ class Transformer(Module):
|
|
|
287
291
|
discovery_phase = False,
|
|
288
292
|
no_grad_transformer = None,
|
|
289
293
|
no_grad_meta_controller = None,
|
|
294
|
+
meta_controller_temperature = 1.,
|
|
290
295
|
return_latents = False,
|
|
291
|
-
return_cache = False
|
|
296
|
+
return_cache = False,
|
|
292
297
|
):
|
|
293
298
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
294
299
|
|
|
@@ -319,7 +324,7 @@ class Transformer(Module):
|
|
|
319
324
|
with meta_controller_context():
|
|
320
325
|
|
|
321
326
|
if exists(meta_controller):
|
|
322
|
-
modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
|
|
327
|
+
modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
|
|
323
328
|
else:
|
|
324
329
|
modified_residual_stream, next_meta_hiddens = residual_stream, None
|
|
325
330
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "metacontroller-pytorch"
|
|
3
|
-
version = "0.0.10"
|
|
3
|
+
version = "0.0.12"
|
|
4
4
|
description = "Transformer Metacontroller"
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
|
|
@@ -25,9 +25,10 @@ classifiers=[
|
|
|
25
25
|
|
|
26
26
|
dependencies = [
|
|
27
27
|
"assoc-scan>=0.0.3",
|
|
28
|
+
"discrete-continuous-embed-readout>=0.1.12",
|
|
28
29
|
"einx>=0.3.0",
|
|
29
30
|
"einops>=0.8.1",
|
|
30
|
-
"
|
|
31
|
+
"loguru",
|
|
31
32
|
"torch>=2.5",
|
|
32
33
|
"x-evolution>=0.1.23",
|
|
33
34
|
"x-mlps-pytorch",
|
{metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/.github/workflows/python-publish.yml
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{metacontroller_pytorch-0.0.10 → metacontroller_pytorch-0.0.12}/tests/test_metacontroller.py
RENAMED
|
File without changes
|