metacontroller-pytorch 0.0.16__tar.gz → 0.0.18__tar.gz
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/PKG-INFO +1 -1
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/metacontroller/metacontroller.py +37 -22
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/pyproject.toml +1 -1
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/tests/test_metacontroller.py +5 -3
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/.github/workflows/python-publish.yml +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/.github/workflows/test.yml +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/.gitignore +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/LICENSE +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/README.md +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/fig1.png +0 -0
- {metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/metacontroller/__init__.py +0 -0
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/metacontroller/metacontroller.py
RENAMED
@@ -57,14 +57,17 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
     'prev_hiddens',
     'action_dist',
     'actions',
-    'kl_loss'
+    'kl_loss',
+    'switch_loss'
 ))
 
 class MetaController(Module):
     def __init__(
         self,
-
+        dim_model,
         *,
+        dim_meta_controller = 256,
+        dim_latent = 128,
         switch_per_latent_dim = True,
         decoder_expansion_factor = 2.,
         decoder_depth = 1,
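For orientation, the namedtuple as it stands after this hunk carries five fields, so downstream code that unpacked the old four-field tuple needs updating. A minimal reconstruction, with the field order exactly as in the diff:

import torch
from collections import namedtuple

MetaControllerOutput = namedtuple('MetaControllerOutput', (
    'prev_hiddens',
    'action_dist',
    'actions',
    'kl_loss',
    'switch_loss'
))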
@@ -72,25 +75,30 @@ class MetaController(Module):
         assoc_scan_kwargs: dict = dict()
     ):
         super().__init__()
+        dim_meta = default(dim_meta_controller, dim_model)
+
+        # the linear that brings from model dimension
+
+        self.model_to_meta = Linear(dim_model, dim_meta)
 
         # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders
 
-        self.bidirectional_temporal_compressor = GRU(
+        self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming
 
-        self.emitter = GRU(
-        self.emitter_to_action_mean_log_var = Readout(
+        self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+        self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
 
         # internal rl phase substitutes the acausal + emitter with a causal ssm
 
-        self.action_proposer = GRU(
-        self.action_proposer_mean_log_var = Readout(
+        self.action_proposer = GRU(dim_meta, dim_meta)
+        self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)
 
         # switching unit
 
         self.switch_per_latent_dim = switch_per_latent_dim
 
-        self.switching_unit = GRU(
-        self.to_switching_unit_beta = nn.Linear(
+        self.switching_unit = GRU(dim_meta, dim_meta)
+        self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
 
         self.switch_gating = AssocScan(**assoc_scan_kwargs)
 
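To see the dimension plumbing the new constructor sets up, here is a standalone sketch using plain torch.nn.GRU with batch-first tensors. The package's GRU and Readout are its own wrappers, so treat the exact signatures here as assumptions:

import torch
from torch import nn

dim_model, dim_meta = 512, 256

model_to_meta = nn.Linear(dim_model, dim_meta)
compressor = nn.GRU(dim_meta, dim_meta, batch_first = True, bidirectional = True)

x = torch.randn(2, 16, dim_model)        # (batch, seq, dim_model) residual stream
meta_embed = model_to_meta(x)            # -> (2, 16, 256)
compressed, _ = compressor(meta_embed)   # -> (2, 16, 512), forward and backward states concatenated

The bidirectional output is later averaged back down to dim_meta and concatenated with meta_embed (see the forward-pass hunk below), which is why the emitter is built with input width dim_meta * 2.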
@@ -104,7 +112,7 @@ class MetaController(Module):
             dim_in = dim_latent,
             dim = dim_decoder_hidden,
             depth = decoder_depth,
-            dim_out = 2 * hypernetwork_low_rank *
+            dim_out = 2 * hypernetwork_low_rank * dim_model
         )
 
         self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
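The decoder head now emits 2 * hypernetwork_low_rank * dim_model features, and the Rearrange that follows splits them into two rank-r factors per token. How the two factors combine into a weight is not visible in this hunk, so the composition below is only one plausible reading (a low-rank outer product), not the package's confirmed behavior:

import torch
from einops import rearrange, einsum

d, r = 512, 8                                # dim_model, hypernetwork_low_rank
decoder_out = torch.randn(2, 16, 2 * r * d)  # (batch, seq, two * d * r)

factors = rearrange(decoder_out, '... (two d r) -> two ... d r', two = 2, r = r)
w_a, w_b = factors                           # each (2, 16, 512, 8)

# hypothetical rank-r composition into a per-token (d, d) weight
hypernetwork_weight = einsum(w_a, w_b, '... d1 r, ... d2 r -> ... d1 d2')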
@@ -113,6 +121,7 @@ class MetaController(Module):
 
     def discovery_parameters(self):
         return [
+            *self.model_to_meta.parameters(),
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
             *self.emitter_to_action_mean_log_var.parameters(),
@@ -143,18 +152,20 @@ class MetaController(Module):
 
         next_action_proposer_hidden = None
 
+        meta_embed = self.model_to_meta(residual_stream)
+
         if discovery_phase:
             logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
 
-            temporal_compressed, _ = self.bidirectional_temporal_compressor(
+            temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
 
-            proposed_action_hidden, _ = self.emitter(cat((temporal_compressed,
+            proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
             readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
 
-            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
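The reduce call in this hunk averages the forward and backward halves of the bidirectional output rather than keeping the concatenation. A self-contained sketch of just that step, with assumed shapes:

import torch
from einops import reduce

dim_meta = 256
bi_out = torch.randn(2, 16, 2 * dim_meta)  # bidirectional GRU output, (forward | backward)

temporal_compressed = reduce(bi_out, '... (two d) -> ... d', 'mean', two = 2)
assert temporal_compressed.shape == (2, 16, dim_meta)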
@@ -167,13 +178,13 @@
 
         batch, _, dim = sampled_action.shape
 
-        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
 
         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
 
         # need to encourage normal distribution
 
-        kl_loss = self.zero
+        kl_loss = switch_loss = self.zero
 
         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +199,10 @@
             kl_loss = kl_loss * switch_beta
             kl_loss = kl_loss.sum(dim = -1).mean()
 
+            # encourage less switching
+
+            switch_loss = switch_beta.mean()
+
         # maybe hard switch, then use associative scan
 
         if hard_switch:
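Since switch_beta is a sigmoid output in [0, 1] (per latent dimension when switch_per_latent_dim is set), taking its mean is a direct penalty on how often, and how hard, the controller switches. A toy illustration with assumed shapes:

import torch

switch_logits = torch.randn(2, 16, 128)  # (batch, seq, dim_latent)
switch_beta = switch_logits.sigmoid()    # gates in [0, 1]

switch_loss = switch_beta.mean()         # scalar; driving it down discourages switching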
@@ -208,9 +223,7 @@
 
         # generating the residual stream controlling signal
 
-        control_signal = einsum(
-
-        modified_residual_stream = residual_stream + control_signal
+        control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
 
         # returning
 
@@ -220,7 +233,7 @@
             next_switch_gated_action
         )
 
-        return
+        return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
@@ -329,9 +342,11 @@ class Transformer(Module):
         with meta_controller_context():
 
             if exists(meta_controller):
-
+                control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
             else:
-
+                control_signal, next_meta_hiddens = self.zero, None
+
+            modified_residual_stream = residual_stream + control_signal
 
         # modified residual stream sent back to transformer upper body
 
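When no meta controller is passed, the control signal falls back to self.zero. Assuming that is a scalar zero buffer (the kl_loss fallback earlier in the file uses the same attribute), the addition broadcasts and leaves the residual stream untouched; a toy check:

import torch

residual_stream = torch.randn(2, 16, 512)
zero = torch.tensor(0.)  # stand-in for the presumed self.zero buffer

assert torch.equal(residual_stream + zero, residual_stream)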
@@ -357,7 +372,7 @@ class Transformer(Module):
 
         action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)
 
-        return action_recon_loss, next_meta_hiddens.kl_loss
+        return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss
 
         # returning
 
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/tests/test_metacontroller.py
RENAMED
@@ -38,14 +38,16 @@ def test_metacontroller(
     # discovery and internal rl phase with meta controller
 
     meta_controller = MetaController(
-
+        dim_model = 512,
+        dim_meta_controller = 256,
+        dim_latent = 128,
         switch_per_latent_dim = switch_per_latent_dim
     )
 
     # discovery phase
 
-    (action_recon_loss, kl_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
-    (action_recon_loss + kl_loss * 0.1).backward()
+    (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+    (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
 
     # internal rl
 
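Putting the pieces together, a sketch of one discovery-phase optimization step wired through the discovery_parameters() helper shown earlier. The loss weights come from this test; the optimizer and learning rate are placeholder choices, and model, meta_controller, state, and actions are assumed set up as in the test:

from torch.optim import Adam

optim = Adam(meta_controller.discovery_parameters(), lr = 1e-4)  # lr is a placeholder

action_recon_loss, kl_loss, switch_loss = model(
    state, actions,
    meta_controller = meta_controller,
    discovery_phase = True
)

(action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()
optim.step()
optim.zero_grad()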
{metacontroller_pytorch-0.0.16 → metacontroller_pytorch-0.0.18}/.github/workflows/python-publish.yml
RENAMED
File without changes