dreamer4 0.0.88__tar.gz → 0.0.89__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- {dreamer4-0.0.88 → dreamer4-0.0.89}/PKG-INFO +1 -1
- {dreamer4-0.0.88 → dreamer4-0.0.89}/dreamer4/dreamer4.py +48 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/pyproject.toml +1 -1
- {dreamer4-0.0.88 → dreamer4-0.0.89}/tests/test_dreamer.py +33 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/.github/workflows/python-publish.yml +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/.github/workflows/test.yml +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/.gitignore +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/LICENSE +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/README.md +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/dreamer4/__init__.py +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/dreamer4/mocks.py +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/dreamer4/trainers.py +0 -0
- {dreamer4-0.0.88 → dreamer4-0.0.89}/dreamer4-fig2.png +0 -0
|
@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
|
|
|
1942
1942
|
# learned set of latent genes
|
|
1943
1943
|
|
|
1944
1944
|
self.agent_has_genes = num_latent_genes > 0
|
|
1945
|
+
self.num_latent_genes = num_latent_genes
|
|
1945
1946
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
|
1946
1947
|
|
|
1947
1948
|
# policy head
|
|
@@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):
|
|
|
2095
2096
|
|
|
2096
2097
|
return align_dims_left(times, align_dims_left_to)
|
|
2097
2098
|
|
|
2099
|
+
# evolutionary policy optimization - https://arxiv.org/abs/2503.19037
|
|
2100
|
+
|
|
2101
|
+
@torch.no_grad()
def evolve_(
    self,
    fitness,
    select_frac = 0.5,
    tournament_frac = 0.5
):
    """
    Run one in-place generation of evolution over `self.latent_genes`.

    The fittest `ceil(pop_size * select_frac)` genes survive unchanged and are
    written to the front of the population, ordered by descending fitness.
    Each remaining slot is filled by a child: a random tournament is drawn
    among the survivors, its two fittest participants become the parents, and
    the child is a per-dimension random interpolation between them.

    Args:
        fitness: tensor with exactly one score per latent gene (any shape,
            it is flattened). Higher is better.
        select_frac: fraction of the population that survives, in (0, 1].
        tournament_frac: fraction of the survivors drawn into each tournament.
            At least two participants are used whenever two survivors exist.

    Returns:
        None. `self.latent_genes` is overwritten in place.
    """
    assert fitness.numel() == self.num_latent_genes, 'fitness must have exactly one value per latent gene'
    assert 0. < select_frac <= 1., 'select_frac must be in the range (0, 1]'

    pop = self.latent_genes
    pop_size = self.num_latent_genes

    # nothing to evolve when there are no genes

    if pop_size == 0:
        return

    # one fitness value per gene, on the same device as the genes

    fitness = fitness.reshape(-1).to(pop.device)

    num_selected = ceil(pop_size * select_frac)
    num_children = pop_size - num_selected

    dim_gene = pop.shape[-1]

    # natural selection just a sort and slice
    # indexing with a tensor copies, so `selected` is safe to read after the final in-place write

    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # use tournament - one tournament per child
    # each row is the first `tournament_size` entries of a random permutation of the survivors

    tournament_size = min(num_selected, max(2, ceil(num_selected * tournament_frac)))

    tournaments = torch.randn((num_children, num_selected), device = pop.device).argsort(dim = -1)[:, :tournament_size]

    # top 2 winners of each tournament become the parents
    # topk yields positions *within* the tournament, which must be mapped back to survivor ids

    participant_fitness = selected_fitness[tournaments]
    winner_slots = participant_fitness.topk(min(2, tournament_size), dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_slots)

    # with a single survivor both parents are that survivor, so children are clones of it

    parent1 = selected[parent_ids[:, 0]]
    parent2 = selected[parent_ids[:, -1]]

    # crossover by random interpolation from parent1 to parent2
    # weights are sigmoid of a standard normal - they lie in (0, 1) but are not uniformly distributed

    mix = torch.randn((num_children, dim_gene), device = pop.device, dtype = pop.dtype).sigmoid()
    children = parent1.lerp(parent2, mix)

    # store next population

    next_pop = cat((selected, children))

    self.latent_genes.copy_(next_pop)
|
|
2145
|
+
|
|
2098
2146
|
# interacting with env for experience
|
|
2099
2147
|
|
|
2100
2148
|
@torch.no_grad()
|
|
@@ -753,3 +753,36 @@ def test_proprioception(
|
|
|
753
753
|
|
|
754
754
|
assert exists(generations.proprio)
|
|
755
755
|
assert generations.video.shape == video_shape
|
|
756
|
+
|
|
757
|
+
def test_epo():
    """
    Evolve the latent genes of a small DynamicsWorldModel once and check that
    the population keeps its shape, stays finite, and that the fittest half
    survives unchanged at the front of the population.
    """
    from dreamer4.dreamer4 import VideoTokenizer, DynamicsWorldModel

    num_latent_genes = 16

    tokenizer = VideoTokenizer(
        512,
        dim_latent = 32,
        patch_size = 32,
        encoder_depth = 2,
        decoder_depth = 2,
        time_block_every = 2,
        attn_heads = 8,
        image_height = 256,
        image_width = 256,
        attn_kwargs = dict(
            query_heads = 16
        )
    )

    dynamics = DynamicsWorldModel(
        512,
        num_agents = 1,
        video_tokenizer = tokenizer,
        dim_latent = 32,
        dim_proprio = 21,
        num_tasks = 4,
        num_latent_genes = num_latent_genes,
        num_discrete_actions = 4,
        num_residual_streams = 1
    )

    fitness = torch.randn(num_latent_genes)

    # snapshot the genes, since evolve_ overwrites them in place

    genes_before = dynamics.latent_genes.detach().clone()

    dynamics.evolve_(fitness)

    genes_after = dynamics.latent_genes.detach()

    assert genes_after.shape == genes_before.shape
    assert torch.isfinite(genes_after).all()

    # default select_frac of 0.5 keeps the top half, stored first in descending fitness order

    num_elites = num_latent_genes // 2
    elite_indices = fitness.topk(num_elites).indices

    assert torch.equal(genes_after[:num_elites], genes_before[elite_indices])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|