dreamer4 0.0.76__tar.gz → 0.0.78__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.76
+ Version: 0.0.78
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -13,7 +13,7 @@ import torch.nn.functional as F
  from torch.nested import nested_tensor
  from torch.distributions import Normal
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
- from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten

  import torchvision
@@ -82,7 +82,8 @@ class Experience:
  log_probs: tuple[Tensor, Tensor] | None = None
  values: Tensor | None = None
  step_size: int | None = None
- lens: Tensor | None = None,
+ lens: Tensor | None = None
+ is_truncated: Tensor | None = None
  agent_index: int = 0
  is_from_world_model: bool = True

@@ -99,7 +100,10 @@ def combine_experiences(
  batch, time, device = *latents.shape[:2], latents.device

  if not exists(exp.lens):
-     exp.lens = torch.full((batch,), time, device = device)
+     exp.lens = full((batch,), time, device = device)
+
+ if not exists(exp.is_truncated):
+     exp.is_truncated = full((batch,), True, device = device)

  # convert to dictionary

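A minimal sketch, outside the diff and with assumed shapes, of the new defaulting behavior in combine_experiences: an Experience that arrives without lens or is_truncated is treated as spanning every timestep and as truncated rather than terminated.

    import torch

    batch, time = 3, 16                            # hypothetical rollout dimensions
    lens, is_truncated = None, None                # fields left unset on the Experience

    if lens is None:
        lens = torch.full((batch,), time)          # assume every episode ran the full horizon

    if is_truncated is None:
        is_truncated = torch.full((batch,), True)  # assume a time limit cut the episode short

    print(lens)          # tensor([16, 16, 16])
    print(is_truncated)  # tensor([True, True, True])
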
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

  time_kv_cache = None

- for _ in range(max_timesteps):
+ for i in range(max_timesteps + 1):

      latents = self.video_tokenizer(video, return_latents = True)

@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

  one_agent_embed = agent_embed[..., -1:, agent_index, :]

+ # values
+
+ value_bins = self.value_head(one_agent_embed)
+ value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+ values = safe_cat((values, value), dim = 1)
+
+ # policy embed
+
  policy_embed = self.policy_head(one_agent_embed)

  # sample actions
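One plausible reading of this reordering, together with the loop now running max_timesteps + 1 iterations: the value head is queried once more for the final state, so the rollout carries one extra bootstrap value for return estimation. The visible hunks do not confirm this, so treat it as an assumption; a toy generalized advantage estimation pass using such an extra value looks like this.

    import torch

    T, gamma, lam = 4, 0.99, 0.95        # hypothetical horizon and GAE constants
    rewards = torch.randn(T)             # one reward per environment step
    values = torch.randn(T + 1)          # T + 1 values, the last bootstraps the final state

    gae, advantages = 0., []
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages.append(gae)

    advantages = torch.stack(advantages[::-1])
    returns = advantages + values[:-1]   # regression targets for the value head
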
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
  discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
  continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

- value_bins = self.value_head(one_agent_embed)
- value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
- values = safe_cat((values, value), dim = 1)
-
  # pass the sampled action to the environment and get back next state and reward

  next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):

  # package up one experience for learning

+ batch, device = latents.shape[0], latents.device
+
  one_experience = Experience(
      latents = latents,
      video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
      values = values,
      step_size = step_size,
      agent_index = agent_index,
+     lens = full((batch,), max_timesteps, device = device),
      is_from_world_model = False
  )

@@ -2224,8 +2235,37 @@ class DynamicsWorldModel(Module):

  assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+ batch, time = latents.shape[0], latents.shape[1]
+
+ # calculate returns
+
+ # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+ # for terminated, will just mask out any after lens
+
+ # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+ if not exists(experience.is_truncated):
+     experience.is_truncated = full((batch,), True, device = latents.device)
+
+ lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+ mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+ rewards = rewards.masked_fill(mask_for_gae, 0.)
+ old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+ # calculate returns
+
  returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+ # handle variable lengths
+
+ max_time = latents.shape[1]
+ is_var_len = exists(experience.lens)
+
+ if is_var_len:
+     lens = experience.lens
+     mask = lens_to_mask(lens, max_time)
+
  # determine whether to finetune entire transformer or just learn the heads

  world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
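A short illustration of the length handling above, using an assumed stand-in for lens_to_mask (the repo's own helper may differ in signature or mask polarity): torch.where picks a per-episode effective length from the truncation flag, and the lengths are then expanded into a boolean (batch, time) mask.

    import torch

    def lens_to_mask(lens, max_time):
        # illustrative stand-in: True for timesteps inside each episode's length
        return torch.arange(max_time, device = lens.device)[None, :] < lens[:, None]

    lens = torch.tensor([3, 5])                   # two rollouts of different lengths
    is_truncated = torch.tensor([True, False])

    lens_for_gae = torch.where(is_truncated, lens, lens + 1)   # per-episode effective length
    mask = lens_to_mask(lens_for_gae, max_time = 6)

    print(mask)
    # tensor([[ True,  True,  True, False, False, False],
    #         [ True,  True,  True,  True,  True,  True]])
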
@@ -2291,13 +2331,20 @@ class DynamicsWorldModel(Module):

  # handle entropy loss for naive exploration bonus

- entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum').mean()
+ entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

  total_policy_loss = (
      policy_loss +
      entropy_loss * self.policy_entropy_weight
  )

+ # maybe handle variable lengths
+
+ if is_var_len:
+     total_policy_loss = total_policy_loss[mask].mean()
+ else:
+     total_policy_loss = total_policy_loss.mean()
+
  # maybe take policy optimizer step

  if exists(policy_optim):
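Dropping the early .mean() and reducing after the entropy bonus is the usual masked-mean pattern: keep the loss at (batch, time), then average only over valid timesteps via boolean indexing. A minimal sketch with assumed shapes:

    import torch

    loss = torch.randn(2, 6)                      # hypothetical per-timestep policy loss
    mask = torch.tensor([[True, True, True, False, False, False],
                         [True, True, True, True,  True,  True]])

    masked_mean = loss[mask].mean()               # averages only the 9 valid positions
    plain_mean = loss.mean()                      # previous behavior: averages all 12 positions
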
@@ -2316,10 +2363,19 @@ class DynamicsWorldModel(Module):

  return_bins = self.reward_encoder(returns)

+ value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
  value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
  value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

- value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
+ value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+ # maybe variable length
+
+ if is_var_len:
+     value_loss = value_loss[mask].mean()
+ else:
+     value_loss = value_loss.mean()

  # maybe take value optimizer step

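The rearrange to 'b l t' is required because F.cross_entropy expects the class (here, bin) dimension at dim 1, whether the target holds class indices or, as the shared 'b t l' shape suggests here, probabilities over the bins; with reduction = 'none' the loss keeps a (batch, time) layout that the variable-length masking can then reduce. A sketch with assumed sizes:

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    batch, time, bins = 2, 6, 41                                       # hypothetical sizes
    value_bins = torch.randn(batch, time, bins)                        # logits over value bins
    return_bins = torch.randn(batch, time, bins).softmax(dim = -1)     # soft probability targets over bins

    value_bins, return_bins = (rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins))

    value_loss = F.cross_entropy(value_bins, return_bins, reduction = 'none')   # shape (batch, time)
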
@@ -2362,7 +2418,7 @@ class DynamicsWorldModel(Module):
  assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

  if isinstance(tasks, int):
-     tasks = torch.full((batch_size,), tasks, device = self.device)
+     tasks = full((batch_size,), tasks, device = self.device)
  assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2599,7 +2655,7 @@ class DynamicsWorldModel(Module):
  # returning agent actions, rewards, and log probs + values for policy optimization

  batch, device = latents.shape[0], latents.device
- experience_lens = torch.full((batch,), time_steps, device = device)
+ experience_lens = full((batch,), time_steps, device = device)

  gen = Experience(
      latents = latents,
@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.0.76"
+ version = "0.0.78"
  description = "Dreamer 4"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -407,14 +407,14 @@ def test_mtp():

  reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

- assert reward_targets.shape == (3, 15, 3)
- assert mask.shape == (3, 15, 3)
+ assert reward_targets.shape == (3, 16, 3)
+ assert mask.shape == (3, 16, 3)

  actions = torch.randint(0, 10, (3, 16, 2))
  action_targets, mask = create_multi_token_prediction_targets(actions, 3)

- assert action_targets.shape == (3, 15, 3, 2)
- assert mask.shape == (3, 15, 3)
+ assert action_targets.shape == (3, 16, 3, 2)
+ assert mask.shape == (3, 16, 3)

  from dreamer4.dreamer4 import ActionEmbedder

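The updated assertions indicate that create_multi_token_prediction_targets now keeps the full time dimension (16 rather than 15) and relies on the returned mask to flag lookahead positions that run past the end of the sequence. An illustrative reimplementation under that reading follows; the actual offset and padding convention in the repo may differ.

    import torch

    def mtp_targets(seq, k):
        # seq: (batch, time, ...) -> targets: (batch, time, k, ...), mask: (batch, time, k)
        b, t = seq.shape[:2]
        padded = torch.cat((seq, torch.zeros_like(seq[:, :k])), dim = 1)
        targets = torch.stack([padded[:, i + 1 : i + 1 + t] for i in range(k)], dim = 2)
        mask = (torch.arange(t)[:, None] + torch.arange(1, k + 1)[None, :]) < t
        return targets, mask.expand(b, t, k)

    rewards = torch.randn(3, 16)
    reward_targets, mask = mtp_targets(rewards, 3)
    assert reward_targets.shape == (3, 16, 3) and mask.shape == (3, 16, 3)

    actions = torch.randint(0, 10, (3, 16, 2))
    action_targets, mask = mtp_targets(actions, 3)
    assert action_targets.shape == (3, 16, 3, 2) and mask.shape == (3, 16, 3)
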
420
420
 
5 files without changes