dreamer4 0.0.7__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dreamer4 has been flagged as potentially problematic.
dreamer4/mocks.py ADDED
@@ -0,0 +1,97 @@
+ from __future__ import annotations
+ from random import choice
+
+ import torch
+ from torch import tensor, empty, randn, randint
+ from torch.nn import Module
+
+ from einops import repeat
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ # mock env
+
+ class MockEnv(Module):
+     def __init__(
+         self,
+         image_shape,
+         reward_range = (-100, 100),
+         num_envs = 1,
+         vectorized = False,
+         terminate_after_step = None,
+         rand_terminate_prob = 0.05,
+         can_truncate = False,
+         rand_truncate_prob = 0.05,
+     ):
+         super().__init__()
+         self.image_shape = image_shape
+         self.reward_range = reward_range
+
+         self.num_envs = num_envs
+         self.vectorized = vectorized
+         assert not (vectorized and num_envs == 1)
+
+         # mocking termination and truncation
+
+         self.can_terminate = exists(terminate_after_step)
+         self.terminate_after_step = terminate_after_step
+         self.rand_terminate_prob = rand_terminate_prob
+
+         self.can_truncate = can_truncate
+         self.rand_truncate_prob = rand_truncate_prob
+
+         self.register_buffer('_step', tensor(0))
+
+     def get_random_state(self):
+         return randn(3, *self.image_shape)
+
+     def reset(
+         self,
+         seed = None
+     ):
+         self._step.zero_()
+         state = self.get_random_state()
+
+         if self.vectorized:
+             state = repeat(state, '... -> b ...', b = self.num_envs)
+
+         return state
+
+     def step(
+         self,
+         actions,
+     ):
+         state = self.get_random_state()
+
+         reward = empty(()).uniform_(*self.reward_range)
+
+         if self.vectorized:
+             discrete, continuous = actions
+             assert discrete.shape[0] == self.num_envs, f'expected batch of actions for {self.num_envs} environments'
+
+             state = repeat(state, '... -> b ...', b = self.num_envs)
+             reward = repeat(reward, ' -> b', b = self.num_envs)
+
+         out = (state, reward)
+
+         # maybe termination - shape, valid_step and a default terminate are
+         # computed up front since the truncation branch below also needs them
+
+         shape = (self.num_envs,) if self.vectorized else (1,)
+         valid_step = (self._step > self.terminate_after_step) if self.can_terminate else tensor(True)
+         terminate = torch.zeros(shape, dtype = torch.bool)
+
+         if self.can_terminate:
+             terminate = (torch.rand(shape) < self.rand_terminate_prob) & valid_step
+
+             out = (*out, terminate)
+
+         # maybe truncation
+
+         if self.can_truncate:
+             truncate = (torch.rand(shape) < self.rand_truncate_prob) & valid_step & ~terminate
+             out = (*out, truncate)
+
+         self._step.add_(1)
+
+         return out
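The mock follows a gym-style reset/step contract: reset returns an initial observation, and step returns (state, reward), with a terminate flag appended when terminate_after_step is set and a truncate flag appended when can_truncate is enabled. A minimal driving loop, sketched under the assumption that the action shapes below are placeholders (the mock only checks that the leading dimension of the discrete actions matches num_envs):

from torch import randn, randint
from dreamer4.mocks import MockEnv

env = MockEnv(
    (64, 64),                    # image_shape -> observations of shape (3, 64, 64)
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 8,    # enables random termination once past step 8
    can_truncate = True
)

state = env.reset()              # (4, 3, 64, 64) in vectorized mode

for _ in range(16):
    actions = (randint(0, 4, (4, 1)), randn(4, 2))           # (discrete, continuous) placeholders
    state, reward, terminate, truncate = env.step(actions)   # flags present because both are enabled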
dreamer4/trainers.py CHANGED
@@ -1,17 +1,539 @@
+ from __future__ import annotations
+
  import torch
+ from torch import is_tensor
  from torch.nn import Module
+ from torch.optim import AdamW
+ from torch.utils.data import Dataset, TensorDataset, DataLoader

  from accelerate import Accelerator

+ from adam_atan2_pytorch import MuonAdamAtan2
+
  from dreamer4.dreamer4 import (
      VideoTokenizer,
-     DynamicsModel
+     DynamicsWorldModel,
+     Experience,
+     combine_experiences
  )

+ from ema_pytorch import EMA
+
+ # helpers
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def cycle(dl):
+     while True:
+         for batch in dl:
+             yield batch
+
+ # trainers
+
  class VideoTokenizerTrainer(Module):
      def __init__(
          self,
-         model: VideoTokenizer
+         model: VideoTokenizer,
+         dataset: Dataset,
+         optim_klass = MuonAdamAtan2,
+         batch_size = 16,
+         learning_rate = 3e-4,
+         max_grad_norm = None,
+         num_train_steps = 10_000,
+         weight_decay = 0.,
+         accelerate_kwargs: dict = dict(),
+         optim_kwargs: dict = dict(),
+         cpu = False,
+     ):
+         super().__init__()
+         batch_size = min(batch_size, len(dataset))
+
+         self.accelerator = Accelerator(
+             cpu = cpu,
+             **accelerate_kwargs
+         )
+
+         self.model = model
+         self.dataset = dataset
+         self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+         optim_kwargs = dict(
+             lr = learning_rate,
+             weight_decay = weight_decay
+         )
+
+         if optim_klass is MuonAdamAtan2:
+             optim = MuonAdamAtan2(
+                 model.muon_parameters(),
+                 model.parameters(),
+                 **optim_kwargs
+             )
+         else:
+             optim = optim_klass(
+                 model.parameters(),
+                 **optim_kwargs
+             )
+
+         self.optim = optim
+
+         self.max_grad_norm = max_grad_norm
+
+         self.num_train_steps = num_train_steps
+         self.batch_size = batch_size
+
+         (
+             self.model,
+             self.train_dataloader,
+             self.optim
+         ) = self.accelerator.prepare(
+             self.model,
+             self.train_dataloader,
+             self.optim
+         )
+
+     @property
+     def device(self):
+         return self.accelerator.device
+
+     def print(self, *args, **kwargs):
+         return self.accelerator.print(*args, **kwargs)
+
+     def forward(
+         self
+     ):
+         iter_train_dl = cycle(self.train_dataloader)
+
+         for _ in range(self.num_train_steps):
+             video = next(iter_train_dl)
+
+             loss = self.model(video)
+             self.accelerator.backward(loss)
+
+             if exists(self.max_grad_norm):
+                 self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+             self.optim.step()
+             self.optim.zero_grad()
+
+         self.print('training complete')
+
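For orientation, a minimal usage sketch of the rewritten VideoTokenizerTrainer. The tokenizer instance and the random-video dataset below are hypothetical stand-ins: the VideoTokenizer constructor and the exact video tensor layout it expects are not shown in this diff, so the trainer is only assumed to need a Dataset whose items collate into whatever tensor the model consumes.

import torch
from torch.utils.data import Dataset
from dreamer4.trainers import VideoTokenizerTrainer

class RandomVideoDataset(Dataset):
    # hypothetical stand-in for a real video dataset
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(3, 16, 64, 64)   # assumed (channels, frames, height, width) layout

trainer = VideoTokenizerTrainer(
    tokenizer,                       # an already constructed VideoTokenizer
    dataset = RandomVideoDataset(),
    batch_size = 4,
    num_train_steps = 100,
    max_grad_norm = 1.,
    cpu = True
)

trainer()                            # runs the training loop under the Accelerator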
+ # dynamics world model
+
+ class BehaviorCloneTrainer(Module):
+     def __init__(
+         self,
+         model: DynamicsWorldModel,
+         dataset: Dataset,
+         optim_klass = MuonAdamAtan2,
+         batch_size = 16,
+         learning_rate = 3e-4,
+         max_grad_norm = None,
+         num_train_steps = 10_000,
+         weight_decay = 0.,
+         accelerate_kwargs: dict = dict(),
+         optim_kwargs: dict = dict(),
+         cpu = False,
      ):
          super().__init__()
-         raise NotImplementedError
+         batch_size = min(batch_size, len(dataset))
+
+         self.accelerator = Accelerator(
+             cpu = cpu,
+             **accelerate_kwargs
+         )
+
+         self.model = model
+         self.dataset = dataset
+         self.train_dataloader = DataLoader(dataset, batch_size = batch_size, drop_last = True, shuffle = True)
+
+         optim_kwargs = dict(
+             lr = learning_rate,
+             weight_decay = weight_decay
+         )
+
+         if optim_klass is MuonAdamAtan2:
+             optim = MuonAdamAtan2(
+                 model.muon_parameters(),
+                 model.parameters(),
+                 **optim_kwargs
+             )
+         else:
+             optim = optim_klass(
+                 model.parameters(),
+                 **optim_kwargs
+             )
+
+         self.optim = optim
+
+         self.max_grad_norm = max_grad_norm
+
+         self.num_train_steps = num_train_steps
+         self.batch_size = batch_size
+
+         (
+             self.model,
+             self.train_dataloader,
+             self.optim
+         ) = self.accelerator.prepare(
+             self.model,
+             self.train_dataloader,
+             self.optim
+         )
+
+     @property
+     def device(self):
+         return self.accelerator.device
+
+     def print(self, *args, **kwargs):
+         return self.accelerator.print(*args, **kwargs)
+
+     def forward(
+         self
+     ):
+         iter_train_dl = cycle(self.train_dataloader)
+
+         for _ in range(self.num_train_steps):
+             batch_data = next(iter_train_dl)
+
+             # just assume raw video dynamics training if batch_data is a tensor
+             # else kwargs for video, actions, rewards
+
+             if is_tensor(batch_data):
+                 loss = self.model(batch_data)
+             else:
+                 loss = self.model(**batch_data)
+
+             self.accelerator.backward(loss)
+
+             if exists(self.max_grad_norm):
+                 self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+             self.optim.step()
+             self.optim.zero_grad()
+
+         self.print('training complete')
+
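Because of the is_tensor check in BehaviorCloneTrainer.forward, a dataset can either yield raw video tensors (handled exactly like the tokenizer trainer above) or yield dicts that are splatted into the world model as keyword arguments. A rough sketch of the keyword path: the key names follow the comment in the code (video, actions, rewards), but the names and shapes DynamicsWorldModel.forward actually accepts are not shown in this diff, so treat them as placeholders, with world_model constructed elsewhere.

import torch
from torch.utils.data import Dataset
from dreamer4.trainers import BehaviorCloneTrainer

class MockTrajectoryDataset(Dataset):
    # hypothetical; keys and shapes must match DynamicsWorldModel.forward
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return dict(
            video = torch.randn(3, 16, 64, 64),   # placeholder layout
            rewards = torch.randn(16)             # one reward per frame, assumed
        )

trainer = BehaviorCloneTrainer(
    world_model,                     # an already constructed DynamicsWorldModel
    dataset = MockTrajectoryDataset(),
    batch_size = 4,
    num_train_steps = 100,
    cpu = True
)

trainer()                            # dict batches are unpacked as model(**batch)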
220
+ # training from dreams
221
+
222
+ class DreamTrainer(Module):
223
+ def __init__(
224
+ self,
225
+ model: DynamicsWorldModel,
226
+ optim_klass = AdamW,
227
+ batch_size = 16,
228
+ generate_timesteps = 16,
229
+ learning_rate = 3e-4,
230
+ max_grad_norm = None,
231
+ num_train_steps = 10_000,
232
+ weight_decay = 0.,
233
+ accelerate_kwargs: dict = dict(),
234
+ optim_kwargs: dict = dict(),
235
+ cpu = False,
236
+ ):
237
+ super().__init__()
238
+
239
+ self.accelerator = Accelerator(
240
+ cpu = cpu,
241
+ **accelerate_kwargs
242
+ )
243
+
244
+ self.model = model
245
+
246
+ optim_kwargs = dict(
247
+ lr = learning_rate,
248
+ weight_decay = weight_decay
249
+ )
250
+
251
+ self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
252
+ self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
253
+
254
+ self.max_grad_norm = max_grad_norm
255
+
256
+ self.num_train_steps = num_train_steps
257
+ self.batch_size = batch_size
258
+ self.generate_timesteps = generate_timesteps
259
+
260
+ self.unwrapped_model = self.model
261
+
262
+ (
263
+ self.model,
264
+ self.policy_head_optim,
265
+ self.value_head_optim,
266
+ ) = self.accelerator.prepare(
267
+ self.model,
268
+ self.policy_head_optim,
269
+ self.value_head_optim
270
+ )
271
+
272
+ @property
273
+ def device(self):
274
+ return self.accelerator.device
275
+
276
+ @property
277
+ def unwrapped_model(self):
278
+ return self.accelerator.unwrap_model(self.model)
279
+
280
+ def print(self, *args, **kwargs):
281
+ return self.accelerator.print(*args, **kwargs)
282
+
283
+ def forward(
284
+ self
285
+ ):
286
+
287
+ for _ in range(self.num_train_steps):
288
+
289
+ dreams = self.unwrapped_model.generate(
290
+ self.generate_timesteps + 1, # plus one for bootstrap value
291
+ batch_size = self.batch_size,
292
+ return_rewards_per_frame = True,
293
+ return_agent_actions = True,
294
+ return_log_probs_and_values = True
295
+ )
296
+
297
+ policy_head_loss, value_head_loss = self.model.learn_from_experience(dreams)
298
+
299
+ self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
300
+
301
+ # update policy head
302
+
303
+ self.accelerator.backward(policy_head_loss)
304
+
305
+ if exists(self.max_grad_norm):
306
+ self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
307
+
308
+ self.policy_head_optim.step()
309
+ self.policy_head_optim.zero_grad()
310
+
311
+ # update value head
312
+
313
+ self.accelerator.backward(value_head_loss)
314
+
315
+ if exists(self.max_grad_norm):
316
+ self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
317
+
318
+ self.value_head_optim.step()
319
+ self.value_head_optim.zero_grad()
320
+
321
+ self.print('training complete')
322
+
323
+ # training from sim
324
+
325
+ class SimTrainer(Module):
326
+ def __init__(
327
+ self,
328
+ model: DynamicsWorldModel,
329
+ optim_klass = AdamW,
330
+ batch_size = 16,
331
+ generate_timesteps = 16,
332
+ learning_rate = 3e-4,
333
+ max_grad_norm = None,
334
+ epochs = 2,
335
+ weight_decay = 0.,
336
+ accelerate_kwargs: dict = dict(),
337
+ optim_kwargs: dict = dict(),
338
+ cpu = False,
339
+ ):
340
+ super().__init__()
341
+
342
+ self.accelerator = Accelerator(
343
+ cpu = cpu,
344
+ **accelerate_kwargs
345
+ )
346
+
347
+ self.model = model
348
+
349
+ optim_kwargs = dict(
350
+ lr = learning_rate,
351
+ weight_decay = weight_decay
352
+ )
353
+
354
+ self.policy_head_optim = AdamW(model.policy_head_parameters(), **optim_kwargs)
355
+ self.value_head_optim = AdamW(model.value_head_parameters(), **optim_kwargs)
356
+
357
+ self.max_grad_norm = max_grad_norm
358
+
359
+ self.epochs = epochs
360
+ self.batch_size = batch_size
361
+
362
+ self.generate_timesteps = generate_timesteps
363
+
364
+ self.unwrapped_model = self.model
365
+
366
+ (
367
+ self.model,
368
+ self.policy_head_optim,
369
+ self.value_head_optim,
370
+ ) = self.accelerator.prepare(
371
+ self.model,
372
+ self.policy_head_optim,
373
+ self.value_head_optim
374
+ )
375
+
376
+ @property
377
+ def device(self):
378
+ return self.accelerator.device
379
+
380
+ @property
381
+ def unwrapped_model(self):
382
+ return self.accelerator.unwrap_model(self.model)
383
+
384
+ def print(self, *args, **kwargs):
385
+ return self.accelerator.print(*args, **kwargs)
386
+
387
+ def learn(
388
+ self,
389
+ experience: Experience
390
+ ):
391
+
392
+ step_size = experience.step_size
393
+ agent_index = experience.agent_index
394
+
395
+ latents = experience.latents
396
+ old_values = experience.values
397
+ rewards = experience.rewards
398
+
399
+ has_agent_embed = exists(experience.agent_embed)
400
+ agent_embed = experience.agent_embed
401
+
402
+ discrete_actions, continuous_actions = experience.actions
403
+ discrete_log_probs, continuous_log_probs = experience.log_probs
404
+
405
+ discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
406
+
407
+ # handle empties
408
+
409
+ empty_tensor = torch.empty_like(rewards)
410
+
411
+ agent_embed = default(agent_embed, empty_tensor)
412
+
413
+ has_discrete = exists(discrete_actions)
414
+ has_continuous = exists(continuous_actions)
415
+
416
+ discrete_actions = default(discrete_actions, empty_tensor)
417
+ continuous_actions = default(continuous_actions, empty_tensor)
418
+
419
+ discrete_log_probs = default(discrete_log_probs, empty_tensor)
420
+ continuous_log_probs = default(continuous_log_probs, empty_tensor)
421
+
422
+ discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
423
+ continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
424
+
425
+ # create the dataset and dataloader
426
+
427
+ dataset = TensorDataset(
428
+ latents,
429
+ discrete_actions,
430
+ continuous_actions,
431
+ discrete_log_probs,
432
+ continuous_log_probs,
433
+ agent_embed,
434
+ discrete_old_action_unembeds,
435
+ continuous_old_action_unembeds,
436
+ old_values,
437
+ rewards
438
+ )
439
+
440
+ dataloader = DataLoader(dataset, batch_size = self.batch_size, shuffle = True)
441
+
442
+ for epoch in range(self.epochs):
443
+
444
+ for (
445
+ latents,
446
+ discrete_actions,
447
+ continuous_actions,
448
+ discrete_log_probs,
449
+ continuous_log_probs,
450
+ agent_embed,
451
+ discrete_old_action_unembeds,
452
+ continuous_old_action_unembeds,
453
+ old_values,
454
+ rewards
455
+ ) in dataloader:
456
+
457
+ actions = (
458
+ discrete_actions if has_discrete else None,
459
+ continuous_actions if has_continuous else None
460
+ )
461
+
462
+ log_probs = (
463
+ discrete_log_probs if has_discrete else None,
464
+ continuous_log_probs if has_continuous else None
465
+ )
466
+
467
+ old_action_unembeds = (
468
+ discrete_old_action_unembeds if has_discrete else None,
469
+ continuous_old_action_unembeds if has_continuous else None
470
+ )
471
+
472
+ batch_experience = Experience(
473
+ latents = latents,
474
+ actions = actions,
475
+ log_probs = log_probs,
476
+ agent_embed = agent_embed if has_agent_embed else None,
477
+ old_action_unembeds = old_action_unembeds,
478
+ values = old_values,
479
+ rewards = rewards,
480
+ step_size = step_size,
481
+ agent_index = agent_index
482
+ )
483
+
484
+ policy_head_loss, value_head_loss = self.model.learn_from_experience(batch_experience)
485
+
486
+ self.print(f'policy head loss: {policy_head_loss.item():.3f} | value head loss: {value_head_loss.item():.3f}')
487
+
488
+ # update policy head
489
+
490
+ self.accelerator.backward(policy_head_loss)
491
+
492
+ if exists(self.max_grad_norm):
493
+ self.accelerator.clip_grad_norm_(self.model.policy_head_parameters()(), self.max_grad_norm)
494
+
495
+ self.policy_head_optim.step()
496
+ self.policy_head_optim.zero_grad()
497
+
498
+ # update value head
499
+
500
+ self.accelerator.backward(value_head_loss)
501
+
502
+ if exists(self.max_grad_norm):
503
+ self.accelerator.clip_grad_norm_(self.model.value_head_parameters(), self.max_grad_norm)
504
+
505
+ self.value_head_optim.step()
506
+ self.value_head_optim.zero_grad()
507
+
508
+ self.print('training complete')
509
+
510
+ def forward(
511
+ self,
512
+ env,
513
+ num_episodes = 50000,
514
+ max_experiences_before_learn = 8,
515
+ env_is_vectorized = False
516
+ ):
517
+
518
+ for _ in range(num_episodes):
519
+
520
+ total_experience = 0
521
+ experiences = []
522
+
523
+ while total_experience < max_experiences_before_learn:
524
+
525
+ experience = self.unwrapped_model.interact_with_env(env, env_is_vectorized = env_is_vectorized)
526
+
527
+ num_experience = experience.video.shape[0]
528
+
529
+ total_experience += num_experience
530
+
531
+ experiences.append(experience.cpu())
532
+
533
+ combined_experiences = combine_experiences(experiences)
534
+
535
+ self.learn(combined_experiences)
536
+
537
+ experiences.clear()
538
+
539
+ self.print('training complete')
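Taken together with the new dreamer4/mocks.py above, the sim loop can presumably be smoke-tested end to end: interact_with_env collects Experience from the environment, combine_experiences merges the rollouts, and learn runs the policy and value head updates over them for the configured epochs. A rough sketch, again assuming world_model is a DynamicsWorldModel constructed elsewhere with policy and value heads and an interact_with_env that speaks the MockEnv reset/step interface:

from dreamer4.mocks import MockEnv
from dreamer4.trainers import SimTrainer

env = MockEnv(
    (64, 64),
    num_envs = 4,
    vectorized = True,
    terminate_after_step = 8,
    can_truncate = True
)

trainer = SimTrainer(
    world_model,                 # an already constructed DynamicsWorldModel
    batch_size = 4,
    epochs = 1,
    cpu = True
)

trainer(
    env,
    num_episodes = 2,
    max_experiences_before_learn = 8,
    env_is_vectorized = True
)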