dreamer4 0.0.88__py3-none-any.whl → 0.0.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dreamer4 might be problematic. Click here for more details.

dreamer4/dreamer4.py CHANGED
@@ -1942,6 +1942,7 @@ class DynamicsWorldModel(Module):
1942
1942
  # learned set of latent genes
1943
1943
 
1944
1944
  self.agent_has_genes = num_latent_genes > 0
1945
+ self.num_latent_genes = num_latent_genes
1945
1946
  self.latent_genes = Parameter(randn(num_latent_genes, dim) * 1e-2)
1946
1947
 
1947
1948
  # policy head
@@ -2095,6 +2096,53 @@ class DynamicsWorldModel(Module):
2095
2096
 
2096
2097
  return align_dims_left(times, align_dims_left_to)
2097
2098
 
2099
+ # evolutionary policy optimization - https://web3.arxiv.org/abs/2503.19037
2100
+
2101
@torch.no_grad()
def evolve_(
    self,
    fitness,
    select_frac = 0.5,
    tournament_frac = 0.5
):
    """
    Evolve the learned latent genes in place, given one fitness score per gene.

    Follows the scheme referenced as evolutionary policy optimization
    (https://arxiv.org/abs/2503.19037):

    1. natural selection - keep the `ceil(pop_size * select_frac)` fittest genes
    2. tournaments       - for every child to be produced, draw a random subset of the
                           selected genes and take its two fittest members as parents
    3. crossover         - each child is a per-dimension random interpolation between its parents
    4. the next population is the selected genes (fittest first) followed by the children

    Args:
        fitness: tensor with exactly `self.num_latent_genes` elements, higher is fitter.
            Any shape is accepted, it is flattened before use.
        select_frac: fraction of the population that survives unchanged.
        tournament_frac: fraction of the survivors that take part in each tournament
            (at least 2 participants are always used).

    Returns:
        None - `self.latent_genes` is overwritten in place.
    """
    pop = self.latent_genes
    pop_size = self.num_latent_genes

    assert fitness.numel() == pop_size, f'expected one fitness value per latent gene ({pop_size}) but received {fitness.numel()}'

    # random tensors must live alongside the population they are combined with

    device, dtype = pop.device, pop.dtype

    # flatten so a (1, pop_size) shaped fitness does not add a stray leading dimension when indexing

    fitness = fitness.reshape(-1).to(device)

    num_selected = min(pop_size, max(0, ceil(pop_size * select_frac)))
    num_children = pop_size - num_selected

    dim_gene = pop.shape[-1]

    # natural selection just a sort and slice
    # topk returns values sorted descending, so `selected` is ordered fittest first

    selected_fitness, selected_indices = fitness.topk(num_selected, dim = -1)
    selected = pop[selected_indices]

    # nothing to breed - the population is only reordered by fitness

    if num_children == 0:
        self.latent_genes.copy_(selected)
        return

    assert num_selected >= 2, 'at least two selected individuals are needed as parents - increase `select_frac` or the number of latent genes'

    # use tournament - one tournament per child
    # each row holds the indices (into `selected`) of the participants of that tournament

    tournament_size = min(num_selected, max(2, ceil(num_selected * tournament_frac)))

    tournaments = torch.randn((num_children, num_selected), device = device).argsort(dim = -1)[:, :tournament_size]

    # topk over the participants' fitness yields positions *within* each tournament,
    # they must be mapped back through `tournaments` to obtain indices into `selected`

    participant_fitness = selected_fitness[tournaments]
    winner_slots = participant_fitness.topk(2, dim = -1).indices # top 2 winners of each tournament

    parent_ids = tournaments.gather(-1, winner_slots)

    parents = selected[parent_ids] # (num_children, 2, dim_gene)

    # crossover by random interpolation from parent1 to parent2
    # the mixing weights are a sigmoid of gaussian noise, so they lie strictly within (0, 1)

    random_mix = torch.randn((num_children, dim_gene), device = device, dtype = dtype).sigmoid()

    parent1, parent2 = parents.unbind(dim = 1)
    children = parent1.lerp(parent2, random_mix)

    # store next population

    next_pop = cat((selected, children))

    self.latent_genes.copy_(next_pop)
2145
+
2098
2146
  # interacting with env for experience
2099
2147
 
2100
2148
  @torch.no_grad()
@@ -2495,6 +2543,7 @@ class DynamicsWorldModel(Module):
2495
2543
  batch_size = 1,
2496
2544
  agent_index = 0,
2497
2545
  tasks: int | Tensor | None = None,
2546
+ latent_gene_ids = None,
2498
2547
  image_height = None,
2499
2548
  image_width = None,
2500
2549
  return_decoded_video = None,
@@ -2610,6 +2659,7 @@ class DynamicsWorldModel(Module):
2610
2659
  step_sizes = step_size,
2611
2660
  rewards = decoded_rewards,
2612
2661
  tasks = tasks,
2662
+ latent_gene_ids = latent_gene_ids,
2613
2663
  discrete_actions = decoded_discrete_actions,
2614
2664
  continuous_actions = decoded_continuous_actions,
2615
2665
  proprio = noised_proprio_with_context,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dreamer4
3
- Version: 0.0.88
3
+ Version: 0.0.90
4
4
  Summary: Dreamer 4
5
5
  Project-URL: Homepage, https://pypi.org/project/dreamer4/
6
6
  Project-URL: Repository, https://github.com/lucidrains/dreamer4
@@ -0,0 +1,8 @@
1
+ dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
+ dreamer4/dreamer4.py,sha256=Ig-t_A8BJWY2eKhsees4_zGXzvtS2JTQTlRuS33ufT8,113812
3
+ dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
+ dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
+ dreamer4-0.0.90.dist-info/METADATA,sha256=94VfjlhIE6dDY5AbipuRF-Ip7pyhvgQOC4EBKc8ZKRg,3065
6
+ dreamer4-0.0.90.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ dreamer4-0.0.90.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ dreamer4-0.0.90.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- dreamer4/__init__.py,sha256=Jssh1obzDRtTfBLZl36kXge1cIQlMjf_8DyjPulvKSk,183
2
- dreamer4/dreamer4.py,sha256=qAVLInnB5kgazzyp1KFIgWEIjJyLhlUQb3RNmybi23g,112219
3
- dreamer4/mocks.py,sha256=TfqOB_Gq6N_GggBYwa6ZAJQx38ntlYbXZe23Ne4jshw,2502
4
- dreamer4/trainers.py,sha256=D2b7WTgTHElLhIWLFgl2Ct2knGJLTk91HHpC5UkNvG0,14028
5
- dreamer4-0.0.88.dist-info/METADATA,sha256=cI2FsEuCzaLcqobMAtLkN_6nBv6iDZQNYr1hlV7YmlY,3065
6
- dreamer4-0.0.88.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- dreamer4-0.0.88.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- dreamer4-0.0.88.dist-info/RECORD,,