locoformer 0.0.7__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
locoformer/locoformer.py CHANGED
@@ -1,11 +1,23 @@
  from __future__ import annotations
  from functools import partial
 
+ from pathlib import Path
+ from contextlib import contextmanager
+ from collections import namedtuple
+
+ import numpy as np
+ from numpy import ndarray
+ from numpy.lib.format import open_memmap
+
+ from beartype import beartype
+ from beartype.door import is_bearable
+
  import torch
- from torch import nn, cat, stack, arange, Tensor, is_tensor
+ from torch import nn, cat, stack, arange, Tensor, tensor, is_tensor, from_numpy
  import torch.nn.functional as F
  from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
  from torch.utils._pytree import tree_map
+ from torch.utils.data import Dataset, DataLoader
 
  import einx
  from einops import rearrange, einsum
@@ -31,20 +43,33 @@ def first(arr):
  def divisible_by(num, den):
      return (num % den) == 0
 
+ # tensor helpers
+
+ def log(t, eps = 1e-20):
+     return t.clamp_min(eps).log()
+
  def tree_map_tensor(x, fn):
      return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
  def detach_all(x):
      return tree_map_tensor(x, lambda t: t.detach())
 
- def combine_kv_cache(cache1, cache2):
-     combined_cache = []
+ def pad_at_dim(
+     t,
+     pad: tuple[int, int],
+     dim = -1,
+     value = 0.
+ ):
+     if pad == (0, 0):
+         return t
 
-     for layer_cache1, layer_cache2 in zip(cache1, cache2):
-         next_cache = cat((layer_cache1, layer_cache2), dim = -2)
-         combined_cache.append(next_cache)
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)
 
-     return combined_cache
+ def calc_entropy(logits):
+     prob = logits.softmax(dim = -1)
+     return -(prob * log(prob)).sum(dim = -1)
 
  # generalized advantage estimate
 
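The new helpers above are small but load-bearing for the rest of the release: `log` is a numerically safe logarithm, `pad_at_dim` right-pads a tensor along an arbitrary dimension (used later to collate variable-length episodes), and `calc_entropy` feeds the PPO entropy bonus. A minimal sanity check of the two non-trivial ones, using only torch, with `pad_at_dim` copied verbatim from the hunk above:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # copied from the hunk above
    if pad == (0, 0):
        return t

    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

t = torch.randn(3, 5)

# pad two timesteps onto the end of the first (time) dimension

assert pad_at_dim(t, (0, 2), dim = 0).shape == (5, 5)

# entropy of a uniform categorical over n outcomes is log(n)

prob = torch.zeros(10).softmax(dim = -1)
entropy = -(prob * prob.clamp_min(1e-20).log()).sum(dim = -1)
assert torch.allclose(entropy, torch.log(torch.tensor(10.)))
```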
@@ -52,7 +77,7 @@ def combine_kv_cache(cache1, cache2):
  def calc_gae(
      rewards,
      values,
-     masks,
+     masks = None,
      gamma = 0.99,
      lam = 0.95,
      use_accelerated = None
@@ -63,6 +88,9 @@ def calc_gae(
      values = F.pad(values, (0, 1), value = 0.)
      values, values_next = values[..., :-1], values[..., 1:]
 
+     if not exists(masks):
+         masks = torch.ones_like(values)
+
      delta = rewards + gamma * values_next * masks - values
      gates = gamma * lam * masks
 
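With `masks` now defaulting to `None`, GAE over a rollout with no episode boundaries no longer needs an explicit all-ones mask; `calc_gae` builds one internally. Both call styles, as a sketch (the `(batch, time)` shapes are inferred from how the new `ppo` method below invokes it):

```python
import torch

rewards = torch.randn(1, 128)
values = torch.randn(1, 128)

# before this release, the mask had to be passed explicitly

returns = calc_gae(rewards, values, torch.ones(1, 128))

# now equivalent, with the all-ones mask defaulted internally

returns = calc_gae(rewards, values)
```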
@@ -114,8 +142,8 @@ def create_xl_mask(
      # handle intra-episodic attention if needed
 
      if exists(episode_ids):
-         q_episode = episodes[b, q + offset]
-         k_episode = episodes[b, k]
+         q_episode = episode_ids[b, q + offset]
+         k_episode = episode_ids[b, k]
 
          intra_episode_mask = q_episode == k_episode
          mask = mask & intra_episode_mask
@@ -146,6 +174,217 @@ def create_sliding_mask(
      create_kwargs = dict(device = device) if exists(device) else dict()
      return create_block_mask(sliding_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = kv_seq_len, _compile = True, **create_kwargs)
 
+ # data
+
+ def collate_var_time(data):
+
+     datum = first(data)
+     keys = datum.keys()
+
+     all_tensors = zip(*[datum.values() for datum in data])
+
+     collated_values = []
+
+     for key, tensors in zip(keys, all_tensors):
+
+         # the episode lens have zero dimension - think of a cleaner way to handle this later
+
+         if key != '_lens':
+
+             times = [t.shape[0] for t in tensors]
+             max_time = max(times)
+             tensors = [pad_at_dim(t, (0, max_time - t.shape[0]), dim = 0) for t in tensors]
+
+         collated_values.append(stack(tensors))
+
+     return dict(zip(keys, collated_values))
+
+ class ReplayDataset(Dataset):
+     def __init__(
+         self,
+         folder: str | Path,
+         fields: tuple[str, ...] | None = None
+     ):
+         if isinstance(folder, str):
+             folder = Path(folder)
+
+         episode_lens = folder / 'episode_lens.npy'
+         self.episode_lens = open_memmap(str(episode_lens), mode = 'r')
+
+         # get indices of non-zero length episodes
+
+         nonzero_episodes = self.episode_lens > 0
+         self.indices = np.arange(self.episode_lens.shape[-1])[nonzero_episodes]
+
+         # get all data files
+
+         filepaths = [*folder.glob('*.data.npy')]
+         assert len(filepaths) > 0
+
+         fieldname_to_filepath = {path.name.split('.')[0]: path for path in filepaths}
+
+         fieldnames_from_files = set(fieldname_to_filepath.keys())
+
+         fields = default(fields, fieldnames_from_files)
+
+         self.memmaps = dict()
+
+         for field in fields:
+             assert field in fieldnames_from_files, f'invalid field {field} - must be one of {fieldnames_from_files}'
+
+             path = fieldname_to_filepath[field]
+
+             self.memmaps[field] = open_memmap(str(path), mode = 'r')
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         episode_index = self.indices[idx]
+
+         episode_len = self.episode_lens[episode_index]
+
+         data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
+
+         data['_lens'] = tensor(episode_len)
+
+         return data
+
+ class ReplayBuffer:
+
+     @beartype
+     def __init__(
+         self,
+         folder: str | Path,
+         max_episodes: int,
+         max_timesteps: int,
+         fields: dict[
+             str,
+             str | tuple[str, int | tuple[int, ...]]
+         ]
+     ):
+
+         # folder for data
+
+         if not isinstance(folder, Path):
+             folder = Path(folder)
+             folder.mkdir(exist_ok = True)
+
+         self.folder = folder
+         assert folder.is_dir()
+
+         # keeping track of episode length
+
+         episode_lens = folder / 'episode_lens.npy'
+
+         self.episode_index = 0
+         self.timestep_index = 0
+
+         self.max_episodes = max_episodes
+         self.max_timesteps = max_timesteps
+
+         self.episode_lens = open_memmap(str(episode_lens), mode = 'w+', dtype = np.int32, shape = (max_episodes,))
+
+         # create the memmap for individual data tracks
+
+         self.shapes = dict()
+         self.dtypes = dict()
+         self.memmaps = dict()
+         self.fieldnames = set(fields.keys())
+
+         for field_name, field_info in fields.items():
+
+             # some flexibility
+
+             field_info = (field_info, ()) if isinstance(field_info, str) else field_info
+
+             dtype_str, shape = field_info
+             assert dtype_str in {'int', 'float', 'bool'}
+
+             dtype = dict(int = np.int32, float = np.float32, bool = np.bool_)[dtype_str]
+
+             # memmap file
+
+             filepath = folder / f'{field_name}.data.npy'
+             memmap = open_memmap(str(filepath), mode = 'w+', dtype = dtype, shape = (max_episodes, max_timesteps, *shape))
+
+             self.memmaps[field_name] = memmap
+             self.shapes[field_name] = shape
+             self.dtypes[field_name] = dtype
+
+         self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+     def reset_(self):
+         self.episode_lens[:] = 0
+         self.episode_index = 0
+         self.timestep_index = 0
+
+     def advance_episode(self):
+         self.episode_index = (self.episode_index + 1) % self.max_episodes
+         self.timestep_index = 0
+
+     def flush(self):
+         self.episode_lens[self.episode_index] = self.timestep_index
+
+         for memmap in self.memmaps.values():
+             memmap.flush()
+
+         self.episode_lens.flush()
+
+     @contextmanager
+     def one_episode(self):
+
+         yield
+
+         self.flush()
+         self.advance_episode()
+
+     @beartype
+     def store_datapoint(
+         self,
+         episode_index: int,
+         timestep_index: int,
+         name: str,
+         datapoint: Tensor | ndarray
+     ):
+         assert 0 <= episode_index < self.max_episodes
+         assert 0 <= timestep_index < self.max_timesteps
+
+         if is_tensor(datapoint):
+             datapoint = datapoint.detach().cpu().numpy()
+
+         assert name in self.fieldnames, f'invalid field name {name} - must be one of {self.fieldnames}'
+
+         assert datapoint.shape == self.shapes[name], f'invalid shape {datapoint.shape} - shape must be {self.shapes[name]}'
+
+         self.memmaps[name][episode_index, timestep_index] = datapoint
+
+     def store(
+         self,
+         **data
+     ):
+         assert is_bearable(data, dict[str, Tensor | ndarray])
+
+         assert self.timestep_index < self.max_timesteps, 'you exceeded the `max_timesteps` set on the replay buffer'
+
+         for name, datapoint in data.items():
+
+             self.store_datapoint(self.episode_index, self.timestep_index, name, datapoint)
+
+         self.timestep_index += 1
+
+         return self.memory_namedtuple(**data)
+
+     def dataset(self) -> Dataset:
+         self.flush()
+
+         return ReplayDataset(self.folder)
+
+     def dataloader(self, batch_size, **kwargs) -> DataLoader:
+         self.flush()
+
+         return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
+
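Putting the new data machinery together: `ReplayBuffer` preallocates one `.npy` memmap per field, `one_episode` flushes and advances the episode pointer on exit, and `dataloader` hands back batches collated by `collate_var_time`, right-padded to the longest episode with the true lengths under the `'_lens'` key. A hedged usage sketch; the field names, shapes, and the import path are illustrative, not taken from this diff:

```python
import torch
from locoformer.locoformer import ReplayBuffer  # assumed import path

buffer = ReplayBuffer(
    './replay',
    max_episodes = 8,
    max_timesteps = 256,
    fields = dict(
        state = ('float', (4,)),  # (dtype, per-timestep shape)
        action = 'int',           # scalar fields may pass the dtype alone
        reward = 'float',
        done = 'bool'
    )
)

for _ in range(2):
    with buffer.one_episode():  # flushes and advances on exit
        for _ in range(100):
            buffer.store(
                state = torch.randn(4),
                action = torch.tensor(1),
                reward = torch.tensor(0.5),
                done = torch.tensor(False)
            )

# episodes come back right-padded to the longest in the batch

batch = next(iter(buffer.dataloader(batch_size = 2)))
print(batch['state'].shape)  # (2, 100, 4)
print(batch['_lens'])        # tensor([100, 100], dtype=torch.int32)
```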
  # transformer-xl with ppo
 
  class Attention(Module):
@@ -204,7 +443,6 @@ class Attention(Module):
          return_kv_cache = False,
      ):
          seq_len = tokens.shape[-2]
-         assert seq_len <= self.window_size
 
          device = tokens.device
 
@@ -365,7 +603,13 @@ class Locoformer(Module):
          embedder: Module,
          unembedder: Module,
          transformer: dict | TransformerXL,
-         value_network: Module | None = None
+         value_network: Module | None = None,
+         discount_factor = 0.999,
+         gae_lam = 0.95,
+         ppo_eps_clip = 0.2,
+         ppo_entropy_weight = 0.01,
+         ppo_value_clip = 0.4,
+         value_loss_weight = 0.5
      ):
          super().__init__()
 
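The PPO hyperparameters now live on the module itself rather than being threaded through each update call. A construction sketch under stated assumptions: the embedder and unembedder here are plain `Linear` stand-ins, and the `transformer` dict keys (`dim`, `depth`, `window_size`) are guesses at `TransformerXL`'s kwargs, since its signature is outside this diff:

```python
from torch import nn
from locoformer.locoformer import Locoformer  # assumed import path

model = Locoformer(
    embedder = nn.Linear(4, 128),    # observation -> model dimension
    unembedder = nn.Linear(128, 2),  # model dimension -> action logits
    transformer = dict(dim = 128, depth = 2, window_size = 32),  # assumed kwargs
    value_network = nn.Linear(128, 1),
    discount_factor = 0.999,
    gae_lam = 0.95,
    ppo_eps_clip = 0.2,
    ppo_entropy_weight = 0.01,
    ppo_value_clip = 0.4,
    value_loss_weight = 0.5
)
```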
@@ -382,6 +626,15 @@ class Locoformer(Module):
          self.fixed_window_size = transformer.fixed_window_size
          self.window_size = transformer.window_size
 
+         # ppo related
+
+         self.discount_factor = discount_factor
+         self.gae_lam = gae_lam
+         self.ppo_eps_clip = ppo_eps_clip
+         self.ppo_entropy_weight = ppo_entropy_weight
+         self.ppo_value_clip = ppo_value_clip
+         self.value_loss_weight = value_loss_weight
+
      @property
      def device(self):
          return next(self.parameters()).device
@@ -390,8 +643,95 @@ class Locoformer(Module):
          return self.unembedder.parameters()
 
      def critic_parameters(self):
+         if not exists(self.value_network):
+             return []
+
          return self.value_network.parameters()
 
+     def ppo(
+         self,
+         state,
+         action,
+         old_action_log_prob,
+         reward,
+         old_value,
+         mask,
+         actor_optim,
+         critic_optim
+     ):
+
+         (action_logits, value), _ = self.forward(state, return_values = True)
+         entropy = calc_entropy(action_logits)
+
+         action = rearrange(action, 'b t -> b t 1')
+         log_prob = action_logits.log_softmax(dim = -1).gather(-1, action)
+         log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+         # update actor, classic clipped surrogate loss
+
+         eps_clip = self.ppo_eps_clip
+         ratio = (log_prob - old_action_log_prob).exp()
+
+         returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor)
+         advantage = returns - old_value
+
+         actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+         actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+         mean_actor_loss = actor_loss[mask].mean()
+         mean_actor_loss.backward(retain_graph = True)
+
+         # update critic
+
+         value_loss = F.mse_loss(returns, value, reduction = 'none')
+
+         value_clip = self.ppo_value_clip
+         clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+         clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+
+         critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+         mean_critic_loss = critic_loss[mask].mean()
+         mean_critic_loss.backward()
+
+         # optimizer update
+
+         actor_optim.step()
+         actor_optim.zero_grad()
+
+         critic_optim.step()
+         critic_optim.zero_grad()
+
+         # return losses for logging
+
+         return mean_actor_loss.detach(), mean_critic_loss.detach()
+
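One `ppo` call performs a full actor and critic update, including both backward passes and optimizer steps. A hedged sketch of invoking it, reusing the `model` from the construction sketch above, with dummy `(batch, time)` rollout tensors standing in for data pulled from the replay buffer:

```python
import torch
from torch.optim import Adam

b, t = 2, 64

state = torch.randn(b, t, 4)
action = torch.randint(0, 2, (b, t))
old_action_log_prob = torch.randn(b, t)
reward = torch.randn(b, t)
old_value = torch.randn(b, t)
mask = torch.ones(b, t, dtype = torch.bool)  # marks valid (unpadded) timesteps

actor_optim = Adam(model.unembedder.parameters(), lr = 3e-4)
critic_optim = Adam(model.critic_parameters(), lr = 3e-4)

actor_loss, critic_loss = model.ppo(
    state, action, old_action_log_prob,
    reward, old_value, mask,
    actor_optim, critic_optim
)
```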
+     def wrap_env_functions(self, env):
+
+         def wrapped_reset(*args, **kwargs):
+             state, info = env.reset(*args, **kwargs)
+
+             if isinstance(state, ndarray):
+                 state = from_numpy(state)
+
+             return state, info
+
+         def wrapped_step(action, *args, **kwargs):
+             out = env.step(action.item(), *args, **kwargs)
+
+             def transform_output(el):
+                 if isinstance(el, ndarray):
+                     return from_numpy(el)
+                 elif isinstance(el, (int, bool, float)):
+                     return tensor(el)
+                 else:
+                     return el
+
+             return tree_map(transform_output, out)
+
+         return wrapped_reset, wrapped_step
+
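`wrap_env_functions` adapts a Gymnasium-style environment so observations, rewards, and termination flags cross the boundary as tensors (`tree_map` walks the tuple returned by `step`). A rollout sketch, assuming the `gymnasium` package and random actions in place of policy sampling:

```python
import torch
import gymnasium as gym

env = gym.make('CartPole-v1')
wrapped_reset, wrapped_step = model.wrap_env_functions(env)

state, info = wrapped_reset()  # observation arrives as a Tensor

for _ in range(10):
    action = torch.randint(0, 2, ())  # stand-in for sampling from the policy

    # scalars come back as 0-dim tensors, observations as Tensors

    state, reward, terminated, truncated, info = wrapped_step(action)

    if terminated or truncated:
        state, info = wrapped_reset()
```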
      def get_stateful_forward(
          self,
          initial_states: Tensor | None = None,
@@ -463,6 +803,8 @@ class Locoformer(Module):
          return_values = False
      ):
 
+         state = state.to(self.device)
+
          tokens = self.embedder(state)
 
          embed, kv_cache = self.transformer(tokens, cache = cache, return_kv_cache = True)
locoformer-0.0.7.dist-info/METADATA → locoformer-0.0.15.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: locoformer
- Version: 0.0.7
+ Version: 0.0.15
  Summary: LocoFormer
  Project-URL: Homepage, https://pypi.org/project/locoformer/
  Project-URL: Repository, https://github.com/lucidrains/locoformer
@@ -35,6 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Requires-Python: >=3.9
  Requires-Dist: assoc-scan
+ Requires-Dist: beartype
  Requires-Dist: einops>=0.8.0
  Requires-Dist: einx>=0.3.0
  Requires-Dist: rotary-embedding-torch
locoformer-0.0.15.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
+ locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
+ locoformer/locoformer.py,sha256=1jPK41G4HB1PEPtlusQxcrne489E-3QKXAULZ20FEZM,22740
+ locoformer-0.0.15.dist-info/METADATA,sha256=IHtK7NvVQewYQ0GBB7v1KG90_H2Jakxir0MakUIA-jU,3218
+ locoformer-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ locoformer-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ locoformer-0.0.15.dist-info/RECORD,,
locoformer-0.0.7.dist-info/RECORD DELETED
@@ -1,6 +0,0 @@
- locoformer/__init__.py,sha256=XctsMGEZSR4mVl75fhds_1BtS5qGFiiItTDV7CmCt_I,45
- locoformer/locoformer.py,sha256=lJQs0CKr9iztF8tie1FRUVEItCt-IZbIILQqKcgK2sI,13142
- locoformer-0.0.7.dist-info/METADATA,sha256=PZ_phKV3t4Bha0GnUB5HPmE9w8A5fvNevsuN532Ls3s,3193
- locoformer-0.0.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- locoformer-0.0.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- locoformer-0.0.7.dist-info/RECORD,,