dreamer4 0.0.77__tar.gz → 0.0.78__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dreamer4-0.0.77 → dreamer4-0.0.78}/PKG-INFO +1 -1
- {dreamer4-0.0.77 → dreamer4-0.0.78}/dreamer4/dreamer4.py +41 -10
- {dreamer4-0.0.77 → dreamer4-0.0.78}/pyproject.toml +1 -1
- {dreamer4-0.0.77 → dreamer4-0.0.78}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/.gitignore +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/LICENSE +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/README.md +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.77 → dreamer4-0.0.78}/tests/test_dreamer.py +0 -0
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchvision
@@ -83,6 +83,7 @@ class Experience:
     values: Tensor | None = None
     step_size: int | None = None
     lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True

@@ -99,7 +100,10 @@ def combine_experiences(
     batch, time, device = *latents.shape[:2], latents.device

     if not exists(exp.lens):
-        exp.lens =
+        exp.lens = full((batch,), time, device = device)
+
+    if not exists(exp.is_truncated):
+        exp.is_truncated = full((batch,), True, device = device)

     # convert to dictionary

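The last two hunks add an is_truncated flag to Experience and have combine_experiences fill in defaults when lens or is_truncated are missing. Below is a minimal sketch of that defaulting behaviour; the exists helper and the batch/time values are illustrative assumptions, not the package's exact code.

    import torch
    from torch import full

    def exists(v):
        return v is not None

    batch, time, device = 2, 5, torch.device('cpu')
    lens, is_truncated = None, None

    if not exists(lens):
        # default: every sequence is assumed to run the full number of timesteps
        lens = full((batch,), time, device = device)

    if not exists(is_truncated):
        # default: episodes are assumed cut off by a time limit rather than terminated
        is_truncated = full((batch,), True, device = device)

    print(lens)          # tensor([5, 5])
    print(is_truncated)  # tensor([True, True])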
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

         time_kv_cache = None

-        for
+        for i in range(max_timesteps + 1):

             latents = self.video_tokenizer(video, return_latents = True)

@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

             one_agent_embed = agent_embed[..., -1:, agent_index, :]

+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)

             # sample actions
@@ -2162,11 +2175,6 @@ class DynamicsWorldModel(Module):
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward

             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
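The previous two hunks relocate the value estimate so it is computed from one_agent_embed before actions are sampled and the environment is stepped. The estimate itself comes from bins_to_scalar_value, which reduces the value head's bin logits to a scalar. The sketch below is a generic illustration of that kind of reduction (softmax-weighted sum over fixed bin centers); the bin count and support range are invented, and this is not the package's exact implementation.

    import torch

    def bins_to_scalar(value_bins: torch.Tensor, low: float = -20., high: float = 20.) -> torch.Tensor:
        # value_bins: (..., num_bins) logits over a fixed discrete support
        num_bins = value_bins.shape[-1]
        centers = torch.linspace(low, high, num_bins, device = value_bins.device)
        probs = value_bins.softmax(dim = -1)      # categorical distribution over bins
        return (probs * centers).sum(dim = -1)    # expected value -> one scalar per position

    value_bins = torch.randn(2, 1, 41)            # (batch, 1 timestep, bins), made-up shape
    value = bins_to_scalar(value_bins)            # (batch, 1)
    print(value.shape)                            # torch.Size([2, 1])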
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):

         # package up one experience for learning

+        batch, device = latents.shape[0], latents.device
+
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps, device = device),
             is_from_world_model = False
         )

@@ -2224,6 +2235,26 @@ class DynamicsWorldModel(Module):

         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+        batch, time = latents.shape[0], latents.shape[1]
+
+        # calculate returns
+
+        # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+        # for terminated, will just mask out any after lens
+
+        # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+        mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

         # handle variable lengths
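The block added in the hunk above prepares the GAE inputs for variable-length episodes: it picks an effective length per sequence (truncated and terminated episodes differ by one bootstrap step), turns those lengths into a boolean mask over time, and zeroes rewards and values beyond it before calc_gae. The sketch below shows that masking pattern under my own conventions: the lens_to_mask helper here is mine (True means the position is inside the episode), so its convention and the bootstrap direction may differ from the package's helpers of the same name.

    import torch

    def lens_to_mask(lens: torch.Tensor, max_time: int) -> torch.Tensor:
        # (batch,) lengths -> (batch, max_time) bool mask, True where t < len
        return torch.arange(max_time, device = lens.device) < lens[:, None]

    def mask_for_returns(rewards, old_values, lens, keep_bootstrap_step):
        # sequences flagged for bootstrapping keep one extra step of reward/value
        effective_lens = torch.where(keep_bootstrap_step, lens + 1, lens)
        valid = lens_to_mask(effective_lens, rewards.shape[1])
        # zero out everything past the effective end of each episode
        return rewards.masked_fill(~valid, 0.), old_values.masked_fill(~valid, 0.)

    rewards = torch.randn(2, 6)
    old_values = torch.randn(2, 6)
    lens = torch.tensor([4, 6])
    flags = torch.tensor([True, False])
    rewards, old_values = mask_for_returns(rewards, old_values, lens, flags)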
@@ -2387,7 +2418,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

         if isinstance(tasks, int):
-            tasks =
+            tasks = full((batch_size,), tasks, device = self.device)

         assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2624,7 +2655,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization

         batch, device = latents.shape[0], latents.device
-        experience_lens =
+        experience_lens = full((batch,), time_steps, device = device)

         gen = Experience(
             latents = latents,