dreamer4-0.0.77-py3-none-any.whl → dreamer4-0.0.79-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dreamer4/dreamer4.py CHANGED
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchvision

@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True

@@ -99,7 +100,10 @@ def combine_experiences(
     batch, time, device = *latents.shape[:2], latents.device

     if not exists(exp.lens):
-        exp.lens = torch.full((batch,), time, device = device)
+        exp.lens = full((batch,), time, device = device)
+
+    if not exists(exp.is_truncated):
+        exp.is_truncated = full((batch,), True, device = device)

     # convert to dictionary

@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

         time_kv_cache = None

-        for _ in range(max_timesteps):
+        for i in range(max_timesteps + 1):

             latents = self.video_tokenizer(video, return_latents = True)

@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

             one_agent_embed = agent_embed[..., -1:, agent_index, :]

+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)

             # sample actions

@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward

             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))

@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):

         # package up one experience for learning

+        batch, device = latents.shape[0], latents.device
+
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],

@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps + 1, device = device),
             is_from_world_model = False
         )

@@ -2224,6 +2235,22 @@ class DynamicsWorldModel(Module):

         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+        batch, time = latents.shape[0], latents.shape[1]
+
+        # calculate returns
+
+        # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        mask_for_gae = lens_to_mask(experience.lens, time)
+
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

         # handle variable lengths

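The masking above feeds calc_gae, whose implementation is not part of this diff. For orientation only, below is a textbook formulation of generalized advantage estimation with a bootstrapped final value; the function name, signature, and shapes are illustrative assumptions, not the package's calc_gae.

    import torch
    from torch import Tensor

    def gae_returns(rewards: Tensor, values: Tensor, gamma = 0.99, lam = 0.95) -> Tensor:
        # rewards: (time,), values: (time + 1,) where the final entry is the bootstrapped value
        # returns lambda-returns (advantage + value) for the `time` learnable steps
        time = rewards.shape[0]
        advantages = torch.zeros_like(rewards)
        running = torch.tensor(0., device = rewards.device)
        for t in reversed(range(time)):
            delta = rewards[t] + gamma * values[t + 1] - values[t]  # one-step TD error
            running = delta + gamma * lam * running                 # discounted sum of TD errors
            advantages[t] = running
        return advantages + values[:time]
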
@@ -2232,8 +2259,8 @@ class DynamicsWorldModel(Module):
         is_var_len = exists(experience.lens)

         if is_var_len:
-            lens = experience.lens
-            mask = lens_to_mask(lens, max_time)
+            learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
+            mask = lens_to_mask(learnable_lens, max_time)

         # determine whether to finetune entire transformer or just learn the heads

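lens_to_mask is likewise referenced but not shown in this diff. A minimal sketch of the kind of helper assumed here, applied to the truncation bookkeeping from the hunk above; the helper's body and the numbers are assumptions for illustration, not taken from the package source.

    import torch
    from torch import Tensor, full, tensor

    def lens_to_mask(lens: Tensor, max_time: int) -> Tensor:
        # (batch,) lengths -> (batch, max_time) boolean mask, True for positions before each length
        return torch.arange(max_time, device = lens.device)[None, :] < lens[:, None]

    # hypothetical rollout of 5 generated steps (4 real steps + 1 bootstrap step) for 2 episodes
    lens = full((2,), 5)
    is_truncated = tensor([True, False])

    learnable_lens = lens - is_truncated.long()    # drop the bootstrapped step where truncated
    mask = lens_to_mask(learnable_lens, 5)         # used to mask the policy / value losses
    print(mask.long())
    # tensor([[1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 1]])
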
@@ -2387,7 +2414,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

         if isinstance(tasks, int):
-            tasks = torch.full((batch_size,), tasks, device = self.device)
+            tasks = full((batch_size,), tasks, device = self.device)

         assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2624,7 +2651,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization

         batch, device = latents.shape[0], latents.device
-        experience_lens = torch.full((batch,), time_steps, device = device)
+        experience_lens = full((batch,), time_steps, device = device)

         gen = Experience(
             latents = latents,

dreamer4/trainers.py CHANGED
@@ -287,7 +287,7 @@ class DreamTrainer(Module):
        for _ in range(self.num_train_steps):

            dreams = self.unwrapped_model.generate(
-                self.generate_timesteps,
+                self.generate_timesteps + 1, # plus one for bootstrap value
                batch_size = self.batch_size,
                return_rewards_per_frame = True,
                return_agent_actions = True,

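This is the trainer-side counterpart of the changes above: one extra timestep is generated purely so its value estimate can serve as the bootstrap, and that step is then excluded from the learnable mask. A minimal sketch of the implied bookkeeping, with hypothetical numbers:

    from torch import full

    generate_timesteps = 16                        # hypothetical trainer setting
    rollout_steps = generate_timesteps + 1         # what generate() is now asked for

    lens = full((4,), rollout_steps)               # as stored on the Experience
    is_truncated = full((4,), True)                # is_truncated defaults to True when not provided
    learnable_steps = lens - is_truncated.long()   # bootstrap step drops out of the learnable mask

    assert (learnable_steps == generate_timesteps).all()
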
dreamer4-0.0.79.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dreamer4
-Version: 0.0.77
+Version: 0.0.79
 Summary: Dreamer 4
 Project-URL: Homepage, https://pypi.org/project/dreamer4/
 Project-URL: Repository, https://github.com/lucidrains/dreamer4

dreamer4-0.0.79.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
+dreamer4/dreamer4.py,sha256=B-mNoFEVC_fk4ZOAr9CKSniQfFrvtV5aXG9jmlRxoBA,108095
+dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
+dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
+dreamer4-0.0.79.dist-info/METADATA,sha256=6QM1bmGCCQeYkYW9HD3OQgNoKYlo-ILQIJ2e03aeZYk,3065
+dreamer4-0.0.79.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dreamer4-0.0.79.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+dreamer4-0.0.79.dist-info/RECORD,,

dreamer4-0.0.77.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
-dreamer4/dreamer4.py,sha256=v6W9psBXMCGo9K3F6xCxm49TEu3jFUqDvo3pf_bsBQo,107099
-dreamer4/mocks.py,sha256=Oi91Yv1oK0E-Wz-KDkf79xoyWzIXCvMLCr0WYCpJDLA,1482
-dreamer4/trainers.py,sha256=898ye9Y1mqxGZnU_gfQS6pECibZwwyA43sL7wK_JHAU,13993
-dreamer4-0.0.77.dist-info/METADATA,sha256=8XGi2Q8meZfrz2K1cpbVH1P34wODclA0JD7OG0xldaQ,3065
-dreamer4-0.0.77.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dreamer4-0.0.77.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-dreamer4-0.0.77.dist-info/RECORD,,