dreamer4 0.0.88__py3-none-any.whl → 0.0.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dreamer4 might be problematic. Click here for more details.
- dreamer4/dreamer4.py +50 -0
- {dreamer4-0.0.88.dist-info → dreamer4-0.0.90.dist-info}/METADATA +1 -1
- dreamer4-0.0.90.dist-info/RECORD +8 -0
- dreamer4-0.0.88.dist-info/RECORD +0 -8
- {dreamer4-0.0.88.dist-info → dreamer4-0.0.90.dist-info}/WHEEL +0 -0
- {dreamer4-0.0.88.dist-info → dreamer4-0.0.90.dist-info}/licenses/LICENSE +0 -0
dreamer4/dreamer4.py
CHANGED
|
@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
|
|
|
1942
1942
|
# learned set of latent genes
|
|
1943
1943
|
|
|
1944
1944
|
self.agent_has_genes = num_latent_genes > 0
|
|
1945
|
+
self.num_latent_genes = num_latent_genes
|
|
1945
1946
|
self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
|
|
1946
1947
|
|
|
1947
1948
|
# policy head
|
|
@@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):
|
|
|
2095
2096
|
|
|
2096
2097
|
return align_dims_left(times, align_dims_left_to)
|
|
2097
2098
|
|
|
2099
|
+
# evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037

@torch.no_grad()
def evolve_(
    self,
    fitness,
    select_frac = 0.5,
    tournament_frac = 0.5
):
    """
    Evolve the latent gene population in-place from per-gene fitness scores.

    Truncation selection keeps the top `select_frac` of the population, then
    each child is produced by a random tournament among the survivors: the two
    fittest tournament entrants are crossed over by elementwise interpolation.
    The new population (survivors + children) overwrites `self.latent_genes`.

    Args:
        fitness: tensor with one scalar fitness per latent gene
            (numel must equal `self.num_latent_genes`); higher is better.
        select_frac: fraction of the population kept as survivors/parents.
        tournament_frac: fraction of the survivors entered into each
            child's tournament (at least 2 entrants).
    """
    assert fitness.numel() == self.num_latent_genes

    pop = self.latent_genes

    pop_size = self.num_latent_genes
    num_selected = ceil(pop_size * select_frac)
    num_children = pop_size - num_selected

    dim_gene = pop.shape[-1]

    # natural selection just a sort and slice

    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # use tournament - one tournament per child

    tournament_size = max(2, ceil(num_selected * tournament_frac))

    # random permutation per child, sliced to tournament size - entries are indices into `selected`

    tournaments = torch.randn((num_children, num_selected), device = self.device).argsort(dim = -1)[:, :tournament_size]

    # get top 2 winners as parents. topk returns positions *within* each
    # tournament, so they must be mapped back through `tournaments` to recover
    # indices into `selected` - indexing `selected` with the raw topk positions
    # would pick arbitrary survivors rather than the tournament winners

    winner_positions = selected_fitness[tournaments].topk(2, dim = -1).indices
    parent_ids = tournaments.gather(-1, winner_positions)

    parents = selected[parent_ids]

    # crossover by random interpolation from parent1 to parent2

    random_uniform_mix = torch.randn((num_children, dim_gene), device = self.device).sigmoid()

    parent1, parent2 = parents.unbind(dim = 1)
    children = parent1.lerp(parent2, random_uniform_mix)

    # store next population

    next_pop = cat((selected, children))

    self.latent_genes.copy_(next_pop)
|
|
2145
|
+
|
|
2098
2146
|
# interacting with env for experience
|
|
2099
2147
|
|
|
2100
2148
|
@torch.no_grad()
|
|
@@ -2495,6 +2543,7 @@ class DynamicsWorldModel(Module):
|
|
|
2495
2543
|
batch_size = 1,
|
|
2496
2544
|
agent_index = 0,
|
|
2497
2545
|
tasks: int | Tensor | None = None,
|
|
2546
|
+
latent_gene_ids = None,
|
|
2498
2547
|
image_height = None,
|
|
2499
2548
|
image_width = None,
|
|
2500
2549
|
return_decoded_video = None,
|
|
@@ -2610,6 +2659,7 @@ class DynamicsWorldModel(Module):
|
|
|
2610
2659
|
step_sizes = step_size,
|
|
2611
2660
|
rewards = decoded_rewards,
|
|
2612
2661
|
tasks = tasks,
|
|
2662
|
+
latent_gene_ids = latent_gene_ids,
|
|
2613
2663
|
discrete_actions = decoded_discrete_actions,
|
|
2614
2664
|
continuous_actions = decoded_continuous_actions,
|
|
2615
2665
|
proprio = noised_proprio_with_context,
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
+
dreamer4/dreamer4.py,sha256=Ig-t_A8BJWY2eKhsees4_zGXzvtS2JTQTlRuS33ufT8,113812
|
|
3
|
+
dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
|
|
4
|
+
dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
|
|
5
|
+
dreamer4-0.0.90.dist-info/METADATA,sha256=94VfjlhIE6dDY5AbipuRF-Ip7pyhvgQOC4EBKc8ZKRg,3065
|
|
6
|
+
dreamer4-0.0.90.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
dreamer4-0.0.90.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
dreamer4-0.0.90.dist-info/RECORD,,
|
dreamer4-0.0.88.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
|
|
2
|
-
dreamer4/dreamer4.py,sha256=qAVLInnB5kgazzyp1KFIgWEIjJyLhlUQb3RNmybi23g,112219
|
|
3
|
-
dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
|
|
4
|
-
dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
|
|
5
|
-
dreamer4-0.0.88.dist-info/METADATA,sha256=cI2FsEuCzaLcqobMAtLkN_6nBv6iDZQNYr1hlV7YmlY,3065
|
|
6
|
-
dreamer4-0.0.88.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
dreamer4-0.0.88.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
dreamer4-0.0.88.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|