dreamer4-0.0.75-py3-none-any.whl → dreamer4-0.0.77-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of dreamer4 has been flagged as potentially problematic.

dreamer4/dreamer4.py CHANGED
@@ -82,7 +82,7 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None,
+    lens: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
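The trailing comma removed on the `lens` field is not cosmetic. Assuming `Experience` is a dataclass-style container (as the annotated defaults suggest), `None,` parses as the one-element tuple `(None,)`, so `lens` silently defaulted to a truthy tuple and any `exists(lens)` check would misfire. A minimal sketch of the pitfall, with hypothetical class names:

from dataclasses import dataclass

@dataclass
class Broken:
    lens: int | None = None,   # stray comma: the default is the tuple (None,)

@dataclass
class Fixed:
    lens: int | None = None

print(Broken().lens)   # (None,) -- truthy, so a None-check never fires
print(Fixed().lens)    # None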
@@ -284,7 +284,7 @@ def create_multi_token_prediction_targets(
     batch, seq_len, device = *t.shape[:2], t.device

     batch_arange = arange(batch, device = device)
-    seq_arange = arange(seq_len, device = device)[1:]
+    seq_arange = arange(seq_len, device = device)
     steps_arange = arange(steps_future, device = device)

     indices = add('t, steps -> t steps', seq_arange, steps_arange)
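For context, `create_multi_token_prediction_targets` builds, for every position t, the indices of the next `steps_future` tokens; the removed `[1:]` dropped position 0 and shifted every target row off by one. A plain-torch sketch of the corrected index grid (the library itself uses einx's `add`):

import torch

def multi_token_indices(seq_len, steps_future):
    # position t gathers targets at t, t + 1, ..., t + steps_future - 1
    seq_arange = torch.arange(seq_len)                 # (t,)
    steps_arange = torch.arange(steps_future)          # (steps,)
    return seq_arange[:, None] + steps_arange[None, :] # (t, steps)

print(multi_token_indices(4, 3))
# tensor([[0, 1, 2],
#         [1, 2, 3],
#         [2, 3, 4],
#         [3, 4, 5]])
# entries >= seq_len fall past the sequence end and are masked downstream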
@@ -2226,6 +2226,15 @@ class DynamicsWorldModel(Module):

         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads

         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
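The new variable-length path derives a boolean mask from per-sample episode lengths. Assuming `lens_to_mask` has the usual semantics of such a helper, a minimal sketch:

import torch

def lens_to_mask(lens, max_time):
    # (batch,) integer lengths -> (batch, max_time) bool mask, True at valid steps
    time_arange = torch.arange(max_time, device = lens.device)
    return time_arange[None, :] < lens[:, None]

print(lens_to_mask(torch.tensor([3, 5]), max_time = 5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])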
@@ -2291,13 +2300,20 @@ class DynamicsWorldModel(Module):

         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step

         if exists(policy_optim):
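Deferring the `.mean()` on the entropy term is what makes the masking work: the loss stays per-timestep with shape (b, t), so for variable-length experience the mean runs only over valid steps instead of being diluted by padding. A small illustration of the difference:

import torch

loss = torch.tensor([[1., 2., 3.],
                     [4., 5., 6.]])
mask = torch.tensor([[True, True, False],   # sample 0 has only 2 valid steps
                     [True, True, True]])

print(loss.mean())         # tensor(3.5000) -- padded step included
print(loss[mask].mean())   # tensor(3.6000) -- mean over the 5 valid entries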
@@ -2316,10 +2332,19 @@ class DynamicsWorldModel(Module):

         return_bins = self.reward_encoder(returns)

+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step
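The added `rearrange` to 'b t l -> b l t' is needed because `F.cross_entropy` expects the class (here, bin) dimension at position 1, and with `reduction = 'none'` it then returns a per-timestep (b, t) loss that the variable-length mask can index. A minimal sketch with soft probability targets, which is what the two-hot return bins are:

import torch
import torch.nn.functional as F

b, t, l = 2, 4, 8                         # batch, time, number of bins
value_logits = torch.randn(b, t, l)
return_probs = torch.randn(b, t, l).softmax(dim = -1)

loss = F.cross_entropy(
    value_logits.transpose(1, 2),         # (b, l, t) -- class dim at position 1
    return_probs.transpose(1, 2),         # soft targets in the same layout
    reduction = 'none'
)

print(loss.shape)                         # torch.Size([2, 4]), i.e. (b, t)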
@@ -3100,7 +3125,7 @@ class DynamicsWorldModel(Module):

         reward_pred = rearrange(reward_pred, 'mtp b t l -> b l t mtp')

-        reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding, self.multi_token_pred_len)
+        reward_targets, reward_loss_mask = create_multi_token_prediction_targets(two_hot_encoding[:, :-1], self.multi_token_pred_len)

         reward_targets = rearrange(reward_targets, 'b t mtp l -> b l t mtp')
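`two_hot_encoding` is the binned representation of the rewards; the fix trims its final timestep so the target grid lines up with the reward predictions. For context, a sketch of two-hot encoding as commonly used for reward and value bins (a hypothetical helper; the library's `reward_encoder` may differ in detail):

import torch

def two_hot(value, bins):
    # spread each scalar across its two nearest bins, weighted by proximity
    value = value.clamp(bins[0], bins[-1])
    idx = torch.searchsorted(bins, value).clamp(1, len(bins) - 1)
    lo, hi = bins[idx - 1], bins[idx]
    weight_hi = (value - lo) / (hi - lo)
    out = torch.zeros(*value.shape, len(bins))
    out.scatter_(-1, (idx - 1).unsqueeze(-1), (1. - weight_hi).unsqueeze(-1))
    out.scatter_(-1, idx.unsqueeze(-1), weight_hi.unsqueeze(-1))
    return out

bins = torch.linspace(-1., 1., 5)   # [-1.0, -0.5, 0.0, 0.5, 1.0]
print(two_hot(torch.tensor([0.25]), bins))
# tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000]])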
@@ -3126,6 +3151,15 @@ class DynamicsWorldModel(Module):
     ):
         assert self.action_embedder.has_actions

+        # handle actions having time vs time - 1 length
+        # remove the first action if it is equal to time (as it would come from some agent token in the past)
+
+        if exists(discrete_actions) and discrete_actions.shape[1] == time:
+            discrete_actions = discrete_actions[:, 1:]
+
+        if exists(continuous_actions) and continuous_actions.shape[1] == time:
+            continuous_actions = continuous_actions[:, 1:]
+
         # only for 1 agent

         agent_tokens = rearrange(agent_tokens, 'b t 1 d -> b t d')
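The new guard normalizes action tensors that arrive with `time` steps, where the first action belongs to an agent token before the current window, down to the `time - 1` steps the rest of the function expects. Standalone, the logic looks like this (hypothetical helper name):

import torch

def maybe_drop_first_action(actions, time):
    # actions: (batch, time, ...) or (batch, time - 1, ...) or None
    # a full `time` steps means the first action predates the window -- drop it
    if actions is not None and actions.shape[1] == time:
        actions = actions[:, 1:]
    return actions

acts = torch.zeros(2, 10, 4)
print(maybe_drop_first_action(acts, time = 10).shape)   # torch.Size([2, 9, 4])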
{dreamer4-0.0.75.dist-info → dreamer4-0.0.77.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.75
+Version: 0.0.77
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4
dreamer4-0.0.77.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=v6W9psBXMCGo9K3F6xCxm49TEu3jFUqDvo3pf_bsBQo,107099
+dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
+dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
+dreamer4-0.0.77.dist-info/METADATA,sha256=8XGi2Q8meZfrz2K1cpbVH1P34wODclA0JD7OG0xldaQ,3065
+dreamer4-0.0.77.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.77.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.77.dist-info/RECORD,,
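Each RECORD row has the form path,sha256=<digest>,size, where the digest is the urlsafe base64 of the file's SHA-256 with the '=' padding stripped (per the wheel spec). A short sketch for reproducing an entry's hash:

import base64
import hashlib

def record_hash(path):
    # urlsafe base64 of the sha256 digest, '=' padding removed
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    return 'sha256=' + base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# e.g. record_hash('dreamer4/dreamer4.py') should give
# 'sha256=v6W9psBXMCGo9K3F6xCxm49TEu3jFUqDvo3pf_bsBQo' for the 0.0.77 file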
dreamer4-0.0.75.dist-info/RECORD REMOVED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=6ngJq_cpil97RsITPmcExlxoZTcZ6XNdFfOBDcdACQg,105915
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
-dreamer4-0.0.75.dist-info/METADATA,sha256=VVaIj0vNfpT2JBm9AaSr8D-SP5wfWNgHcJjszdmkwU4,3065
-dreamer4-0.0.75.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.75.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.75.dist-info/RECORD,,