evolutionary-policy-optimization 0.0.5-py3-none-any.whl → 0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/evolutionary_policy_optimization/epo.py
+++ b/evolutionary_policy_optimization/epo.py
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 from torch.nn import Linear, Module, ModuleList

-from einops import rearrange, repeat
+from einops import rearrange, repeat, einsum

 from assoc_scan import AssocScan

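`einsum` is newly imported from einops; it contracts named axes given in a pattern string, and is used further down to mix latent sets. A minimal sketch of the named-axis signature (the shapes are illustrative):

```python
import torch
from einops import einsum

latent = torch.randn(4, 32)  # (n, g) - n latent sets, gene dimension g
gates = torch.randn(8, 4)    # (b, n) - per-sample mixing weights over the sets

# contract over the shared axis n, keeping batch b and gene dim g
mixed = einsum(latent, gates, 'n g, b n -> b g')
assert mixed.shape == (8, 32)
```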
 
@@ -162,6 +162,7 @@ class MLP(Module):
         self,
         dims: tuple[int, ...],
         dim_latent = 0,
+        num_latent_sets = 1
     ):
         super().__init__()
         assert len(dims) >= 2, 'must have at least two dimensions'
@@ -169,17 +170,26 @@ class MLP(Module):
         # add the latent to the first dim

         first_dim, *rest_dims = dims
-        first_dim += dim_latent
-        dims = (first_dim, *rest_dims)
+        dims = (first_dim + dim_latent, *rest_dims)
+
+        assert num_latent_sets >= 1

         self.dim_latent = dim_latent
+        self.num_latent_sets = num_latent_sets
+
         self.needs_latent = dim_latent > 0
+        self.needs_latent_gate = num_latent_sets > 1

         self.encode_latent = nn.Sequential(
             Linear(dim_latent, dim_latent),
             nn.SiLU()
         ) if self.needs_latent else None

+        self.to_latent_gate = nn.Sequential(
+            Linear(first_dim, num_latent_sets),
+            nn.Softmax(dim = -1)
+        ) if self.needs_latent_gate else None
+
         # pairs of dimension

         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
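For intuition, the new `to_latent_gate` is just a linear projection of the raw input followed by a softmax, yielding a probability distribution over the latent sets. A standalone sketch (the dimensions are made up):

```python
import torch
from torch import nn

dim_input, num_latent_sets = 512, 4

to_latent_gate = nn.Sequential(
    nn.Linear(dim_input, num_latent_sets),
    nn.Softmax(dim = -1)
)

gates = to_latent_gate(torch.randn(2, dim_input))           # (batch, num_latent_sets)
assert torch.allclose(gates.sum(dim = -1), torch.ones(2))   # each row is a distribution
```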
@@ -195,16 +205,27 @@ class MLP(Module):
         x,
         latent = None
     ):
+        batch = x.shape[0]
+
         assert xnor(self.needs_latent, exists(latent))

+        if exists(latent) and self.needs_latent_gate:
+            # an improvisation where a set of genes has its expression controlled by the environment
+
+            gates = self.to_latent_gate(x)
+            latent = einsum(latent, gates, 'n g, b n -> b g')
+        elif exists(latent):
+            assert latent.shape[0] == 1
+            latent = latent[0]
+
         if exists(latent):
             # start with naive concatenative conditioning
             # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)

-            batch = x.shape[0]
-
             latent = self.encode_latent(latent)
-            latent = repeat(latent, 'd -> b d', b = batch)
+
+            if latent.ndim == 1:
+                latent = repeat(latent, 'd -> b d', b = batch)

             x = cat((x, latent), dim = -1)

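Putting it together: when `num_latent_sets > 1`, an individual's latent arrives as a stack of gene vectors `(n, g)`, and the forward pass mixes them per input before the usual concatenative conditioning. A hedged usage sketch; the `MLP` signature follows this diff, while the import path and all sizes are assumptions:

```python
import torch
from evolutionary_policy_optimization.epo import MLP

net = MLP(dims = (512, 256, 4), dim_latent = 32, num_latent_sets = 4)

state = torch.randn(8, 512)   # (batch, first_dim)
latent = torch.randn(4, 32)   # (n, g) - one individual's four gene vectors

# gates computed from `state` mix (n, g) -> (batch, g), then concat to the input
out = net(state, latent = latent)
```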
 
@@ -321,6 +342,7 @@ class LatentGenePool(Module):
         self,
         num_latents,              # same as gene pool size
         dim_latent,               # gene dimension
+        num_latent_sets = 1,      # allow for sets of latents / genes per individual, with expression of a set controlled by the environment
         crossover_random = True,  # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
         l2norm_latent = False,    # whether to enforce latents on hypersphere
         frac_tournaments = 0.25,  # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -333,12 +355,13 @@ class LatentGenePool(Module):

         maybe_l2norm = l2norm if l2norm_latent else identity

-        latents = torch.randn(num_latents, dim_latent)
+        latents = torch.randn(num_latents, num_latent_sets, dim_latent)

         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)

         self.num_latents = num_latents
+        self.num_latent_sets = num_latent_sets
         self.latents = nn.Parameter(latents, requires_grad = False)

         self.maybe_l2norm = maybe_l2norm
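The pool tensor thus gains a set axis: each of the `num_latents` individuals now carries `num_latent_sets` gene vectors. A small sketch of the shape and the per-gene normalization, assuming the repo's `l2norm` helper behaves like `F.normalize` over the last dim:

```python
import torch
import torch.nn.functional as F

num_latents, num_latent_sets, dim_latent = 128, 4, 32

latents = torch.randn(num_latents, num_latent_sets, dim_latent)

# with l2norm_latent = True, each gene vector is projected onto the unit hypersphere
latents = F.normalize(latents, dim = -1)
assert latents.shape == (128, 4, 32)
```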
@@ -364,9 +387,17 @@ class LatentGenePool(Module):
         # network for the latent / gene

         if isinstance(net, dict):
+            assert 'dim_latent' not in net
+            assert 'num_latent_sets' not in net
+
+            net.update(dim_latent = dim_latent)
+            net.update(num_latent_sets = num_latent_sets)
+
             net = MLP(**net)

         assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
+        assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between the MLP and the latent gene pool container'
+
         self.net = net

     @torch.no_grad()
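Because `dim_latent` and `num_latent_sets` are now injected into the dict before `MLP(**net)`, a config-style construction only needs the layer dims. A sketch assuming `net` is accepted as a constructor argument (as the surrounding code suggests) and that the remaining arguments keep their defaults; the import path and sizes are illustrative:

```python
from evolutionary_policy_optimization.epo import LatentGenePool

pool = LatentGenePool(
    num_latents = 128,
    dim_latent = 32,
    num_latent_sets = 4,
    net = dict(dims = (512, 256, 4))  # dim_latent / num_latent_sets are filled in automatically
)
```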
@@ -379,6 +410,7 @@ class LatentGenePool(Module):
         """
         p - population
         g - gene dimension
+        n - number of genes per individual
         """
         assert self.num_latents > 1

@@ -403,7 +435,7 @@ class LatentGenePool(Module):

         tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices

-        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
+        tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)

-        parents = participants.gather(-2, tournament_winner_indices)
+        parents = participants.gather(-3, tournament_winner_indices)

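The winner indices out of `topk` only index the tournament axis, so they are broadcast across both the new set axis `n` and the gene axis `g` before the gather, which now happens along the participant axis sitting third from the end. A toy sketch of the shape bookkeeping (all sizes are illustrative):

```python
import torch
from einops import repeat

pools, tourn_size, n, g = 6, 5, 4, 32

participants = torch.randn(pools, tourn_size, n, g)
participant_fitness = torch.randn(pools, tourn_size)

winner_indices = participant_fitness.topk(2, dim = -1).indices           # (pools, 2)
winner_indices = repeat(winner_indices, '... -> ... n g', n = n, g = g)  # (pools, 2, n, g)

parents = participants.gather(-3, winner_indices)                        # (pools, 2, n, g)
assert parents.shape == (pools, 2, n, g)
```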
 
--- a/evolutionary_policy_optimization-0.0.5.dist-info/METADATA
+++ b/evolutionary_policy_optimization-0.0.6.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.5
+Version: 0.0.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
 Requires-Dist: box2d-py; extra == 'examples-gym'
 Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
 Requires-Dist: tqdm; extra == 'examples-gym'
+Provides-Extra: test
+Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown

 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
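The new `test` extra follows the standard packaging convention, so the test dependencies install with `pip install 'evolutionary-policy-optimization[test]'`.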
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.

 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

-Besides their latent variable method, I'll also throw in some attempts with crossover in weight space
+Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
+
+Update: I see, mixing genetic algorithms with gradient based methods is already a research field of its own, known as [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). Incidentally, this is also what I have concluded Science to be; I am directly exposed to this phenomenon on a daily basis

 ## Usage

--- /dev/null
+++ b/evolutionary_policy_optimization-0.0.6.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+evolutionary_policy_optimization/epo.py,sha256=vXkwsQE0CNEUPpguZP-XXsuDyIBN-bS3xDJDXpYlTHM,14772
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.6.dist-info/METADATA,sha256=M_0SbTqdifHQ_R9LWIe7ZfHMXgCiFDJ0sDpD29ctiNk,4460
+evolutionary_policy_optimization-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.6.dist-info/RECORD,,
--- a/evolutionary_policy_optimization-0.0.5.dist-info/RECORD
+++ /dev/null
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
-evolutionary_policy_optimization/epo.py,sha256=lDhMV535MhUw1di7D7RM-Rr_J6aiuLqV-puh4EaNCd8,13455
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.5.dist-info/METADATA,sha256=uzkB4DrpzLLxbMEeiTID4CDxDxmEX1pO9fabwryDQcY,4098
-evolutionary_policy_optimization-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.5.dist-info/RECORD,,