dreamer4 0.0.76__tar.gz → 0.0.78__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 has been flagged as possibly problematic by the registry.
- {dreamer4-0.0.76 → dreamer4-0.0.78}/PKG-INFO +1 -1
- {dreamer4-0.0.76 → dreamer4-0.0.78}/dreamer4/dreamer4.py +69 -13
- {dreamer4-0.0.76 → dreamer4-0.0.78}/pyproject.toml +1 -1
- {dreamer4-0.0.76 → dreamer4-0.0.78}/tests/test_dreamer.py +4 -4
- {dreamer4-0.0.76 → dreamer4-0.0.78}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/.gitignore +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/LICENSE +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/README.md +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.76 → dreamer4-0.0.78}/dreamer4-fig2.png +0 -0
dreamer4/dreamer4.py

@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.nested import nested_tensor
 from torch.distributions import Normal
 from torch.nn import Module, ModuleList, Embedding, Parameter, Sequential, Linear, RMSNorm, Identity
-from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
+from torch import nn, cat, stack, arange, tensor, Tensor, is_tensor, full, zeros, ones, randint, rand, randn, randn_like, empty, full, linspace, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten

 import torchvision
@@ -82,7 +82,8 @@ class Experience:
     log_probs: tuple[Tensor, Tensor] | None = None
     values: Tensor | None = None
     step_size: int | None = None
-    lens: Tensor | None = None
+    lens: Tensor | None = None
+    is_truncated: Tensor | None = None
     agent_index: int = 0
     is_from_world_model: bool = True

@@ -99,7 +100,10 @@ def combine_experiences(
     batch, time, device = *latents.shape[:2], latents.device

     if not exists(exp.lens):
-        exp.lens =
+        exp.lens = full((batch,), time, device = device)
+
+    if not exists(exp.is_truncated):
+        exp.is_truncated = full((batch,), True, device = device)

     # convert to dictionary

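For readers following along, the two new defaults simply mark every trajectory in a combined experience as running the full `time` steps and as truncated rather than terminated. A minimal sketch of that behavior, assuming `exists` is the usual `is not None` helper used throughout the codebase and using made-up batch and time sizes:

```python
import torch
from torch import full

def exists(v):
    # assumed to be the same trivial helper used throughout dreamer4
    return v is not None

# hypothetical stand-in for an Experience covering 4 trajectories of 16 timesteps
batch, time, device = 4, 16, 'cpu'
lens, is_truncated = None, None

if not exists(lens):
    lens = full((batch,), time, device = device)          # every trajectory assumed full length

if not exists(is_truncated):
    is_truncated = full((batch,), True, device = device)  # and assumed truncated, not terminated

print(lens)          # tensor([16, 16, 16, 16])
print(is_truncated)  # tensor([True, True, True, True])
```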
@@ -2115,7 +2119,7 @@ class DynamicsWorldModel(Module):

         time_kv_cache = None

-        for
+        for i in range(max_timesteps + 1):

             latents = self.video_tokenizer(video, return_latents = True)

@@ -2141,6 +2145,15 @@ class DynamicsWorldModel(Module):

             one_agent_embed = agent_embed[..., -1:, agent_index, :]

+            # values
+
+            value_bins = self.value_head(one_agent_embed)
+            value = self.reward_encoder.bins_to_scalar_value(value_bins)
+
+            values = safe_cat((values, value), dim = 1)
+
+            # policy embed
+
             policy_embed = self.policy_head(one_agent_embed)

             # sample actions
@@ -2162,11 +2175,6 @@
             discrete_log_probs = safe_cat((discrete_log_probs, one_discrete_log_probs), dim = 1)
             continuous_log_probs = safe_cat((continuous_log_probs, one_continuous_log_probs), dim = 1)

-            value_bins = self.value_head(one_agent_embed)
-            value = self.reward_encoder.bins_to_scalar_value(value_bins)
-
-            values = safe_cat((values, value), dim = 1)
-
             # pass the sampled action to the environment and get back next state and reward

             next_frame, reward = env.step((sampled_discrete_actions, sampled_continuous_actions))
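These two hunks move the value estimate earlier in the rollout loop, before actions are sampled and the environment is stepped. `safe_cat` is the repository's own helper and is not shown in this diff; the sketch below only assumes it drops `None` entries before concatenating, which is enough to show how per-step values would accumulate into a `(batch, time)` tensor over the loop:

```python
import torch

def safe_cat(tensors, dim = -1):
    # a guess at the helper's behavior: drop Nones, concat whatever remains
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim = dim)

values = None
for step in range(3):
    value = torch.randn(2, 1)                  # one value per batch element per step
    values = safe_cat((values, value), dim = 1)

print(values.shape)  # torch.Size([2, 3])
```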
@@ -2187,6 +2195,8 @@ class DynamicsWorldModel(Module):

         # package up one experience for learning

+        batch, device = latents.shape[0], latents.device
+
         one_experience = Experience(
             latents = latents,
             video = video[:, :, :-1],
@@ -2196,6 +2206,7 @@ class DynamicsWorldModel(Module):
             values = values,
             step_size = step_size,
             agent_index = agent_index,
+            lens = full((batch,), max_timesteps, device = device),
             is_from_world_model = False
         )

@@ -2224,8 +2235,37 @@

         assert all([*map(exists, (old_log_probs, actions, old_values, rewards, step_size))]), 'the generations need to contain the log probs, values, and rewards for policy optimization'

+        batch, time = latents.shape[0], latents.shape[1]
+
+        # calculate returns
+
+        # for truncated (true by default), we will calculate experience lens + 1 and mask out anything after that
+        # for terminated, will just mask out any after lens
+
+        # if not supplied, assume truncated (which is the case for games like minecraft or capped timesteps)
+
+        if not exists(experience.is_truncated):
+            experience.is_truncated = full((batch,), True, device = latents.device)
+
+        lens_for_gae_calc = torch.where(experience.is_truncated, experience.lens, experience.lens + 1)
+        mask_for_gae = lens_to_mask(lens_for_gae_calc, time)
+
+        rewards = rewards.masked_fill(mask_for_gae, 0.)
+        old_values = old_values.masked_fill(mask_for_gae, 0.)
+
+        # calculate returns
+
         returns = calc_gae(rewards, old_values, gamma = self.gae_discount_factor, lam = self.gae_lambda, use_accelerated = self.gae_use_accelerated)

+        # handle variable lengths
+
+        max_time = latents.shape[1]
+        is_var_len = exists(experience.lens)
+
+        if is_var_len:
+            lens = experience.lens
+            mask = lens_to_mask(lens, max_time)
+
         # determine whether to finetune entire transformer or just learn the heads

         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
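The key addition in this hunk is masking rewards and value estimates beyond each trajectory's length before generalized advantage estimation, with the cutoff shifted by one step depending on whether an episode is truncated or terminated (one case keeps a bootstrap step, the other does not). `lens_to_mask` is the repository's helper and its exact polarity (whether `True` marks valid or padded positions) is not visible in this diff, so the sketch below only illustrates the common formulation, with the `masked_fill` polarity chosen to zero out positions past each length:

```python
import torch

def lens_to_mask(lens, total_len):
    # a common formulation (assumed, not copied from the repo):
    # True at positions that fall within each sequence's length
    positions = torch.arange(total_len, device = lens.device)
    return positions[None, :] < lens[:, None]

lens         = torch.tensor([3, 5])
is_truncated = torch.tensor([True, False])

# mirror the diff: shift the cutoff by one step depending on the truncation flag
lens_for_gae = torch.where(is_truncated, lens, lens + 1)
mask = lens_to_mask(lens_for_gae, total_len = 7)

print(mask.int())
# tensor([[1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1, 0]])

# zero out rewards outside the valid region before computing returns
rewards = torch.ones(2, 7)
rewards = rewards.masked_fill(~mask, 0.)
```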
@@ -2291,13 +2331,20 @@ class DynamicsWorldModel(Module):

         # handle entropy loss for naive exploration bonus

-        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')
+        entropy_loss = - reduce(entropies, 'b t na -> b t', 'sum')

         total_policy_loss = (
             policy_loss +
             entropy_loss * self.policy_entropy_weight
         )

+        # maybe handle variable lengths
+
+        if is_var_len:
+            total_policy_loss = total_policy_loss[mask].mean()
+        else:
+            total_policy_loss = total_policy_loss.mean()
+
         # maybe take policy optimizer step

         if exists(policy_optim):
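Both the policy loss here and the value loss further down now reduce with a masked mean when trajectory lengths vary. A tiny, self-contained illustration (with made-up numbers) of why boolean indexing is used instead of a plain `.mean()`:

```python
import torch

loss = torch.tensor([[1., 2., 3.],
                     [4., 5., 6.]])
mask = torch.tensor([[True, True, False],
                     [True, False, False]])

# boolean indexing keeps only the valid (batch, time) positions before averaging,
# so padded timesteps no longer dilute the loss the way a plain .mean() would
masked_mean = loss[mask].mean()   # (1 + 2 + 4) / 3
plain_mean  = loss.mean()         # 21 / 6

print(masked_mean, plain_mean)
```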
@@ -2316,10 +2363,19 @@

         return_bins = self.reward_encoder(returns)

+        value_bins, return_bins, clipped_value_bins = tuple(rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins, clipped_value_bins))
+
         value_loss_1 = F.cross_entropy(value_bins, return_bins, reduction = 'none')
         value_loss_2 = F.cross_entropy(clipped_value_bins, return_bins, reduction = 'none')

-        value_loss = torch.maximum(value_loss_1, value_loss_2)
+        value_loss = torch.maximum(value_loss_1, value_loss_2)
+
+        # maybe variable length
+
+        if is_var_len:
+            value_loss = value_loss[mask].mean()
+        else:
+            value_loss = value_loss.mean()

         # maybe take value optimizer step

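The new `rearrange` is there because `F.cross_entropy` expects the class (here, bin) dimension in position 1, while the heads produce `(batch, time, bins)`. A small sketch with made-up sizes, assuming the return bins act as soft targets over the same bin dimension:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

batch, time, num_bins = 2, 5, 41          # arbitrary sizes for illustration
value_bins  = torch.randn(batch, time, num_bins)
return_bins = torch.randn(batch, time, num_bins).softmax(dim = -1)  # soft targets over bins

# F.cross_entropy wants the class/bin dimension at index 1,
# hence the 'b t l -> b l t' rearrange before computing the loss
value_bins, return_bins = (rearrange(t, 'b t l -> b l t') for t in (value_bins, return_bins))

loss = F.cross_entropy(value_bins, return_bins, reduction = 'none')
print(loss.shape)  # torch.Size([2, 5]) - one loss per (batch, time) position
```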
@@ -2362,7 +2418,7 @@ class DynamicsWorldModel(Module):
         assert 0 < num_steps <= self.max_steps, f'number of steps {num_steps} must be between 0 and {self.max_steps}'

         if isinstance(tasks, int):
-            tasks =
+            tasks = full((batch_size,), tasks, device = self.device)

         assert not exists(tasks) or tasks.shape[0] == batch_size

@@ -2599,7 +2655,7 @@ class DynamicsWorldModel(Module):
         # returning agent actions, rewards, and log probs + values for policy optimization

         batch, device = latents.shape[0], latents.device
-        experience_lens =
+        experience_lens = full((batch,), time_steps, device = device)

         gen = Experience(
             latents = latents,
tests/test_dreamer.py

@@ -407,14 +407,14 @@ def test_mtp():

     reward_targets, mask = create_multi_token_prediction_targets(rewards, 3) # say three token lookahead

-    assert reward_targets.shape == (3,
-    assert mask.shape == (3,
+    assert reward_targets.shape == (3, 16, 3)
+    assert mask.shape == (3, 16, 3)

     actions = torch.randint(0, 10, (3, 16, 2))
     action_targets, mask = create_multi_token_prediction_targets(actions, 3)

-    assert action_targets.shape == (3,
-    assert mask.shape == (3,
+    assert action_targets.shape == (3, 16, 3, 2)
+    assert mask.shape == (3, 16, 3)

     from dreamer4.dreamer4 import ActionEmbedder

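The corrected assertions pin down the expected shapes: a lookahead of 3 over a `(3, 16)` reward sequence yields `(3, 16, 3)` targets, and over `(3, 16, 2)` actions yields `(3, 16, 3, 2)`, with a `(3, 16, 3)` validity mask in both cases. The real `create_multi_token_prediction_targets` lives in `dreamer4/dreamer4.py` and is not shown in this diff; the sketch below is only a guess at its semantics that reproduces these shapes:

```python
import torch

def create_multi_token_prediction_targets(seq, lookahead):
    # hypothetical reimplementation: for every position t, gather the next
    # `lookahead` values, with a mask marking which of them actually exist
    batch, time = seq.shape[:2]
    rest = seq.shape[2:]
    targets = seq.new_zeros(batch, time, lookahead, *rest)
    mask = torch.zeros(batch, time, lookahead, dtype = torch.bool)

    for offset in range(1, lookahead + 1):
        valid = time - offset
        targets[:, :valid, offset - 1] = seq[:, offset:]
        mask[:, :valid, offset - 1] = True

    return targets, mask

rewards = torch.randn(3, 16)
reward_targets, mask = create_multi_token_prediction_targets(rewards, 3)
print(reward_targets.shape, mask.shape)   # torch.Size([3, 16, 3]) torch.Size([3, 16, 3])

actions = torch.randint(0, 10, (3, 16, 2))
action_targets, mask = create_multi_token_prediction_targets(actions, 3)
print(action_targets.shape, mask.shape)   # torch.Size([3, 16, 3, 2]) torch.Size([3, 16, 3])
```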