dreamer4 0.0.76__tar.gz → 0.0.77__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.76
+Version: 0.0.77
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
 
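The one-character change above is a real bug fix, not cosmetics: in a Python class body, a trailing comma after the default value turns it into a one-element tuple, so lens would default to (None,) rather than None and defeat any exists(experience.lens) style check. A minimal illustration (hypothetical Broken/Fixed classes, not from the repo; Experience is assumed to be a dataclass):

    from dataclasses import dataclass

    @dataclass
    class Broken:
        lens: int | None = None,   # trailing comma: the default is the tuple (None,)

    @dataclass
    class Fixed:
        lens: int | None = None    # the default is None, as intended

    print(Broken().lens)   # (None,)
    print(Fixed().lens)    # None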
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):
 
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
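The variable-length path added here converts per-sequence lengths into a boolean mask over the time axis. The lens_to_mask helper itself is not part of this diff; the following is a minimal sketch of what a helper with that call signature typically does, not necessarily the repo's actual implementation:

    import torch
    from torch import Tensor

    def lens_to_mask(lens: Tensor, max_time: int) -> Tensor:
        # (batch,) sequence lengths -> (batch, max_time) boolean mask,
        # True wherever the timestep index falls inside the sequence
        arange = torch.arange(max_time, device = lens.device)
        return arange < lens[:, None]

    print(lens_to_mask(torch.tensor([2, 4]), 4))
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])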
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):
 
         # handle entropy loss for naive exploration bonus
 
-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
 
         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )
 
+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step
 
         if exists(policy_optim):
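Note the pattern: the .mean() is deferred from the entropy term to the combined per-timestep loss, and under variable lengths the loss is indexed by the mask before averaging, so padded timesteps no longer dilute the gradient signal. A small numeric sketch of the difference (values are made up):

    import torch

    loss = torch.tensor([[1., 1., 0.],     # (batch, time); last entry of row 0 is padding
                         [1., 1., 1.]])
    mask = torch.tensor([[True, True, False],
                         [True, True, True]])

    print(loss.mean())         # tensor(0.8333) - padding drags the average down
    print(loss[mask].mean())   # tensor(1.)     - averaged over valid timesteps only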
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):
 
         return_bins = self.reward_encoder(returns)
 
+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')
 
-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()
 
         # maybe take value optimizer step
 
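The new rearrange line exists because F.cross_entropy expects the class dimension at position 1: with logits of shape (batch, classes, ...) and soft targets of the same shape, any trailing dimensions are treated as independent positions, and reduction = 'none' returns one loss per position. Moving the bin dimension from last to second therefore yields a (batch, time) loss that the variable-length mask can index. A self-contained sketch (the bin count of 16 is arbitrary):

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, t, l = 2, 5, 16                               # batch, time, number of bins
    logits = torch.randn(b, t, l)
    target = torch.randn(b, t, l).softmax(dim = -1)  # soft targets over bins

    # move bins to dim 1, as F.cross_entropy requires
    logits, target = (rearrange(x, 'b t l -> b l t') for x in (logits, target))

    loss = F.cross_entropy(logits, target, reduction = 'none')
    print(loss.shape)   # torch.Size([2, 5]) - one loss per (batch, time) position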
@@ -1,6 +1,6 @@
 [project]
 name = "dreamer4"
-version = "0.0.76"
+version = "0.0.77"
 description = "Dreamer 4"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -407,14 +407,14 @@ def test_mtp():
 
     reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead
 
-    assert reward_targets.shape == (3, 15, 3)
-    assert mask.shape == (3, 15, 3)
+    assert reward_targets.shape == (3, 16, 3)
+    assert mask.shape == (3, 16, 3)
 
     actions = torch.randint(0, 10, (3, 16, 2))
     action_targets, mask = create_multi_token_prediction_targets(actions, 3)
 
-    assert action_targets.shape == (3, 15, 3, 2)
-    assert mask.shape == (3, 15, 3)
+    assert action_targets.shape == (3, 16, 3, 2)
+    assert mask.shape == (3, 16, 3)
 
     from dreamer4.dreamer4 import ActionEmbedder
 
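The updated assertions show the behavior change behind these tests: the targets now keep the full sequence length (16 timesteps in, 16 out) rather than truncating to 15, with the returned mask marking which lookahead positions are valid. Assuming create_multi_token_prediction_targets(seq, k) gathers the next k values for every timestep and masks positions that run past the end, here is a hypothetical reimplementation consistent with the asserted shapes (not the repo's actual code):

    import torch
    from torch import Tensor

    def mtp_targets(seq: Tensor, k: int) -> tuple[Tensor, Tensor]:
        # hypothetical sketch: targets[:, t, i] = seq[:, t + 1 + i],
        # with positions past the sequence end clamped and masked invalid
        t = seq.shape[1]
        idx = torch.arange(t)[:, None] + torch.arange(1, k + 1)[None, :]   # (t, k)
        mask = idx < t
        targets = seq[:, idx.clamp(max = t - 1)]                           # (b, t, k, *rest)
        return targets, mask.expand(seq.shape[0], t, k)

    rewards = torch.randn(3, 16)
    targets, mask = mtp_targets(rewards, 3)
    print(targets.shape, mask.shape)   # torch.Size([3, 16, 3]) torch.Size([3, 16, 3])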