dreamer4-0.0.71-py3-none-any.whl → dreamer4-0.0.72-py3-none-any.whl

dreamer4/dreamer4.py CHANGED
@@ -155,6 +155,15 @@ def is_power_two(num):
 def is_empty(t):
     return t.numel() == 0
 
+def lens_to_mask(t, max_len = None):
+    if not exists(max_len):
+        max_len = t.amax().item()
+
+    device = t.device
+    seq = torch.arange(max_len, device = device)
+
+    return einx.less('j, i -> i j', seq, t)
+
 def log(t, eps = 1e-20):
     return t.clamp(min = eps).log()
 
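Note: the new `lens_to_mask` helper turns per-sample lengths into a boolean padding mask; `einx.less('j, i -> i j', seq, t)` yields `mask[i, j] = seq[j] < t[i]`. A minimal broadcast-only sketch of the same semantics, with illustrative values:

```python
import torch

# equivalent of lens_to_mask via plain broadcasting (values are illustrative)
lens = torch.tensor([2, 4])              # per-sample valid lengths, shape (b,)
seq = torch.arange(5)                    # positions 0 .. max_len - 1
mask = seq[None, :] < lens[:, None]      # (b, max_len) boolean mask
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])
```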
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
         *,
         video = None, # (b v? c t vh vw)
         latents = None, # (b t v? n d) | (b t v? d)
+        lens = None, # (b)
         signal_levels = None, # () | (b) | (b t)
         step_sizes = None, # () | (b)
         step_sizes_log2 = None, # () | (b)
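The new `lens` argument is a keyword-only `(b,)` tensor of valid frame counts per batch row. A hypothetical call sketch; tensor shapes, the model instance, and the returned value are assumed rather than taken from the diff:

```python
import torch

# hypothetical usage: row 0 of the batch is padded past frame 7
video = torch.randn(2, 3, 10, 64, 64)    # (b c t vh vw), shapes assumed
lens = torch.tensor([7, 10])             # number of real frames per row

# `world_model` assumed to be a configured DynamicsWorldModel instance
losses = world_model(video = video, lens = lens)
```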
@@ -3014,7 +3024,19 @@
 
         flow_losses = flow_losses * loss_weight
 
-        flow_loss = flow_losses.mean()
+        # handle variable lengths if needed
+
+        is_var_len = exists(lens)
+
+        if is_var_len:
+
+            loss_mask = lens_to_mask(lens, time)
+            loss_mask_without_last = loss_mask[:, :-1]
+
+            flow_loss = flow_losses[loss_mask].mean()
+
+        else:
+            flow_loss = flow_losses.mean()
 
         # now take care of the agent token losses
 
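Boolean indexing with the `(b, t)` mask flattens `flow_losses` to only the valid (batch, time) positions, so padding never dilutes the mean. A self-contained sketch with assumed shapes:

```python
import torch

# masked mean: only entries where loss_mask is True contribute
flow_losses = torch.rand(2, 5)           # (b, t), shape assumed
loss_mask = torch.tensor([[ True,  True, False, False, False],
                          [ True,  True,  True,  True, False]])

flow_loss = flow_losses[loss_mask].mean()  # averages the 6 valid entries, not all 10
```

`loss_mask_without_last` drops the final timestep, presumably because the agent-token losses below target the next step, which the last frame does not have.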
@@ -3037,7 +3059,10 @@
 
         reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
 
-        reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
+        if is_var_len:
+            reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
+        else:
+            reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
 
         # maybe autoregressive action loss
 
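In the variable-length branch, indexing with the `(b, t-1)` mask leaves a `(num_valid, mtp)` tensor, and `mean(dim = 0)` keeps the per-prediction-step (mtp) dimension, matching what the `reduce` does in the fixed-length branch. A sketch under assumed shapes:

```python
import torch

# boolean mask over (b, t-1) selects valid rows, preserving the mtp dim
reward_losses = torch.rand(2, 4, 3)      # (b, t-1, mtp), shapes assumed
loss_mask_without_last = torch.tensor([[ True, False, False, False],
                                       [ True,  True,  True, False]])

reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)  # (mtp,)
```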
@@ -3080,12 +3105,20 @@
         if exists(discrete_log_probs):
             discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
 
-            discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
+                discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
 
         if exists(continuous_log_probs):
             continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
 
-            continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
+            if is_var_len:
+                continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
+                continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
+            else:
+                continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
 
         # handle loss normalization
 
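Both action branches first `rearrange` so that `mtp` is the trailing axis, letting the `(b, t)` mask index the leading batch/time dims, after which `reduce` averages over everything except `mtp`. A sketch with assumed shapes and a stand-in mask:

```python
import torch
from einops import rearrange, reduce

log_probs = torch.randn(3, 2, 4, 6)          # (mtp, b, t, na), shapes assumed
mask = torch.ones(2, 4, dtype = torch.bool)  # stand-in for loss_mask_without_last

# move mtp last, mask out padded (batch, time) rows, then average per mtp step
losses = rearrange(-log_probs, 'mtp b t na -> b t na mtp')
action_loss = reduce(losses[mask], '... mtp -> mtp', 'mean')  # (mtp,)
```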
{dreamer4-0.0.71.dist-info → dreamer4-0.0.72.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.71
+Version: 0.0.72
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
dreamer4-0.0.72.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=o9f0Y-BIKAm_qpqxJOkAnkP_WmzNqPq92Nu5tep2tm0,104640
+dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
+dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
+dreamer4-0.0.72.dist-info/METADATA,sha256=WCVEjMXAe8awAYkuBhdvyHfWugqbdsRR3WCSBQXFs18,3065
+dreamer4-0.0.72.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.72.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.72.dist-info/RECORD,,
dreamer4-0.0.71.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=ywbJM384Z1VXQOHyv5RwUaCmsKAJyo2CmaDpVknml2c,103385
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
-dreamer4-0.0.71.dist-info/METADATA,sha256=wUWSPxR5xZfSsICp3GLE-p25hOSdqcdV-7Zh_vAYV64,3065
-dreamer4-0.0.71.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.71.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.71.dist-info/RECORD,,