locoformer 0.0.11__tar.gz → 0.0.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.11
+ Version: 0.0.17
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -3,6 +3,7 @@ from functools import partial

  from pathlib import Path
  from contextlib import contextmanager
+ from collections import namedtuple

  import numpy as np
  from numpy import ndarray
@@ -17,6 +18,7 @@ import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
  from torch.utils._pytree import tree_map
  from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Optimizer

  import einx
  from einops import rearrange, einsum
@@ -26,6 +28,8 @@ from rotary_embedding_torch import RotaryEmbedding

  from assoc_scan import AssocScan

+ # constants
+
  LinearNoBias = partial(Linear, bias = False)

  # helper functions
@@ -42,12 +46,14 @@ def first(arr):
  def divisible_by(num, den):
      return (num % den) == 0

+ # tensor helpers
+
+ def log(t, eps = 1e-20):
+     return t.clamp_min(eps).log()
+
  def tree_map_tensor(x, fn):
      return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

- def detach_all(x):
-     return tree_map_tensor(x, lambda t: t.detach())
-
  def pad_at_dim(
      t,
      pad: tuple[int, int],
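The new `log` helper clamps its input to `eps` before taking the logarithm, so the entropy term built from it (see `calc_entropy` in the next hunk) stays finite even when a probability is exactly zero. A quick illustration in plain PyTorch:

    import torch

    probs = torch.tensor([0., 0.5, 0.5])
    probs.log()                   # tensor([-inf, -0.6931, -0.6931])
    probs.clamp_min(1e-20).log()  # tensor([-46.0517, -0.6931, -0.6931])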
@@ -61,13 +67,17 @@ def pad_at_dim(
      zeros = ((0, 0) * dims_from_right)
      return F.pad(t, (*zeros, *pad), value = value)

+ def calc_entropy(logits):
+     prob = logits.softmax(dim = -1)
+     return -(prob * log(prob)).sum(dim = -1)
+
  # generalized advantage estimate

  @torch.no_grad()
  def calc_gae(
      rewards,
      values,
-     masks,
+     masks = None,
      gamma = 0.99,
      lam = 0.95,
      use_accelerated = None
@@ -78,6 +88,9 @@ def calc_gae(
      values = F.pad(values, (0, 1), value = 0.)
      values, values_next = values[..., :-1], values[..., 1:]

+     if not exists(masks):
+         masks = torch.ones_like(values)
+
      delta = rewards + gamma * values_next * masks - values
      gates = gamma * lam * masks

@@ -87,7 +100,7 @@ def calc_gae(

      returns = gae + values

-     return returns
+     return gae, returns

  # transformer-xl mask w/ flex attn
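`calc_gae` now treats `masks` as optional and returns both the advantages and the bootstrapped returns. For orientation, the recurrence it evaluates (a scan over `delta` with gates `gamma * lam * masks`) is the standard GAE backward pass; below is a minimal plain-PyTorch sketch of the same computation written as an explicit loop rather than the package's `AssocScan` — the name `gae_reference` and the loop form are illustrative, not package API.

    import torch
    import torch.nn.functional as F

    def gae_reference(rewards, values, masks = None, gamma = 0.99, lam = 0.95):
        # rewards, values, masks: (batch, time)
        if masks is None:
            masks = torch.ones_like(values)

        # bootstrap the next values with a trailing zero, as calc_gae does with F.pad
        values_next = F.pad(values, (0, 1), value = 0.)[..., 1:]
        delta = rewards + gamma * values_next * masks - values

        gae = torch.zeros_like(delta)
        running = torch.zeros_like(delta[..., -1])

        # A_t = delta_t + gamma * lam * mask_t * A_{t+1}, evaluated right to left
        for t in reversed(range(delta.shape[-1])):
            running = delta[..., t] + gamma * lam * masks[..., t] * running
            gae[..., t] = running

        return gae, gae + values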
93
106
 
@@ -129,8 +142,8 @@ def create_xl_mask(
        # handle intra-episodic attention if needed

        if exists(episode_ids):
-             q_episode = episodes[b, q + offset]
-             k_episode = episodes[b, k]
+             q_episode = episode_ids[b, q + offset]
+             k_episode = episode_ids[b, k]

            intra_episode_mask = q_episode == k_episode
            mask = mask & intra_episode_mask
@@ -231,7 +244,7 @@ class ReplayDataset(Dataset):

          episode_len = self.episode_lens[episode_index]

-         data = {field: torch.from_numpy(memmap[episode_index, :episode_len]) for field, memmap in self.memmaps.items()}
+         data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}

          data['_lens'] = tensor(episode_len)

@@ -299,6 +312,13 @@ class ReplayBuffer:
              self.shapes[field_name] = shape
              self.dtypes[field_name] = dtype

+         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+     def reset_(self):
+         self.episode_lens[:] = 0
+         self.episode_index = 0
+         self.timestep_index = 0
+
      def advance_episode(self):
          self.episode_index = (self.episode_index + 1) % self.max_episodes
          self.timestep_index = 0
@@ -353,15 +373,17 @@ class ReplayBuffer:

          self.timestep_index += 1

+         return self.memory_namedtuple(**data)
+
      def dataset(self) -> Dataset:
          self.flush()

          return ReplayDataset(self.folder)

-     def dataloader(self, **kwargs) -> DataLoader:
+     def dataloader(self, batch_size, **kwargs) -> DataLoader:
          self.flush()

-         return DataLoader(self.dataset(), collate_fn = collate_var_time, **kwargs)
+         return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)

  # transformer-xl with ppo
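`ReplayBuffer.store` now hands the stored timestep back as a `Memory` namedtuple, and `dataloader` takes an explicit `batch_size`. A hedged usage sketch — the field layout and tensor values here are invented for illustration (the lunar lander example further down shows the real configuration):

    import torch
    from locoformer.locoformer import ReplayBuffer

    # folder, max episodes, max timesteps per episode - positional, as in the example script
    replay = ReplayBuffer('replay', 4, 64, fields = dict(
        state = ('float', (8,)),
        action = 'int',
        reward = 'float',
        done = 'bool'
    ))

    memory = replay.store(
        state = torch.randn(8),
        action = torch.tensor(1),
        reward = torch.tensor(0.5),
        done = torch.tensor(False)
    )

    # memory is a Memory namedtuple mirroring the declared fields
    print(memory._fields)   # ('state', 'action', 'reward', 'done')

    dl = replay.dataloader(batch_size = 2, shuffle = True)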
@@ -421,7 +443,6 @@ class Attention(Module):
          return_kv_cache = False,
      ):
          seq_len = tokens.shape[-2]
-         assert seq_len <= self.window_size

          device = tokens.device

@@ -582,7 +603,14 @@ class Locoformer(Module):
          embedder: Module,
          unembedder: Module,
          transformer: dict | TransformerXL,
-         value_network: Module | None = None
+         value_network: Module | None = None,
+         discount_factor = 0.999,
+         gae_lam = 0.95,
+         ppo_eps_clip = 0.2,
+         ppo_entropy_weight = 0.01,
+         ppo_value_clip = 0.4,
+         value_loss_weight = 0.5,
+         calc_gae_kwargs: dict = dict()
      ):
          super().__init__()

@@ -599,6 +627,21 @@ class Locoformer(Module):
          self.fixed_window_size = transformer.fixed_window_size
          self.window_size = transformer.window_size

+         # ppo related
+
+         self.discount_factor = discount_factor
+         self.gae_lam = gae_lam
+         self.ppo_eps_clip = ppo_eps_clip
+         self.ppo_entropy_weight = ppo_entropy_weight
+         self.ppo_value_clip = ppo_value_clip
+         self.value_loss_weight = value_loss_weight
+
+         self.calc_gae_kwargs = calc_gae_kwargs
+
+         # loss related
+
+         self.register_buffer('zero', tensor(0.), persistent = False)
+
      @property
      def device(self):
          return next(self.parameters()).device
@@ -612,6 +655,101 @@ class Locoformer(Module):

          return self.value_network.parameters()

+     def ppo(
+         self,
+         state,
+         action,
+         old_action_log_prob,
+         reward,
+         old_value,
+         mask,
+         actor_optim: Optimizer | None = None,
+         critic_optim: Optimizer | None = None
+     ):
+         window_size = self.window_size
+         total_learnable_tokens = mask.sum().item()
+
+         windowed_tensors = [
+             t.split(window_size, dim = 1) for t in
+             (
+                 state,
+                 action,
+                 old_action_log_prob,
+                 reward,
+                 old_value,
+                 mask
+             )
+         ]
+
+         mean_actor_loss = self.zero.clone()
+         mean_critic_loss = self.zero.clone()
+
+         # learn across windows
+
+         cache = None
+
+         for (
+             state,
+             action,
+             old_action_log_prob,
+             reward,
+             old_value,
+             mask
+         ) in zip(*windowed_tensors):
+
+             (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
+             entropy = calc_entropy(action_logits)
+
+             action = rearrange(action, 'b t -> b t 1')
+             log_prob = action_logits.gather(-1, action)
+             log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+             # update actor, classic clipped surrogate loss
+
+             eps_clip = self.ppo_eps_clip
+             ratio = (log_prob - old_action_log_prob).exp()
+
+             advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+             actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+             actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+             windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+             windowed_actor_loss.backward(retain_graph = True)
+
+             # update critic
+
+             value_loss = F.mse_loss(returns, value, reduction = 'none')
+
+             value_clip = self.ppo_value_clip
+             clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+             clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+
+             critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+             windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+             windowed_critic_loss.backward(retain_graph = True)
+
+             # accumulate
+
+             mean_actor_loss.add_(windowed_actor_loss)
+             mean_critic_loss.add_(windowed_critic_loss)
+
+         # optimizer update
+
+         if exists(actor_optim):
+             actor_optim.step()
+             actor_optim.zero_grad()
+
+         if exists(critic_optim):
+             critic_optim.step()
+             critic_optim.zero_grad()
+
+         # return losses for logging
+
+         return mean_actor_loss.detach(), mean_critic_loss.detach()
+
      def wrap_env_functions(self, env):

          def wrapped_reset(*args, **kwargs):
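The actor term in the new `ppo` method is the standard PPO clipped surrogate objective (with an entropy bonus), and the critic term is a clipped value loss. Pulled out of context, with illustrative helper names (this is a sketch of the objective, not package API):

    import torch

    def clipped_surrogate(log_prob, old_log_prob, advantage, eps_clip = 0.2):
        # r_t = exp(log pi(a_t|s_t) - log pi_old(a_t|s_t))
        ratio = (log_prob - old_log_prob).exp()
        # maximize min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t), hence the leading minus
        return -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)

    def clipped_value_loss(value, old_value, returns, value_clip = 0.4):
        # take the worse of the raw and clipped value errors, as in the critic branch above
        clipped = old_value + (value - old_value).clamp(-value_clip, value_clip)
        return torch.maximum((value - returns) ** 2, (clipped - returns) ** 2)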
@@ -723,7 +861,7 @@ class Locoformer(Module):
          # maybe detach cache

          if detach_cache:
-             kv_cache = detach_all(kv_cache)
+             kv_cache = kv_cache.detach()

          # handle returning of values

@@ -1,6 +1,6 @@
  [project]
  name = "locoformer"
- version = "0.0.11"
+ version = "0.0.17"
  description = "LocoFormer"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
      prime = prime.to(model.device)
      out = prime

-     stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+     stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)

      # sample

@@ -3,7 +3,7 @@
  # "accelerate",
  # "fire",
  # "gymnasium[box2d]>=1.0.0",
- # "locoformer",
+ # "locoformer>=0.0.12",
  # "moviepy",
  # "tqdm"
  # ]
@@ -13,17 +13,19 @@ from fire import Fire
  from shutil import rmtree
  from tqdm import tqdm
  from collections import deque
+ from types import SimpleNamespace

  from accelerate import Accelerator

  import gymnasium as gym

  import torch
- from torch import from_numpy, randint, tensor, stack
+ from torch import from_numpy, randint, tensor, stack, arange
  import torch.nn.functional as F
  from torch.utils.data import TensorDataset, DataLoader
  from torch.optim import Adam

+ import einx
  from einops import rearrange

  from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -47,26 +49,69 @@ def gumbel_sample(logits, temperature = 1., eps = 1e-6):
      noise = gumbel_noise(logits)
      return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)

+ # learn
+
+ def learn(
+     model,
+     actor_optim,
+     critic_optim,
+     accelerator,
+     replay,
+     batch_size = 16,
+     epochs = 2,
+ ):
+     device = accelerator.device
+
+     dl = replay.dataloader(batch_size = batch_size, shuffle = True)
+     model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
+
+     for _ in range(epochs):
+         for data in dl:
+
+             data = SimpleNamespace(**data)
+
+             seq_len = data.state.shape[1]
+
+             value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
+             value = torch.where(value_mask, data.value, 0.)
+
+             actor_loss, critic_loss = model.ppo(
+                 state = data.state,
+                 action = data.action,
+                 old_action_log_prob = data.action_log_prob,
+                 reward = data.reward,
+                 old_value = value,
+                 mask = data.learnable,
+                 actor_optim = actor_optim,
+                 critic_optim = critic_optim
+             )
+
+             accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
  # main function

  def main(
      env_name = 'LunarLander-v3',
      num_episodes = 50_000,
      max_timesteps = 500,
-     num_timestep_before_learn = 5000,
+     num_episodes_before_learn = 32,
      clear_video = True,
      video_folder = 'recordings',
      record_every_episode = 250,
+     learning_rate = 8e-4,
      discount_factor = 0.99,
-     learning_rate = 1e-4,
+     betas = (0.9, 0.99),
+     gae_lam = 0.95,
+     ppo_eps_clip = 0.2,
+     ppo_entropy_weight = .01,
      batch_size = 16,
      epochs = 2
  ):

      # accelerate

-     accelerate = Accelerator()
-     device = accelerate.device
+     accelerator = Accelerator()
+     device = accelerator.device

      # environment
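The `value_mask` line in the new `learn` function masks out the padded tail of each variable-length episode: `einx.less('j, i -> i j', arange(seq_len), lens)` broadcasts a position index against each episode length. A concrete illustration with made-up lengths:

    import torch
    import einx

    lens = torch.tensor([3, 5])
    mask = einx.less('j, i -> i j', torch.arange(5), lens)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])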
@@ -91,14 +136,15 @@ def main(
      replay = ReplayBuffer(
          'replay',
          num_episodes,
-         max_timesteps,
+         max_timesteps + 1, # one extra node for bootstrap node - not relevant for locoformer, but for completeness
          fields = dict(
              state = ('float', (dim_state,)),
              action = 'int',
              action_log_prob = 'float',
              reward = 'float',
              value = 'float',
-             done = 'bool'
+             done = 'bool',
+             learnable = 'bool'
          )
      )

@@ -114,11 +160,18 @@ def main(
              heads = 4,
              depth = 4,
              window_size = 16
+         ),
+         discount_factor = discount_factor,
+         gae_lam = gae_lam,
+         ppo_eps_clip = ppo_eps_clip,
+         ppo_entropy_weight = ppo_entropy_weight,
+         calc_gae_kwargs = dict(
+             use_accelerated = False
          )
      ).to(device)

-     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
-     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)
+     optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
+     optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)

      timesteps_learn = 0

@@ -129,7 +182,8 @@ def main(

      # loop

-     for _ in tqdm(range(num_episodes)):
+     for episodes_index in tqdm(range(num_episodes)):
+
          state, *_ = env_reset()

          timestep = 0
@@ -158,27 +212,29 @@ def main(
              action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
              action_log_prob = rearrange(action_log_prob, '1 ->')

-             replay.store(
+             memory = replay.store(
                  state = state,
                  action = action,
                  action_log_prob = action_log_prob,
                  reward = reward,
                  value = value,
-                 done = done
+                 done = done,
+                 learnable = tensor(True)
              )

-             # increment counters
+             # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+             # only if terminated signal not detected

-             timestep += 1
-             timesteps_learn += 1
+             if not terminated:
+                 _, next_value = stateful_forward(next_state, return_values = True)
+
+                 memory._replace(value = next_value, learnable = False)

-             # learn if hit the number of learn timesteps
+                 replay.store(**memory._asdict())

-             if timesteps_learn >= num_timestep_before_learn:
-                 # todo - carry out learning
+             # increment counters

-                 timesteps_learn = 0
-                 memories.clear()
+             timestep += 1

              # break if done or exceed max timestep
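One thing to keep in mind when reading the bootstrap handling above: `namedtuple._replace` does not mutate the tuple in place, it returns a new one, so its result normally has to be rebound. A standalone illustration (the `Memory` fields here are just an example):

    from collections import namedtuple

    Memory = namedtuple('Memory', ['value', 'learnable'])

    m = Memory(value = 1., learnable = True)
    m._replace(value = 2., learnable = False)      # returns a new Memory, m itself is unchanged
    m = m._replace(value = 2., learnable = False)  # rebinding captures the replacement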
@@ -187,6 +243,19 @@ def main(

              state = next_state

+         # learn if hit the number of learn timesteps
+
+         if divisible_by(episodes_index + 1, num_episodes_before_learn):
+
+             learn(
+                 locoformer,
+                 optim_actor,
+                 optim_critic,
+                 accelerator,
+                 replay,
+                 batch_size,
+                 epochs,
+             )
  # main

  if __name__ == '__main__':
6 files without changes