dreamer4 0.0.79.tar.gz → 0.0.81.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dreamer4 might be problematic.
- {dreamer4-0.0.79 → dreamer4-0.0.81}/PKG-INFO +1 -1
- {dreamer4-0.0.79 → dreamer4-0.0.81}/dreamer4/dreamer4.py +37 -5
- {dreamer4-0.0.79 → dreamer4-0.0.81}/pyproject.toml +1 -1
- {dreamer4-0.0.79 → dreamer4-0.0.81}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/.gitignore +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/LICENSE +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/README.md +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/dreamer4-fig2.png +0 -0
- {dreamer4-0.0.79 → dreamer4-0.0.81}/tests/test_dreamer.py +0 -0
dreamer4/dreamer4.py

```diff
@@ -1817,6 +1817,8 @@ class DynamicsWorldModel(Module):
         continuous_action_loss_weight: float | list[float] = 1.,
         num_latent_genes = 0, # for carrying out evolution within the dreams https://web3.arxiv.org/abs/2503.19037
         num_residual_streams = 1,
+        keep_reward_ema_stats = False,
+        reward_ema_decay = 0.998,
         gae_discount_factor = 0.997,
         gae_lambda = 0.95,
         ppo_eps_clip = 0.2,
```
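The two new constructor arguments gate the return-normalization logic added further down in this diff. As a quick check on the default, reward_ema_decay = 0.998 means each new batch of return statistics enters the running average with weight 1 - 0.998 = 0.002, an effective averaging horizon of roughly 1 / 0.002 = 500 updates.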
```diff
@@ -2022,6 +2024,14 @@ class DynamicsWorldModel(Module):
         self.value_clip = value_clip
         self.policy_entropy_weight = value_clip
 
+        # rewards related
+
+        self.keep_reward_ema_stats = keep_reward_ema_stats
+        self.reward_ema_decay = reward_ema_decay
+
+        self.register_buffer('ema_returns_mean', tensor(0.))
+        self.register_buffer('ema_returns_var', tensor(1.))
+
         # loss related
 
         self.flow_loss_normalizer = LossNormalizer(1)
```
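register_buffer stores the running statistics as non-trainable tensors that follow the module across devices and persist in its state_dict, which is why it is used here instead of plain attributes. A self-contained sketch of the same pattern (the class name RunningReturnStats and its update method are illustrative, not part of dreamer4):

```python
import torch
from torch import nn, tensor

class RunningReturnStats(nn.Module):
    # minimal sketch: buffers move with .to(device) and are saved in
    # state_dict, but receive no gradients and are not parameters
    def __init__(self, decay = 0.998):
        super().__init__()
        self.decay = decay
        self.register_buffer('ema_returns_mean', tensor(0.))
        self.register_buffer('ema_returns_var', tensor(1.))

    @torch.no_grad()
    def update(self, returns: torch.Tensor):
        # lerp_(x, w) computes buf = buf + w * (x - buf), i.e. one EMA step
        w = 1. - self.decay
        self.ema_returns_mean.lerp_(returns.mean(), w)
        self.ema_returns_var.lerp_(returns.var(), w)
```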
```diff
@@ -2244,10 +2254,11 @@ class DynamicsWorldModel(Module):
         if not exists(experience.is_truncated):
             experience.is_truncated = full((batch,), True, device = latents.device)
 
-
+        if exists(experience.lens):
+            mask_for_gae = lens_to_mask(experience.lens, time)
 
-
-
+            rewards = rewards.masked_fill(mask_for_gae, 0.)
+            old_values = old_values.masked_fill(mask_for_gae, 0.)
 
         # calculate returns
 
```
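lens_to_mask is a helper defined elsewhere in the package, and the diff does not show its body, so the polarity of mask_for_gae is an assumption here. A minimal sketch under the assumption that True marks the steps masked_fill should zero, i.e. positions at or beyond each episode's length:

```python
import torch

def lens_to_mask(lens: torch.Tensor, total_len: int) -> torch.Tensor:
    # (batch,) lengths -> (batch, total_len) boolean mask
    # assumed polarity: True on padded steps, so masked_fill(mask, 0.)
    # zeroes rewards and values past each episode's true length
    steps = torch.arange(total_len, device = lens.device)
    return steps[None, :] >= lens[:, None]
```

For example, lens = tensor([2, 4]) with total_len = 4 gives [[False, False, True, True], [False, False, False, False]], zeroing only the padded tail of the first sequence.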
```diff
@@ -2266,11 +2277,32 @@ class DynamicsWorldModel(Module):
 
         world_model_forward_context = torch.no_grad if only_learn_policy_value_heads else nullcontext
 
+        # maybe keep track returns statistics and normalize returns and values before calculating advantage, as done in dreamer v3
+
+        if self.keep_reward_ema_stats:
+            ema_returns_mean, ema_returns_var = self.ema_returns_mean, self.ema_returns_var
+
+            decay = 1. - self.reward_ema_decay
+
+            # todo - handle distributed
+
+            returns_mean, returns_var = returns.mean(), returns.var()
+
+            ema_returns_mean.lerp_(returns_mean, decay)
+            ema_returns_var.lerp_(returns_var, decay)
+
+            ema_returns_std = ema_returns_var.clamp(min = 1e-5).sqrt()
+
+            normed_returns = (returns - ema_returns_mean) / ema_returns_std
+            normed_old_values = (old_values - ema_returns_mean) / ema_returns_std
+
+            advantage = normed_returns - normed_old_values
+        else:
+            advantage = returns - old_values
+
         # apparently they just use the sign of the advantage
         # https://arxiv.org/abs/2410.04166v1
 
-        advantage = returns - old_values
-
         if use_signed_advantage:
             advantage = advantage.sign()
         else:
```
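The new branch is the Dreamer v3 style return normalization: an EMA of the return mean and variance is maintained, returns and values are standardized with it, and the advantage is formed from the standardized quantities. A self-contained sketch of the same arithmetic (the function name normalized_advantage is illustrative):

```python
import torch

def normalized_advantage(
    returns: torch.Tensor,
    old_values: torch.Tensor,
    ema_mean: torch.Tensor,    # scalar buffer, initialized to 0.
    ema_var: torch.Tensor,     # scalar buffer, initialized to 1.
    ema_decay: float = 0.998,
) -> torch.Tensor:
    # update the running statistics in place, as the diff does
    w = 1. - ema_decay
    ema_mean.lerp_(returns.mean(), w)
    ema_var.lerp_(returns.var(), w)

    # standardize both quantities with the shared running statistics
    std = ema_var.clamp(min = 1e-5).sqrt()
    normed_returns = (returns - ema_mean) / std
    normed_values = (old_values - ema_mean) / std
    return normed_returns - normed_values
```

Since both terms subtract the same ema_mean, the mean cancels and the branch reduces to dividing the raw advantage (returns - old_values) by the running return standard deviation; with use_signed_advantage enabled, the subsequent sign() makes even that positive scaling a no-op.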