metacontroller-pytorch 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -64,8 +64,10 @@ MetaControllerOutput = namedtuple('MetaControllerOutput', (
64
64
  class MetaController(Module):
65
65
  def __init__(
66
66
  self,
67
- dim_latent,
67
+ dim_model,
68
68
  *,
69
+ dim_meta_controller = 256,
70
+ dim_latent = 128,
69
71
  switch_per_latent_dim = True,
70
72
  decoder_expansion_factor = 2.,
71
73
  decoder_depth = 1,
@@ -73,25 +75,30 @@ class MetaController(Module):
73
75
  assoc_scan_kwargs: dict = dict()
74
76
  ):
75
77
  super().__init__()
78
+ dim_meta = default(dim_meta_controller, dim_model)
79
+
80
+ # linear projection from the model dimension to the meta controller dimension
81
+
82
+ self.model_to_meta = Linear(dim_model, dim_meta)
76
83
 
77
84
  # there are two phases; the first (discovery ssl phase) uses an acausal model with some ssm i don't really believe in - let's just use a bidirectional GRU as a placeholder
78
85
 
79
- self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
86
+ self.bidirectional_temporal_compressor = GRU(dim_meta, dim_meta, bidirectional = True) # revisit naming
80
87
 
81
- self.emitter = GRU(dim_latent * 2, dim_latent * 2)
82
- self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
88
+ self.emitter = GRU(dim_meta * 2, dim_meta * 2)
89
+ self.emitter_to_action_mean_log_var = Readout(dim_meta * 2, num_continuous = dim_latent)
83
90
 
84
91
  # internal rl phase substitutes the acausal + emitter with a causal ssm
85
92
 
86
- self.action_proposer = GRU(dim_latent, dim_latent)
87
- self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
93
+ self.action_proposer = GRU(dim_meta, dim_meta)
94
+ self.action_proposer_mean_log_var = Readout(dim_meta, num_continuous = dim_latent)
88
95
 
89
96
  # switching unit
90
97
 
91
98
  self.switch_per_latent_dim = switch_per_latent_dim
92
99
 
93
- self.switching_unit = GRU(dim_latent, dim_latent)
94
- self.to_switching_unit_beta = nn.Linear(dim_latent, dim_latent if switch_per_latent_dim else 1, bias = False)
100
+ self.switching_unit = GRU(dim_meta, dim_meta)
101
+ self.to_switching_unit_beta = nn.Linear(dim_meta, dim_latent if switch_per_latent_dim else 1, bias = False)
95
102
 
96
103
  self.switch_gating = AssocScan(**assoc_scan_kwargs)
97
104
 
@@ -105,7 +112,7 @@ class MetaController(Module):
105
112
  dim_in = dim_latent,
106
113
  dim = dim_decoder_hidden,
107
114
  depth = decoder_depth,
108
- dim_out = 2 * hypernetwork_low_rank * dim_latent
115
+ dim_out = 2 * hypernetwork_low_rank * dim_model
109
116
  )
110
117
 
111
118
  self.to_hyper_network_weights = Rearrange('... (two d r) -> two ... d r', two = 2, r = hypernetwork_low_rank)
@@ -114,6 +121,7 @@ class MetaController(Module):
114
121
 
115
122
  def discovery_parameters(self):
116
123
  return [
124
+ *self.model_to_meta.parameters(),
117
125
  *self.bidirectional_temporal_compressor.parameters(),
118
126
  *self.emitter.parameters(),
119
127
  *self.emitter_to_action_mean_log_var.parameters(),
@@ -144,18 +152,20 @@ class MetaController(Module):
144
152
 
145
153
  next_action_proposer_hidden = None
146
154
 
155
+ meta_embed = self.model_to_meta(residual_stream)
156
+
147
157
  if discovery_phase:
148
158
  logger.warning('meta controller cache being passed back in for discovery phase, which does not make sense given bidirectional encoder')
149
159
 
150
- temporal_compressed, _ = self.bidirectional_temporal_compressor(residual_stream)
160
+ temporal_compressed, _ = self.bidirectional_temporal_compressor(meta_embed)
151
161
  temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
152
162
 
153
- proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
163
+ proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, meta_embed), dim = -1))
154
164
  readout = self.emitter_to_action_mean_log_var
155
165
 
156
166
  else: # else internal rl phase
157
167
 
158
- proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(residual_stream, prev_action_proposer_hidden)
168
+ proposed_action_hidden, next_action_proposer_hidden = self.action_proposer(meta_embed, prev_action_proposer_hidden)
159
169
  readout = self.action_proposer_mean_log_var
160
170
 
161
171
  # sample from the gaussian as the action from the meta controller
@@ -168,7 +178,7 @@ class MetaController(Module):
168
178
 
169
179
  batch, _, dim = sampled_action.shape
170
180
 
171
- switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(residual_stream, prev_switching_unit_gru_hidden)
181
+ switching_unit_gru_out, next_switching_unit_gru_hidden = self.switching_unit(meta_embed, prev_switching_unit_gru_hidden)
172
182
 
173
183
  switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
174
184
 
@@ -213,9 +223,7 @@ class MetaController(Module):
213
223
 
214
224
  # generating the residual stream controlling signal
215
225
 
216
- control_signal = einsum(gated_action, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
217
-
218
- modified_residual_stream = residual_stream + control_signal
226
+ control_signal = einsum(residual_stream, hypernetwork_weight, '... d1, ... d1 d2 -> ... d1')
219
227
 
220
228
  # returning
221
229
 
@@ -225,7 +233,7 @@ class MetaController(Module):
225
233
  next_switch_gated_action
226
234
  )
227
235
 
228
- return modified_residual_stream, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
236
+ return control_signal, MetaControllerOutput(next_hiddens, action_dist, sampled_action, kl_loss, switch_loss)
229
237
 
230
238
  # main transformer, which is subsumed into the environment after behavioral cloning
231
239
 
@@ -248,7 +256,7 @@ class Transformer(Module):
248
256
  super().__init__()
249
257
 
250
258
  if isinstance(lower_body, dict):
251
- lower_body = Decoder(dim = dim, **lower_body)
259
+ lower_body = Decoder(dim = dim, pre_norm_has_final_norm = False, **lower_body)
252
260
 
253
261
  if isinstance(upper_body, dict):
254
262
  upper_body = Decoder(dim = dim, **upper_body)
@@ -334,9 +342,11 @@ class Transformer(Module):
334
342
  with meta_controller_context():
335
343
 
336
344
  if exists(meta_controller):
337
- modified_residual_stream, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
345
+ control_signal, next_meta_hiddens = meta_controller(residual_stream, cache = meta_hiddens, discovery_phase = discovery_phase, temperature = meta_controller_temperature)
338
346
  else:
339
- modified_residual_stream, next_meta_hiddens = residual_stream, None
347
+ control_signal, next_meta_hiddens = self.zero, None
348
+
349
+ modified_residual_stream = residual_stream + control_signal
340
350
 
341
351
  # modified residual stream sent back to transformer upper body
342
352
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: metacontroller-pytorch
3
- Version: 0.0.17
3
+ Version: 0.0.19
4
4
  Summary: Transformer Metacontroller
5
5
  Project-URL: Homepage, https://pypi.org/project/metacontroller/
6
6
  Project-URL: Repository, https://github.com/lucidrains/metacontroller
@@ -0,0 +1,6 @@
1
+ metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
+ metacontroller/metacontroller.py,sha256=GTErzikqVd8XDY8pmDnY8t4uIjbGCUd1GZBJX13peo8,12339
3
+ metacontroller_pytorch-0.0.19.dist-info/METADATA,sha256=lX3L7J3CKoSyxvJniLdSJsCu0UMEbJTxQLEw6zzT7dY,3741
4
+ metacontroller_pytorch-0.0.19.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ metacontroller_pytorch-0.0.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
+ metacontroller_pytorch-0.0.19.dist-info/RECORD,,
@@ -1,6 +0,0 @@
1
- metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
2
- metacontroller/metacontroller.py,sha256=blxDztbtXyP3cNbjnM3fEw_KZdLFJp_l1Sub6-7zIKg,12041
3
- metacontroller_pytorch-0.0.17.dist-info/METADATA,sha256=_8hYYTO_ME23kgZXqSfhA1XXAA8W877F-AL8amA7LKM,3741
4
- metacontroller_pytorch-0.0.17.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
- metacontroller_pytorch-0.0.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
6
- metacontroller_pytorch-0.0.17.dist-info/RECORD,,