metacontroller-pytorch 0.0.9__tar.gz → 0.0.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/PKG-INFO +3 -2
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/metacontroller/metacontroller.py +67 -22
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/pyproject.toml +3 -2
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/tests/test_metacontroller.py +6 -1
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/README.md +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/PKG-INFO
RENAMED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.9
+Version: 0.0.12
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -35,9 +35,10 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan>=0.0.3
-Requires-Dist: discrete-continuous-embed-readout>=0.1.
+Requires-Dist: discrete-continuous-embed-readout>=0.1.12
 Requires-Dist: einops>=0.8.1
 Requires-Dist: einx>=0.3.0
+Requires-Dist: loguru
 Requires-Dist: torch>=2.5
 Requires-Dist: x-evolution>=0.1.23
 Requires-Dist: x-mlps-pytorch
{metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/metacontroller/metacontroller.py
RENAMED

@@ -3,6 +3,7 @@ from contextlib import nullcontext
 
 from functools import partial
 from collections import namedtuple
+from loguru import logger
 
 import torch
 from torch import nn, cat, stack, tensor
@@ -52,6 +53,13 @@ def straight_through(src, tgt):
 
 # meta controller
 
+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'actions',
+    'kl_loss'
+))
+
 class MetaController(Module):
     def __init__(
         self,
@@ -107,9 +115,9 @@ class MetaController(Module):
         return [
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
-            *self.emitter_to_action_mean_log_var.parameters()
+            *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
-            *self.switch_gating
+            *self.switch_gating.parameters()
         ]
 
     def internal_rl_parameters(self):
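Both changes in this hunk look like bug fixes rather than features: the missing trailing comma after `*self.emitter_to_action_mean_log_var.parameters()` made the next starred entry parse as a multiplication of two parameter generators, and `*self.switch_gating` unpacked the module itself, which is not iterable. A minimal sketch of the second failure, with an `nn.Linear` standing in for the switch gating module:

import torch
from torch import nn

gate = nn.Linear(4, 4)  # stand-in for self.switch_gating

try:
    params = [*gate]  # unpacking the module itself, as 0.0.9 did
except TypeError as err:
    print(err)  # 'Linear' object is not iterable

params = [*gate.parameters()]  # the 0.0.12 fix: unpack the parameter iterator
print(len(params))  # 2 -> weight and bias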
@@ -121,11 +129,23 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
+        cache: MetaControllerOutput | None = None,
         discovery_phase = False,
-        hard_switch = False
+        hard_switch = False,
+        temperature = 1.
     ):
 
+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
         if discovery_phase:
+            logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
+
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
 
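The new `cache` argument and the `prev_hiddens` destructure follow a simple None-safe pattern: if a cache namedtuple was returned by the previous step, its three hidden states are unpacked in order; otherwise a tuple of `None`s stands in. A minimal self-contained sketch of the pattern (the `exists` helper mirrors the one this codebase already defines):

from collections import namedtuple

MetaControllerOutput = namedtuple('MetaControllerOutput', (
    'prev_hiddens',
    'action_dist',
    'actions',
    'kl_loss'
))

def exists(v):
    return v is not None

def unpack_cache(cache):
    # three hidden states come back out in the same order they were packed
    return cache.prev_hiddens if exists(cache) else ((None,) * 3)

first_step = unpack_cache(None)                               # (None, None, None)
cache = MetaControllerOutput(('h1', 'h2', 'h3'), None, None, None)
later_step = unpack_cache(cache)                              # ('h1', 'h2', 'h3')
print(first_step, later_step)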
@@ -133,48 +153,51 @@
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
-
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
 
         action_dist = readout(proposed_action_hidden)
 
-        sampled_action = readout.sample(action_dist)
+        sampled_action = readout.sample(action_dist, temperature = temperature)
 
         # switching unit timer
 
         batch, _, dim = sampled_action.shape
 
-        switching_unit_gru_out,
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         # need to encourage normal distribution
 
-
+        kl_loss = self.zero
 
         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
 
-
+            kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
             ))
 
-
-
+            kl_loss = kl_loss * switch_beta
+            kl_loss = kl_loss.sum(dim = -1).mean()
 
         # maybe hard switch, then use associative scan
 
         if hard_switch:
-
-            switch_beta = straight_through(switch_beta,
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)
 
         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_action = gated_action[:, -1]
 
         # decoder
 
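Two numerical pieces of this hunk are worth unpacking. The new `kl_loss` is the closed-form KL divergence between the proposed diagonal Gaussian and a standard normal, KL = ½ · (σ² + μ² − log σ² − 1), and the hard switch uses a straight-through estimator so the binarized gate stays differentiable. A sketch verifying both, where `straight_through` is one standard implementation of the helper named in the hunk header above, not necessarily the repo's exact one:

import torch
from torch.distributions import Normal, kl_divergence

# closed-form KL( N(mean, var) || N(0, 1) ), matching the hunk above
mean, log_var = torch.randn(4), torch.randn(4)
kl_closed = 0.5 * (log_var.exp() + mean.square() - log_var - 1.)
kl_ref = kl_divergence(Normal(mean, (0.5 * log_var).exp()), Normal(0., 1.))
assert torch.allclose(kl_closed, kl_ref, atol = 1e-6)

# straight-through: forward uses the hard value, gradient flows to the soft one
def straight_through(src, tgt):
    return src + (tgt - src).detach()

switch_beta = torch.rand(4, requires_grad = True)
hard = (switch_beta > 0.5).float()    # the non-differentiable hard gate
gated = straight_through(switch_beta, hard)
gated.sum().backward()
assert torch.equal(gated, hard)                       # forward value is hard
assert torch.equal(switch_beta.grad, torch.ones(4))   # gradient reaches the soft beta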
@@ -189,10 +212,23 @@
 
         modified_residual_stream = residual_stream + control_signal
 
-
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_action
+        )
+
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
+TransformerOutput = namedtuple('TransformerOutput', (
+    'residual_stream_latent',
+    'prev_hiddens'
+))
+
 class Transformer(Module):
     def __init__(
         self,
@@ -251,10 +287,13 @@
         self,
         ids,
         meta_controller: Module | None = None,
+        cache: TransformerOutput | None = None,
         discovery_phase = False,
-        return_latents = False,
         no_grad_transformer = None,
-        no_grad_meta_controller = None
+        no_grad_meta_controller = None,
+        meta_controller_temperature = 1.,
+        return_latents = False,
+        return_cache = False,
     ):
         meta_controller = default(meta_controller, self.meta_controller)
 
@@ -268,28 +307,32 @@
         transformer_context = torch.no_grad if no_grad_transformer else nullcontext
         meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext
 
+        # handle cache
+
+        lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
         # transformer lower body
 
         with transformer_context():
 
             embed = self.embed(ids)
 
-            residual_stream = self.lower_body(embed)
+            residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)
 
         # meta controller acts on residual stream here
 
         with meta_controller_context():
 
             if exists(meta_controller):
-                modified_residual_stream,
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
             else:
-                modified_residual_stream,
+                modified_residual_stream, next_meta_hiddens = residual_stream, None
 
         # modified residual stream sent back to transformer upper body
 
         with transformer_context():
 
-            attended = self.upper_body(modified_residual_stream)
+            attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)
 
         # head readout
 
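A pattern in the unchanged context lines above deserves a note: rather than branching on the `no_grad_*` flags inside the body, the code picks the context-manager class up front (`torch.no_grad` or `nullcontext`) and instantiates it at each `with` site. A minimal sketch of the idiom:

import torch
from contextlib import nullcontext

def forward(x, no_grad = False):
    context = torch.no_grad if no_grad else nullcontext
    with context():
        return x * 2

x = torch.ones(3, requires_grad = True)
assert forward(x).requires_grad                       # autograd tracks the op
assert not forward(x, no_grad = True).requires_grad   # op runs untracked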
@@ -297,7 +340,9 @@
 
         # returning
 
-
+        return_one = not (return_latents or return_cache)
+
+        if return_one:
             return dist_params
 
-        return dist_params,
+        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
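With `return_cache = True` (or `return_latents = True`) the forward now returns `(dist_params, TransformerOutput(...))`, and the `TransformerOutput` is fed back in as `cache` on the next call, as the updated test below exercises. A toy stand-in with the same calling convention, not the actual `Transformer` (whose constructor arguments are outside this diff):

import torch
from torch import nn
from collections import namedtuple

TransformerOutput = namedtuple('TransformerOutput', ('residual_stream_latent', 'prev_hiddens'))

class ToyCachedModel(nn.Module):
    # mimics the cache contract: past state goes in, extended state comes out
    def __init__(self, dim = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, cache = None, return_cache = False):
        past = cache.prev_hiddens if cache is not None else None
        full = x if past is None else torch.cat((past, x), dim = 1)
        out = self.proj(full)[:, -x.shape[1]:]  # readout only for the new positions
        if not return_cache:
            return out
        return out, TransformerOutput(out, full)

model = ToyCachedModel()
seq = torch.randn(1, 4, 8)
logits, cache = model(seq, return_cache = True)
step = torch.randn(1, 1, 8)
logits, cache = model(step, cache = cache, return_cache = True)  # thread the cache back in
assert logits.shape == (1, 1, 8) and cache.prev_hiddens.shape == (1, 5, 8)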
{metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/pyproject.toml
RENAMED

@@ -1,6 +1,6 @@
 [project]
 name = "metacontroller-pytorch"
-version = "0.0.9"
+version = "0.0.12"
 description = "Transformer Metacontroller"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,9 +25,10 @@ classifiers=[
 
 dependencies = [
     "assoc-scan>=0.0.3",
+    "discrete-continuous-embed-readout>=0.1.12",
     "einx>=0.3.0",
     "einops>=0.8.1",
-    "
+    "loguru",
     "torch>=2.5",
     "x-evolution>=0.1.23",
     "x-mlps-pytorch",
{metacontroller_pytorch-0.0.9 → metacontroller_pytorch-0.0.12}/tests/test_metacontroller.py
RENAMED

@@ -26,9 +26,14 @@ def test_metacontroller(
         switch_per_latent_dim = switch_per_latent_dim
     )
 
-    logits = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True)
 
     assert logits.shape == (1, 1024, 256)
 
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+    logits, cache = model(ids, meta_controller = meta_controller, discovery_phase = discovery_phase, return_cache = True, cache = cache)
+
+    assert logits.shape == (1, 1, 256)
+
     model.meta_controller = meta_controller
     model.evolve(1, lambda _: 1., noise_population_size = 2)