metacontroller-pytorch 0.0.9__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- metacontroller/metacontroller.py +60 -20
- {metacontroller_pytorch-0.0.9.dist-info → metacontroller_pytorch-0.0.10.dist-info}/METADATA +1 -1
- metacontroller_pytorch-0.0.10.dist-info/RECORD +6 -0
- metacontroller_pytorch-0.0.9.dist-info/RECORD +0 -6
- {metacontroller_pytorch-0.0.9.dist-info → metacontroller_pytorch-0.0.10.dist-info}/WHEEL +0 -0
- {metacontroller_pytorch-0.0.9.dist-info → metacontroller_pytorch-0.0.10.dist-info}/licenses/LICENSE +0 -0
metacontroller/metacontroller.py
CHANGED
@@ -52,6 +52,13 @@ def straight_through(src, tgt):

 # meta controller

+MetaControllerOutput = namedtuple('MetaControllerOutput', (
+    'prev_hiddens',
+    'action_dist',
+    'actions',
+    'kl_loss'
+))
+
 class MetaController(Module):
     def __init__(
         self,
@@ -107,9 +114,9 @@ class MetaController(Module):
         return [
             *self.bidirectional_temporal_compressor.parameters(),
             *self.emitter.parameters(),
-            *self.emitter_to_action_mean_log_var.parameters()
+            *self.emitter_to_action_mean_log_var.parameters(),
             *self.decoder.parameters(),
-            *self.switch_gating
+            *self.switch_gating.parameters()
         ]

     def internal_rl_parameters(self):
@@ -121,10 +128,19 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
+        cache: MetaControllerOutput | None = None,
         discovery_phase = False,
         hard_switch = False
     ):

+        # destruct prev cache
+
+        prev_action_proposer_hidden, prev_switching_unit_gru_hidden, prev_switch_gated_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
+        # getting proposed action for the two phases
+
+        next_action_proposer_hidden = None
+
         if discovery_phase:
             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
@@ -133,7 +149,8 @@ class MetaController(Module):
             readout = self.emitter_to_action_mean_log_var

         else: # else internal rl phase
-
+
+            proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
             readout = self.action_proposer_mean_log_var

         # sample from the gaussian as the action from the meta controller
@@ -146,35 +163,37 @@ class MetaController(Module):

         batch, _, dim = sampled_action.shape

-        switching_unit_gru_out,
+        switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)

         switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

         # need to encourage normal distribution

-
+        kl_loss = self.zero

         if discovery_phase:
             mean, log_var = action_dist.unbind(dim = -1)

-
+            kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
             ))

-
-
+            kl_loss = kl_loss * switch_beta
+            kl_loss = kl_loss.sum(dim = -1).mean()

         # maybe hard switch, then use associative scan

         if hard_switch:
-
-            switch_beta = straight_through(switch_beta,
+            hard_switch_beta = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch_beta)

         forget = 1. - switch_beta
-        gated_action = self.switch_gating(switch_beta, sampled_action * forget)
+        gated_action = self.switch_gating(switch_beta, sampled_action * forget, prev = prev_switch_gated_hiddens)
+
+        next_switch_gated_action = gated_action[:, -1]

         # decoder

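The discovery-phase lines added in the hunk above compute the closed-form KL divergence between the proposed Gaussian action distribution N(mean, exp(log_var)) and a standard normal, per dimension, before gating it with switch_beta. A quick standalone sanity check of that identity against torch.distributions (plain PyTorch, not part of this package) could look like this:

import torch
from torch.distributions import Normal, kl_divergence

mean = torch.randn(4, 8)
log_var = torch.randn(4, 8)

# closed form used in the diff: 0.5 * (exp(log_var) + mean^2 - log_var - 1)
kl_manual = 0.5 * (log_var.exp() + mean.square() - log_var - 1.)

# reference: KL(N(mean, sigma) || N(0, 1)) with sigma = exp(0.5 * log_var)
kl_reference = kl_divergence(Normal(mean, (0.5 * log_var).exp()), Normal(0., 1.))

assert torch.allclose(kl_manual, kl_reference, atol = 1e-6)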
@@ -189,10 +208,23 @@ class MetaController(Module):

         modified_residual_stream = residual_stream + control_signal

-
+        # returning
+
+        next_hiddens = (
+            next_action_proposer_hidden,
+            next_switching_unit_gru_hidden,
+            next_switch_gated_action
+        )
+
+        return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)

 # main transformer, which is subsumed into the environment after behavioral cloning

+TransformerOutput = namedtuple('TransformerOutput', (
+    'residual_stream_latent',
+    'prev_hiddens'
+))
+
 class Transformer(Module):
     def __init__(
         self,
@@ -251,10 +283,12 @@ class Transformer(Module):
         self,
         ids,
         meta_controller: Module | None = None,
+        cache: TransformerOutput | None = None,
         discovery_phase = False,
-        return_latents = False,
         no_grad_transformer = None,
-        no_grad_meta_controller = None
+        no_grad_meta_controller = None,
+        return_latents = False,
+        return_cache = False
     ):
         meta_controller = default(meta_controller, self.meta_controller)

@@ -268,28 +302,32 @@ class Transformer(Module):
         transformer_context = torch.no_grad if no_grad_transformer else nullcontext
         meta_controller_context = torch.no_grad if no_grad_meta_controller else nullcontext

+        # handle cache
+
+        lower_transformer_hiddens, meta_hiddens, upper_transformer_hiddens = cache.prev_hiddens if exists(cache) else ((None,) * 3)
+
         # transformer lower body

         with transformer_context():

             embed = self.embed(ids)

-            residual_stream = self.lower_body(embed)
+            residual_stream, next_lower_hiddens = self.lower_body(embed, cache = lower_transformer_hiddens, return_hiddens = True)

         # meta controller acts on residual stream here

         with meta_controller_context():

             if exists(meta_controller):
-                modified_residual_stream,
+                modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase)
             else:
-                modified_residual_stream,
+                modified_residual_stream, next_meta_hiddens = residual_stream, None

         # modified residual stream sent back to transformer upper body

         with transformer_context():

-            attended = self.upper_body(modified_residual_stream)
+            attended, next_upper_hiddens = self.upper_body(modified_residual_stream, cache = upper_transformer_hiddens, return_hiddens = True)

         # head readout

@@ -297,7 +335,9 @@ class Transformer(Module):

         # returning

-
+        return_one = not (return_latents or return_cache)
+
+        if return_one:
             return dist_params

-        return dist_params,
+        return dist_params, TransformerOutput(residual_stream, (next_lower_hiddens, next_meta_hiddens, next_upper_hiddens))
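Taken together, the changes to metacontroller.py thread recurrent state through both modules: MetaController.forward now accepts a cache (a previous MetaControllerOutput) and returns the next hidden states in prev_hiddens, and Transformer.forward does the same with a TransformerOutput cache plus a return_cache flag. A minimal sketch of how the cached calls might be chained step by step follows; the import path is assumed from the package layout, and the constructor arguments and token shapes are hypothetical, chosen only for illustration:

import torch
from metacontroller import Transformer  # assumed import path

# hypothetical constructor arguments, purely for illustration
transformer = Transformer(num_tokens = 256, dim = 512, depth = 8)

ids = torch.randint(0, 256, (1, 128))

# first call primes the cache, later calls pass it back in
dist_params, cache = transformer(ids, return_cache = True)

next_ids = torch.randint(0, 256, (1, 1))
dist_params, cache = transformer(next_ids, cache = cache, return_cache = True)

# cache.prev_hiddens holds the (lower transformer, meta controller, upper transformer) hidden states

Since the cache is a plain namedtuple of hidden states, it can presumably be detached or dropped between segments without touching the modules themselves.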
metacontroller_pytorch-0.0.10.dist-info/RECORD
ADDED
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=H-bZi70445-4JlhUFL8x_fgePY7bTxkDO4CCdItKao4,10642
+metacontroller_pytorch-0.0.10.dist-info/METADATA,sha256=AFk9SUK6TGSG1APtt51yiASCEWIOTIvzAhtJJnS-Dsc,3714
+metacontroller_pytorch-0.0.10.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.10.dist-info/RECORD,,
metacontroller_pytorch-0.0.9.dist-info/RECORD
DELETED
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=V2Nb7ByGj310CalTzho-grwNsoHMp55oN5spkedJihc,9189
-metacontroller_pytorch-0.0.9.dist-info/METADATA,sha256=BA4AHlFW8DsD_NPXNv8N8rmRPISZNTkcjvGautB7xJA,3713
-metacontroller_pytorch-0.0.9.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.9.dist-info/RECORD,,
{metacontroller_pytorch-0.0.9.dist-info → metacontroller_pytorch-0.0.10.dist-info}/WHEEL
RENAMED
File without changes
{metacontroller_pytorch-0.0.9.dist-info → metacontroller_pytorch-0.0.10.dist-info}/licenses/LICENSE
RENAMED
File without changes