dreamer4-0.0.75-py3-none-any.whl → dreamer4-0.0.77-py3-none-any.whl
Potentially problematic release: this version of dreamer4 has been flagged by the registry as possibly problematic.
- dreamer4/dreamer4.py +39 -5
- {dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/METADATA +1 -1
- dreamer4-0.0.77.dist-info/RECORD +8 -0
- dreamer4-0.0.75.dist-info/RECORD +0 -8
- {dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
@@ -284,7 +284,7 @@ def create_multi_token_prediction_targets(
     batch, seq_len, device = *t.shape[:2], t.device

     batch_arange = arange(batch, device = device)
-    seq_arange = arange(seq_len, device = device)
+    seq_arange = arange(seq_len, device = device)
     steps_arange = arange(steps_future, device = device)

     indices = add('t, steps -> t steps', seq_arange, steps_arange)
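The index construction in this hunk broadcasts every sequence position against every future step offset. For orientation only, below is a rough, self-contained sketch of how multi-token prediction targets can be gathered from such indices; the helper name, the clamping, and the masking details are illustrative assumptions, not the package's exact implementation.

```python
import torch

def multi_token_targets_sketch(t: torch.Tensor, steps_future: int):
    # t: (batch, seq_len); gather the next `steps_future` elements for every position
    batch, seq_len = t.shape[:2]
    indices = torch.arange(seq_len)[:, None] + torch.arange(steps_future)[None, :]  # (seq_len, steps)
    valid = indices < seq_len                          # False where the window runs past the end
    gathered = t[:, indices.clamp(max = seq_len - 1)]  # (batch, seq_len, steps)
    return gathered, valid.expand(batch, -1, -1)

x = torch.arange(5)[None]                              # (1, 5)
targets, mask = multi_token_targets_sketch(x, 2)
print(targets[0])  # tensor([[0, 1], [1, 2], [2, 3], [3, 4], [4, 4]])
print(mask[0])     # last row's second step is marked invalid
```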
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):

         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads

         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
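A note on the new variable-length handling: `lens_to_mask` is presumably a small helper that turns per-sample lengths into a boolean validity mask over the padded time axis. A minimal sketch of that assumption (the function below is illustrative, not the package's actual definition):

```python
import torch

def lens_to_mask_sketch(lens: torch.Tensor, max_time: int) -> torch.Tensor:
    # positions strictly before each sample's length are valid
    seq = torch.arange(max_time, device = lens.device)
    return seq[None, :] < lens[:, None]

lens = torch.tensor([3, 5])
print(lens_to_mask_sketch(lens, 5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```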
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):

         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step

         if exists(policy_optim):
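The new branch changes how the per-timestep policy loss is reduced. A toy example with made-up numbers shows why: a plain `.mean()` also averages over padded timesteps, while indexing with the mask first averages only the valid ones.

```python
import torch

total_policy_loss = torch.tensor([[1.0, 2.0, 0.0],
                                  [3.0, 3.0, 3.0]])
mask = torch.tensor([[True, True, False],   # first sample has only 2 valid steps
                     [True, True, True]])

print(total_policy_loss.mean())        # tensor(2.) -- the padded 0.0 is counted
print(total_policy_loss[mask].mean())  # tensor(2.4000) -- valid positions only
```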
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):

         return_bins = self.reward_encoder(returns)

+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2)
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step
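The added `'b t l -> b l t'` rearrange appears to be needed because `F.cross_entropy` expects the class (return-bin) dimension at dim 1 when the target is itself a per-bin distribution, as with a two-hot return encoding. A small sketch under that assumption; the shapes and the softmax stand-in for the two-hot targets are illustrative.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

batch, time, bins = 2, 4, 8
value_bins = torch.randn(batch, time, bins)                     # predicted logits per timestep
return_bins = torch.randn(batch, time, bins).softmax(dim = -1)  # stand-in for a two-hot target distribution

value_bins, return_bins = (rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins))

loss = F.cross_entropy(value_bins, return_bins, reduction = 'none')  # (batch, time): one loss per timestep
print(loss.shape)  # torch.Size([2, 4])
```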
@@ -3100,7 +3125,7 @@ class DynamicsWorldModel(Module):

         reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')

-        reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
+        reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)

         reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
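The `[:, :-1]` trim reads as an off-by-one fix: the encoded reward targets have to share the prediction tensor's time length before the per-bin cross entropy is taken. A toy shape check under that assumption; all names and sizes here are illustrative, not the package's real shapes.

```python
import torch
import torch.nn.functional as F

batch, time, bins, mtp = 2, 6, 8, 1
reward_pred = torch.randn(batch, bins, time - 1, mtp)                      # predictions cover one fewer step
reward_targets = torch.randn(batch, bins, time - 1, mtp).softmax(dim = 1)  # targets trimmed to match

loss = F.cross_entropy(reward_pred, reward_targets, reduction = 'none')
print(loss.shape)  # torch.Size([2, 5, 1]) -- mismatched time dims would raise instead
```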
@@ -3126,6 +3151,15 @@ class DynamicsWorldModel(Module):
     ):
         assert self.action_embedder.has_actions

+        # handle actions having time vs time - 1 length
+        # remove the first action if it is equal to time (as it would come from some agent token in the past)
+
+        if exists(discrete_actions) and discrete_actions.shape[1] == time:
+            discrete_actions = discrete_actions[:, 1:]
+
+        if exists(continuous_actions) and continuous_actions.shape[1] == time:
+            continuous_actions = continuous_actions[:, 1:]
+
         # only for 1 agent

         agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
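The new length reconciliation drops the leading action whenever a full `time`-length action sequence is passed in, per the comment in the hunk. A minimal sketch with made-up shapes:

```python
import torch

time = 5
discrete_actions = torch.randint(0, 10, (2, time))   # (batch, time) -- one action too many

if discrete_actions.shape[1] == time:
    discrete_actions = discrete_actions[:, 1:]        # drop the action belonging to a past agent token

print(discrete_actions.shape)  # torch.Size([2, 4])
```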
dreamer4-0.0.77.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=v6W9psBXMCGo9K3F6xCxm49TEu3jFUqDvo3pf_bsBQo,107099
+dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
+dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
+dreamer4-0.0.77.dist-info/METADATA,sha256=8XGi2Q8meZfrz2K1cpbVH1P34wODclA0JD7OG0xldaQ,3065
+dreamer4-0.0.77.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.77.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.77.dist-info/RECORD,,
dreamer4-0.0.75.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=6ngJq_cpil97RsITPmcExlxoZTcZ6XNdFfOBDcdACQg,105915
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
-dreamer4-0.0.75.dist-info/METADATA,sha256=VVaIj0vNfpT2JBm9AaSr8D-SP5wfWNgHcJjszdmkwU4,3065
-dreamer4-0.0.75.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.75.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.75.dist-info/RECORD,,
{dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/WHEEL
File without changes

{dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/licenses/LICENSE
File without changes