dreamer4 0.0.77__tar.gz → 0.0.78__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dreamer4 might be problematic.

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dreamer4
- Version: 0.0.77
+ Version: 0.0.78
  Summary: Dreamer 4
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -13,7 +13,7 @@ import torch.nn.functional as F
  from torch.nested import nested_tensor
  from torch.distributions import Normal
  from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
- from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+ from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten

  import torchvision
@@ -83,6 +83,7 @@ class Experience:
      values: Tensor | None = None
      step_size: int | None = None
      lens: Tensor | None = None
+     is_truncated: Tensor | None = None
      agent_index: int = 0
      is_from_world_model: bool = True

@@ -99,7 +100,10 @@ def combine_experiences(
      batch, time, device = *latents.shape[:2], latents.device

      if not exists(exp.lens):
-         exp.lens = torch.full((batch,), time, device = device)
+         exp.lens = full((batch,), time, device = device)
+
+     if not exists(exp.is_truncated):
+         exp.is_truncated = full((batch,), True, device = device)

      # convert to dictionary

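To make the new field concrete, here is a minimal, self-contained sketch of the pattern these two hunks establish: an optional per-sample is_truncated flag on the experience record that, when left unset, defaults to all-True via torch.full. The Experience fields shown are only the subset visible in this diff, and exists is a local stand-in for the package's own helper, so treat this as an illustration rather than the actual class.

# minimal sketch (not the package's actual class): optional per-sample
# is_truncated flag on an experience record, defaulted to all-True when absent

from dataclasses import dataclass

import torch
from torch import Tensor, full

def exists(v):
    # stand-in for the helper used in dreamer4
    return v is not None

@dataclass
class Experience:
    lens: Tensor | None = None           # per-sample episode lengths
    is_truncated: Tensor | None = None   # per-sample truncated-vs-terminated flag
    agent_index: int = 0
    is_from_world_model: bool = True

def default_flags(exp: Experience, batch: int, time: int, device = 'cpu') -> Experience:
    # mirrors the defaulting added to combine_experiences above
    if not exists(exp.lens):
        exp.lens = full((batch,), time, device = device)

    if not exists(exp.is_truncated):
        exp.is_truncated = full((batch,), True, device = device)

    return exp

exp = default_flags(Experience(), batch = 4, time = 16)
assert exp.is_truncated.all() and (exp.lens == 16).all()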
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

          time_kv_cache = None

-         for _ in range(max_timesteps):
+         for i in range(max_timesteps + 1):

              latents = self.video_tokenizer(video, return_latents = True)

@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

              one_agent_embed = agent_embed[..., -1:, agent_index, :]

+             # values
+
+             value_bins = self.value_head(one_agent_embed)
+             value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+             values = safe_cat((values, value), dim = 1)
+
+             # policy embed
+
              policy_embed = self.policy_head(one_agent_embed)

              # sample actions
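This hunk moves value prediction ahead of action sampling inside the rollout loop: the value head produces bin logits that reward_encoder.bins_to_scalar_value collapses into a scalar per agent. The sketch below shows one common way such a conversion works, namely an expectation over fixed bin centers followed by an inverse squashing transform; the bin count, value range, and symexp step are assumptions for illustration, not read from the dreamer4 source.

# illustrative sketch only: one common way to turn value-bin logits into a
# scalar value via an expectation over fixed bin centers; the bin range,
# count, and the symexp transform are assumptions, not taken from dreamer4

import torch
from torch import linspace

def symexp(t):
    # inverse of the symlog transform often paired with binned value heads
    return torch.sign(t) * (torch.exp(torch.abs(t)) - 1.)

def bins_to_scalar_value(value_bins, num_bins = 41, value_range = (-20., 20.)):
    # value_bins: (..., num_bins) logits from a value head
    bin_centers = linspace(*value_range, num_bins, device = value_bins.device)
    probs = value_bins.softmax(dim = -1)
    return symexp((probs * bin_centers).sum(dim = -1))

value_bins = torch.randn(2, 1, 41)          # (batch, one timestep, bins)
value = bins_to_scalar_value(value_bins)    # (batch, 1)
print(value.shape)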
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
              discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
              continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-             value_bins = self.value_head(one_agent_embed)
-             value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-             values = safe_cat((values, value), dim = 1)
-
              # pass the sampled action to the environment and get back next state and reward

              next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):

          # package up one experience for learning

+         batch, device = latents.shape[0], latents.device
+
          one_experience = Experience(
              latents = latents,
              video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
              values = values,
              step_size = step_size,
              agent_index = agent_index,
+             lens = full((batch,), max_timesteps, device = device),
              is_from_world_model = False
          )

@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):

          assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+         batch, time = latents.shape[0], latents.shape[1]
+
+         # calculate returns
+
+         # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+         # for terminated, will just mask out any after lens
+
+         # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+         if not exists(experience.is_truncated):
+             experience.is_truncated = full((batch,), True, device = latents.device)
+
+         lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+         mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+         rewards = rewards.masked_fill(mask_for_gae, 0.)
+         old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+         # calculate returns
+
          returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

          # handle variable lengths
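The added block derives an effective per-sample length from lens and is_truncated, turns it into a boolean mask, and zeroes rewards and old values beyond it before calc_gae. Below is a small, self-contained sketch of that masking mechanics; mask_after_lens is a local stand-in (True marks the positions to be zeroed) rather than the repo's lens_to_mask, and all shapes and values are illustrative.

# self-contained sketch of masking rewards / values past each sample's
# effective length before a GAE pass; mask_after_lens is a local helper,
# not the package's lens_to_mask

import torch
from torch import arange

def mask_after_lens(lens, total_time):
    # (batch,) lengths -> (batch, total_time) bool, True at timesteps >= length
    return arange(total_time, device = lens.device)[None, :] >= lens[:, None]

batch, time = 2, 6
rewards    = torch.randn(batch, time)
old_values = torch.randn(batch, time)

lens         = torch.tensor([4, 6])
is_truncated = torch.tensor([True, False])

# same length adjustment as lens_for_gae_calc in the hunk above
lens_for_gae = torch.where(is_truncated, lens, lens + 1)
mask         = mask_after_lens(lens_for_gae, time)

rewards    = rewards.masked_fill(mask, 0.)
old_values = old_values.masked_fill(mask, 0.)

assert (rewards[0, 4:] == 0.).all()  # positions past the first sample's length are zeroed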
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
          assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

          if isinstance(tasks, int):
-             tasks = torch.full((batch_size,), tasks, device = self.device)
+             tasks = full((batch_size,), tasks, device = self.device)

          assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
          # returning agent actions, rewards, and log probs + values for policy optimization

          batch, device = latents.shape[0], latents.device
-         experience_lens = torch.full((batch,), time_steps, device = device)
+         experience_lens = full((batch,), time_steps, device = device)

          gen = Experience(
              latents = latents,
@@ -1,6 +1,6 @@
  [project]
  name = "dreamer4"
- version = "0.0.77"
+ version = "0.0.78"
  description = "Dreamer 4"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }