metacontroller-pytorch 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl

This diff represents the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
--- a/metacontroller/metacontroller.py
+++ b/metacontroller/metacontroller.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch import nn, cat, stack, tensor
@@ -42,6 +43,11 @@ def default(*args):
             return arg
     return None
 
+# tensor helpers
+
+def straight_through(src, tgt):
+    return tgt + src - src.detach()
+
 # meta controller
 
 class MetaController(Module):
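
The straight_through helper added here is the standard straight-through estimator: the output takes the value of tgt on the forward pass, while gradients flow back through src as if the substitution never happened. A minimal sketch of the trick in isolation (tensor names below are illustrative, not from the package):

    import torch

    def straight_through(src, tgt):
        # forward value equals tgt; gradient w.r.t. src is the identity
        return tgt + src - src.detach()

    beta = torch.rand(4, requires_grad = True)  # soft gate in (0, 1)
    hard = (beta > 0.5).float()                 # non-differentiable hard gate

    gated = straight_through(beta, hard)        # forward: exactly `hard`
    gated.sum().backward()                      # backward: gradient reaches `beta`

    assert torch.equal(gated, hard)
    assert torch.equal(beta.grad, torch.ones(4))
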
@@ -62,12 +68,12 @@ class MetaController(Module):
         self.bidirectional_temporal_compressor = GRU(dim_latent, dim_latent, bidirectional = True) # revisit naming
 
         self.emitter = GRU(dim_latent * 2, dim_latent * 2)
-        self.emitter_to_action_mean_log_var = LinearNoBias(dim_latent * 2, dim_latent * 2)
+        self.emitter_to_action_mean_log_var = Readout(dim_latent * 2, num_continuous = dim_latent)
 
         # internal rl phase substitutes the acausal + emitter with a causal ssm
 
         self.action_proposer = GRU(dim_latent, dim_latent)
-        self.action_proposer_mean_log_var = LinearNoBias(dim_latent, dim_latent * 2)
+        self.action_proposer_mean_log_var = Readout(dim_latent, num_continuous = dim_latent)
 
         # switching unit
 
@@ -115,7 +121,8 @@ class MetaController(Module):
     def forward(
         self,
         residual_stream,
-        discovery_phase = False
+        discovery_phase = False,
+        hard_switch = False
     ):
 
         if discovery_phase:
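
The two Readout heads constructed above replace raw LinearNoBias projections, and the new hard_switch flag feeds the straight-through gating further down. The diff shows only Readout's call sites: it is called on a hidden state to produce an action distribution whose last dimension unbinds into (mean, log_var), and it exposes a sample method. A hypothetical minimal implementation consistent with those call sites, not necessarily the one shipped in the package:

    import torch
    from torch import nn, stack

    class Readout(nn.Module):
        # hypothetical sketch, inferred only from how this diff uses it
        def __init__(self, dim, num_continuous):
            super().__init__()
            self.to_mean_log_var = nn.Linear(dim, num_continuous * 2, bias = False)

        def forward(self, hidden):
            mean, log_var = self.to_mean_log_var(hidden).chunk(2, dim = -1)
            return stack((mean, log_var), dim = -1)  # unbound downstream with .unbind(dim = -1)

        def sample(self, action_dist):
            # reparameterized gaussian sample, matching the inline code deleted below
            mean, log_var = action_dist.unbind(dim = -1)
            std = (0.5 * log_var).exp()
            return mean + torch.randn_like(mean) * std
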
@@ -123,41 +130,51 @@ class MetaController(Module):
             temporal_compressed = reduce(temporal_compressed, '... (two d) -> ... d', 'mean', two = 2)
 
             proposed_action_hidden, _ = self.emitter(cat((temporal_compressed, residual_stream), dim = -1))
-            proposed_action = self.emitter_to_action_mean_log_var(proposed_action_hidden)
+            readout = self.emitter_to_action_mean_log_var
 
         else: # else internal rl phase
             proposed_action_hidden, _ = self.action_proposer(residual_stream)
-            proposed_action = self.action_proposer_mean_log_var(proposed_action_hidden)
+            readout = self.action_proposer_mean_log_var
 
         # sample from the gaussian as the action from the meta controller
 
-        mean, log_var = proposed_action.chunk(2, dim = -1)
+        action_dist = readout(proposed_action_hidden)
+
+        sampled_action = readout.sample(action_dist)
+
+        # switching unit timer
+
+        batch, _, dim = sampled_action.shape
 
-        std = (0.5 * log_var).exp()
-        sampled_action_intents = mean + torch.randn_like(mean) * std
+        switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
+
+        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
+
+        action_intent_for_gating = rearrange(sampled_action, 'b n d -> (b d) n')
+        switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
 
         # need to encourage normal distribution
 
         vae_kl_loss = self.zero
 
         if discovery_phase:
+            mean, log_var = action_dist.unbind(dim = -1)
+
             vae_kl_loss = (0.5 * (
                 log_var.exp()
                 + mean.square()
                 - log_var
                 - 1.
-            )).sum(dim = -1).mean()
+            )).sum(dim = -1)
 
-        # switching unit timer
+            vae_kl_loss = vae_kl_loss * switch_beta
+            vae_kl_loss = vae_kl_loss.mean()
 
-        batch, _, dim = sampled_action_intents.shape
+        # maybe hard switch, then use associative scan
 
-        switching_unit_gru_out, switching_unit_gru_hidden = self.switching_unit(residual_stream)
-
-        switch_beta = self.to_switching_unit_beta(switching_unit_gru_out).sigmoid()
-
-        action_intent_for_gating = rearrange(sampled_action_intents, 'b n d -> (b d) n')
-        switch_beta = repeat(switch_beta, 'b n d -> (b r d) n', r = dim if not self.switch_per_latent_dim else 1)
+        if hard_switch:
+            hard_switch = (switch_beta > 0.5).float()
+            switch_beta = straight_through(switch_beta, hard_switch)
 
         forget = 1. - switch_beta
         gated_action_intent = self.switch_gating(action_intent_for_gating * forget, switch_beta)
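
For reference, the expression reshaped in this hunk is the closed-form KL divergence between the predicted gaussian and a standard normal, now kept per position so it can be weighted by the switch gate before averaging:

    KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - log(sigma^2) - 1)

With log_var = log(sigma^2), this is exactly the 0.5 * (log_var.exp() + mean.square() - log_var - 1.) term summed over the last dimension in the code above.
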
@@ -177,7 +194,7 @@ class MetaController(Module):
 
         modified_residual_stream = residual_stream + control_signal
 
-        return modified_residual_stream, vae_kl_loss
+        return modified_residual_stream, action_dist, sampled_action, vae_kl_loss
 
 # main transformer, which is subsumed into the environment after behavioral cloning
 
@@ -249,9 +266,9 @@ class Transformer(Module):
         # meta controller acts on residual stream here
 
         if exists(meta_controller):
-            modified_residual_stream, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
+            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = meta_controller(residual_stream, discovery_phase = discovery_phase)
         else:
-            modified_residual_stream, vae_aux_loss = residual_stream, self.zero
+            modified_residual_stream, action_dist, sampled_action, vae_aux_loss = residual_stream, None, None, self.zero
 
         # modified residual stream sent back
 
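
As of 0.0.5, callers unpack four values from the meta controller instead of two. A hedged usage sketch; construction arguments and shapes are illustrative, only the keyword arguments and return arity come from this diff:

    import torch

    # meta_controller = MetaController(...)  # constructed elsewhere; init args not shown in this diff

    residual_stream = torch.randn(2, 16, 512)  # (batch, seq, dim), shapes illustrative

    modified_stream, action_dist, sampled_action, vae_kl_loss = meta_controller(
        residual_stream,
        discovery_phase = True,  # acausal emitter path, trains the vae objective
        hard_switch = True       # new in 0.0.5: binarize the gate via straight_through
    )

    # with no meta controller attached, Transformer fills the same four slots
    # with (residual_stream, None, None, zero)
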
--- a/metacontroller_pytorch-0.0.3.dist-info/METADATA
+++ b/metacontroller_pytorch-0.0.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: metacontroller-pytorch
-Version: 0.0.3
+Version: 0.0.5
 Summary: Transformer Metacontroller
 Project-URL: Homepage, https://pypi.org/project/metacontroller/
 Project-URL: Repository, https://github.com/lucidrains/metacontroller
--- /dev/null
+++ b/metacontroller_pytorch-0.0.5.dist-info/RECORD
@@ -0,0 +1,6 @@
+metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
+metacontroller/metacontroller.py,sha256=-glgWcv6QQZ6wVAy6tK2Ye8QNXuWBiGmB1rs6DVWA1I,8573
+metacontroller_pytorch-0.0.5.dist-info/METADATA,sha256=CfXW_uO8B9gz31XkUO-2aVl4TN64iYicdZPtW7DzzHc,3706
+metacontroller_pytorch-0.0.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+metacontroller_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+metacontroller_pytorch-0.0.5.dist-info/RECORD,,
--- a/metacontroller_pytorch-0.0.3.dist-info/RECORD
+++ /dev/null
@@ -1,6 +0,0 @@
-metacontroller/__init__.py,sha256=lj7IOGpN--qMxJWbB-4SGqoPXG7Hd4mgtToTRSyTZ58,57
-metacontroller/metacontroller.py,sha256=JZHAkyYbonP5aj_Ob-Xe41ziLEoJLL0KjJy3Dm_bXGY,8091
-metacontroller_pytorch-0.0.3.dist-info/METADATA,sha256=XO3zKqbfSpGmYe34P0uXetIygqkfD1QLuKSSClEeZhk,3706
-metacontroller_pytorch-0.0.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-metacontroller_pytorch-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-metacontroller_pytorch-0.0.3.dist-info/RECORD,,