metacontroller-pytorch 0.0.10__tar.gz → 0.0.14__tar.gz

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.10
+Version: 0.0.14
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan>=0.0.3
-Requires-Dist: discrete-continuous-embed-readout>=0.1.11
+Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
+Requires-Dist: loguru
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
@@ -3,6 +3,7 @@ from contextlib import nullcontext
 
 from functools import partial
 from collections import namedtuple
+from loguru import logger
 
 import torch
 from torch import nn, cat, stack, tensor
@@ -130,7 +131,8 @@ class MetaController(Module):
         residual_stream,
         cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch = False
+        hard_switch = False,
+        temperature = 1.
     ):
 
         # destruct prev cache
@@ -142,6 +144,8 @@ class MetaController(Module):
         next_action_proposer_hidden = None
 
         if discovery_phase:
+            logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
+
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
 
@@ -157,7 +161,7 @@ class MetaController(Module):
 
         action_dist = readout(proposed_action_hidden)
 
-        sampled_action = readout.sample(action_dist)
+        sampled_action = readout.sample(action_dist, temperature = temperature)
 
         # switching unit timer
 
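The `temperature` argument threaded through here lands in `readout.sample` from the `discrete-continuous-embed-readout` dependency (bumped to >=0.1.12 above, presumably to pick up the `temperature` keyword). Below is a minimal sketch of what temperature-scaled sampling of a categorical action distribution typically looks like; the helper name and shapes are illustrative, not the library's actual internals:

```python
import torch
import torch.nn.functional as F

def sample_with_temperature(logits, temperature = 1.):
    # logits: (batch, num_actions); temperature < 1 sharpens the
    # distribution, > 1 flattens it toward uniform
    if temperature <= 0.:
        # the zero-temperature limit degenerates to a greedy argmax
        return logits.argmax(dim = -1)

    probs = F.softmax(logits / temperature, dim = -1)
    return torch.multinomial(probs, 1).squeeze(-1)
```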
@@ -285,10 +289,9 @@ class Transformer(Module):
         meta_controller: Module | None = None,
         cache: TransformerOutput | None = None,
         discovery_phase = False,
-        no_grad_transformer = None,
-        no_grad_meta_controller = None,
+        meta_controller_temperature = 1.,
         return_latents = False,
-        return_cache = False
+        return_cache = False,
     ):
         meta_controller = default(meta_controller, self.meta_controller)
 
@@ -296,11 +299,9 @@
 
         # by default, if meta controller is passed in, transformer is no grad
 
-        no_grad_transformer = default(no_grad_transformer, meta_controlling)
-        no_grad_meta_controller = default(no_grad_meta_controller, no_grad_transformer) # by default, if transformer is eval no grad then meta controller is being learnt
-
-        transformer_context = torch.no_grad if no_grad_transformer else nullcontext
-        meta_controller_context = nullcontext if meta_controlling else torch.no_grad
+        lower_transformer_context = nullcontext if not meta_controlling else torch.no_grad
+        meta_controller_context = nullcontext if meta_controlling else torch.no_grad
+        upper_transformer_context = nullcontext if meta_controlling and discovery_phase else torch.no_grad
 
         # handle cache
 
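The three contexts above replace the old configurable `no_grad_*` flags with fixed gating tied to the two training phases. With a meta controller attached (`meta_controlling`), the lower transformer body is always frozen and the meta controller always receives gradients; the upper body is unfrozen only during the discovery phase. A condensed sketch of the same gating, assuming only that `meta_controlling` and `discovery_phase` are plain booleans:

```python
import torch
from contextlib import nullcontext

def grad_contexts(meta_controlling: bool, discovery_phase: bool):
    # lower body: frozen as soon as a meta controller is attached
    lower = torch.no_grad if meta_controlling else nullcontext
    # meta controller: learning exactly when it is attached
    meta = nullcontext if meta_controlling else torch.no_grad
    # upper body: gradients flow only during the discovery phase
    upper = nullcontext if meta_controlling and discovery_phase else torch.no_grad
    return lower, meta, upper
```

For example, `grad_contexts(True, False)` (the control phase) freezes both transformer halves and trains only the meta controller.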
@@ -308,7 +309,7 @@
 
         # transformer lower body
 
-        with transformer_context():
+        with lower_transformer_context():
 
             embed = self.embed(ids)
 
@@ -319,13 +320,13 @@
         with meta_controller_context():
 
             if exists(meta_controller):
-                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
             else:
                 modified_residual_stream, next_meta_hiddens = residual_stream, None
 
         # modified residual stream sent back to transformer upper body
 
-        with transformer_context():
+        with upper_transformer_context():
 
             attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
 
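Taken together, the new sampling knob flows from `Transformer.forward` through `MetaController.forward` into the readout sampler. A hypothetical call under the new signature, assuming `model` is a constructed `Transformer` with a meta controller attached and `ids` is a batch of token ids (only the keyword names come from the signatures above; the return value is left opaque):

```python
# control phase: both transformer halves frozen, meta controller
# sampling actions with a sharper-than-default temperature
out = model(
    ids,
    discovery_phase = False,
    meta_controller_temperature = 0.7,
    return_cache = True
)
```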
@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.10"
+version = "0.0.14"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,9 +25,10 @@ classifiers=[
 
 dependencies = [
     "assoc-scan>=0.0.3",
+    "discrete-continuous-embed-readout>=0.1.12",
     "einx>=0.3.0",
     "einops>=0.8.1",
-    "discrete-continuous-embed-readout>=0.1.11",
+    "loguru",
    "torch>=2.5",
    "x-evolution>=0.1.23",
    "x-mlps-pytorch",