dreamer4-0.0.7-py3-none-any.whl → dreamer4-0.1.16-py3-none-any.whl
This diff shows the changes between publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
Potentially problematic release: this version of dreamer4 has been flagged as possibly problematic; further details are available on the package's registry page.
- dreamer4/__init__.py +8 -2
- dreamer4/dreamer4.py +2808 -814
- dreamer4/mocks.py +97 -0
- dreamer4/trainers.py +525 -3
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/METADATA +97 -11
- dreamer4-0.1.16.dist-info/RECORD +8 -0
- dreamer4-0.0.7.dist-info/RECORD +0 -7
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.7.dist-info → dreamer4-0.1.16.dist-info}/licenses/LICENSE +0 -0
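At a glance, the bulk of the change is the new mock environment (dreamer4/mocks.py) and a much expanded trainers module (dreamer4/trainers.py); the dreamer4/dreamer4.py diff is not expanded on this page. The sketch below only restates where the newly added classes live, based on the file contents shown further down; whether dreamer4/__init__.py re-exports them cannot be seen from the +8 -2 summary above, so the imports go through the submodules.

# orientation sketch (Python), based solely on the class definitions visible in this diff
from dreamer4.mocks import MockEnv              # new file: stand-in environment for testing
from dreamer4.trainers import (
    VideoTokenizerTrainer,    # existed in 0.0.7, now wired to a Dataset, an optimizer and accelerate
    BehaviorCloneTrainer,     # new: offline training of DynamicsWorldModel on tensors or kwarg dicts
    DreamTrainer,             # new: policy/value head updates on imagined rollouts
    SimTrainer,               # new: collects Experience from an environment, then learns over epochs
)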
dreamer4/mocks.py
ADDED
@@ -0,0 +1,97 @@
+from __future__ import annotations
+from random import choice
+
+import torch
+from torch import tensor, empty, randn, randint
+from torch.nn import Module
+
+from einops import repeat
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+# mock env
+
+class MockEnv(Module):
+    def __init__(
+        self,
+        image_shape,
+        reward_range = (-100, 100),
+        num_envs = 1,
+        vectorized = False,
+        terminate_after_step = None,
+        rand_terminate_prob = 0.05,
+        can_truncate = False,
+        rand_truncate_prob = 0.05,
+    ):
+        super().__init__()
+        self.image_shape = image_shape
+        self.reward_range = reward_range
+
+        self.num_envs = num_envs
+        self.vectorized = vectorized
+        assert not (vectorized and num_envs == 1)
+
+        # mocking termination and truncation
+
+        self.can_terminate = exists(terminate_after_step)
+        self.terminate_after_step = terminate_after_step
+        self.rand_terminate_prob = rand_terminate_prob
+
+        self.can_truncate = can_truncate
+        self.rand_truncate_prob = rand_truncate_prob
+
+        self.register_buffer('_step', tensor(0))
+
+    def get_random_state(self):
+        return randn(3, *self.image_shape)
+
+    def reset(
+        self,
+        seed = None
+    ):
+        self._step.zero_()
+        state = self.get_random_state()
+
+        if self.vectorized:
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+
+        return state
+
+    def step(
+        self,
+        actions,
+    ):
+        state = self.get_random_state()
+
+        reward = empty(()).uniform_(*self.reward_range)
+
+        if self.vectorized:
+            discrete, continuous = actions
+            assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+            state = repeat(state, '... -> b ...', b = self.num_envs)
+            reward = repeat(reward, ' -> b', b = self.num_envs)
+
+        out = (state, reward)
+
+
+        if self.can_terminate:
+            shape = (self.num_envs,) if self.vectorized else (1,)
+            valid_step = self._step > self.terminate_after_step
+
+            terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+
+            out = (*out, terminate)
+
+        # maybe truncation
+
+        if self.can_truncate:
+            truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+            out = (*out, truncate)
+
+        self._step.add_(1)
+
+        return out
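MockEnv is a stand-in environment for exercising the training loops without a real simulator: reset() returns a random (3, *image_shape) frame, step() returns a fresh random frame plus a uniform reward, and once _step exceeds terminate_after_step it can also emit random termination and truncation flags. Below is a minimal usage sketch assuming only what is shown above; the action tensors are placeholders, since the non-vectorized path never inspects them, and their real shapes would come from the agent. Note that the truncation branch reuses shape, valid_step and terminate from the termination branch, so can_truncate only makes sense together with terminate_after_step.

import torch
from dreamer4.mocks import MockEnv

# single (non-vectorized) mock environment emitting 64x64 RGB frames
env = MockEnv(
    image_shape = (64, 64),
    terminate_after_step = 10,   # enables the (rare) terminate flag after step 10
    can_truncate = True          # adds a truncate flag as well
)

state = env.reset()              # tensor of shape (3, 64, 64)

# placeholder actions - MockEnv ignores them unless vectorized = True
actions = (torch.randint(0, 4, (1,)), torch.randn(1, 2))

state, reward, terminated, truncated = env.step(actions)

# for a vectorized mock, pass num_envs > 1 and vectorized = True; step() then expects
# a (discrete, continuous) action pair whose leading dimension equals num_envs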
dreamer4/trainers.py
CHANGED
@@ -1,17 +1,539 @@
+from __future__ import annotations
+
 import torch
+from torch import is_tensor
 from torch.nn import Module
+from torch.optim import AdamW
+from torch.utils.data import Dataset, TensorDataset, DataLoader

 from accelerate import Accelerator

+from adam_atan2_pytorch import MuonAdamAtan2
+
 from dreamer4.dreamer4 import (
     VideoTokenizer,
-
+    DynamicsWorldModel,
+    Experience,
+    combine_experiences
 )

+from ema_pytorch import EMA
+
+# helpers
+
+def exists(v):
+    return v is not None
+
+def default(v, d):
+    return v if exists(v) else d
+
+def cycle(dl):
+    while True:
+        for batch in dl:
+            yield batch
+
+# trainers
+
 class VideoTokenizerTrainer(Module):
     def __init__(
         self,
-        model: VideoTokenizer
+        model: VideoTokenizer,
+        dataset: Dataset,
+        optim_klass = MuonAdamAtan2,
+        batch_size = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+        batch_size = min(batch_size, len(dataset))
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+        self.dataset = dataset
+        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        if optim_klass is MuonAdamAtan2:
+            optim = MuonAdamAtan2(
+                model.muon_parameters(),
+                model.parameters(),
+                **optim_kwargs
+            )
+        else:
+            optim = optim_klass(
+                model.parameters(),
+                **optim_kwargs
+            )
+
+        self.optim = optim
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+
+        (
+            self.model,
+            self.train_dataloader,
+            self.optim
+        ) = self.accelerator.prepare(
+            self.model,
+            self.train_dataloader,
+            self.optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+        iter_train_dl = cycle(self.train_dataloader)
+
+        for _ in range(self.num_train_steps):
+            video = next(iter_train_dl)
+
+            loss = self.model(video)
+            self.accelerator.backward(loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+            self.optim.step()
+            self.optim.zero_grad()
+
+        self.print('training complete')
+
+# dynamics world model
+
+class BehaviorCloneTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        dataset: Dataset,
+        optim_klass = MuonAdamAtan2,
+        batch_size = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
     ):
         super().__init__()
-
+        batch_size = min(batch_size, len(dataset))
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+        self.dataset = dataset
+        self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        if optim_klass is MuonAdamAtan2:
+            optim = MuonAdamAtan2(
+                model.muon_parameters(),
+                model.parameters(),
+                **optim_kwargs
+            )
+        else:
+            optim = optim_klass(
+                model.parameters(),
+                **optim_kwargs
+            )
+
+        self.optim = optim
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+
+        (
+            self.model,
+            self.train_dataloader,
+            self.optim
+        ) = self.accelerator.prepare(
+            self.model,
+            self.train_dataloader,
+            self.optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+        iter_train_dl = cycle(self.train_dataloader)
+
+        for _ in range(self.num_train_steps):
+            batch_data = next(iter_train_dl)
+
+            # just assume raw video dynamics training if batch_data is a tensor
+            # else kwargs for video, actions, rewards
+
+            if is_tensor(batch_data):
+                loss = self.model(batch_data)
+            else:
+                loss = self.model(**batch_data)
+
+            self.accelerator.backward(loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+            self.optim.step()
+            self.optim.zero_grad()
+
+        self.print('training complete')
+
+# training from dreams
+
+class DreamTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        num_train_steps = 10_000,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.generate_timesteps = generate_timesteps
+
+        self.unwrapped_model = self.model
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def forward(
+        self
+    ):
+
+        for _ in range(self.num_train_steps):
+
+            dreams = self.unwrapped_model.generate(
+                self.generate_timesteps + 1, # plus one for bootstrap value
+                batch_size = self.batch_size,
+                return_rewards_per_frame = True,
+                return_agent_actions = True,
+                return_log_probs_and_values = True
+            )
+
+            policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
+
+            self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+            # update policy head
+
+            self.accelerator.backward(policy_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
+
+            self.policy_head_optim.step()
+            self.policy_head_optim.zero_grad()
+
+            # update value head
+
+            self.accelerator.backward(value_head_loss)
+
+            if exists(self.max_grad_norm):
+                self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+            self.value_head_optim.step()
+            self.value_head_optim.zero_grad()
+
+        self.print('training complete')
+
+# training from sim
+
+class SimTrainer(Module):
+    def __init__(
+        self,
+        model: DynamicsWorldModel,
+        optim_klass = AdamW,
+        batch_size = 16,
+        generate_timesteps = 16,
+        learning_rate = 3e-4,
+        max_grad_norm = None,
+        epochs = 2,
+        weight_decay = 0.,
+        accelerate_kwargs: dict = dict(),
+        optim_kwargs: dict = dict(),
+        cpu = False,
+    ):
+        super().__init__()
+
+        self.accelerator = Accelerator(
+            cpu = cpu,
+            **accelerate_kwargs
+        )
+
+        self.model = model
+
+        optim_kwargs = dict(
+            lr = learning_rate,
+            weight_decay = weight_decay
+        )
+
+        self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
+        self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
+
+        self.max_grad_norm = max_grad_norm
+
+        self.epochs = epochs
+        self.batch_size = batch_size
+
+        self.generate_timesteps = generate_timesteps
+
+        self.unwrapped_model = self.model
+
+        (
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.policy_head_optim,
+            self.value_head_optim
+        )
+
+    @property
+    def device(self):
+        return self.accelerator.device
+
+    @property
+    def unwrapped_model(self):
+        return self.accelerator.unwrap_model(self.model)
+
+    def print(self, *args, **kwargs):
+        return self.accelerator.print(*args, **kwargs)
+
+    def learn(
+        self,
+        experience: Experience
+    ):
+
+        step_size = experience.step_size
+        agent_index = experience.agent_index
+
+        latents = experience.latents
+        old_values = experience.values
+        rewards = experience.rewards
+
+        has_agent_embed = exists(experience.agent_embed)
+        agent_embed = experience.agent_embed
+
+        discrete_actions, continuous_actions = experience.actions
+        discrete_log_probs, continuous_log_probs = experience.log_probs
+
+        discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
+
+        # handle empties
+
+        empty_tensor = torch.empty_like(rewards)
+
+        agent_embed = default(agent_embed, empty_tensor)
+
+        has_discrete = exists(discrete_actions)
+        has_continuous = exists(continuous_actions)
+
+        discrete_actions = default(discrete_actions, empty_tensor)
+        continuous_actions = default(continuous_actions, empty_tensor)
+
+        discrete_log_probs = default(discrete_log_probs, empty_tensor)
+        continuous_log_probs = default(continuous_log_probs, empty_tensor)
+
+        discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+        continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
+
+        # create the dataset and dataloader
+
+        dataset = TensorDataset(
+            latents,
+            discrete_actions,
+            continuous_actions,
+            discrete_log_probs,
+            continuous_log_probs,
+            agent_embed,
+            discrete_old_action_unembeds,
+            continuous_old_action_unembeds,
+            old_values,
+            rewards
+        )
+
+        dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
+
+        for epoch in range(self.epochs):
+
+            for (
+                latents,
+                discrete_actions,
+                continuous_actions,
+                discrete_log_probs,
+                continuous_log_probs,
+                agent_embed,
+                discrete_old_action_unembeds,
+                continuous_old_action_unembeds,
+                old_values,
+                rewards
+            ) in dataloader:
+
+                actions = (
+                    discrete_actions if has_discrete else None,
+                    continuous_actions if has_continuous else None
+                )
+
+                log_probs = (
+                    discrete_log_probs if has_discrete else None,
+                    continuous_log_probs if has_continuous else None
+                )
+
+                old_action_unembeds = (
+                    discrete_old_action_unembeds if has_discrete else None,
+                    continuous_old_action_unembeds if has_continuous else None
+                )
+
+                batch_experience = Experience(
+                    latents = latents,
+                    actions = actions,
+                    log_probs = log_probs,
+                    agent_embed = agent_embed if has_agent_embed else None,
+                    old_action_unembeds = old_action_unembeds,
+                    values = old_values,
+                    rewards = rewards,
+                    step_size = step_size,
+                    agent_index = agent_index
+                )
+
+                policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
+
+                self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
+
+                # update policy head
+
+                self.accelerator.backward(policy_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
+
+                self.policy_head_optim.step()
+                self.policy_head_optim.zero_grad()
+
+                # update value head
+
+                self.accelerator.backward(value_head_loss)
+
+                if exists(self.max_grad_norm):
+                    self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
+
+                self.value_head_optim.step()
+                self.value_head_optim.zero_grad()
+
+        self.print('training complete')
+
+    def forward(
+        self,
+        env,
+        num_episodes = 50000,
+        max_experiences_before_learn = 8,
+        env_is_vectorized = False
+    ):
+
+        for _ in range(num_episodes):
+
+            total_experience = 0
+            experiences = []
+
+            while total_experience < max_experiences_before_learn:
+
+                experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
+
+                num_experience = experience.video.shape[0]
+
+                total_experience += num_experience
+
+                experiences.append(experience.cpu())
+
+            combined_experiences = combine_experiences(experiences)
+
+            self.learn(combined_experiences)
+
+            experiences.clear()
+
+        self.print('training complete')
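Taken together, the new trainers wrap an accelerate-based loop around the models defined in dreamer4/dreamer4.py. The sketch below shows how the pieces appear intended to fit together, using only the signatures visible in this diff; the VideoTokenizer and DynamicsWorldModel constructor arguments live in dreamer4/dreamer4.py, whose diff is not expanded on this page, so they are left elided rather than guessed, and the RandomVideoDataset and its (channels, frames, height, width) shape are assumptions made for illustration.

import torch
from torch.utils.data import Dataset

from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel   # constructor args not shown in this diff
from dreamer4.mocks import MockEnv
from dreamer4.trainers import VideoTokenizerTrainer, BehaviorCloneTrainer, DreamTrainer, SimTrainer

class RandomVideoDataset(Dataset):
    # hypothetical stand-in dataset; each item is one random video tensor
    def __init__(self, num_samples = 64, shape = (3, 16, 64, 64)):   # shape layout is an assumption
        self.num_samples, self.shape = num_samples, shape

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return torch.randn(*self.shape)

videos = RandomVideoDataset()

# 1. pretrain the tokenizer on raw videos

tokenizer = VideoTokenizer(...)          # args elided - defined in dreamer4/dreamer4.py
VideoTokenizerTrainer(tokenizer, dataset = videos, batch_size = 4, num_train_steps = 100, cpu = True)()

# 2. offline dynamics / behavior clone training (tensor batches, or dicts of kwargs for video, actions, rewards)

world_model = DynamicsWorldModel(...)    # args elided - defined in dreamer4/dreamer4.py
BehaviorCloneTrainer(world_model, dataset = videos, num_train_steps = 100)()

# 3a. improve the policy and value heads on imagined rollouts

DreamTrainer(world_model, generate_timesteps = 16, num_train_steps = 10)()

# 3b. or on experience collected from an environment (MockEnv stands in for a real one)

env = MockEnv((64, 64), num_envs = 4, vectorized = True, terminate_after_step = 10)
SimTrainer(world_model, epochs = 2)(env, num_episodes = 1, env_is_vectorized = True)

All four trainers subclass torch.nn.Module and implement their training loop in forward, so calling a trainer instance runs the loop.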