dreamer4 0.0.71__py3-none-any.whl → 0.0.72__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dreamer4/dreamer4.py +37 -4
- {dreamer4-0.0.71.dist-info → dreamer4-0.0.72.dist-info}/METADATA +1 -1
- dreamer4-0.0.72.dist-info/RECORD +8 -0
- dreamer4-0.0.71.dist-info/RECORD +0 -8
- {dreamer4-0.0.71.dist-info → dreamer4-0.0.72.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.71.dist-info → dreamer4-0.0.72.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -155,6 +155,15 @@ def is_power_two(num):
|
|
|
155
155
|
def is_empty(t):
    """Return True when the tensor `t` contains zero elements."""
    # numel() is a non-negative int, so "falsy" is exactly "== 0"
    return not t.numel()
|
|
157
157
|
|
|
158
|
+
def lens_to_mask(t, max_len = None):
    """Build a boolean mask marking the valid positions of variable-length sequences.

    Args:
        t: integer tensor of sequence lengths. The common case is a 1-D tensor
           of shape (b,); any number of leading dimensions is accepted.
        max_len: size of the trailing (position) dimension of the mask.
           Defaults to the largest length in `t`, or 0 when `t` is empty.

    Returns:
        Boolean tensor of shape (*t.shape, max_len) whose entry [..., j] is
        True iff j < t[...].
    """
    if max_len is None:
        # amax() raises on an empty tensor, so fall back to a zero-width mask
        max_len = t.amax().item() if t.numel() > 0 else 0

    seq = torch.arange(max_len, device = t.device)

    # broadcast positions (n,) against lengths (..., 1) -> (..., n)
    return seq < t[..., None]
|
|
166
|
+
|
|
158
167
|
def log(t, eps = 1e-20):
    """Natural log of `t`, with values floored at `eps` so zeros do not yield -inf."""
    floored = t.clamp(min = eps)
    return floored.log()
|
|
160
169
|
|
|
@@ -2581,6 +2590,7 @@ class DynamicsWorldModel(Module):
|
|
|
2581
2590
|
*,
|
|
2582
2591
|
video = None, # (b v? c t vh vw)
|
|
2583
2592
|
latents = None, # (b t v? n d) | (b t v? d)
|
|
2593
|
+
lens = None, # (b)
|
|
2584
2594
|
signal_levels = None, # () | (b) | (b t)
|
|
2585
2595
|
step_sizes = None, # () | (b)
|
|
2586
2596
|
step_sizes_log2 = None, # () | (b)
|
|
@@ -3014,7 +3024,19 @@ class DynamicsWorldModel(Module):
|
|
|
3014
3024
|
|
|
3015
3025
|
flow_losses = flow_losses * loss_weight
|
|
3016
3026
|
|
|
3017
|
-
|
|
3027
|
+
# handle variable lengths if needed
|
|
3028
|
+
|
|
3029
|
+
is_var_len = exists(lens)
|
|
3030
|
+
|
|
3031
|
+
if is_var_len:
|
|
3032
|
+
|
|
3033
|
+
loss_mask = lens_to_mask(lens, time)
|
|
3034
|
+
loss_mask_without_last = loss_mask[:, :-1]
|
|
3035
|
+
|
|
3036
|
+
flow_loss = flow_losses[loss_mask].mean()
|
|
3037
|
+
|
|
3038
|
+
else:
|
|
3039
|
+
flow_loss = flow_losses.mean()
|
|
3018
3040
|
|
|
3019
3041
|
# now take care of the agent token losses
|
|
3020
3042
|
|
|
@@ -3037,7 +3059,10 @@ class DynamicsWorldModel(Module):
|
|
|
3037
3059
|
|
|
3038
3060
|
reward_losses = reward_losses.masked_fill(reward_loss_mask, 0.)
|
|
3039
3061
|
|
|
3040
|
-
|
|
3062
|
+
if is_var_len:
|
|
3063
|
+
reward_loss = reward_losses[loss_mask_without_last].mean(dim = 0)
|
|
3064
|
+
else:
|
|
3065
|
+
reward_loss = reduce(reward_losses, '... mtp -> mtp', 'mean') # they sum across the prediction steps (mtp dimension) - eq(9)
|
|
3041
3066
|
|
|
3042
3067
|
# maybe autoregressive action loss
|
|
3043
3068
|
|
|
@@ -3080,12 +3105,20 @@ class DynamicsWorldModel(Module):
|
|
|
3080
3105
|
if exists(discrete_log_probs):
|
|
3081
3106
|
discrete_log_probs = discrete_log_probs.masked_fill(~discrete_mask[..., None], 0.)
|
|
3082
3107
|
|
|
3083
|
-
|
|
3108
|
+
if is_var_len:
|
|
3109
|
+
discrete_action_losses = rearrange(-discrete_log_probs, 'mtp b t na -> b t na mtp')
|
|
3110
|
+
discrete_action_loss = reduce(discrete_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
|
|
3111
|
+
else:
|
|
3112
|
+
discrete_action_loss = reduce(-discrete_log_probs, 'mtp b t na -> mtp', 'mean')
|
|
3084
3113
|
|
|
3085
3114
|
if exists(continuous_log_probs):
|
|
3086
3115
|
continuous_log_probs = continuous_log_probs.masked_fill(~continuous_mask[..., None], 0.)
|
|
3087
3116
|
|
|
3088
|
-
|
|
3117
|
+
if is_var_len:
|
|
3118
|
+
continuous_action_losses = rearrange(-continuous_log_probs, 'mtp b t na -> b t na mtp')
|
|
3119
|
+
continuous_action_loss = reduce(continuous_action_losses[loss_mask_without_last], '... mtp -> mtp', 'mean')
|
|
3120
|
+
else:
|
|
3121
|
+
continuous_action_loss = reduce(-continuous_log_probs, 'mtp b t na -> mtp', 'mean')
|
|
3089
3122
|
|
|
3090
3123
|
# handle loss normalization
|
|
3091
3124
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=o9f0Y-BIKAm_qpqxJOkAnkP_WmzNqPq92Nu5tep2tm0,104640
|
|
3
|
+
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
+
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
+
dreamer4-0.0.72.dist-info/METADATA,sha256=WCVEjMXAe8awAYkuBhdvyHfWugqbdsRR3WCSBQXFs18,3065
|
|
6
|
+
dreamer4-0.0.72.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.72.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.72.dist-info/RECORD,,
|
dreamer4-0.0.71.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=ywbJM384Z1VXQOHyv5RwUaCmsKAJyo2CmaDpVknml2c,103385
|
|
3
|
-
dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
|
|
4
|
-
dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
|
|
5
|
-
dreamer4-0.0.71.dist-info/METADATA,sha256=wUWSPxR5xZfSsICp3GLE-p25hOSdqcdV-7Zh_vAYV64,3065
|
|
6
|
-
dreamer4-0.0.71.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.71.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.71.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|