metacontroller-pytorch 0.0.10__py3-none-any.whl → 0.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +14 -13
- {metacontroller_pytorch-0.0.10.dist-info → metacontroller_pytorch-0.0.14.dist-info}/METADATA +3 -2
- metacontroller_pytorch-0.0.14.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.10.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.10.dist-info → metacontroller_pytorch-0.0.14.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.10.dist-info → metacontroller_pytorch-0.0.14.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
|
@@ -3,6 +3,7 @@ from contextlib import nullcontext
|
|
|
3
3
|
|
|
4
4
|
from functools import partial
|
|
5
5
|
from collections import namedtuple
|
|
6
|
+
from loguru import logger
|
|
6
7
|
|
|
7
8
|
import torch
|
|
8
9
|
from torch import nn, cat, stack, tensor
|
|
@@ -130,7 +131,8 @@ class MetaController(Module):
|
|
|
130
131
|
residual_stream,
|
|
131
132
|
cache: MetaControllerOutput | None = None,
|
|
132
133
|
discovery_phase = False,
|
|
133
|
-
hard_switch = False
|
|
134
|
+
hard_switch = False,
|
|
135
|
+
temperature = 1.
|
|
134
136
|
):
|
|
135
137
|
|
|
136
138
|
# destruct prev cache
|
|
@@ -142,6 +144,8 @@ class MetaController(Module):
|
|
|
142
144
|
next_action_proposer_hidden = None
|
|
143
145
|
|
|
144
146
|
if discovery_phase:
|
|
147
|
+
logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
|
|
148
|
+
|
|
145
149
|
temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
|
|
146
150
|
temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
|
|
147
151
|
|
|
@@ -157,7 +161,7 @@ class MetaController(Module):
|
|
|
157
161
|
|
|
158
162
|
action_dist = readout(proposed_action_hidden)
|
|
159
163
|
|
|
160
|
-
sampled_action = readout.sample(action_dist)
|
|
164
|
+
sampled_action = readout.sample(action_dist, temperature = temperature)
|
|
161
165
|
|
|
162
166
|
# switching unit timer
|
|
163
167
|
|
|
@@ -285,10 +289,9 @@ class Transformer(Module):
|
|
|
285
289
|
meta_controller: Module | None = None,
|
|
286
290
|
cache: TransformerOutput | None = None,
|
|
287
291
|
discovery_phase = False,
|
|
288
|
-
|
|
289
|
-
no_grad_meta_controller = None,
|
|
292
|
+
meta_controller_temperature = 1.,
|
|
290
293
|
return_latents = False,
|
|
291
|
-
return_cache = False
|
|
294
|
+
return_cache = False,
|
|
292
295
|
):
|
|
293
296
|
meta_controller = default(meta_controller, self.meta_controller)
|
|
294
297
|
|
|
@@ -296,11 +299,9 @@ class Transformer(Module):
|
|
|
296
299
|
|
|
297
300
|
# by default, if meta controller is passed in, transformer is no grad
|
|
298
301
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
transformer_context = torch.no_grad if no_grad_transformer else nullcontext
|
|
303
|
-
meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
|
|
302
|
+
lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
|
|
303
|
+
meta_controller_context = nullcontext if meta_controlling else torch.no_grad
|
|
304
|
+
upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
|
|
304
305
|
|
|
305
306
|
# handle cache
|
|
306
307
|
|
|
@@ -308,7 +309,7 @@ class Transformer(Module):
|
|
|
308
309
|
|
|
309
310
|
# transformer lower body
|
|
310
311
|
|
|
311
|
-
with transformer_context():
|
|
312
|
+
with lower_transformer_context():
|
|
312
313
|
|
|
313
314
|
embed = self.embed(ids)
|
|
314
315
|
|
|
@@ -319,13 +320,13 @@ class Transformer(Module):
|
|
|
319
320
|
with meta_controller_context():
|
|
320
321
|
|
|
321
322
|
if exists(meta_controller):
|
|
322
|
-
modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
|
|
323
|
+
modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
|
|
323
324
|
else:
|
|
324
325
|
modified_residual_stream, next_meta_hiddens = residual_stream, None
|
|
325
326
|
|
|
326
327
|
# modified residual stream sent back to transformer upper body
|
|
327
328
|
|
|
328
|
-
with transformer_context():
|
|
329
|
+
with upper_transformer_context():
|
|
329
330
|
|
|
330
331
|
attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
|
|
331
332
|
|
{metacontroller_pytorch-0.0.10.dist-info → metacontroller_pytorch-0.0.14.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: metacontroller-pytorch
|
|
3
|
-
Version: 0.0.10
|
|
3
|
+
Version: 0.0.14
|
|
4
4
|
Summary: Transformer Metacontroller
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/metacontroller/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/metacontroller
|
|
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
36
36
|
Requires-Python: >=3.9
|
|
37
37
|
Requires-Dist: assoc-scan>=0.0.3
|
|
38
|
-
Requires-Dist: discrete-continuous-embed-readout>=0.1.
|
|
38
|
+
Requires-Dist: discrete-continuous-embed-readout>=0.1.12
|
|
39
39
|
Requires-Dist: einops>=0.8.1
|
|
40
40
|
Requires-Dist: einx>=0.3.0
|
|
41
|
+
Requires-Dist: loguru
|
|
41
42
|
Requires-Dist: torch>=2.5
|
|
42
43
|
Requires-Dist: x-evolution>=0.1.23
|
|
43
44
|
Requires-Dist: x-mlps-pytorch
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
+
metacontroller/metacontroller.py,sha256=rQJLyXMHHNCZm0iohWqAkcMpYSi8b1Z0dgB-8AJVqqo,10751
|
|
3
|
+
metacontroller_pytorch-0.0.14.dist-info/METADATA,sha256=-CP3Ak1NPaqTpyF4tTgwn-T47Pv2OiPzPFxecwGe3Ng,3736
|
|
4
|
+
metacontroller_pytorch-0.0.14.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
metacontroller_pytorch-0.0.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
+
metacontroller_pytorch-0.0.14.dist-info/RECORD,,
|
|
@@ -1,6 +0,0 @@
|
|
|
1
|
-
metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
|
|
2
|
-
metacontroller/metacontroller.py,sha256=H-bZi70445-4JlhUFL8x_fgePY7bTxkDO4CCdItKao4,10642
|
|
3
|
-
metacontroller_pytorch-0.0.10.dist-info/METADATA,sha256=AFk9SUK6TGSG1APtt51yiASCEWIOTIvzAhtJJnS-Dsc,3714
|
|
4
|
-
metacontroller_pytorch-0.0.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
-
metacontroller_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
6
|
-
metacontroller_pytorch-0.0.10.dist-info/RECORD,,
|
|
File without changes
|
{metacontroller_pytorch-0.0.10.dist-info → metacontroller_pytorch-0.0.14.dist-info}/licenses/LICENSE
RENAMED
|
File without changes
|