dreamer4 0.0.87__tar.gz → 0.0.89__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- {dreamer4-0.0.87 → dreamer4-0.0.89}/PKG-INFO +1 -1
- {dreamer4-0.0.87 → dreamer4-0.0.89}/dreamer4/dreamer4.py +53 -3
- {dreamer4-0.0.87 → dreamer4-0.0.89}/pyproject.toml +1 -1
- {dreamer4-0.0.87 → dreamer4-0.0.89}/tests/test_dreamer.py +33 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/.gitignore +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/LICENSE +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/README.md +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.87 → dreamer4-0.0.89}/dreamer4-fig2.png +0 -0
|
@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
|
|
|
1942
1942
|
# learned set of latent genes
|
|
1943
1943
|
|
|
1944
1944
|
self.agent_has_genes = num_latent_genes > 0
|
|
1945
|
+
self.num_latent_genes = num_latent_genes
|
|
1945
1946
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
|
1946
1947
|
|
|
1947
1948
|
# policy head
|
|
@@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):
|
|
|
2095
2096
|
|
|
2096
2097
|
return align_dims_left(times, align_dims_left_to)
|
|
2097
2098
|
|
|
2099
|
+
# evolutionary policy optimization - https://arxiv.org/abs/2503.19037

@torch.no_grad()
def evolve_(
    self,
    fitness,
    select_frac = 0.5,
    tournament_frac = 0.5
):
    """
    Evolve `self.latent_genes` in place by one generation.

    The fittest `ceil(num_latent_genes * select_frac)` genes survive unchanged and are
    written first, ordered fittest to least fit. The remaining rows are refilled with
    children. For each child, a random tournament is drawn (without replacement) from
    the survivors; its two fittest members become the parents; the child is a
    per-dimension interpolation between them.

    Args:
        fitness: one score per latent gene (`numel()` must equal `num_latent_genes`);
            higher is fitter. It is flattened and moved to the genes' device before use.
        select_frac: fraction of the population kept by selection, in (0, 1].
        tournament_frac: fraction of the survivors entered into each tournament. The
            tournament size is clamped to at least 2 and at most the number of survivors.
    """
    assert fitness.numel() == self.num_latent_genes, f'expected {self.num_latent_genes} fitness scores, got {fitness.numel()}'
    assert 0. < select_frac <= 1., 'select_frac must be in (0, 1]'

    pop = self.latent_genes
    device = pop.device

    # flatten so topk / indexing operate on a single population axis, and keep the
    # scores on the same device as the random tournament indices built below

    fitness = fitness.reshape(-1).to(device)

    pop_size = self.num_latent_genes
    num_selected = ceil(pop_size * select_frac)
    num_children = pop_size - num_selected

    dim_gene = pop.shape[-1]

    # natural selection just a sort and slice - topk returns survivors fittest first

    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # nothing to breed - the next population is just the sorted survivors

    if num_children == 0:
        self.latent_genes.copy_(selected)
        return

    assert num_selected >= 2, 'need at least 2 survivors to choose two parents per child'

    # use tournament - one tournament per child

    tournament_size = min(num_selected, max(2, ceil(num_selected * tournament_frac)))

    # each row is a random subset (without replacement) of survivor indices

    tournaments = torch.randn((num_children, num_selected), device = device).argsort(dim = -1)[:, :tournament_size]

    # topk yields the winners' positions *within* each tournament row,
    # so map them back to indices into `selected`

    winner_slots = selected_fitness[tournaments].topk(2, dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_slots) # top 2 winners of each tournament become parents

    parents = selected[parent_ids]

    # crossover by random interpolation from parent1 to parent2
    # weights are sigmoid of a standard normal - each in (0, 1), concentrated around 0.5

    mix = torch.randn((num_children, dim_gene), device = device, dtype = pop.dtype).sigmoid()

    parent1, parent2 = parents.unbind(dim = 1)
    children = parent1.lerp(parent2, mix)

    # store next population - survivors first, then children

    next_pop = cat((selected, children))

    self.latent_genes.copy_(next_pop)
|
|
2145
|
+
|
|
2098
2146
|
# interacting with env for experience
|
|
2099
2147
|
|
|
2100
2148
|
@torch.no_grad()
|
|
@@ -2255,7 +2303,7 @@ class DynamicsWorldModel(Module):
|
|
|
2255
2303
|
video = cat((video, next_frame), dim = 2)
|
|
2256
2304
|
rewards = safe_cat((rewards, reward), dim = 1)
|
|
2257
2305
|
|
|
2258
|
-
acc_agent_embed = safe_cat((acc_agent_embed,
|
|
2306
|
+
acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
|
|
2259
2307
|
|
|
2260
2308
|
# package up one experience for learning
|
|
2261
2309
|
|
|
@@ -2397,7 +2445,7 @@ class DynamicsWorldModel(Module):
|
|
|
2397
2445
|
return_intermediates = True
|
|
2398
2446
|
)
|
|
2399
2447
|
|
|
2400
|
-
|
|
2448
|
+
agent_embeds = agent_embeds[..., agent_index, :]
|
|
2401
2449
|
|
|
2402
2450
|
# maybe detach agent embed
|
|
2403
2451
|
|
|
@@ -2672,7 +2720,9 @@ class DynamicsWorldModel(Module):
|
|
|
2672
2720
|
|
|
2673
2721
|
# maybe store agent embed
|
|
2674
2722
|
|
|
2675
|
-
|
|
2723
|
+
if store_agent_embed:
|
|
2724
|
+
one_agent_embed = agent_embed[:, -1:, agent_index]
|
|
2725
|
+
acc_agent_embed = safe_cat((acc_agent_embed, one_agent_embed), dim = 1)
|
|
2676
2726
|
|
|
2677
2727
|
# decode the agent actions if needed
|
|
2678
2728
|
|
|
@@ -753,3 +753,36 @@ def test_proprioception(
|
|
|
753
753
|
|
|
754
754
|
assert exists(generations.proprio)
|
|
755
755
|
assert generations.video.shape == video_shape
|
|
756
|
+
|
|
757
|
+
def test_epo():
    """Evolving the latent genes keeps the population shape and the fittest half intact."""
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    num_latent_genes = 16

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_latent_genes = num_latent_genes,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    fitness = torch.randn(num_latent_genes,)

    genes_before = dynamics.latent_genes.detach().clone()

    dynamics.evolve_(fitness)

    genes_after = dynamics.latent_genes.detach()

    # population size and gene dimension are preserved

    assert genes_after.shape == genes_before.shape

    # with the default select_frac of 0.5 the fittest half survives unchanged, fittest first

    num_selected = num_latent_genes // 2
    survivors = genes_before[fitness.topk(num_selected).indices]
    assert torch.allclose(genes_after[:num_selected], survivors)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|