locoformer 0.0.11__tar.gz → 0.0.17__tar.gz
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries, and is provided for informational purposes only.
- {locoformer-0.0.11 → locoformer-0.0.17}/PKG-INFO +1 -1
- {locoformer-0.0.11 → locoformer-0.0.17}/locoformer/locoformer.py +151 -13
- {locoformer-0.0.11 → locoformer-0.0.17}/pyproject.toml +1 -1
- {locoformer-0.0.11 → locoformer-0.0.17}/train.py +1 -1
- {locoformer-0.0.11 → locoformer-0.0.17}/train_gym.py +90 -21
- {locoformer-0.0.11 → locoformer-0.0.17}/.github/workflows/python-publish.yml +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/.github/workflows/test.yml +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/.gitignore +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/LICENSE +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/README.md +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/data/README.md +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/data/enwik8.gz +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/fig3.png +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/locoformer/__init__.py +0 -0
- {locoformer-0.0.11 → locoformer-0.0.17}/tests/test_locoformer.py +0 -0
{locoformer-0.0.11 → locoformer-0.0.17}/locoformer/locoformer.py

@@ -3,6 +3,7 @@ from functools import partial
 
 from pathlib import Path
 from contextlib import contextmanager
+from collections import namedtuple
 
 import numpy as np
 from numpy import ndarray
@@ -17,6 +18,7 @@ import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear, RMSNorm, Identity, Sequential
 from torch.utils._pytree import tree_map
 from torch.utils.data import Dataset, DataLoader
+from torch.optim import Optimizer
 
 import einx
 from einops import rearrange, einsum
@@ -26,6 +28,8 @@ from rotary_embedding_torch import RotaryEmbedding
 
 from assoc_scan import AssocScan
 
+# constants
+
 LinearNoBias = partial(Linear, bias = False)
 
 # helper functions
@@ -42,12 +46,14 @@ def first(arr):
 def divisible_by(num, den):
     return (num % den) == 0
 
+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return t.clamp_min(eps).log()
+
 def tree_map_tensor(x, fn):
     return tree_map(lambda t: t if not is_tensor(t) else fn(t), x)
 
-def detach_all(x):
-    return tree_map_tensor(x, lambda t: t.detach())
-
 def pad_at_dim(
     t,
     pad: tuple[int, int],
@@ -61,13 +67,17 @@ def pad_at_dim(
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
+def calc_entropy(logits):
+    prob = logits.softmax(dim = -1)
+    return -(prob * log(prob)).sum(dim = -1)
+
 # generalized advantage estimate
 
 @torch.no_grad()
 def calc_gae(
     rewards,
     values,
-    masks,
+    masks = None,
     gamma = 0.99,
     lam = 0.95,
     use_accelerated = None
@@ -78,6 +88,9 @@ def calc_gae(
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[..., :-1], values[..., 1:]
 
+    if not exists(masks):
+        masks = torch.ones_like(values)
+
     delta = rewards + gamma * values_next * masks - values
     gates = gamma * lam * masks
 
@@ -87,7 +100,7 @@ def calc_gae(
 
     returns = gae + values
 
-    return returns
+    return gae, returns
 
 # transformer-xl mask w/ flex attn
 
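calc_gae now returns both the advantage and the bootstrapped return, and masks may be omitted. A minimal consumption sketch, not taken from the package's documentation; the (batch, time) tensor layout and keyword arguments are assumptions based on how ppo() calls it further down:

# hedged usage sketch of the new two-value return of calc_gae
import torch
from locoformer.locoformer import calc_gae

rewards = torch.randn(2, 10)   # assumed (batch, time) layout
values  = torch.randn(2, 10)   # value estimates aligned with the rewards

# masks now defaults to all-ones when omitted
advantage, returns = calc_gae(rewards, values, gamma = 0.99, lam = 0.95, use_accelerated = False)

assert advantage.shape == returns.shape == rewards.shape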
@@ -129,8 +142,8 @@ def create_xl_mask(
         # handle intra-episodic attention if needed
 
         if exists(episode_ids):
-            q_episode =
-            k_episode =
+            q_episode = episode_ids[b, q + offset]
+            k_episode = episode_ids[b, k]
 
             intra_episode_mask = q_episode == k_episode
             mask = mask & intra_episode_mask
@@ -231,7 +244,7 @@ class ReplayDataset(Dataset):
 
         episode_len = self.episode_lens[episode_index]
 
-        data = {field:
+        data = {field: from_numpy(memmap[episode_index, :episode_len].copy()) for field, memmap in self.memmaps.items()}
 
         data['_lens'] = tensor(episode_len)
 
@@ -299,6 +312,13 @@ class ReplayBuffer:
             self.shapes[field_name] = shape
             self.dtypes[field_name] = dtype
 
+        self.memory_namedtuple = namedtuple('Memory', list(fields.keys()))
+
+    def reset_(self):
+        self.episode_lens[:] = 0
+        self.episode_index = 0
+        self.timestep_index = 0
+
     def advance_episode(self):
         self.episode_index = (self.episode_index + 1) % self.max_episodes
         self.timestep_index = 0
@@ -353,15 +373,17 @@ class ReplayBuffer:
 
         self.timestep_index += 1
 
+        return self.memory_namedtuple(**data)
+
     def dataset(self) -> Dataset:
         self.flush()
 
         return ReplayDataset(self.folder)
 
-    def dataloader(self, **kwargs) -> DataLoader:
+    def dataloader(self, batch_size, **kwargs) -> DataLoader:
         self.flush()
 
-        return DataLoader(self.dataset(), collate_fn = collate_var_time, **kwargs)
+        return DataLoader(self.dataset(), batch_size = batch_size, collate_fn = collate_var_time, **kwargs)
 
 # transformer-xl with ppo
 
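store() now returns an instance of the Memory namedtuple built from the declared fields, and the bootstrap handling later in train_gym.py leans on standard namedtuple behaviour. A stdlib-only illustration; the field names here are invented for the example:

from collections import namedtuple

Memory = namedtuple('Memory', ['state', 'value', 'learnable'])

memory = Memory(state = 1.0, value = 0.5, learnable = True)

# _replace returns a new namedtuple; the original is left unchanged
bootstrap = memory._replace(value = 0.7, learnable = False)

# _asdict yields keyword arguments, e.g. for replay.store(**memory._asdict())
print(bootstrap._asdict())   # {'state': 1.0, 'value': 0.7, 'learnable': False}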
@@ -421,7 +443,6 @@ class Attention(Module):
         return_kv_cache = False,
     ):
         seq_len = tokens.shape[-2]
-        assert seq_len <= self.window_size
 
         device = tokens.device
 
@@ -582,7 +603,14 @@ class Locoformer(Module):
         embedder: Module,
         unembedder: Module,
         transformer: dict | TransformerXL,
-        value_network: Module | None = None
+        value_network: Module | None = None,
+        discount_factor = 0.999,
+        gae_lam = 0.95,
+        ppo_eps_clip = 0.2,
+        ppo_entropy_weight = 0.01,
+        ppo_value_clip = 0.4,
+        value_loss_weight = 0.5,
+        calc_gae_kwargs: dict = dict()
     ):
         super().__init__()
 
@@ -599,6 +627,21 @@ class Locoformer(Module):
         self.fixed_window_size = transformer.fixed_window_size
         self.window_size = transformer.window_size
 
+        # ppo related
+
+        self.discount_factor = discount_factor
+        self.gae_lam = gae_lam
+        self.ppo_eps_clip = ppo_eps_clip
+        self.ppo_entropy_weight = ppo_entropy_weight
+        self.ppo_value_clip = ppo_value_clip
+        self.value_loss_weight = value_loss_weight
+
+        self.calc_gae_kwargs = calc_gae_kwargs
+
+        # loss related
+
+        self.register_buffer('zero', tensor(0.), persistent = False)
+
     @property
     def device(self):
         return next(self.parameters()).device
@@ -612,6 +655,101 @@ class Locoformer(Module):
 
         return self.value_network.parameters()
 
+    def ppo(
+        self,
+        state,
+        action,
+        old_action_log_prob,
+        reward,
+        old_value,
+        mask,
+        actor_optim: Optimizer | None = None,
+        critic_optim: Optimizer | None = None
+    ):
+        window_size = self.window_size
+        total_learnable_tokens = mask.sum().item()
+
+        windowed_tensors = [
+            t.split(window_size, dim = 1) for t in
+            (
+                state,
+                action,
+                old_action_log_prob,
+                reward,
+                old_value,
+                mask
+            )
+        ]
+
+        mean_actor_loss = self.zero.clone()
+        mean_critic_loss = self.zero.clone()
+
+        # learn across windows
+
+        cache = None
+
+        for (
+            state,
+            action,
+            old_action_log_prob,
+            reward,
+            old_value,
+            mask
+        ) in zip(*windowed_tensors):
+
+            (action_logits, value), cache = self.forward(state, cache = cache, detach_cache = True, return_values = True)
+            entropy = calc_entropy(action_logits)
+
+            action = rearrange(action, 'b t -> b t 1')
+            log_prob = action_logits.gather(-1, action)
+            log_prob = rearrange(log_prob, 'b t 1 -> b t')
+
+            # update actor, classic clipped surrogate loss
+
+            eps_clip = self.ppo_eps_clip
+            ratio = (log_prob - old_action_log_prob).exp()
+
+            advantage, returns = calc_gae(reward, old_value, lam = self.gae_lam, gamma = self.discount_factor, **self.calc_gae_kwargs)
+
+            actor_loss = -torch.min(ratio * advantage, ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage)
+
+            actor_loss = actor_loss - self.ppo_entropy_weight * entropy
+
+            windowed_actor_loss = actor_loss[mask].sum() / total_learnable_tokens
+            windowed_actor_loss.backward(retain_graph = True)
+
+            # update critic
+
+            value_loss = F.mse_loss(returns, value, reduction = 'none')
+
+            value_clip = self.ppo_value_clip
+            clipped_value = old_value + (value - old_value).clamp(-value_clip, value_clip)
+            clipped_value_loss = F.mse_loss(returns, clipped_value, reduction = 'none')
+
+            critic_loss = torch.maximum(value_loss, clipped_value_loss) * self.value_loss_weight
+
+            windowed_critic_loss = critic_loss[mask].sum() / total_learnable_tokens
+            windowed_critic_loss.backward(retain_graph = True)
+
+            # accumulate
+
+            mean_actor_loss.add_(windowed_actor_loss)
+            mean_critic_loss.add_(windowed_critic_loss)
+
+        # optimizer update
+
+        if exists(actor_optim):
+            actor_optim.step()
+            actor_optim.zero_grad()
+
+        if exists(critic_optim):
+            critic_optim.step()
+            critic_optim.zero_grad()
+
+        # return losses for logging
+
+        return mean_actor_loss.detach(), mean_critic_loss.detach()
+
     def wrap_env_functions(self, env):
 
         def wrapped_reset(*args, **kwargs):
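The new ppo method combines the clipped surrogate objective, an entropy bonus, and a clipped value loss, applied window by window over the transformer-xl cache. A self-contained toy computation of just the clipped surrogate term, with all tensor values invented for illustration:

import torch

eps_clip = 0.2

log_prob     = torch.tensor([-0.9, -1.2, -0.4])   # new policy log-probs of the taken actions
old_log_prob = torch.tensor([-1.0, -1.0, -1.0])   # log-probs recorded at rollout time
advantage    = torch.tensor([ 0.5, -0.3,  1.0])   # e.g. the first output of calc_gae

ratio = (log_prob - old_log_prob).exp()

surrogate = torch.min(
    ratio * advantage,
    ratio.clamp(1. - eps_clip, 1. + eps_clip) * advantage
)

actor_loss = -surrogate.mean()   # minimized by the actor optimizer
print(actor_loss)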
@@ -723,7 +861,7 @@ class Locoformer(Module):
         # maybe detach cache
 
         if detach_cache:
-            kv_cache =
+            kv_cache = kv_cache.detach()
 
         # handle returning of values
 
{locoformer-0.0.11 → locoformer-0.0.17}/train.py

@@ -169,7 +169,7 @@ for i in range(NUM_BATCHES):
     prime = prime.to(model.device)
     out = prime
 
-    stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, initial_states = prime, inference_mode = True)
+    stateful_forward, logits = model.get_stateful_forward(has_batch_dim = False, has_time_dim = True, initial_states = prime, inference_mode = True)
 
     # sample
 
{locoformer-0.0.11 → locoformer-0.0.17}/train_gym.py

@@ -3,7 +3,7 @@
 # "accelerate",
 # "fire",
 # "gymnasium[box2d]>=1.0.0",
-# "locoformer",
+# "locoformer>=0.0.12",
 # "moviepy",
 # "tqdm"
 # ]
@@ -13,17 +13,19 @@ from fire import Fire
 from shutil import rmtree
 from tqdm import tqdm
 from collections import deque
+from types import SimpleNamespace
 
 from accelerate import Accelerator
 
 import gymnasium as gym
 
 import torch
-from torch import from_numpy, randint, tensor, stack
+from torch import from_numpy, randint, tensor, stack, arange
 import torch.nn.functional as F
 from torch.utils.data import TensorDataset, DataLoader
 from torch.optim import Adam
 
+import einx
 from einops import rearrange
 
 from locoformer.locoformer import Locoformer, ReplayBuffer
@@ -47,26 +49,69 @@ def gumbel_sample(logits, temperature = 1., eps = 1e-6):
     noise = gumbel_noise(logits)
     return ((logits / max(temperature, eps)) + noise).argmax(dim = -1)
 
+# learn
+
+def learn(
+    model,
+    actor_optim,
+    critic_optim,
+    accelerator,
+    replay,
+    batch_size = 16,
+    epochs = 2,
+):
+    device = accelerator.device
+
+    dl = replay.dataloader(batch_size = batch_size, shuffle = True)
+    model, dl, actor_optim, critic_optim = accelerator.prepare(model, dl, actor_optim, critic_optim)
+
+    for _ in range(epochs):
+        for data in dl:
+
+            data = SimpleNamespace(**data)
+
+            seq_len = data.state.shape[1]
+
+            value_mask = einx.less('j, i -> i j', arange(seq_len, device = device), data._lens)
+            value = torch.where(value_mask, data.value, 0.)
+
+            actor_loss, critic_loss = model.ppo(
+                state = data.state,
+                action = data.action,
+                old_action_log_prob = data.action_log_prob,
+                reward = data.reward,
+                old_value = value,
+                mask = data.learnable,
+                actor_optim = actor_optim,
+                critic_optim = critic_optim
+            )
+
+            accelerator.print(f'actor: {actor_loss.item():.3f} | critic: {critic_loss.item():.3f}')
+
 # main function
 
 def main(
     env_name = 'LunarLander-v3',
     num_episodes = 50_000,
     max_timesteps = 500,
-
+    num_episodes_before_learn = 32,
     clear_video = True,
     video_folder = 'recordings',
     record_every_episode = 250,
+    learning_rate = 8e-4,
     discount_factor = 0.99,
-
+    betas = (0.9, 0.99),
+    gae_lam = 0.95,
+    ppo_eps_clip = 0.2,
+    ppo_entropy_weight = .01,
     batch_size = 16,
     epochs = 2
 ):
 
     # accelerate
 
-
-    device =
+    accelerator = Accelerator()
+    device = accelerator.device
 
     # environment
 
@@ -91,14 +136,15 @@ def main(
     replay = ReplayBuffer(
         'replay',
         num_episodes,
-        max_timesteps,
+        max_timesteps + 1, # one extra node for bootstrap node - not relevant for locoformer, but for completeness
        fields = dict(
             state = ('float', (dim_state,)),
             action = 'int',
             action_log_prob = 'float',
             reward = 'float',
             value = 'float',
-            done = 'bool'
+            done = 'bool',
+            learnable = 'bool'
         )
     )
 
@@ -114,11 +160,18 @@ def main(
             heads = 4,
             depth = 4,
             window_size = 16
+        ),
+        discount_factor = discount_factor,
+        gae_lam = gae_lam,
+        ppo_eps_clip = ppo_eps_clip,
+        ppo_entropy_weight = ppo_entropy_weight,
+        calc_gae_kwargs = dict(
+            use_accelerated = False
         )
     ).to(device)
 
-    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate)
-    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate)
+    optim_actor = Adam([*locoformer.transformer.parameters(), *locoformer.actor_parameters()], lr = learning_rate, betas = betas)
+    optim_critic = Adam([*locoformer.transformer.parameters(), *locoformer.critic_parameters()], lr = learning_rate, betas = betas)
 
     timesteps_learn = 0
 
@@ -129,7 +182,8 @@ def main(
 
     # loop
 
-    for
+    for episodes_index in tqdm(range(num_episodes)):
+
         state, *_ = env_reset()
 
         timestep = 0
@@ -158,27 +212,29 @@ def main(
             action_log_prob = action_logits.gather(-1, rearrange(action, '-> 1'))
             action_log_prob = rearrange(action_log_prob, '1 ->')
 
-            replay.store(
+            memory = replay.store(
                 state = state,
                 action = action,
                 action_log_prob = action_log_prob,
                 reward = reward,
                 value = value,
-                done = done
+                done = done,
+                learnable = tensor(True)
             )
 
-            #
+            # handle bootstrap value, which is a non-learnable timestep added with the next value for GAE
+            # only if terminated signal not detected
 
-
-
+            if not terminated:
+                _, next_value = stateful_forward(next_state, return_values = True)
+
+                memory._replace(value = next_value, learnable = False)
 
-
+                replay.store(**memory._asdict())
 
-
-            # todo - carry out learning
+            # increment counters
 
-
-            memories.clear()
+            timestep += 1
 
             # break if done or exceed max timestep
 
@@ -187,6 +243,19 @@ def main(
 
         state = next_state
 
+        # learn if hit the number of learn timesteps
+
+        if divisible_by(episodes_index + 1, num_episodes_before_learn):
+
+            learn(
+                locoformer,
+                optim_actor,
+                optim_critic,
+                accelerator,
+                replay,
+                batch_size,
+                epochs,
+            )
+
 # main
 
 if __name__ == '__main__':