dreamer4 0.0.95__py3-none-any.whl → 0.0.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. See the registry's advisory for this release for more details.

dreamer4/trainers.py CHANGED
@@ -396,13 +396,20 @@ class SimTrainer(Module):
396
396
  old_values = experience.values
397
397
  rewards = experience.rewards
398
398
 
399
+ has_agent_embed = exists(experience.agent_embed)
400
+ agent_embed = experience.agent_embed
401
+
399
402
  discrete_actions, continuous_actions = experience.actions
400
403
  discrete_log_probs, continuous_log_probs = experience.log_probs
401
404
 
405
+ discrete_old_action_unembeds, continuous_old_action_unembeds = default(experience.old_action_unembeds, (None, None))
406
+
402
407
  # handle empties
403
408
 
404
409
  empty_tensor = torch.empty_like(rewards)
405
410
 
411
+ agent_embed = default(agent_embed, empty_tensor)
412
+
406
413
  has_discrete = exists(discrete_actions)
407
414
  has_continuous = exists(continuous_actions)
408
415
 
@@ -412,6 +419,9 @@ class SimTrainer(Module):
412
419
  discrete_log_probs = default(discrete_log_probs, empty_tensor)
413
420
  continuous_log_probs = default(continuous_log_probs, empty_tensor)
414
421
 
422
+ discrete_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
423
+ continuous_old_action_unembeds = default(discrete_old_action_unembeds, empty_tensor)
424
+
415
425
  # create the dataset and dataloader
416
426
 
417
427
  dataset = TensorDataset(
@@ -420,6 +430,9 @@ class SimTrainer(Module):
420
430
  continuous_actions,
421
431
  discrete_log_probs,
422
432
  continuous_log_probs,
433
+ agent_embed,
434
+ discrete_old_action_unembeds,
435
+ continuous_old_action_unembeds,
423
436
  old_values,
424
437
  rewards
425
438
  )
@@ -434,6 +447,9 @@ class SimTrainer(Module):
434
447
  continuous_actions,
435
448
  discrete_log_probs,
436
449
  continuous_log_probs,
450
+ agent_embed,
451
+ discrete_old_action_unembeds,
452
+ continuous_old_action_unembeds,
437
453
  old_values,
438
454
  rewards
439
455
  ) in dataloader:
@@ -448,10 +464,17 @@ class SimTrainer(Module):
448
464
  continuous_log_probs if has_continuous else None
449
465
  )
450
466
 
467
+ old_action_unembeds = (
468
+ discrete_old_action_unembeds if has_discrete else None,
469
+ continuous_old_action_unembeds if has_continuous else None
470
+ )
471
+
451
472
  batch_experience = Experience(
452
473
  latents = latents,
453
474
  actions = actions,
454
475
  log_probs = log_probs,
476
+ agent_embed = agent_embed if has_agent_embed else None,
477
+ old_action_unembeds = old_action_unembeds,
455
478
  values = old_values,
456
479
  rewards = rewards,
457
480
  step_size = step_size,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.95
3
+ Version: 0.0.96
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=pSwW4DpHMF78ortsHLVHDdWzrsjse6bIVxX3oolA-Ao,118572
3
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
+ dreamer4/trainers.py,sha256=JsnJwQJcbI_75KBTNddG6b7QVkO6LD1N_HQiVe-VnCM,15087
5
+ dreamer4-0.0.96.dist-info/METADATA,sha256=yIipeLorTHctLFwpd3bPExL6-nG5LxTfxPImf5BngN4,3065
6
+ dreamer4-0.0.96.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.96.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.96.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=pSwW4DpHMF78ortsHLVHDdWzrsjse6bIVxX3oolA-Ao,118572
3
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
- dreamer4-0.0.95.dist-info/METADATA,sha256=VWiRy5xotYsn9HL5EFymn9N8j8-_wFKbYeTA5k6E0z4,3065
6
- dreamer4-0.0.95.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.95.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.95.dist-info/RECORD,,