dreamer4 0.0.76__tar.gz → 0.0.77__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- {dreamer4-0.0.76 → dreamer4-0.0.77}/PKG-INFO +1 -1
- {dreamer4-0.0.76 → dreamer4-0.0.77}/dreamer4/dreamer4.py +28 -3
- {dreamer4-0.0.76 → dreamer4-0.0.77}/pyproject.toml +1 -1
- {dreamer4-0.0.76 → dreamer4-0.0.77}/tests/test_dreamer.py +4 -4
- {dreamer4-0.0.76 → dreamer4-0.0.77}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/.gitignore +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/LICENSE +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/README.md +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.77}/dreamer4-fig2.png +0 -0
--- dreamer4-0.0.76/dreamer4/dreamer4.py
+++ dreamer4-0.0.77/dreamer4/dreamer4.py
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
 
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
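The added variable-length handling depends on a `lens_to_mask` helper that lives elsewhere in `dreamer4/dreamer4.py` and does not appear in this diff. As a hedged sketch of what such a helper typically does (an assumption, not the package's actual code), it converts per-sequence lengths into a boolean mask over the padded time dimension:

```python
# Hypothetical stand-in for dreamer4's lens_to_mask helper (not shown in this diff):
# turns per-sequence lengths into a boolean mask over a padded time axis.

import torch
from torch import Tensor

def lens_to_mask(lens: Tensor, max_time: int) -> Tensor:
    # lens: (batch,) integer lengths -> mask: (batch, max_time) booleans
    arange = torch.arange(max_time, device = lens.device)
    return arange < lens[:, None]

lens = torch.tensor([3, 5])
print(lens_to_mask(lens, max_time = 5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```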
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):
 
         # handle entropy loss for naive exploration bonus
 
-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step
 
         if exists(policy_optim):
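The effect of the new masked reduction: with variable lengths, only timesteps inside each sequence's true length contribute to the policy loss average, so padded positions no longer skew it. A small illustration with made-up numbers (not taken from the package):

```python
import torch

# per-(batch, time) loss with one padded position in the first sequence
loss = torch.tensor([[1., 2., 100.],
                     [3., 4., 5.]])
mask = torch.tensor([[True, True, False],
                     [True, True, True]])

print(loss.mean())        # tensor(19.1667) - the padded value pollutes the average
print(loss[mask].mean())  # tensor(3.)      - only valid timesteps contribute
```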
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):
 
         return_bins = self.reward_encoder(returns)
 
+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
 
-        value_loss = torch.maximum(value_loss_1, value_loss_2)
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()
 
         # maybe take value optimizer step
 
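The new `rearrange` moves the bin dimension into position 1, which is where `F.cross_entropy` expects the class dimension when given dense (e.g. two-hot) targets; with `reduction = 'none'` it then returns one loss per (batch, time) position, ready for the masked reduction above. A self-contained sketch, with an illustrative bin count rather than the package's actual value:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, t, l = 2, 4, 41                                    # batch, time, number of bins (41 is illustrative)
value_bins = torch.randn(b, t, l)                     # predicted logits over bins
return_bins = torch.rand(b, t, l).softmax(dim = -1)   # dense target distribution over bins

# F.cross_entropy wants the class dimension at dim 1, hence 'b t l -> b l t'
value_bins, return_bins = (rearrange(x, 'b t l -> b l t') for x in (value_bins, return_bins))

loss = F.cross_entropy(value_bins, return_bins, reduction = 'none')
print(loss.shape)  # torch.Size([2, 4]) - one value per (batch, time) position
```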
--- dreamer4-0.0.76/tests/test_dreamer.py
+++ dreamer4-0.0.77/tests/test_dreamer.py
@@ -407,14 +407,14 @@ def test_mtp():
 
     reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
 
-    assert reward_targets.shape == (3,
-    assert mask.shape == (3,
+    assert reward_targets.shape == (3, 16, 3)
+    assert mask.shape == (3, 16, 3)
 
     actions = torch.randint(0, 10, (3, 16, 2))
     action_targets, mask = create_multi_token_prediction_targets(actions, 3)
 
-    assert action_targets.shape == (3,
-    assert mask.shape == (3,
+    assert action_targets.shape == (3, 16, 3, 2)
+    assert mask.shape == (3, 16, 3)
 
     from dreamer4.dreamer4 import ActionEmbedder
 
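The updated assertions fix the expected shapes from `create_multi_token_prediction_targets`: one target per lookahead step at every timestep, with any trailing feature dimensions preserved, plus a boolean mask per lookahead step. The sketch below is a hypothetical re-implementation written only to reproduce those shapes; the real helper in `dreamer4/dreamer4.py` may build or mask the targets differently:

```python
import torch
from torch import Tensor

def create_mtp_targets(seq: Tensor, lookahead: int) -> tuple[Tensor, Tensor]:
    # hypothetical sketch: for each timestep, gather the next `lookahead` values
    # and mark which of them actually fall inside the sequence
    b, t, *rest = seq.shape
    targets = seq.new_zeros((b, t, lookahead, *rest))
    mask = torch.zeros((b, t, lookahead), dtype = torch.bool, device = seq.device)

    for step in range(1, lookahead + 1):
        targets[:, :t - step, step - 1] = seq[:, step:]
        mask[:, :t - step, step - 1] = True

    return targets, mask

rewards = torch.randn(3, 16)
targets, mask = create_mtp_targets(rewards, 3)
print(targets.shape, mask.shape)  # torch.Size([3, 16, 3]) torch.Size([3, 16, 3])

actions = torch.randint(0, 10, (3, 16, 2))
targets, mask = create_mtp_targets(actions, 3)
print(targets.shape, mask.shape)  # torch.Size([3, 16, 3, 2]) torch.Size([3, 16, 3])
```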