locoformer 0.0.6.tar.gz → 0.0.17.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
--- .gitignore
+++ .gitignore
@@ -1,3 +1,5 @@
1
+ recordings/
2
+ replay/
1
3
 
2
4
  # Byte-compiled / optimized / DLL files
3
5
  __pycache__/
--- PKG-INFO
+++ PKG-INFO
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: locoformer
3
- Version: 0.0.6
3
+ Version: 0.0.17
4
4
  Summary: LocoFormer
5
5
  Project-URL: Homepage, https://pypi.org/project/locoformer/
6
6
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: assoc-scan
38
+ Requires-Dist: beartype
38
39
  Requires-Dist: einops>=0.8.0
39
40
  Requires-Dist: einx>=0.3.0
40
41
  Requires-Dist: rotary-embedding-torch
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown
53
54
 
54
55
  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
55
56
 
56
- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
57
+ The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment), under extreme domain randomization. When transferred to the real world, the robot gained the ability to adapt to insults (physical perturbations). The XL memories span multiple trials, which lets the robot adapt in-context.
57
58
 
58
59
  ## Sponsors
59
60
 
--- README.md
+++ README.md
@@ -4,7 +4,7 @@
4
4
 
5
5
  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)
6
6
 
7
- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
7
+ The gist is that they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment), under extreme domain randomization. When transferred to the real world, the robot gained the ability to adapt to insults (physical perturbations). The XL memories span multiple trials, which lets the robot adapt in-context.
8
8
 
9
9
  ## Sponsors
10
10
 
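The paragraph above is the core claim of the release: because the Transformer-XL cache is carried across trials, the policy can adapt in-context. Below is a rough sketch of what that looks like with the API added in this diff; the dimensions, the random stand-in environment, and the use of `x_mlps_pytorch` mirror the tests and examples further down, and nothing here comes from the package's own documentation.

```python
# rough sketch, assuming only the API visible in this diff; the dims and the
# random "environment" are placeholders, not part of the real package examples

import torch
from x_mlps_pytorch import MLP
from locoformer.locoformer import Locoformer

model = Locoformer(
    embedder = MLP(8, 64, bias = False),      # 8-dim observation (placeholder)
    unembedder = MLP(64, 4, bias = False),    # 4 discrete actions (placeholder)
    value_network = MLP(64, 1, bias = False),
    transformer = dict(dim = 64, depth = 2, window_size = 16)
)

# one closure carries the Transformer-XL KV cache for the entire rollout

stateful_forward = model.get_stateful_forward(
    has_batch_dim = False,
    has_time_dim = False,
    inference_mode = True
)

for trial in range(3):                         # cache is NOT reset between trials
    state = torch.randn(8)                     # stand-in for env.reset()

    for _ in range(32):
        logits, value = stateful_forward(state, return_values = True)
        action = logits.argmax(dim = -1)
        state = torch.randn(8)                 # stand-in for env.step(action)
```

The only point of the sketch is the placement of `get_stateful_forward` outside the trial loop: the KV cache, and therefore the in-context memory, survives episode boundaries.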
--- locoformer/locoformer.py
+++ locoformer/locoformer.py
@@ -1,11 +1,24 @@
1
1
  from __future__ import annotations
2
2
  from functools import partial
3
3
 
4
+ from pathlib import Path
5
+ from contextlib import contextmanager
6
+ from collections import namedtuple
7
+
8
+ import numpy as np
9
+ from numpy import ndarray
10
+ from numpy.lib.format import open_memmap
11
+
12
+ from beartype import beartype
13
+ from beartype.door import is_bearable
14
+
4
15
  import torch
5
- from torch import nn, cat, stack, arange, is_tensor
16
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
6
17
  import torch.nn.functional as F
7
18
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
8
19
  from torch.utils._pytree import tree_map
20
+ from torch.utils.data import Dataset, DataLoader
21
+ from torch.optim import Optimizer
9
22
 
10
23
  import einx
11
24
  from einops import rearrange, einsum
@@ -15,6 +28,8 @@ from rotary_embedding_torch import RotaryEmbedding
15
28
 
16
29
  from assoc_scan import AssocScan
17
30
 
31
+ # constants
32
+
18
33
  LinearNoBias = partial(Linear, bias = False)
19
34
 
20
35
  # helper functions
@@ -31,20 +46,30 @@ def first(arr):
31
46
  def divisible_by(num, den):
32
47
  return (num % den) == 0
33
48
 
49
+ # tensor helpers
50
+
51
+ def log(t, eps = 1e-20):
52
+ return t.clamp_min(eps).log()
53
+
34
54
  def tree_map_tensor(x, fn):
35
55
  return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
36
56
 
37
- def detach_all(x):
38
- return tree_map_tensor(x, lambda t: t.detach())
39
-
40
- def combine_kv_cache(cache1, cache2):
41
- combined_cache = []
57
+ def pad_at_dim(
58
+ t,
59
+ pad: tuple[int, int],
60
+ dim = -1,
61
+ value = 0.
62
+ ):
63
+ if pad == (0, 0):
64
+ return t
42
65
 
43
- for layer_cache1, layer_cache2 in zip(cache1, cache2):
44
- next_cache = cat((layer_cache1, layer_cache2), dim = -2)
45
- combined_cache.append(next_cache)
66
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
67
+ zeros = ((0, 0) * dims_from_right)
68
+ return F.pad(t, (*zeros, *pad), value = value)
46
69
 
47
- return combined_cache
70
+ def calc_entropy(logits):
71
+ prob = logits.softmax(dim = -1)
72
+ return -(prob * log(prob)).sum(dim = -1)
48
73
 
49
74
  # generalized advantage estimate
50
75
 
@@ -52,7 +77,7 @@ def combine_kv_cache(cache1, cache2):
52
77
  def calc_gae(
53
78
  rewards,
54
79
  values,
55
- masks,
80
+ masks = None,
56
81
  gamma = 0.99,
57
82
  lam = 0.95,
58
83
  use_accelerated = None
@@ -63,6 +88,9 @@ def calc_gae(
63
88
  values = F.pad(values, (0, 1), value = 0.)
64
89
  values, values_next = values[..., :-1], values[..., 1:]
65
90
 
91
+ if not exists(masks):
92
+ masks = torch.ones_like(values)
93
+
66
94
  delta = rewards + gamma * values_next * masks - values
67
95
  gates = gamma * lam * masks
68
96
 
@@ -72,7 +100,7 @@ def calc_gae(
72
100
 
73
101
  returns = gae + values
74
102
 
75
- return returns
103
+ return gae, returns
76
104
 
77
105
  # transformer-xl mask w/ flex attn
78
106
 
@@ -114,8 +142,8 @@ def create_xl_mask(
114
142
  # handle intra-episodic attention if needed
115
143
 
116
144
  if exists(episode_ids):
117
- q_episode = episodes[b, q + offset]
118
- k_episode = episodes[b, k]
145
+ q_episode = episode_ids[b, q + offset]
146
+ k_episode = episode_ids[b, k]
119
147
 
120
148
  intra_episode_mask = q_episode == k_episode
121
149
  mask = mask & intra_episode_mask
@@ -146,6 +174,217 @@ def create_sliding_mask(
146
174
  create_kwargs = dict(device = device) if exists(device) else dict()
147
175
  return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
148
176
 
177
+ # data
178
+
179
+ def collate_var_time(data):
180
+
181
+ datum = first(data)
182
+ keys = datum.keys()
183
+
184
+ all_tensors = zip(*[datum.values() for datum in data])
185
+
186
+ collated_values = []
187
+
188
+ for key, tensors in zip(keys, all_tensors):
189
+
190
+ # the episode lens have zero dimension - think of a cleaner way to handle this later
191
+
192
+ if key != '_lens':
193
+
194
+ times = [t.shape[0] for t in tensors]
195
+ max_time = max(times)
196
+ tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
197
+
198
+ collated_values.append(stack(tensors))
199
+
200
+ return dict(zip(keys, collated_values))
201
+
202
+ class ReplayDataset(Dataset):
203
+ def __init__(
204
+ self,
205
+ folder: str | Path,
206
+ fields: tuple[str, ...] | None = None
207
+ ):
208
+ if isinstance(folder, str):
209
+ folder = Path(folder)
210
+
211
+ episode_lens = folder / 'episode_lens.npy'
212
+ self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
213
+
214
+ # get indices of non-zero lengthed episodes
215
+
216
+ nonzero_episodes = self.episode_lens > 0
217
+ self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
218
+
219
+ # get all data files
220
+
221
+ filepaths = [*folder.glob('*.data.npy')]
222
+ assert len(filepaths) > 0
223
+
224
+ fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
225
+
226
+ fieldnames_from_files = set(fieldname_to_filepath.keys())
227
+
228
+ fields = default(fields, fieldnames_from_files)
229
+
230
+ self.memmaps = dict()
231
+
232
+ for field in fields:
233
+ assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
234
+
235
+ path = fieldname_to_filepath[field]
236
+
237
+ self.memmaps[field] = open_memmap(str(path), mode = 'r')
238
+
239
+ def __len__(self):
240
+ return len(self.indices)
241
+
242
+ def __getitem__(self, idx):
243
+ episode_index = self.indices[idx]
244
+
245
+ episode_len = self.episode_lens[episode_index]
246
+
247
+ data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
248
+
249
+ data['_lens'] = tensor(episode_len)
250
+
251
+ return data
252
+
253
+ class ReplayBuffer:
254
+
255
+ @beartype
256
+ def __init__(
257
+ self,
258
+ folder: str | Path,
259
+ max_episodes: int,
260
+ max_timesteps: int,
261
+ fields: dict[
262
+ str,
263
+ str | tuple[str, int | tuple[int, ...]]
264
+ ]
265
+ ):
266
+
267
+ # folder for data
268
+
269
+ if not isinstance(folder, Path):
270
+ folder = Path(folder)
271
+ folder.mkdir(exist_ok = True)
272
+
273
+ self.folder = folder
274
+ assert folder.is_dir()
275
+
276
+ # keeping track of episode length
277
+
278
+ episode_lens = folder / 'episode_lens.npy'
279
+
280
+ self.episode_index = 0
281
+ self.timestep_index = 0
282
+
283
+ self.max_episodes = max_episodes
284
+ self.max_timesteps= max_timesteps
285
+
286
+ self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
287
+
288
+ # create the memmap for individual data tracks
289
+
290
+ self.shapes = dict()
291
+ self.dtypes = dict()
292
+ self.memmaps = dict()
293
+ self.fieldnames = set(fields.keys())
294
+
295
+ for field_name, field_info in fields.items():
296
+
297
+ # some flexibility
298
+
299
+ field_info = (field_info, ()) if isinstance(field_info, str) else field_info
300
+
301
+ dtype_str, shape = field_info
302
+ assert dtype_str in {'int', 'float', 'bool'}
303
+
304
+ dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
305
+
306
+ # memmap file
307
+
308
+ filepath = folder / f'{field_name}.data.npy'
309
+ memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
310
+
311
+ self.memmaps[field_name] = memmap
312
+ self.shapes[field_name] = shape
313
+ self.dtypes[field_name] = dtype
314
+
315
+ self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
316
+
317
+ def reset_(self):
318
+ self.episode_lens[:] = 0
319
+ self.episode_index = 0
320
+ self.timestep_index = 0
321
+
322
+ def advance_episode(self):
323
+ self.episode_index = (self.episode_index + 1) % self.max_episodes
324
+ self.timestep_index = 0
325
+
326
+ def flush(self):
327
+ self.episode_lens[self.episode_index] = self.timestep_index
328
+
329
+ for memmap in self.memmaps.values():
330
+ memmap.flush()
331
+
332
+ self.episode_lens.flush()
333
+
334
+ @contextmanager
335
+ def one_episode(self):
336
+
337
+ yield
338
+
339
+ self.flush()
340
+ self.advance_episode()
341
+
342
+ @beartype
343
+ def store_datapoint(
344
+ self,
345
+ episode_index: int,
346
+ timestep_index: int,
347
+ name: str,
348
+ datapoint: Tensor | ndarray
349
+ ):
350
+ assert 0 <= episode_index < self.max_episodes
351
+ assert 0 <= timestep_index < self.max_timesteps
352
+
353
+ if is_tensor(datapoint):
354
+ datapoint = datapoint.detach().cpu().numpy()
355
+
356
+ assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
357
+
358
+ assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
359
+
360
+ self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
361
+
362
+ def store(
363
+ self,
364
+ **data
365
+ ):
366
+ assert is_bearable(data, dict[str, Tensor | ndarray])
367
+
368
+ assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
369
+
370
+ for name, datapoint in data.items():
371
+
372
+ self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
373
+
374
+ self.timestep_index += 1
375
+
376
+ return self.memory_namedtuple(**data)
377
+
378
+ def dataset(self) -> Dataset:
379
+ self.flush()
380
+
381
+ return ReplayDataset(self.folder)
382
+
383
+ def dataloader(self, batch_size, **kwargs) -> DataLoader:
384
+ self.flush()
385
+
386
+ return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
387
+
149
388
  # transformer-xl with ppo
150
389
 
151
390
  class Attention(Module):
@@ -204,7 +443,6 @@ class Attention(Module):
204
443
  return_kv_cache = False,
205
444
  ):
206
445
  seq_len = tokens.shape[-2]
207
- assert seq_len <= self.window_size
208
446
 
209
447
  device = tokens.device
210
448
 
@@ -365,7 +603,14 @@ class Locoformer(Module):
365
603
  embedder: Module,
366
604
  unembedder: Module,
367
605
  transformer: dict | TransformerXL,
368
- value_network: Module | None = None
606
+ value_network: Module | None = None,
607
+ discount_factor = 0.999,
608
+ gae_lam = 0.95,
609
+ ppo_eps_clip = 0.2,
610
+ ppo_entropy_weight = 0.01,
611
+ ppo_value_clip = 0.4,
612
+ value_loss_weight = 0.5,
613
+ calc_gae_kwargs: dict = dict()
369
614
  ):
370
615
  super().__init__()
371
616
 
@@ -382,15 +627,160 @@ class Locoformer(Module):
382
627
  self.fixed_window_size = transformer.fixed_window_size
383
628
  self.window_size = transformer.window_size
384
629
 
630
+ # ppo related
631
+
632
+ self.discount_factor = discount_factor
633
+ self.gae_lam = gae_lam
634
+ self.ppo_eps_clip = ppo_eps_clip
635
+ self.ppo_entropy_weight = ppo_entropy_weight
636
+ self.ppo_value_clip = ppo_value_clip
637
+ self.value_loss_weight = value_loss_weight
638
+
639
+ self.calc_gae_kwargs = calc_gae_kwargs
640
+
641
+ # loss related
642
+
643
+ self.register_buffer('zero', tensor(0.), persistent = False)
644
+
385
645
  @property
386
646
  def device(self):
387
647
  return next(self.parameters()).device
388
648
 
649
+ def actor_parameters(self):
650
+ return self.unembedder.parameters()
651
+
652
+ def critic_parameters(self):
653
+ if not exists(self.value_network):
654
+ return []
655
+
656
+ return self.value_network.parameters()
657
+
658
+ def ppo(
659
+ self,
660
+ state,
661
+ action,
662
+ old_action_log_prob,
663
+ reward,
664
+ old_value,
665
+ mask,
666
+ actor_optim: Optimizer | None = None,
667
+ critic_optim: Optimizer | None = None
668
+ ):
669
+ window_size = self.window_size
670
+ total_learnable_tokens = mask.sum().item()
671
+
672
+ windowed_tensors = [
673
+ t.split(window_size, dim = 1) for t in
674
+ (
675
+ state,
676
+ action,
677
+ old_action_log_prob,
678
+ reward,
679
+ old_value,
680
+ mask
681
+ )
682
+ ]
683
+
684
+ mean_actor_loss = self.zero.clone()
685
+ mean_critic_loss = self.zero.clone()
686
+
687
+ # learn across windows
688
+
689
+ cache = None
690
+
691
+ for (
692
+ state,
693
+ action,
694
+ old_action_log_prob,
695
+ reward,
696
+ old_value,
697
+ mask
698
+ ) in zip(*windowed_tensors):
699
+
700
+ (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
701
+ entropy = calc_entropy(action_logits)
702
+
703
+ action = rearrange(action, 'b t -> b t 1')
704
+ log_prob = action_logits.gather(-1, action)
705
+ log_prob = rearrange(log_prob, 'b t 1 -> b t')
706
+
707
+ # update actor, classic clipped surrogate loss
708
+
709
+ eps_clip = self.ppo_eps_clip
710
+ ratio = (log_prob - old_action_log_prob).exp()
711
+
712
+ advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
713
+
714
+ actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
715
+
716
+ actor_loss = actor_loss - self.ppo_entropy_weight * entropy
717
+
718
+ windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
719
+ windowed_actor_loss.backward(retain_graph = True)
720
+
721
+ # update critic
722
+
723
+ value_loss = F.mse_loss(returns, value, reduction = 'none')
724
+
725
+ value_clip = self.ppo_value_clip
726
+ clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
727
+ clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
728
+
729
+ critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
730
+
731
+ windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
732
+ windowed_critic_loss.backward(retain_graph = True)
733
+
734
+ # accumulate
735
+
736
+ mean_actor_loss.add_(windowed_actor_loss)
737
+ mean_critic_loss.add_(windowed_critic_loss)
738
+
739
+ # optimizer update
740
+
741
+ if exists(actor_optim):
742
+ actor_optim.step()
743
+ actor_optim.zero_grad()
744
+
745
+ if exists(critic_optim):
746
+ critic_optim.step()
747
+ critic_optim.zero_grad()
748
+
749
+ # return losses for logging
750
+
751
+ return mean_actor_loss.detach(), mean_critic_loss.detach()
752
+
753
+ def wrap_env_functions(self, env):
754
+
755
+ def wrapped_reset(*args, **kwargs):
756
+ state, _ = env.reset(*args, **kwargs)
757
+
758
+ if isinstance(state, ndarray):
759
+ state = from_numpy(state)
760
+
761
+ return state, _
762
+
763
+ def wrapped_step(action, *args, **kwargs):
764
+ out = env.step(action.item(), *args, **kwargs)
765
+
766
+ def transform_output(el):
767
+ if isinstance(el, ndarray):
768
+ return from_numpy(el)
769
+ elif isinstance(el, (int, bool, float)):
770
+ return tensor(el)
771
+ else:
772
+ return el
773
+
774
+ return tree_map(transform_output, out)
775
+
776
+ return wrapped_reset, wrapped_step
777
+
389
778
  def get_stateful_forward(
390
779
  self,
391
780
  initial_states: Tensor | None = None,
392
781
  inference_mode = False,
393
782
  has_batch_dim = False,
783
+ has_time_dim = False,
394
784
  **kwargs
395
785
  ):
396
786
  window_size = self.window_size
@@ -400,11 +790,14 @@ class Locoformer(Module):
400
790
  def stateful_forward(state: Tensor, **override_kwargs):
401
791
  nonlocal cache
402
792
 
403
- # handle no batch, for easier time rolling out against envs
793
+ # handle no batch or time, for easier time rolling out against envs
404
794
 
405
795
  if not has_batch_dim:
406
796
  state = rearrange(state, '... -> 1 ...')
407
797
 
798
+ if not has_time_dim:
799
+ state = rearrange(state, '... d -> ... 1 d')
800
+
408
801
  # forwards
409
802
 
410
803
  out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
@@ -416,7 +809,10 @@ class Locoformer(Module):
416
809
  if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
417
810
  cache = cache[..., -window_size:, :]
418
811
 
419
- # maybe remove batch
812
+ # maybe remove batch or time
813
+
814
+ if not has_time_dim:
815
+ out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))
420
816
 
421
817
  if not has_batch_dim:
422
818
  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -450,6 +846,8 @@ class Locoformer(Module):
450
846
  return_values = False
451
847
  ):
452
848
 
849
+ state = state.to(self.device)
850
+
453
851
  tokens = self.embedder(state)
454
852
 
455
853
  embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
@@ -463,7 +861,7 @@ class Locoformer(Module):
463
861
  # maybe detach cache
464
862
 
465
863
  if detach_cache:
466
- kv_cache = detach_all(kv_cache)
864
+ kv_cache = kv_cache.detach()
467
865
 
468
866
  # handle returning of values
469
867
 
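One API change buried in the hunks above is worth calling out: `calc_gae` now treats `masks` as optional and returns the pair `(gae, returns)` instead of just `returns`, which is what the new `ppo` method consumes. A minimal sketch of the new call shape, assuming only what is visible in this diff (the advantage normalization at the end is a typical downstream step, not something the diff itself does):

```python
# minimal sketch of the changed calc_gae signature in this diff:
# masks now defaults to None (treated as all ones) and two values are returned

import torch
from locoformer.locoformer import calc_gae

rewards = torch.tensor([[1., 0., 0., 1.]])
values  = torch.tensor([[0.5, 0.4, 0.3, 0.2]])

gae, returns = calc_gae(rewards, values, gamma = 0.99, lam = 0.95, use_accelerated = False)

# typical downstream use (not from the diff): normalize the advantages
advantage = (gae - gae.mean()) / (gae.std() + 1e-5)
```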
--- pyproject.toml
+++ pyproject.toml
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "locoformer"
3
- version = "0.0.6"
3
+ version = "0.0.17"
4
4
  description = "LocoFormer"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,6 +27,7 @@ classifiers=[
27
27
 
28
28
  dependencies = [
29
29
  "assoc-scan",
30
+ "beartype",
30
31
  "einx>=0.3.0",
31
32
  "einops>=0.8.0",
32
33
  "rotary-embedding-torch",
@@ -0,0 +1,86 @@
1
+ import pytest
2
+ param = pytest.mark.parametrize
3
+
4
+ import torch
5
+ from x_mlps_pytorch import MLP
6
+
7
+ from einops import rearrange
8
+
9
+ def test_locoformer():
10
+ from locoformer.locoformer import Locoformer
11
+ from torch import nn
12
+
13
+ model = Locoformer(
14
+ embedder = nn.Embedding(256, 128),
15
+ unembedder = nn.Linear(128, 256, bias = False),
16
+ value_network = MLP(128, 32, 1),
17
+ transformer = dict(
18
+ dim = 128,
19
+ depth = 1,
20
+ window_size = 512
21
+ )
22
+ )
23
+
24
+ seq = torch.randint(0, 256, (3, 512))
25
+
26
+ (logits, values), cache = model(seq, return_values = True)
27
+ (logits, values), cache = model(seq, return_values = True, cache = cache)
28
+ (logits, values), cache = model(seq, return_values = True, cache = cache)
29
+
30
+ assert logits.shape == (3, 512, 256)
31
+
32
+ stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
33
+
34
+ for state in seq.unbind(dim = -1):
35
+ state = rearrange(state, 'b -> b 1')
36
+
37
+ logits, values = stateful_forward(state)
38
+ assert logits.shape == (3, 1, 256)
39
+
40
+ def test_replay():
41
+ from locoformer.locoformer import ReplayBuffer
42
+
43
+ replay_buffer = ReplayBuffer(
44
+ './replay_data',
45
+ max_episodes = 10_000,
46
+ max_timesteps = 501,
47
+ fields = dict(
48
+ state = ('float', (8,)),
49
+ action = 'int',
50
+ action_log_prob = 'float',
51
+ reward = 'float',
52
+ value = 'float',
53
+ done = 'bool'
54
+ )
55
+ )
56
+
57
+ lens = [3, 5, 4]
58
+
59
+ for episode_len in lens:
60
+ with replay_buffer.one_episode():
61
+ for _ in range(episode_len):
62
+ state = torch.randn((8,))
63
+ action = torch.randint(0, 4, ())
64
+ log_prob = torch.randn(())
65
+ reward = torch.randn(())
66
+ value = torch.randn(())
67
+ done = torch.randint(0, 2, ()).bool()
68
+
69
+ replay_buffer.store(
70
+ state = state,
71
+ action = action,
72
+ action_log_prob = log_prob,
73
+ reward = reward,
74
+ value = value,
75
+ done = done
76
+ )
77
+
78
+ dataset = replay_buffer.dataset()
79
+
80
+ assert len(dataset) == 3
81
+
82
+ assert torch.is_tensor(dataset[0]['state'])
83
+
84
+ dataloader = replay_buffer.dataloader(batch_size = 3)
85
+
86
+ assert next(iter(dataloader))['state'].shape[0] == 3
@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
169
169
  prime = prime.to(model.device)
170
170
  out = prime
171
171
 
172
- stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
172
+ stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)
173
173
 
174
174
  # sample
175
175
 
@@ -0,0 +1,262 @@
1
+ # /// script
2
+ # dependencies = [
3
+ # "accelerate",
4
+ # "fire",
5
+ # "gymnasium[box2d]>=1.0.0",
6
+ # "locoformer>=0.0.12",
7
+ # "moviepy",
8
+ # "tqdm"
9
+ # ]
10
+ # ///
11
+
12
+ from fire import Fire
13
+ from shutil import rmtree
14
+ from tqdm import tqdm
15
+ from collections import deque
16
+ from types import SimpleNamespace
17
+
18
+ from accelerate import Accelerator
19
+
20
+ import gymnasium as gym
21
+
22
+ import torch
23
+ from torch import from_numpy, randint, tensor, stack, arange
24
+ import torch.nn.functional as F
25
+ from torch.utils.data import TensorDataset, DataLoader
26
+ from torch.optim import Adam
27
+
28
+ import einx
29
+ from einops import rearrange
30
+
31
+ from locoformer.locoformer import Locoformer, ReplayBuffer
32
+ from x_mlps_pytorch import MLP
33
+
34
+ # helper functions
35
+
36
+ def exists(v):
37
+ return v is not None
38
+
39
+ def divisible_by(num, den):
40
+ return (num % den) == 0
41
+
42
+ def log(t, eps = 1e-20):
43
+ return t.clamp(min = eps).log()
44
+
45
+ def gumbel_noise(t):
46
+ return -log(-log(torch.rand_like(t)))
47
+
48
+ def gumbel_sample(logits, temperature = 1., eps = 1e-6):
49
+ noise = gumbel_noise(logits)
50
+ return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
51
+
52
+ # learn
53
+
54
+ def learn(
55
+ model,
56
+ actor_optim,
57
+ critic_optim,
58
+ accelerator,
59
+ replay,
60
+ batch_size = 16,
61
+ epochs = 2,
62
+ ):
63
+ device = accelerator.device
64
+
65
+ dl = replay.dataloader(batch_size = batch_size, shuffle = True)
66
+ model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
67
+
68
+ for _ in range(epochs):
69
+ for data in dl:
70
+
71
+ data = SimpleNamespace(**data)
72
+
73
+ seq_len = data.state.shape[1]
74
+
75
+ value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
76
+ value = torch.where(value_mask, data.value, 0.)
77
+
78
+ actor_loss, critic_loss = model.ppo(
79
+ state = data.state,
80
+ action = data.action,
81
+ old_action_log_prob = data.action_log_prob,
82
+ reward = data.reward,
83
+ old_value = value,
84
+ mask = data.learnable,
85
+ actor_optim = actor_optim,
86
+ critic_optim = critic_optim
87
+ )
88
+
89
+ accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
90
+
91
+ # main function
92
+
93
+ def main(
94
+ env_name = 'LunarLander-v3',
95
+ num_episodes = 50_000,
96
+ max_timesteps = 500,
97
+ num_episodes_before_learn = 32,
98
+ clear_video = True,
99
+ video_folder = 'recordings',
100
+ record_every_episode = 250,
101
+ learning_rate = 8e-4,
102
+ discount_factor = 0.99,
103
+ betas = (0.9, 0.99),
104
+ gae_lam = 0.95,
105
+ ppo_eps_clip = 0.2,
106
+ ppo_entropy_weight = .01,
107
+ batch_size = 16,
108
+ epochs = 2
109
+ ):
110
+
111
+ # accelerate
112
+
113
+ accelerator = Accelerator()
114
+ device = accelerator.device
115
+
116
+ # environment
117
+
118
+ env = gym.make(env_name, render_mode = 'rgb_array')
119
+
120
+ if clear_video:
121
+ rmtree(video_folder, ignore_errors = True)
122
+
123
+ env = gym.wrappers.RecordVideo(
124
+ env = env,
125
+ video_folder = video_folder,
126
+ name_prefix = 'lunar-video',
127
+ episode_trigger = lambda eps: divisible_by(eps, record_every_episode),
128
+ disable_logger = True
129
+ )
130
+
131
+ dim_state = env.observation_space.shape[0]
132
+ num_actions = env.action_space.n
133
+
134
+ # memory
135
+
136
+ replay = ReplayBuffer(
137
+ 'replay',
138
+ num_episodes,
139
+ max_timesteps + 1, # one extra timestep for the bootstrap value - not strictly needed for locoformer, but kept for completeness
140
+ fields = dict(
141
+ state = ('float', (dim_state,)),
142
+ action = 'int',
143
+ action_log_prob = 'float',
144
+ reward = 'float',
145
+ value = 'float',
146
+ done = 'bool',
147
+ learnable = 'bool'
148
+ )
149
+ )
150
+
151
+ # networks
152
+
153
+ locoformer = Locoformer(
154
+ embedder = MLP(dim_state, 64, bias = False),
155
+ unembedder = MLP(64, num_actions, bias = False),
156
+ value_network = MLP(64, 1, bias = False),
157
+ transformer = dict(
158
+ dim = 64,
159
+ dim_head = 32,
160
+ heads = 4,
161
+ depth = 4,
162
+ window_size = 16
163
+ ),
164
+ discount_factor = discount_factor,
165
+ gae_lam = gae_lam,
166
+ ppo_eps_clip = ppo_eps_clip,
167
+ ppo_entropy_weight = ppo_entropy_weight,
168
+ calc_gae_kwargs = dict(
169
+ use_accelerated = False
170
+ )
171
+ ).to(device)
172
+
173
+ optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
174
+ optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)
175
+
176
+ timesteps_learn = 0
177
+
178
+ # wrap the env functions so all values are converted to torch tensors and back
179
+ # all environments should follow usual MDP interface, domain randomization should be given at instantiation
180
+
181
+ env_reset, env_step = locoformer.wrap_env_functions(env)
182
+
183
+ # loop
184
+
185
+ for episodes_index in tqdm(range(num_episodes)):
186
+
187
+ state, *_ = env_reset()
188
+
189
+ timestep = 0
190
+
191
+ stateful_forward = locoformer.get_stateful_forward(has_batch_dim = False, has_time_dim = False, inference_mode = True)
192
+
193
+ with replay.one_episode():
194
+ while True:
195
+
196
+ # predict next action
197
+
198
+ action_logits, value = stateful_forward(state, return_values = True)
199
+
200
+ action = gumbel_sample(action_logits)
201
+
202
+ # pass to environment
203
+
204
+ next_state, reward, truncated, terminated, *_ = env_step(action)
205
+
206
+ # append to memory
207
+
208
+ done = truncated or terminated
209
+
210
+ # get log prob of action
211
+
212
+ action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
213
+ action_log_prob = rearrange(action_log_prob, '1 ->')
214
+
215
+ memory = replay.store(
216
+ state = state,
217
+ action = action,
218
+ action_log_prob = action_log_prob,
219
+ reward = reward,
220
+ value = value,
221
+ done = done,
222
+ learnable = tensor(True)
223
+ )
224
+
225
+ # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
226
+ # only if terminated signal not detected
227
+
228
+ if not terminated:
229
+ _, next_value = stateful_forward(next_state, return_values = True)
230
+
231
+ memory = memory._replace(value = next_value, learnable = tensor(False))
232
+
233
+ replay.store(**memory._asdict())
234
+
235
+ # increment counters
236
+
237
+ timestep += 1
238
+
239
+ # break if done or exceed max timestep
240
+
241
+ if done or timestep >= max_timesteps:
242
+ break
243
+
244
+ state = next_state
245
+
246
+ # learn if hit the number of learn timesteps
247
+
248
+ if divisible_by(episodes_index + 1, num_episodes_before_learn):
249
+
250
+ learn(
251
+ locoformer,
252
+ optim_actor,
253
+ optim_critic,
254
+ accelerator,
255
+ replay,
256
+ batch_size,
257
+ epochs,
258
+ )
259
+ # main
260
+
261
+ if __name__ == '__main__':
262
+ Fire(main)
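The `# /// script` header at the top of this new example is PEP 723 inline script metadata, so the file can be run directly with a PEP 723-aware runner such as `uv run`, which resolves the listed dependencies automatically; the script's filename is not shown in this diff.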
@@ -1,38 +0,0 @@
1
- import pytest
2
- param = pytest.mark.parametrize
3
-
4
- import torch
5
- from x_mlps_pytorch import MLP
6
-
7
- from einops import rearrange
8
-
9
- def test_locoformer():
10
- from locoformer.locoformer import Locoformer
11
- from torch import nn
12
-
13
- model = Locoformer(
14
- embedder = nn.Embedding(256, 128),
15
- unembedder = nn.Linear(128, 256, bias = False),
16
- value_network = MLP(128, 32, 1),
17
- transformer = dict(
18
- dim = 128,
19
- depth = 1,
20
- window_size = 256
21
- )
22
- )
23
-
24
- seq = torch.randint(0, 256, (3, 512))
25
-
26
- (logits, values), cache = model(seq, return_values = True)
27
- (logits, values), cache = model(seq, return_values = True, cache = cache)
28
- (logits, values), cache = model(seq, return_values = True, cache = cache)
29
-
30
- assert logits.shape == (3, 512, 256)
31
-
32
- stateful_forward = model.get_stateful_forward(256, has_batch_dim = True, return_values = True, inference_mode = True)
33
-
34
- for state in seq.unbind(dim = -1):
35
- state = rearrange(state, 'b -> b 1')
36
-
37
- logits, values = stateful_forward(state)
38
- assert logits.shape == (3, 1, 256)