dreamer4 0.0.78__tar.gz → 0.0.79__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- {dreamer4-0.0.78 → dreamer4-0.0.79}/PKG-INFO +1 -1
- {dreamer4-0.0.78 → dreamer4-0.0.79}/dreamer4/dreamer4.py +5 -9
- {dreamer4-0.0.78 → dreamer4-0.0.79}/dreamer4/trainers.py +1 -1
- {dreamer4-0.0.78 → dreamer4-0.0.79}/pyproject.toml +1 -1
- {dreamer4-0.0.78 → dreamer4-0.0.79}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/.gitignore +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/LICENSE +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/README.md +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.78 → dreamer4-0.0.79}/tests/test_dreamer.py +0 -0
|
@@ -2206,7 +2206,7 @@ class DynamicsWorldModel(Module):
 values = values,
 step_size = step_size,
 agent_index = agent_index,
-lens = full((batch,), max_timesteps, device = device),
+lens = full((batch,), max_timesteps + 1, device = device),
 is_from_world_model = False
 )
 
@@ -2239,16 +2239,12 @@ class DynamicsWorldModel(Module):
 
 # calculate returns
 
-#
-# for terminated, will just mask out any after lens
-
-# if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+# mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
 
 if not exists(experience.is_truncated):
 experience.is_truncated = full((batch,), True, device = latents.device)
 
-
-mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+mask_for_gae = lens_to_mask(experience.lens, time)
 
 rewards = rewards.masked_fill(mask_for_gae, 0.)
 old_values = old_values.masked_fill(mask_for_gae, 0.)
@@ -2263,8 +2259,8 @@ class DynamicsWorldModel(Module):
 is_var_len = exists(experience.lens)
 
 if is_var_len:
-
-mask = lens_to_mask(
+learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
+mask = lens_to_mask(learnable_lens, max_time)
 
 # determine whether to finetune entire transformer or just learn the heads
 
@@ -287,7 +287,7 @@ class DreamTrainer(Module):
 for _ in range(self.num_train_steps):
 
 dreams = self.unwrapped_model.generate(
-self.generate_timesteps,
+self.generate_timesteps + 1, # plus one for bootstrap value
 batch_size = self.batch_size,
 return_rewards_per_frame = True,
 return_agent_actions = True,
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes