dreamer4 0.0.77__tar.gz → 0.0.79__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic.
- {dreamer4-0.0.77 → dreamer4-0.0.79}/PKG-INFO +1 -1
- {dreamer4-0.0.77 → dreamer4-0.0.79}/dreamer4/dreamer4.py +39 -12
- {dreamer4-0.0.77 → dreamer4-0.0.79}/dreamer4/trainers.py +1 -1
- {dreamer4-0.0.77 → dreamer4-0.0.79}/pyproject.toml +1 -1
- {dreamer4-0.0.77 → dreamer4-0.0.79}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/.gitignore +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/LICENSE +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/README.md +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.79}/tests/test_dreamer.py +0 -0
dreamer4/dreamer4.py

@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
 import torchvision
@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True
 
@@ -99,7 +100,10 @@ def combine_experiences(
     batch, time, device = *latents.shape[:2], latents.device
 
     if not exists(exp.lens):
-        exp.lens =
+        exp.lens = full((batch,), time, device = device)
+
+    if not exists(exp.is_truncated):
+        exp.is_truncated = full((batch,), True, device = device)
 
     # convert to dictionary
 
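The new defaults in combine_experiences amount to treating an experience with no recorded metadata as a batch of full-length, truncated sequences. A minimal standalone sketch of that behavior (the function and argument names below are stand-ins, not the package's own API):

import torch

def fill_length_defaults(lens, is_truncated, batch, time, device):
    # no per-sequence lengths recorded -> assume every sequence spans the full time axis
    if lens is None:
        lens = torch.full((batch,), time, device = device)

    # no truncation flags recorded -> assume every sequence was cut off,
    # i.e. its final step may hold a bootstrapped value rather than a terminal one
    if is_truncated is None:
        is_truncated = torch.full((batch,), True, device = device)

    return lens, is_truncated

lens, is_truncated = fill_length_defaults(None, None, batch = 4, time = 16, device = 'cpu')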
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):
 
         time_kv_cache = None
 
-        for
+        for i in range(max_timesteps + 1):
 
             latents = self.video_tokenizer(video, return_latents = True)
 
@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):
 
             one_agent_embed = agent_embed[..., -1:, agent_index, :]
 
+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)
 
             # sample actions
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)
 
-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward
 
             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):
 
         # package up one experience for learning
 
+        batch, device = latents.shape[0], latents.device
+
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps + 1, device = device),
             is_from_world_model = False
         )
 
@@ -2224,6 +2235,22 @@ class DynamicsWorldModel(Module):
 
         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'
 
+        batch, time = latents.shape[0], latents.shape[1]
+
+        # calculate returns
+
+        # mask out anything after the `lens`, which may include a bootstrapped node at the very end if `is_truncated = True`
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        mask_for_gae = lens_to_mask(experience.lens, time)
+
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)
 
         # handle variable lengths
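The masking added ahead of calc_gae keeps steps past each sequence's recorded length from leaking into the advantage recursion. A rough, unvectorized illustration of the idea (this is not the package's calc_gae or its accelerated path, and the mask convention here is local to the sketch, with True marking valid steps):

import torch

def masked_gae(rewards, values, lens, gamma = 0.99, lam = 0.95):
    # rewards, values: (batch, time); lens: (batch,) number of valid steps per row
    batch, time = rewards.shape
    valid = torch.arange(time, device = rewards.device)[None, :] < lens[:, None]

    # zero out everything past each sequence's recorded length
    rewards = rewards.masked_fill(~valid, 0.)
    values = values.masked_fill(~valid, 0.)

    # plain generalized advantage estimation, scanned backwards over time
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros(batch, device = rewards.device)
    next_value = torch.zeros(batch, device = rewards.device)

    for t in reversed(range(time)):
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        gae = delta + gamma * lam * gae
        advantages[:, t] = gae
        next_value = values[:, t]

    return advantages + values  # the returns fed to the policy/value losses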
@@ -2232,8 +2259,8 @@ class DynamicsWorldModel(Module):
         is_var_len = exists(experience.lens)
 
         if is_var_len:
-
-            mask = lens_to_mask(
+            learnable_lens = experience.lens - experience.is_truncated.long() # if is truncated, remove the last one, as it is bootstrapped value
+            mask = lens_to_mask(learnable_lens, max_time)
 
         # determine whether to finetune entire transformer or just learn the heads
 
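The learnable_lens arithmetic is the key bookkeeping detail: when a rollout is truncated, its final recorded step exists only to supply a bootstrap value, so it is excluded from the positions the losses train on. A small sketch of that mask, using a hand-rolled stand-in for lens_to_mask:

import torch

def learnable_mask(lens, is_truncated, max_time):
    # drop the final step of truncated rollouts - it holds a bootstrapped
    # value estimate, not a transition the policy should be trained on
    learnable_lens = lens - is_truncated.long()
    return torch.arange(max_time, device = lens.device)[None, :] < learnable_lens[:, None]

# lens = [5, 5], is_truncated = [True, False] -> rows keep 4 and 5 learnable steps
mask = learnable_mask(torch.tensor([5, 5]), torch.tensor([True, False]), max_time = 5)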
@@ -2387,7 +2414,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'
 
         if isinstance(tasks, int):
-            tasks =
+            tasks = full((batch_size,), tasks, device = self.device)
 
         assert not exists(tasks) or tasks.shape[0] == batch_size
 
@@ -2624,7 +2651,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization
 
         batch, device = latents.shape[0], latents.device
-        experience_lens =
+        experience_lens = full((batch,), time_steps, device = device)
 
         gen = Experience(
             latents = latents,
dreamer4/trainers.py

@@ -287,7 +287,7 @@ class DreamTrainer(Module):
         for _ in range(self.num_train_steps):
 
             dreams = self.unwrapped_model.generate(
-                self.generate_timesteps,
+                self.generate_timesteps + 1, # plus one for bootstrap value
                 batch_size = self.batch_size,
                 return_rewards_per_frame = True,
                 return_agent_actions = True,
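Tying the trainer change back to the length bookkeeping in dreamer4.py: generating one extra timestep gives every dreamed rollout a trailing value estimate to bootstrap returns from, and the is_truncated flag later removes that same step from the learnable mask. A toy trace under those assumptions (the variable names are illustrative, not the trainer's own):

import torch

generate_timesteps = 8                        # steps the policy should actually learn from
rollout_steps = generate_timesteps + 1        # one extra frame generated for the bootstrap value

lens = torch.full((2,), rollout_steps)        # recorded length includes the bootstrap frame
is_truncated = torch.full((2,), True)         # dreamed rollouts are cut off, not terminal

learnable_lens = lens - is_truncated.long()   # back to 8 learnable steps per rollout
print(learnable_lens.tolist())                # [8, 8]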