metacontroller-pytorch 0.0.16__tar.gz → 0.0.18__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: metacontroller-pytorch
- Version: 0.0.16
+ Version: 0.0.18
  Summary: Transformer Metacontroller
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -57,14 +57,17 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
      'prev_hiddens',
      'action_dist',
      'actions',
-     'kl_loss'
+     'kl_loss',
+     'switch_loss'
  ))

  class MetaController(Module):
      def __init__(
          self,
-         dim_latent,
+         dim_model,
          *,
+         dim_meta_controller = 256,
+         dim_latent = 128,
          switch_per_latent_dim = True,
          decoder_expansion_factor = 2.,
          decoder_depth = 1,
@@ -72,25 +75,30 @@ class MetaController(Module):
          assoc_scan_kwargs: dict = dict()
      ):
          super().__init__()
+         dim_meta = default(dim_meta_controller, dim_model)
+
+         # the linear that brings from model dimension
+
+         self.model_to_meta = Linear(dim_model, dim_meta)

          # there are two phases, the first (discovery ssl phase) uses acausal with some ssm i don't really believe in - let's just use a bidirectional GRU as placeholders

-         self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
+         self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming

-         self.emitter = GRU(dim_latent * 2, dim_latent * 2)
-         self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
+         self.emitter = GRU(dim_meta * 2, dim_meta * 2)
+         self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)

          # internal rl phase substitutes the acausal + emitter with a causal ssm

-         self.action_proposer = GRU(dim_latent, dim_latent)
-         self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
+         self.action_proposer = GRU(dim_meta, dim_meta)
+         self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)

          # switching unit

          self.switch_per_latent_dim = switch_per_latent_dim

-         self.switching_unit = GRU(dim_latent, dim_latent)
-         self.to_switching_unit_beta = nn.Linear(dim_latent, dim_latent if switch_per_latent_dim else 1, bias = False)
+         self.switching_unit = GRU(dim_meta, dim_meta)
+         self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)

          self.switch_gating = AssocScan(**assoc_scan_kwargs)

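A note on the new sizes: `dim_meta_controller` defaults to 256, and the body only falls back to the full `dim_model` width when it is passed as `None`. A tiny sketch of the assumed `default` helper semantics (the helper itself is not shown in this diff):

    def default(value, fallback):
        # assumed exists-or-fallback helper, as commonly used in this codebase
        return value if value is not None else fallback

    dim_model = 512

    dim_meta = default(256, dim_model)    # explicit width -> 256 (the new keyword default)
    dim_meta = default(None, dim_model)   # None -> fall back to the residual stream width, 512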
@@ -104,7 +112,7 @@ class MetaController(Module):
              dim_in = dim_latent,
              dim = dim_decoder_hidden,
              depth = decoder_depth,
-             dim_out = 2 * hypernetwork_low_rank * dim_latent
+             dim_out = 2 * hypernetwork_low_rank * dim_model
          )

          self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
@@ -113,6 +121,7 @@ class MetaController(Module):

      def discovery_parameters(self):
          return [
+             *self.model_to_meta.parameters(),
              *self.bidirectional_temporal_compressor.parameters(),
              *self.emitter.parameters(),
              *self.emitter_to_action_mean_log_var.parameters(),
@@ -143,18 +152,20 @@ class MetaController(Module):

          next_action_proposer_hidden = None

+         meta_embed = self.model_to_meta(residual_stream)
+
          if discovery_phase:
              logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')

-             temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
+             temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
              temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)

-             proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
+             proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
              readout = self.emitter_to_action_mean_log_var

          else: # else internal rl phase

-             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
+             proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
              readout = self.action_proposer_mean_log_var

          # sample from the gaussian as the action from the meta controller
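To make the `dim_model` / `dim_meta` / `dim_latent` split concrete, here is a minimal, self-contained sketch of the discovery-phase shape flow using plain `torch` and `einops` stand-ins; a single `nn.Linear` stands in for `Readout`, sizes follow the new defaults, and nothing here is the package's actual API:

    import torch
    from torch import nn
    from einops import reduce

    dim_model, dim_meta, dim_latent = 512, 256, 128

    model_to_meta = nn.Linear(dim_model, dim_meta)                    # residual stream -> meta width
    compressor    = nn.GRU(dim_meta, dim_meta, bidirectional = True)  # discovery-phase bidirectional GRU
    emitter       = nn.GRU(dim_meta * 2, dim_meta * 2)
    readout       = nn.Linear(dim_meta * 2, dim_latent * 2)           # stand-in for Readout: (mean, log_var)

    residual_stream = torch.randn(16, dim_model)                      # (time, dim_model), unbatched for brevity

    meta_embed    = model_to_meta(residual_stream)                               # (16, 256)
    compressed, _ = compressor(meta_embed)                                       # (16, 512), both directions
    compressed    = reduce(compressed, 't (two d) -> t d', 'mean', two = 2)      # (16, 256)
    hidden, _     = emitter(torch.cat((compressed, meta_embed), dim = -1))       # (16, 512)
    mean, log_var = readout(hidden).chunk(2, dim = -1)                           # each (16, 128) = dim_latent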
@@ -167,13 +178,13 @@ class MetaController(Module):

          batch, _, dim = sampled_action.shape

-         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
+         switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)

          switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()

          # need to encourage normal distribution

-         kl_loss = self.zero
+         kl_loss = switch_loss = self.zero

          if discovery_phase:
              mean, log_var = action_dist.unbind(dim = -1)
@@ -188,6 +199,10 @@ class MetaController(Module):
              kl_loss = kl_loss * switch_beta
              kl_loss = kl_loss.sum(dim = -1).mean()

+             # encourage less switching
+
+             switch_loss = switch_beta.mean()
+
          # maybe hard switch, then use associative scan

          if hard_switch:
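The new auxiliary term is just the mean of the post-sigmoid switching gate, so minimizing it pushes the gates toward zero and discourages switching; it sits next to the KL term, which was already modulated by the same `switch_beta`. A toy sketch of the two terms in isolation, with illustrative shapes only:

    import torch

    # illustrative shapes: (batch, time, dim_latent)
    kl_per_dim  = torch.randn(4, 16, 128).abs()       # stand-in for the per-dimension KL term
    switch_beta = torch.randn(4, 16, 128).sigmoid()   # gate produced by the switching unit

    kl_loss     = (kl_per_dim * switch_beta).sum(dim = -1).mean()   # KL weighted by where switches open
    switch_loss = switch_beta.mean()                                # new: penalizes switching overall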
@@ -208,9 +223,7 @@ class MetaController(Module):

          # generating the residual stream controlling signal

-         control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
-
-         modified_residual_stream = residual_stream + control_signal
+         control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')

          # returning

@@ -220,7 +233,7 @@ class MetaController(Module):
              next_switch_gated_action
          )

-         return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss)
+         return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)

  # main transformer, which is subsumed into the environment after behavioral cloning

@@ -329,9 +342,11 @@ class Transformer(Module):
          with meta_controller_context():

              if exists(meta_controller):
-                 modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
+                 control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
              else:
-                 modified_residual_stream, next_meta_hiddens = residual_stream, None
+                 control_signal, next_meta_hiddens = self.zero, None
+
+             modified_residual_stream = residual_stream + control_signal

          # modified residual stream sent back to transformer upper body

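With this change the controller returns only an additive control signal and the transformer performs the residual update itself; when no controller is attached, a zero tensor keeps the addition a no-op. A small sketch of that pattern with stand-in tensors, not the package's API:

    import torch

    residual_stream = torch.randn(4, 16, 512)   # (batch, time, dim_model)

    has_controller = False
    control_signal = (
        torch.randn_like(residual_stream)       # with a controller: same shape as the stream
        if has_controller else
        torch.tensor(0.)                        # without one: a scalar zero broadcasts harmlessly
    )

    modified_residual_stream = residual_stream + control_signal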
@@ -357,7 +372,7 @@ class Transformer(Module):

              action_recon_loss = self.action_readout.calculate_loss(dist_params, target_action_ids)

-             return action_recon_loss, next_meta_hiddens.kl_loss
+             return action_recon_loss, next_meta_hiddens.kl_loss, next_meta_hiddens.switch_loss

          # returning

@@ -1,6 +1,6 @@
  [project]
  name = "metacontroller-pytorch"
- version = "0.0.16"
+ version = "0.0.18"
  description = "Transformer Metacontroller"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -38,14 +38,16 @@ def test_metacontroller(
      # discovery and internal rl phase with meta controller

      meta_controller = MetaController(
-         dim_latent = 512,
+         dim_model = 512,
+         dim_meta_controller = 256,
+         dim_latent = 128,
          switch_per_latent_dim = switch_per_latent_dim
      )

      # discovery phase

-     (action_recon_loss, kl_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
-     (action_recon_loss + kl_loss * 0.1).backward()
+     (action_recon_loss, kl_loss, switch_loss) = model(state, actions, meta_controller = meta_controller, discovery_phase = True)
+     (action_recon_loss + kl_loss * 0.1 + switch_loss * 0.2).backward()

      # internal rl
