locoformer 0.0.6.tar.gz → 0.0.11.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
+recordings/
+replay/

 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: locoformer
-Version: 0.0.6
+Version: 0.0.11
 Summary: LocoFormer
 Project-URL: Homepage, https://pypi.org/project/locoformer/
 Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: assoc-scan
+Requires-Dist: beartype
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: rotary-embedding-torch
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown

 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

 ## Sponsors

@@ -4,7 +4,7 @@

 [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

-The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

 ## Sponsors

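Editor's note: the cross-trial in-context adaptation described above corresponds, in code, to the Transformer-XL KV cache that get_stateful_forward (see the locoformer.py hunks below) keeps alive inside a closure and trims to the window size. A minimal sketch of carrying that cache across back-to-back trials, assuming the constructor arguments and flags shown in this diff; the Linear embedder, the 8-dimensional toy observations and the loop bounds are illustrative only, not part of the package:

import torch
from torch import nn
from locoformer.locoformer import Locoformer

# toy continuous-state model, mirroring the shapes used by the new training script further down
model = Locoformer(
    embedder = nn.Linear(8, 64, bias = False),
    unembedder = nn.Linear(64, 4, bias = False),
    value_network = nn.Linear(64, 1, bias = False),
    transformer = dict(dim = 64, depth = 2, window_size = 16)
)

# the returned closure owns the XL KV cache, so memories persist across calls - and across trials
stateful_forward = model.get_stateful_forward(
    has_batch_dim = False,
    has_time_dim = False,
    inference_mode = True
)

for trial in range(3):                            # hypothetical back-to-back trials
    for step in range(32):
        state = torch.randn(8)                    # stand-in for an environment observation
        action_logits = stateful_forward(state)   # conditioned on the cache built up over earlier steps and trials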
@@ -1,11 +1,22 @@
 from __future__ import annotations
 from functools import partial

+from pathlib import Path
+from contextlib import contextmanager
+
+import numpy as np
+from numpy import ndarray
+from numpy.lib.format import open_memmap
+
+from beartype import beartype
+from beartype.door import is_bearable
+
 import torch
-from torch import nn, cat, stack, arange, is_tensor
+from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
+from torch.utils.data import Dataset, DataLoader

 import einx
 from einops import rearrange, einsum
@@ -37,14 +48,18 @@ def tree_map_tensor(x, fn):
 def detach_all(x):
     return tree_map_tensor(x, lambda t: t.detach())

-def combine_kv_cache(cache1, cache2):
-    combined_cache = []
-
-    for layer_cache1, layer_cache2 in zip(cache1, cache2):
-        next_cache = cat((layer_cache1, layer_cache2), dim = -2)
-        combined_cache.append(next_cache)
+def pad_at_dim(
+    t,
+    pad: tuple[int, int],
+    dim = -1,
+    value = 0.
+):
+    if pad == (0, 0):
+        return t

-    return combined_cache
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)

 # generalized advantage estimate

@@ -146,6 +161,208 @@ def create_sliding_mask(
     create_kwargs = dict(device = device) if exists(device) else dict()
     return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

+# data
+
+def collate_var_time(data):
+
+    datum = first(data)
+    keys = datum.keys()
+
+    all_tensors = zip(*[datum.values() for datum in data])
+
+    collated_values = []
+
+    for key, tensors in zip(keys, all_tensors):
+
+        # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+        if key != '_lens':
+
+            times = [t.shape[0] for t in tensors]
+            max_time = max(times)
+            tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+        collated_values.append(stack(tensors))
+
+    return dict(zip(keys, collated_values))
+
+class ReplayDataset(Dataset):
+    def __init__(
+        self,
+        folder: str | Path,
+        fields: tuple[str, ...] | None = None
+    ):
+        if isinstance(folder, str):
+            folder = Path(folder)
+
+        episode_lens = folder / 'episode_lens.npy'
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+        # get indices of non-zero lengthed episodes
+
+        nonzero_episodes = self.episode_lens > 0
+        self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+        # get all data files
+
+        filepaths = [*folder.glob('*.data.npy')]
+        assert len(filepaths) > 0
+
+        fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+        fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+        fields = default(fields, fieldnames_from_files)
+
+        self.memmaps = dict()
+
+        for field in fields:
+            assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+            path = fieldname_to_filepath[field]
+
+            self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+    def __len__(self):
+        return len(self.indices)
+
+    def __getitem__(self, idx):
+        episode_index = self.indices[idx]
+
+        episode_len = self.episode_lens[episode_index]
+
+        data = {field: torch.from_numpy(memmap[episode_index, :episode_len]) for field, memmap in self.memmaps.items()}
+
+        data['_lens'] = tensor(episode_len)
+
+        return data
+
+class ReplayBuffer:
+
+    @beartype
+    def __init__(
+        self,
+        folder: str | Path,
+        max_episodes: int,
+        max_timesteps: int,
+        fields: dict[
+            str,
+            str | tuple[str, int | tuple[int, ...]]
+        ]
+    ):
+
+        # folder for data
+
+        if not isinstance(folder, Path):
+            folder = Path(folder)
+            folder.mkdir(exist_ok = True)
+
+        self.folder = folder
+        assert folder.is_dir()
+
+        # keeping track of episode length
+
+        episode_lens = folder / 'episode_lens.npy'
+
+        self.episode_index = 0
+        self.timestep_index = 0
+
+        self.max_episodes = max_episodes
+        self.max_timesteps= max_timesteps
+
+        self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+        # create the memmap for individual data tracks
+
+        self.shapes = dict()
+        self.dtypes = dict()
+        self.memmaps = dict()
+        self.fieldnames = set(fields.keys())
+
+        for field_name, field_info in fields.items():
+
+            # some flexibility
+
+            field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+            dtype_str, shape = field_info
+            assert dtype_str in {'int', 'float', 'bool'}
+
+            dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+            # memmap file
+
+            filepath = folder / f'{field_name}.data.npy'
+            memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+            self.memmaps[field_name] = memmap
+            self.shapes[field_name] = shape
+            self.dtypes[field_name] = dtype
+
+    def advance_episode(self):
+        self.episode_index = (self.episode_index + 1) % self.max_episodes
+        self.timestep_index = 0
+
+    def flush(self):
+        self.episode_lens[self.episode_index] = self.timestep_index
+
+        for memmap in self.memmaps.values():
+            memmap.flush()
+
+        self.episode_lens.flush()
+
+    @contextmanager
+    def one_episode(self):
+
+        yield
+
+        self.flush()
+        self.advance_episode()
+
+    @beartype
+    def store_datapoint(
+        self,
+        episode_index: int,
+        timestep_index: int,
+        name: str,
+        datapoint: Tensor | ndarray
+    ):
+        assert 0 <= episode_index < self.max_episodes
+        assert 0 <= timestep_index < self.max_timesteps
+
+        if is_tensor(datapoint):
+            datapoint = datapoint.detach().cpu().numpy()
+
+        assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+        assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+        self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
+
+    def store(
+        self,
+        **data
+    ):
+        assert is_bearable(data, dict[str, Tensor | ndarray])
+
+        assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+        for name, datapoint in data.items():
+
+            self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+        self.timestep_index += 1
+
+    def dataset(self) -> Dataset:
+        self.flush()
+
+        return ReplayDataset(self.folder)
+
+    def dataloader(self, **kwargs) -> DataLoader:
+        self.flush()
+
+        return DataLoader(self.dataset(), collate_fn = collate_var_time, **kwargs)
+
 # transformer-xl with ppo

 class Attention(Module):
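Editor's note: the ReplayBuffer added above keeps one memmapped .npy file per field, shaped (max_episodes, max_timesteps, *field_shape), plus an episode_lens.npy track, and ReplayDataset later reads back only the episodes with non-zero length. A minimal sketch of the episode lifecycle and the resulting on-disk layout, assuming the API exactly as added in this hunk; the folder name and the toy 'state' field are illustrative:

import torch
from numpy.lib.format import open_memmap
from locoformer.locoformer import ReplayBuffer

# hypothetical buffer with a single 3-dimensional float field
buffer = ReplayBuffer(
    './replay_sketch',
    max_episodes = 4,
    max_timesteps = 8,
    fields = dict(state = ('float', (3,)))
)

# one episode of 5 timesteps - flushed and advanced on exiting the context manager
with buffer.one_episode():
    for _ in range(5):
        buffer.store(state = torch.randn(3))

# each field lives in its own memmap, alongside the per-episode length track
states = open_memmap('./replay_sketch/state.data.npy', mode = 'r')
lens = open_memmap('./replay_sketch/episode_lens.npy', mode = 'r')

assert states.shape == (4, 8, 3)
assert lens[0] == 5

# the dataset / dataloader views skip episodes that were never written
assert len(buffer.dataset()) == 1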
@@ -386,11 +603,46 @@ class Locoformer(Module):
     def device(self):
         return next(self.parameters()).device

+    def actor_parameters(self):
+        return self.unembedder.parameters()
+
+    def critic_parameters(self):
+        if not exists(self.value_network):
+            return []
+
+        return self.value_network.parameters()
+
+    def wrap_env_functions(self, env):
+
+        def wrapped_reset(*args, **kwargs):
+            state, _ = env.reset(*args, **kwargs)
+
+            if isinstance(state, ndarray):
+                state = from_numpy(state)
+
+            return state, _
+
+        def wrapped_step(action, *args, **kwargs):
+            out = env.step(action.item(), *args, **kwargs)
+
+            def transform_output(el):
+                if isinstance(el, ndarray):
+                    return from_numpy(el)
+                elif isinstance(el, (int, bool, float)):
+                    return tensor(el)
+                else:
+                    return el
+
+            return tree_map(transform_output, out)
+
+        return wrapped_reset, wrapped_step
+
     def get_stateful_forward(
         self,
         initial_states: Tensor | None = None,
         inference_mode = False,
         has_batch_dim = False,
+        has_time_dim = False,
         **kwargs
     ):
         window_size = self.window_size
@@ -400,11 +652,14 @@ class Locoformer(Module):
         def stateful_forward(state: Tensor, **override_kwargs):
             nonlocal cache

-            # handle no batch, for easier time rolling out against envs
+            # handle no batch or time, for easier time rolling out against envs

             if not has_batch_dim:
                 state = rearrange(state, '... -> 1 ...')

+            if not has_time_dim:
+                state = rearrange(state, '... d -> ... 1 d')
+
             # forwards

             out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
@@ -416,7 +671,10 @@ class Locoformer(Module):
             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
                 cache = cache[..., -window_size:, :]

-            # maybe remove batch
+            # maybe remove batch or time
+
+            if not has_time_dim:
+                out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))

             if not has_batch_dim:
                 out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -450,6 +708,8 @@ class Locoformer(Module):
         return_values = False
     ):

+        state = state.to(self.device)
+
         tokens = self.embedder(state)

         embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
@@ -1,6 +1,6 @@
 [project]
 name = "locoformer"
-version = "0.0.6"
+version = "0.0.11"
 description = "LocoFormer"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,6 +27,7 @@ classifiers=[

 dependencies = [
     "assoc-scan",
+    "beartype",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "rotary-embedding-torch",
@@ -0,0 +1,86 @@
+import pytest
+param = pytest.mark.parametrize
+
+import torch
+from x_mlps_pytorch import MLP
+
+from einops import rearrange
+
+def test_locoformer():
+    from locoformer.locoformer import Locoformer
+    from torch import nn
+
+    model = Locoformer(
+        embedder = nn.Embedding(256, 128),
+        unembedder = nn.Linear(128, 256, bias = False),
+        value_network = MLP(128, 32, 1),
+        transformer = dict(
+            dim = 128,
+            depth = 1,
+            window_size = 512
+        )
+    )
+
+    seq = torch.randint(0, 256, (3, 512))
+
+    (logits, values), cache = model(seq, return_values = True)
+    (logits, values), cache = model(seq, return_values = True, cache = cache)
+    (logits, values), cache = model(seq, return_values = True, cache = cache)
+
+    assert logits.shape == (3, 512, 256)
+
+    stateful_forward = model.get_stateful_forward(has_batch_dim = True, has_time_dim = True, return_values = True, inference_mode = True)
+
+    for state in seq.unbind(dim = -1):
+        state = rearrange(state, 'b -> b 1')
+
+        logits, values = stateful_forward(state)
+        assert logits.shape == (3, 1, 256)
+
+def test_replay():
+    from locoformer.locoformer import ReplayBuffer
+
+    replay_buffer = ReplayBuffer(
+        './replay_data',
+        max_episodes = 10_000,
+        max_timesteps = 501,
+        fields = dict(
+            state = ('float', (8,)),
+            action = 'int',
+            action_log_prob = 'float',
+            reward = 'float',
+            value = 'float',
+            done = 'bool'
+        )
+    )
+
+    lens = [3, 5, 4]
+
+    for episode_len in lens:
+        with replay_buffer.one_episode():
+            for _ in range(episode_len):
+                state = torch.randn((8,))
+                action = torch.randint(0, 4, ())
+                log_prob = torch.randn(())
+                reward = torch.randn(())
+                value = torch.randn(())
+                done = torch.randint(0, 2, ()).bool()
+
+                replay_buffer.store(
+                    state = state,
+                    action = action,
+                    action_log_prob = log_prob,
+                    reward = reward,
+                    value = value,
+                    done = done
+                )
+
+    dataset = replay_buffer.dataset()
+
+    assert len(dataset) == 3
+
+    assert torch.is_tensor(dataset[0]['state'])
+
+    dataloader = replay_buffer.dataloader(batch_size = 3)
+
+    assert next(iter(dataloader))['state'].shape[0] == 3
@@ -0,0 +1,193 @@
+# /// script
+# dependencies = [
+#   "accelerate",
+#   "fire",
+#   "gymnasium[box2d]>=1.0.0",
+#   "locoformer",
+#   "moviepy",
+#   "tqdm"
+# ]
+# ///
+
+from fire import Fire
+from shutil import rmtree
+from tqdm import tqdm
+from collections import deque
+
+from accelerate import Accelerator
+
+import gymnasium as gym
+
+import torch
+from torch import from_numpy, randint, tensor, stack
+import torch.nn.functional as F
+from torch.utils.data import TensorDataset, DataLoader
+from torch.optim import Adam
+
+from einops import rearrange
+
+from locoformer.locoformer import Locoformer, ReplayBuffer
+from x_mlps_pytorch import MLP
+
+# helper functions
+
+def exists(v):
+    return v is not None
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
+def gumbel_noise(t):
+    return -log(-log(torch.rand_like(t)))
+
+def gumbel_sample(logits, temperature = 1., eps = 1e-6):
+    noise = gumbel_noise(logits)
+    return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
+
+# main function
+
+def main(
+    env_name = 'LunarLander-v3',
+    num_episodes = 50_000,
+    max_timesteps = 500,
+    num_timestep_before_learn = 5000,
+    clear_video = True,
+    video_folder = 'recordings',
+    record_every_episode = 250,
+    discount_factor = 0.99,
+    learning_rate = 1e-4,
+    batch_size = 16,
+    epochs = 2
+):
+
+    # accelerate
+
+    accelerate = Accelerator()
+    device = accelerate.device
+
+    # environment
+
+    env = gym.make(env_name, render_mode = 'rgb_array')
+
+    if clear_video:
+        rmtree(video_folder, ignore_errors = True)
+
+    env = gym.wrappers.RecordVideo(
+        env = env,
+        video_folder = video_folder,
+        name_prefix = 'lunar-video',
+        episode_trigger = lambda eps: divisible_by(eps, record_every_episode),
+        disable_logger = True
+    )
+
+    dim_state = env.observation_space.shape[0]
+    num_actions = env.action_space.n
+
+    # memory
+
+    replay = ReplayBuffer(
+        'replay',
+        num_episodes,
+        max_timesteps,
+        fields = dict(
+            state = ('float', (dim_state,)),
+            action = 'int',
+            action_log_prob = 'float',
+            reward = 'float',
+            value = 'float',
+            done = 'bool'
+        )
+    )
+
+    # networks
+
+    locoformer = Locoformer(
+        embedder = MLP(dim_state, 64, bias = False),
+        unembedder = MLP(64, num_actions, bias = False),
+        value_network = MLP(64, 1, bias = False),
+        transformer = dict(
+            dim = 64,
+            dim_head = 32,
+            heads = 4,
+            depth = 4,
+            window_size = 16
+        )
+    ).to(device)
+
+    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
+    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)
+
+    timesteps_learn = 0
+
+    # able to wrap the env for all values to torch tensors and back
+    # all environments should follow usual MDP interface, domain randomization should be given at instantiation
+
+    env_reset, env_step = locoformer.wrap_env_functions(env)
+
+    # loop
+
+    for _ in tqdm(range(num_episodes)):
+        state, *_ = env_reset()
+
+        timestep = 0
+
+        stateful_forward = locoformer.get_stateful_forward(has_batch_dim = False, has_time_dim = False, inference_mode = True)
+
+        with replay.one_episode():
+            while True:
+
+                # predict next action
+
+                action_logits, value = stateful_forward(state, return_values = True)
+
+                action = gumbel_sample(action_logits)
+
+                # pass to environment
+
+                next_state, reward, truncated, terminated, *_ = env_step(action)
+
+                # append to memory
+
+                done = truncated or terminated
+
+                # get log prob of action
+
+                action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
+                action_log_prob = rearrange(action_log_prob, '1 ->')
+
+                replay.store(
+                    state = state,
+                    action = action,
+                    action_log_prob = action_log_prob,
+                    reward = reward,
+                    value = value,
+                    done = done
+                )
+
+                # increment counters
+
+                timestep += 1
+                timesteps_learn += 1
+
+                # learn if hit the number of learn timesteps
+
+                if timesteps_learn >= num_timestep_before_learn:
+                    # todo - carry out learning
+
+                    timesteps_learn = 0
+                    memories.clear()
+
+                # break if done or exceed max timestep
+
+                if done or timestep >= max_timesteps:
+                    break
+
+                state = next_state
+
+# main
+
+if __name__ == '__main__':
+    Fire(main)
@@ -1,38 +0,0 @@
-import pytest
-param = pytest.mark.parametrize
-
-import torch
-from x_mlps_pytorch import MLP
-
-from einops import rearrange
-
-def test_locoformer():
-    from locoformer.locoformer import Locoformer
-    from torch import nn
-
-    model = Locoformer(
-        embedder = nn.Embedding(256, 128),
-        unembedder = nn.Linear(128, 256, bias = False),
-        value_network = MLP(128, 32, 1),
-        transformer = dict(
-            dim = 128,
-            depth = 1,
-            window_size = 256
-        )
-    )
-
-    seq = torch.randint(0, 256, (3, 512))
-
-    (logits, values), cache = model(seq, return_values = True)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-    (logits, values), cache = model(seq, return_values = True, cache = cache)
-
-    assert logits.shape == (3, 512, 256)
-
-    stateful_forward = model.get_stateful_forward(256, has_batch_dim = True, return_values = True, inference_mode = True)
-
-    for state in seq.unbind(dim = -1):
-        state = rearrange(state, 'b -> b 1')
-
-        logits, values = stateful_forward(state)
-        assert logits.shape == (3, 1, 256)