locoformer 0.0.5__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/locoformer.py CHANGED
@@ -1,12 +1,25 @@
  from __future__ import annotations
  from functools import partial

+ from pathlib import Path
+ from contextlib import contextmanager
+ from collections import namedtuple
+
+ import numpy as np
+ from numpy import ndarray
+ from numpy.lib.format import open_memmap
+
+ from beartype import beartype
+ from beartype.door import is_bearable
+
  import torch
- from torch import cat, stack, is_tensor
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
- from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity
+ from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
  from torch.utils._pytree import tree_map
+ from torch.utils.data import Dataset, DataLoader

+ import einx
  from einops import rearrange, einsum
  from einops.layers.torch import Rearrange

@@ -24,23 +37,39 @@ def exists(v):
  def default(v, d):
      return v if exists(v) else d

+ def first(arr):
+     return arr[0]
+
  def divisible_by(num, den):
      return (num % den) == 0

+ # tensor helpers
+
+ def log(t, eps = 1e-20):
+     return t.clamp_min(eps).log()
+
  def tree_map_tensor(x, fn):
      return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)

  def detach_all(x):
      return tree_map_tensor(x, lambda t: t.detach())

- def combine_kv_cache(cache1, cache2):
-     combined_cache = []
+ def pad_at_dim(
+     t,
+     pad: tuple[int, int],
+     dim = -1,
+     value = 0.
+ ):
+     if pad == (0, 0):
+         return t

-     for layer_cache1, layer_cache2 in zip(cache1, cache2):
-         next_cache = cat((layer_cache1, layer_cache2), dim = -2)
-         combined_cache.append(next_cache)
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)

-     return combined_cache
+ def calc_entropy(logits):
+     prob = logits.softmax(dim = -1)
+     return -(prob * log(prob)).sum(dim = -1)

  # generalized advantage estimate

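
`pad_at_dim` and `calc_entropy` replace the removed `combine_kv_cache` helper. A minimal sketch of what they compute, in plain PyTorch (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    t = torch.randn(3, 5)

    # pad_at_dim(t, (0, 2), dim = 0) right-pads only dim 0; the leading (0, 0) pair
    # handed to F.pad covers the dimension to the right of the padded one
    padded = F.pad(t, (0, 0, 0, 2), value = 0.)
    assert padded.shape == (5, 5)

    # calc_entropy is the Shannon entropy of the softmax distribution,
    # with probabilities clamped at 1e-20 before the log for numerical stability
    logits = torch.randn(2, 4, 10)
    prob = logits.softmax(dim = -1)
    entropy = -(prob * prob.clamp_min(1e-20).log()).sum(dim = -1)   # shape (2, 4)
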
@@ -48,7 +77,7 @@ def combine_kv_cache(cache1, cache2):
  def calc_gae(
      rewards,
      values,
-     masks,
+     masks = None,
      gamma = 0.99,
      lam = 0.95,
      use_accelerated = None
@@ -59,6 +88,9 @@ def calc_gae(
      values = F.pad(values, (0, 1), value = 0.)
      values, values_next = values[..., :-1], values[..., 1:]

+     if not exists(masks):
+         masks = torch.ones_like(values)
+
      delta = rewards + gamma * values_next * masks - values
      gates = gamma * lam * masks

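
With `masks` defaulting to all ones, `calc_gae` reduces to the textbook recurrence delta_t = r_t + gamma * V_{t+1} - V_t, A_t = delta_t + gamma * lam * A_{t+1}. Judging by how `ppo` later subtracts `old_value` from the result, the function returns lambda-returns (advantage plus value). A naive reference loop for the same quantity, useful for checking the accelerated associative-scan path against (this loop is a sketch, not the library's implementation):

    import torch

    def naive_gae_returns(rewards, values, gamma = 0.99, lam = 0.95):
        # bootstrap value of 0 after the final step, mirroring the F.pad above
        values_next = torch.cat((values[..., 1:], torch.zeros_like(values[..., :1])), dim = -1)

        advantages = torch.zeros_like(rewards)
        running = torch.zeros_like(rewards[..., 0])

        for t in reversed(range(rewards.shape[-1])):
            delta = rewards[..., t] + gamma * values_next[..., t] - values[..., t]
            running = delta + gamma * lam * running
            advantages[..., t] = running

        return advantages + values   # lambda-returns
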
@@ -110,8 +142,8 @@ def create_xl_mask(
          # handle intra-episodic attention if needed

          if exists(episode_ids):
-             q_episode = episodes[b, q + offset]
-             k_episode = episodes[b, k]
+             q_episode = episode_ids[b, q + offset]
+             k_episode = episode_ids[b, k]

              intra_episode_mask = q_episode == k_episode
              mask = mask & intra_episode_mask
@@ -142,15 +174,229 @@ def create_sliding_mask(
      create_kwargs = dict(device = device) if exists(device) else dict()
      return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)

+ # data
+
+ def collate_var_time(data):
+
+     datum = first(data)
+     keys = datum.keys()
+
+     all_tensors = zip(*[datum.values() for datum in data])
+
+     collated_values = []
+
+     for key, tensors in zip(keys, all_tensors):
+
+         # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+         if key != '_lens':
+
+             times = [t.shape[0] for t in tensors]
+             max_time = max(times)
+             tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+         collated_values.append(stack(tensors))
+
+     return dict(zip(keys, collated_values))
+
+ class ReplayDataset(Dataset):
+     def __init__(
+         self,
+         folder: str | Path,
+         fields: tuple[str, ...] | None = None
+     ):
+         if isinstance(folder, str):
+             folder = Path(folder)
+
+         episode_lens = folder / 'episode_lens.npy'
+         self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+         # get indices of non-zero lengthed episodes
+
+         nonzero_episodes = self.episode_lens > 0
+         self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+         # get all data files
+
+         filepaths = [*folder.glob('*.data.npy')]
+         assert len(filepaths) > 0
+
+         fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+         fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+         fields = default(fields, fieldnames_from_files)
+
+         self.memmaps = dict()
+
+         for field in fields:
+             assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+             path = fieldname_to_filepath[field]
+
+             self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         episode_index = self.indices[idx]
+
+         episode_len = self.episode_lens[episode_index]
+
+         data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
+
+         data['_lens'] = tensor(episode_len)
+
+         return data
+
+ class ReplayBuffer:
+
+     @beartype
+     def __init__(
+         self,
+         folder: str | Path,
+         max_episodes: int,
+         max_timesteps: int,
+         fields: dict[
+             str,
+             str | tuple[str, int | tuple[int, ...]]
+         ]
+     ):
+
+         # folder for data
+
+         if not isinstance(folder, Path):
+             folder = Path(folder)
+             folder.mkdir(exist_ok = True)
+
+         self.folder = folder
+         assert folder.is_dir()
+
+         # keeping track of episode length
+
+         episode_lens = folder / 'episode_lens.npy'
+
+         self.episode_index = 0
+         self.timestep_index = 0
+
+         self.max_episodes = max_episodes
+         self.max_timesteps = max_timesteps
+
+         self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+         # create the memmap for individual data tracks
+
+         self.shapes = dict()
+         self.dtypes = dict()
+         self.memmaps = dict()
+         self.fieldnames = set(fields.keys())
+
+         for field_name, field_info in fields.items():
+
+             # some flexibility
+
+             field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+             dtype_str, shape = field_info
+             assert dtype_str in {'int', 'float', 'bool'}
+
+             dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+             # memmap file
+
+             filepath = folder / f'{field_name}.data.npy'
+             memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+             self.memmaps[field_name] = memmap
+             self.shapes[field_name] = shape
+             self.dtypes[field_name] = dtype
+
+         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+     def reset_(self):
+         self.episode_lens[:] = 0
+         self.episode_index = 0
+         self.timestep_index = 0
+
+     def advance_episode(self):
+         self.episode_index = (self.episode_index + 1) % self.max_episodes
+         self.timestep_index = 0
+
+     def flush(self):
+         self.episode_lens[self.episode_index] = self.timestep_index
+
+         for memmap in self.memmaps.values():
+             memmap.flush()
+
+         self.episode_lens.flush()
+
+     @contextmanager
+     def one_episode(self):
+
+         yield
+
+         self.flush()
+         self.advance_episode()
+
+     @beartype
+     def store_datapoint(
+         self,
+         episode_index: int,
+         timestep_index: int,
+         name: str,
+         datapoint: Tensor | ndarray
+     ):
+         assert 0 <= episode_index < self.max_episodes
+         assert 0 <= timestep_index < self.max_timesteps
+
+         if is_tensor(datapoint):
+             datapoint = datapoint.detach().cpu().numpy()
+
+         assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+         assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+         self.memmaps[name][self.episode_index, self.timestep_index] = datapoint
+
+     def store(
+         self,
+         **data
+     ):
+         assert is_bearable(data, dict[str, Tensor | ndarray])
+
+         assert not self.timestep_index >= self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+         for name, datapoint in data.items():
+
+             self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+         self.timestep_index += 1
+
+         return self.memory_namedtuple(**data)
+
+     def dataset(self) -> Dataset:
+         self.flush()
+
+         return ReplayDataset(self.folder)
+
+     def dataloader(self, batch_size, **kwargs) -> DataLoader:
+         self.flush()
+
+         return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
  # transformer-xl with ppo

  class Attention(Module):
      def __init__(
          self,
          dim,
+         window_size,
          dim_head = 64,
          heads = 8,
-         pre_rmsnorm = True
+         pre_rmsnorm = True,
+         fixed_window_size = False,
+         accept_value_residual = False
      ):
          super().__init__()
          self.scale = dim_head ** -0.5
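
The replay machinery added in the hunk above stores one `.npy` memmap per field, shaped `(max_episodes, max_timesteps, *field_shape)`, records per-episode lengths in `episode_lens.npy`, and serves batches through `collate_var_time`, which pads each episode to the longest one in the batch. A usage sketch (the field names, shapes, and sizes here are invented for illustration):

    import torch

    buffer = ReplayBuffer(
        './replay',
        max_episodes = 64,
        max_timesteps = 256,
        fields = dict(
            state  = ('float', (11,)),   # dtype string plus per-timestep shape
            action = 'int',              # a bare dtype string means a scalar per timestep
            reward = 'float',
            done   = 'bool'
        )
    )

    for _ in range(16):                      # collect a handful of episodes
        with buffer.one_episode():           # flushes and advances the episode index on exit
            for _ in range(100):
                buffer.store(
                    state  = torch.randn(11),
                    action = torch.randint(0, 4, ()),
                    reward = torch.randn(()),
                    done   = torch.tensor(False)
                )

    # each batch is a dict of (batch, padded_time, ...) tensors, plus '_lens' with the true lengths
    for batch in buffer.dataloader(batch_size = 8, shuffle = True):
        states, lens = batch['state'], batch['_lens']
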
@@ -167,20 +413,54 @@ class Attention(Module):
          self.to_kv = LinearNoBias(dim, dim_inner * 2)
          self.to_out = LinearNoBias(dim_inner, dim)

+         self.to_v_gates = Sequential(
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> b h n 1'),
+             nn.Sigmoid()
+         )
+
+         # value residual
+
+         self.accept_value_residual = accept_value_residual
+
+         if accept_value_residual:
+             self.to_value_residual_mix = Sequential(
+                 LinearNoBias(dim, heads),
+                 Rearrange('b n h -> b h n 1'),
+                 nn.Sigmoid()
+             )
+
+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
      def forward(
          self,
          tokens,
+         value_residual = None,
          kv_cache = None,
-         return_kv_cache = False
+         return_kv_cache = False,
      ):
+         seq_len = tokens.shape[-2]
+
+         device = tokens.device
+
          tokens = self.norm(tokens)

          q, k, v = (self.to_q(tokens), *self.to_kv(tokens).chunk(2, dim = -1))

          q, k, v = map(self.split_heads, (q, k, v))

+         orig_v = v
+
          q = q * self.scale

+         if exists(value_residual):
+             assert self.accept_value_residual
+             mix = self.to_value_residual_mix(tokens)
+             v = v.lerp(value_residual, mix)
+
          if exists(kv_cache):
              ck, cv = kv_cache
              k = cat((ck, k), dim = -2)
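
The `accept_value_residual` branch is value residual learning: every layer except the first receives the first layer's unmixed values (`orig_v`, threaded through by the transformer further down) and blends them into its own values with a learned per-head, per-position sigmoid gate. The mixing step in isolation, with hypothetical shapes:

    import torch

    b, h, n, d = 1, 8, 16, 64
    v = torch.randn(b, h, n, d)              # this layer's values
    first_layer_v = torch.randn(b, h, n, d)  # orig_v from the first layer
    mix = torch.rand(b, h, n, 1)             # what to_value_residual_mix(tokens) yields, in [0, 1]

    mixed = v.lerp(first_layer_v, mix)       # v * (1 - mix) + first_layer_v * mix
    assert torch.allclose(mixed, v * (1 - mix) + first_layer_v * mix, atol = 1e-6)
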
@@ -195,7 +475,13 @@ class Attention(Module):

          i, j = sim.shape[-2:]

-         causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
+         if self.fixed_window_size:
+             i_seq = arange(i, device = device)
+             j_seq = arange(j, device = device) - (j - i)
+             dist = einx.subtract('i, j -> i j', i_seq, j_seq)
+             causal_mask = (dist < 0) | (dist > self.window_size)
+         else:
+             causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)

          sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

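
With `fixed_window_size` on, key positions are shifted so the cached prefix sits in the past (`j - i` is the cache length), then anything in the future (`dist < 0`) or more than `window_size` steps back (`dist > window_size`) is masked. A tiny worked example of the same construction:

    import torch

    i, j, window_size = 3, 5, 2                       # 3 new queries over 2 cached + 3 new keys
    i_seq = torch.arange(i)                           # query positions: 0, 1, 2
    j_seq = torch.arange(j) - (j - i)                 # key positions:  -2, -1, 0, 1, 2
    dist = i_seq[:, None] - j_seq[None, :]            # same as einx.subtract('i, j -> i j', ...)
    causal_mask = (dist < 0) | (dist > window_size)   # True = masked out

    # rows are queries, columns keys - each query sees itself plus at most window_size steps back
    # tensor([[False, False, False,  True,  True],
    #         [ True, False, False, False,  True],
    #         [ True,  True, False, False, False]])
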
@@ -203,6 +489,8 @@ class Attention(Module):

          out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

+         out = out * self.to_v_gates(tokens)
+
          out = self.merge_heads(out)

          out = self.to_out(out)
@@ -210,7 +498,7 @@ class Attention(Module):
          if not return_kv_cache:
              return out

-         return out, next_kv_cache
+         return out, (next_kv_cache, orig_v)

  class FeedForward(Module):
      def __init__(
@@ -244,17 +532,21 @@ class TransformerXL(Module):
          self,
          dim,
          depth,
+         window_size,
          dim_head = 64,
          heads = 8,
          expansion_factor = 4.,
-         final_norm = True
+         final_norm = True,
+         fixed_window_size = False,
      ):
          super().__init__()

          layers = ModuleList([])

-         for _ in range(depth):
-             attn = Attention(dim = dim, dim_head = dim_head, heads = heads)
+         for i in range(depth):
+             is_first = i == 0
+
+             attn = Attention(dim = dim, dim_head = dim_head, heads = heads, fixed_window_size = fixed_window_size, window_size = window_size, accept_value_residual = not is_first)

              ff = FeedForward(dim = dim, expansion_factor = expansion_factor)

@@ -265,6 +557,11 @@ class TransformerXL(Module):
          self.layers = layers
          self.norm = RMSNorm(dim) if final_norm else Identity()

+         # fixed window size
+
+         self.fixed_window_size = fixed_window_size
+         self.window_size = window_size
+
      def forward(
          self,
          x,
@@ -275,22 +572,28 @@ class TransformerXL(Module):
          cache = default(cache, (None,) * len(self.layers))

          next_kv_caches = []
+         value_residual = None

          for (attn, ff), kv_cache in zip(self.layers, cache):

-             attn_out, next_kv_cache = attn(x, kv_cache = kv_cache, return_kv_cache = True)
-
-             next_kv_caches.append(next_kv_cache)
+             attn_out, (next_kv_cache, values) = attn(x, value_residual = value_residual, kv_cache = kv_cache, return_kv_cache = True)

              x = attn_out + x
              x = ff(x) + x

+             next_kv_caches.append(next_kv_cache)
+             value_residual = default(value_residual, values)
+
          embed = self.norm(x)

          if not return_kv_cache:
              return embed

-         return embed, stack(next_kv_caches)
+         next_kv_cache = stack(next_kv_caches)
+
+         next_kv_cache = next_kv_cache[..., -self.window_size:, :]
+
+         return embed, next_kv_cache

  # class

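
Two changes flow through `TransformerXL.forward`: the first layer's raw values are captured once and handed to every later layer as the value residual, and the stacked per-layer KV cache is clipped to the last `window_size` timesteps before it is returned, so the recurrent state never grows past one window. A shape sketch, assuming each layer's cache stacks its keys and values along a leading dimension of 2 (the layout and sizes here are illustrative, not confirmed by the diff):

    import torch

    depth, b, h, t, d, window_size = 4, 1, 8, 100, 64, 64

    # hypothetical per-layer caches: (2, b, h, t, d) holding stacked keys and values
    next_kv_caches = [torch.randn(2, b, h, t, d) for _ in range(depth)]

    next_kv_cache = torch.stack(next_kv_caches)            # (depth, 2, b, h, t, d)
    next_kv_cache = next_kv_cache[..., -window_size:, :]   # keep only the most recent window

    assert next_kv_cache.shape == (depth, 2, b, h, window_size, d)
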
@@ -300,7 +603,13 @@ class Locoformer(Module):
          embedder: Module,
          unembedder: Module,
          transformer: dict | TransformerXL,
-         value_network: Module | None = None
+         value_network: Module | None = None,
+         discount_factor = 0.999,
+         gae_lam = 0.95,
+         ppo_eps_clip = 0.2,
+         ppo_entropy_weight = 0.01,
+         ppo_value_clip = 0.4,
+         value_loss_weight = 0.5
      ):
          super().__init__()

@@ -314,28 +623,138 @@ class Locoformer(Module):

          self.value_network = value_network

+         self.fixed_window_size = transformer.fixed_window_size
+         self.window_size = transformer.window_size
+
+         # ppo related
+
+         self.discount_factor = discount_factor
+         self.gae_lam = gae_lam
+         self.ppo_eps_clip = ppo_eps_clip
+         self.ppo_entropy_weight = ppo_entropy_weight
+         self.ppo_value_clip = ppo_value_clip
+         self.value_loss_weight = value_loss_weight
+
      @property
      def device(self):
          return next(self.parameters()).device

+     def actor_parameters(self):
+         return self.unembedder.parameters()
+
+     def critic_parameters(self):
+         if not exists(self.value_network):
+             return []
+
+         return self.value_network.parameters()
+
+     def ppo(
+         self,
+         state,
+         action,
+         old_action_log_prob,
+         reward,
+         old_value,
+         mask,
+         actor_optim,
+         critic_optim
+     ):
+
+         (action_logits, value), _ = self.forward(state, return_values = True)
+         entropy = calc_entropy(action_logits)
+
+         action = rearrange(action, 'b t -> b t 1')
+         log_prob = action_logits.gather(-1, action)
+         log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+         # update actor, classic clipped surrogate loss
+
+         eps_clip = self.ppo_eps_clip
+         ratio = (log_prob - old_action_log_prob).exp()
+
+         returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor)
+         advantage = returns - old_value
+
+         actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+         actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+         mean_actor_loss = actor_loss[mask].mean()
+         mean_actor_loss.backward(retain_graph = True)
+
+         # update critic
+
+         value_loss = F.mse_loss(returns, value, reduction = 'none')
+
+         value_clip = self.ppo_value_clip
+         clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+         clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+
+         critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+         mean_critic_loss = critic_loss[mask].mean()
+         mean_critic_loss.backward()
+
+         # optimizer update
+
+         actor_optim.step()
+         actor_optim.zero_grad()
+
+         critic_optim.step()
+         critic_optim.zero_grad()
+
+         # return losses for logging
+
+         return mean_actor_loss.detach(), mean_critic_loss.detach()
+
+     def wrap_env_functions(self, env):
+
+         def wrapped_reset(*args, **kwargs):
+             state, _ = env.reset(*args, **kwargs)
+
+             if isinstance(state, ndarray):
+                 state = from_numpy(state)
+
+             return state, _
+
+         def wrapped_step(action, *args, **kwargs):
+             out = env.step(action.item(), *args, **kwargs)
+
+             def transform_output(el):
+                 if isinstance(el, ndarray):
+                     return from_numpy(el)
+                 elif isinstance(el, (int, bool, float)):
+                     return tensor(el)
+                 else:
+                     return el
+
+             return tree_map(transform_output, out)
+
+         return wrapped_reset, wrapped_step
+
      def get_stateful_forward(
          self,
-         segment_size,
          initial_states: Tensor | None = None,
          inference_mode = False,
          has_batch_dim = False,
+         has_time_dim = False,
          **kwargs
      ):
+         window_size = self.window_size
+
          cache = None

-         def stateful_forward(state: Tensor, override_kwargs: dict = dict()):
+         def stateful_forward(state: Tensor, **override_kwargs):
              nonlocal cache

-             # handle no batch, for easier time rolling out against envs
+             # handle no batch or time, for easier time rolling out against envs

              if not has_batch_dim:
                  state = rearrange(state, '... -> 1 ...')

+             if not has_time_dim:
+                 state = rearrange(state, '... d -> ... 1 d')
+
              # forwards

              out, cache = self.forward(state, cache = cache, **{**kwargs, **override_kwargs})
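
Putting the new pieces together: `actor_parameters` / `critic_parameters` split the optimizers, `wrap_env_functions` tensor-ifies a Gymnasium-style `env`, and each `ppo` call does one clipped-surrogate actor step plus one clipped value-loss critic step. The sketch below is hypothetical: the embedder / unembedder / value head, their sizes, and the assumption that the `transformer` dict is forwarded to `TransformerXL` are all illustrative, not prescribed by the library.

    import torch
    from torch import nn
    from torch.optim import Adam

    locoformer = Locoformer(
        embedder = nn.Linear(11, 256),       # observation -> token (hypothetical)
        unembedder = nn.Linear(256, 4),      # embedding -> action logits (hypothetical)
        value_network = nn.Linear(256, 1),   # embedding -> value (output shape handling assumed)
        transformer = dict(dim = 256, depth = 4, window_size = 64, fixed_window_size = True)
    )

    actor_optim = Adam(locoformer.actor_parameters(), lr = 3e-4)
    critic_optim = Adam(locoformer.critic_parameters(), lr = 1e-3)

    # adapting a gym-style environment so reset / step hand back tensors
    # reset_fn, step_fn = locoformer.wrap_env_functions(env)

    # a rollout already collated to the (batch, time) layout `ppo` expects
    batch, time = 8, 64
    states = torch.randn(batch, time, 11)
    actions = torch.randint(0, 4, (batch, time))
    old_log_probs = torch.randn(batch, time)
    rewards = torch.randn(batch, time)
    old_values = torch.randn(batch, time)
    mask = torch.ones(batch, time, dtype = torch.bool)   # False past each episode's end

    actor_loss, critic_loss = locoformer.ppo(
        states, actions, old_log_probs, rewards, old_values, mask,
        actor_optim, critic_optim
    )
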
@@ -344,10 +763,13 @@ class Locoformer(Module):

              cache_len = cache.shape[-2]

-             if divisible_by(cache_len, segment_size * 2):
-                 cache = cache[..., -segment_size:, :]
+             if self.fixed_window_size or divisible_by(cache_len, window_size * 2):
+                 cache = cache[..., -window_size:, :]

-             # maybe remove batch
+             # maybe remove batch or time
+
+             if not has_time_dim:
+                 out = tree_map_tensor(out, lambda t: rearrange(t, '... 1 d -> ... d'))

              if not has_batch_dim:
                  out = tree_map_tensor(out, lambda t: rearrange(t, '1 ... -> ...'))
@@ -364,7 +786,7 @@ class Locoformer(Module):

          initial_logits = []

-         for state_segments in initial_states.split(segment_size, dim = -1):
+         for state_segments in initial_states.split(self.window_size, dim = -1):

              logits = stateful_forward(state_segments, return_values = False)
              initial_logits.append(logits)
@@ -381,6 +803,8 @@ class Locoformer(Module):
          return_values = False
      ):

+         state = state.to(self.device)
+
          tokens = self.embedder(state)

          embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
locoformer-0.0.5.dist-info/METADATA → locoformer-0.0.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.5
+ Version: 0.0.15
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: assoc-scan
+ Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: rotary-embedding-torch
@@ -53,7 +54,7 @@ Description-Content-Type: text/markdown

  [LocoFormer - Generalist Locomotion via Long-Context Adaptation](https://generalist-locomotion.github.io/)

- The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment). When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.
+ The gist is they trained a simple Transformer-XL in simulation on robots with many different bodies (cross-embodiment) with extreme domain randomization. When transferring to the real-world, they noticed the robot now gains the ability to adapt to insults. The XL memories span across multiple trials, which allowed the robot to learn in-context adaptation.

  ## Sponsors

locoformer-0.0.15.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
+ locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
+ locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.15.dist-info/RECORD,,
locoformer-0.0.5.dist-info/RECORD REMOVED
@@ -1,6 +0,0 @@
- locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
- locoformer/locoformer.py,sha256=Yoh3hrj2E_91YLoYRa73wGzjdIiMdcd5ofNjkiVlogI,10570
- locoformer-0.0.5.dist-info/METADATA,sha256=oe6HfOwWKQvusiJl1ukmNFcrGRhdDZ6NcKZi3upv-SY,3159
- locoformer-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- locoformer-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- locoformer-0.0.5.dist-info/RECORD,,