dreamer4 0.0.78__tar.gz → 0.0.79__tar.gz

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.78
3
+ Version: 0.0.79
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -2206,7 +2206,7 @@ class DynamicsWorldModel(Module):
2206
2206
  values = values,
2207
2207
  step_size = step_size,
2208
2208
  agent_index = agent_index,
2209
- lens = full((batch,), max_timesteps, device = device),
2209
+ lens = full((batch,), max_timesteps + 1, device = device),
2210
2210
  is_from_world_model = False
2211
2211
  )
2212
2212
 
@@ -2239,16 +2239,12 @@ class DynamicsWorldModel(Module):
2239
2239
 
2240
2240
  # calculate returns
2241
2241
 
2242
- # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
2243
- # for terminated, will just mask out any after lens
2244
-
2245
- # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
2242
+ # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
2246
2243
 
2247
2244
  if not exists(experience.is_truncated):
2248
2245
  experience.is_truncated = full((batch,), True, device = latents.device)
2249
2246
 
2250
- lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
2251
- mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
2247
+ mask_for_gae = lens_to_mask(experience.lens, time)
2252
2248
 
2253
2249
  rewards = rewards.masked_fill(mask_for_gae, 0.)
2254
2250
  old_values = old_values.masked_fill(mask_for_gae, 0.)
@@ -2263,8 +2259,8 @@ class DynamicsWorldModel(Module):
2263
2259
  is_var_len = exists(experience.lens)
2264
2260
 
2265
2261
  if is_var_len:
2266
- lens = experience.lens
2267
- mask = lens_to_mask(lens, max_time)
2262
+ learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
2263
+ mask = lens_to_mask(learnable_lens, max_time)
2268
2264
 
2269
2265
  # determine whether to finetune entire transformer or just learn the heads
2270
2266
 
@@ -287,7 +287,7 @@ class DreamTrainer(Module):
287
287
  for _ in range(self.num_train_steps):
288
288
 
289
289
  dreams = self.unwrapped_model.generate(
290
- self.generate_timesteps,
290
+ self.generate_timesteps + 1, # plus one for bootstrap value
291
291
  batch_size = self.batch_size,
292
292
  return_rewards_per_frame = True,
293
293
  return_agent_actions = True,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dreamer4"
3
- version = "0.0.78"
3
+ version = "0.0.79"
4
4
  description = "Dreamer 4"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
File without changes
File without changes
File without changes
File without changes
File without changes