metacontroller-pytorch 0.0.10__py3-none-any.whl → 0.0.12__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -3,6 +3,7 @@ from contextlib import nullcontext
3
3
 
4
4
  from functools import partial
5
5
  from collections import namedtuple
6
+ from loguru import logger
6
7
 
7
8
  import torch
8
9
  from torch import nn, cat, stack, tensor
@@ -130,7 +131,8 @@ class MetaController(Module):
130
131
  residual_stream,
131
132
  cache: MetaControllerOutput | None = None,
132
133
  discovery_phase = False,
133
- hard_switch = False
134
+ hard_switch = False,
135
+ temperature = 1.
134
136
  ):
135
137
 
136
138
  # destruct prev cache
@@ -142,6 +144,8 @@ class MetaController(Module):
142
144
  next_action_proposer_hidden = None
143
145
 
144
146
  if discovery_phase:
147
+ logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
148
+
145
149
  temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
146
150
  temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
147
151
 
@@ -157,7 +161,7 @@ class MetaController(Module):
157
161
 
158
162
  action_dist = readout(proposed_action_hidden)
159
163
 
160
- sampled_action = readout.sample(action_dist)
164
+ sampled_action = readout.sample(action_dist, temperature = temperature)
161
165
 
162
166
  # switching unit timer
163
167
 
@@ -287,8 +291,9 @@ class Transformer(Module):
287
291
  discovery_phase = False,
288
292
  no_grad_transformer = None,
289
293
  no_grad_meta_controller = None,
294
+ meta_controller_temperature = 1.,
290
295
  return_latents = False,
291
- return_cache = False
296
+ return_cache = False,
292
297
  ):
293
298
  meta_controller = default(meta_controller, self.meta_controller)
294
299
 
@@ -319,7 +324,7 @@ class Transformer(Module):
319
324
  with meta_controller_context():
320
325
 
321
326
  if exists(meta_controller):
322
- modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
327
+ modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
323
328
  else:
324
329
  modified_residual_stream, next_meta_hiddens = residual_stream, None
325
330
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.10
3
+ Version: 0.0.12
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan>=0.0.3
38
- Requires-Dist: discrete-continuous-embed-readout>=0.1.11
38
+ Requires-Dist: discrete-continuous-embed-readout>=0.1.12
39
39
  Requires-Dist: einops>=0.8.1
40
40
  Requires-Dist: einx>=0.3.0
41
+ Requires-Dist: loguru
41
42
  Requires-Dist: torch>=2.5
42
43
  Requires-Dist: x-evolution>=0.1.23
43
44
  Requires-Dist: x-mlps-pytorch
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=11Du7hg4Fj3epfOY42xa9IZO9b0bhquBddenWliwIMY,10956
3
+ metacontroller_pytorch-0.0.12.dist-info/METADATA,sha256=fKGk-rrrTXv1ueGzehcMV2k2G5n_Tmyk-VpXNo1gcFE,3736
4
+ metacontroller_pytorch-0.0.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.12.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=H-bZi70445-4JlhUFL8x_fgePY7bTxkDO4CCdItKao4,10642
3
- metacontroller_pytorch-0.0.10.dist-info/METADATA,sha256=AFk9SUK6TGSG1APtt51yiASCEWIOTIvzAhtJJnS-Dsc,3714
4
- metacontroller_pytorch-0.0.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.10.dist-info/RECORD,,