evolutionary-policy-optimization 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,7 @@
  from __future__ import annotations

+ from collections import namedtuple
+
  import torch
  from torch import nn, cat
  import torch.nn.functional as F
@@ -7,7 +9,7 @@ import torch.nn.functional as F
  import torch.nn.functional as F
  from torch.nn import Linear, Module, ModuleList

- from einops import rearrange, repeat
+ from einops import rearrange, repeat, einsum

  from assoc_scan import AssocScan

@@ -85,9 +87,9 @@ def critic_loss(
  # generalized advantage estimate

  def calc_generalized_advantage_estimate(
-     rewards: Float['g n'],
-     values: Float['g n+1'],
-     masks: Bool['n'],
+     rewards, # Float[g n]
+     values, # Float[g n+1]
+     masks, # Bool[n]
      gamma = 0.99,
      lam = 0.95,
      use_accelerated = None
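For readers of the diff, the hunk above only swaps jaxtyping-style annotations for plain shape comments. Below is a minimal, reference-only sketch of the textbook GAE recurrence that this signature implies; it is an illustration under the stated shape assumptions, not the package's implementation (which presumably uses the imported `AssocScan` when `use_accelerated` is set).

```python
# Reference-only sketch of the standard GAE recurrence; shapes follow the comments above.
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # rewards: (g, n), values: (g, n + 1), masks: (n,) - True where the episode continues
    deltas = rewards + gamma * values[:, 1:] * masks - values[:, :-1]

    gae = torch.zeros_like(rewards[:, 0])
    out = []

    # accumulate advantages backwards in time
    for t in reversed(range(rewards.shape[-1])):
        gae = deltas[:, t] + gamma * lam * masks[t] * gae
        out.append(gae)

    return torch.stack(out[::-1], dim = -1)
```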
@@ -160,6 +162,7 @@ class MLP(Module):
          self,
          dims: tuple[int, ...],
          dim_latent = 0,
+         num_latent_sets = 1
      ):
          super().__init__()
          assert len(dims) >= 2, 'must have at least two dimensions'
@@ -167,17 +170,26 @@ class MLP(Module):
          # add the latent to the first dim

          first_dim, *rest_dims = dims
-         first_dim += dim_latent
-         dims = (first_dim, *rest_dims)
+         dims = (first_dim + dim_latent, *rest_dims)
+
+         assert num_latent_sets >= 1

          self.dim_latent = dim_latent
+         self.num_latent_sets = num_latent_sets
+
          self.needs_latent = dim_latent > 0
+         self.needs_latent_gate = num_latent_sets > 1

          self.encode_latent = nn.Sequential(
              Linear(dim_latent, dim_latent),
              nn.SiLU()
          ) if self.needs_latent else None

+         self.to_latent_gate = nn.Sequential(
+             Linear(first_dim, num_latent_sets),
+             nn.Softmax(dim = -1)
+         ) if self.needs_latent_gate else None
+
          # pairs of dimension

          dim_pairs = tuple(zip(dims[:-1], dims[1:]))
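The new `to_latent_gate` produces a per-example softmax over an individual's latent sets, and the forward pass in the next hunk mixes the sets with that gate. A small standalone sketch of the shape flow, with hypothetical dimensions (not the package API):

```python
# Standalone illustration of latent-set gating; all names and shapes here are made up.
import torch
from torch import nn
from einops import einsum

batch, dim_in, num_latent_sets, dim_latent = 4, 16, 3, 32

x = torch.randn(batch, dim_in)
latent_sets = torch.randn(num_latent_sets, dim_latent)  # one individual's gene sets

# same structure as the gate added above: a linear followed by a softmax
to_latent_gate = nn.Sequential(nn.Linear(dim_in, num_latent_sets), nn.Softmax(dim = -1))

gates = to_latent_gate(x)                               # (batch, num_latent_sets), rows sum to 1
mixed = einsum(latent_sets, gates, 'n g, b n -> b g')   # (batch, dim_latent)

assert mixed.shape == (batch, dim_latent)
```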
@@ -193,16 +205,27 @@ class MLP(Module):
          x,
          latent = None
      ):
+         batch = x.shape[0]
+
          assert xnor(self.needs_latent, exists(latent))

+         if exists(latent) and self.needs_latent_gate:
+             # an improvisation where set of genes with controlled expression by environment
+
+             gates = self.to_latent_gate(x)
+             latent = einsum(latent, gates, 'n g, b n -> b g')
+         else:
+             assert latent.shape[0] == 1
+             latent = latent[0]
+
          if exists(latent):
              # start with naive concatenative conditioning
              # but will also offer some alternatives once a spark is seen (film, adaptive linear from stylegan, etc)

-             batch = x.shape[0]
-
              latent = self.encode_latent(latent)
-             latent = repeat(latent, 'd -> b d', b = batch)
+
+             if latent.ndim == 1:
+                 latent = repeat(latent, 'd -> b d', b = batch)

              x = cat((x, latent), dim = -1)

@@ -218,6 +241,100 @@ class MLP(Module):

          return x

+ # actor, critic, and agent (actor + critic)
+ # eventually, should just create a separate repo and aggregate all the MLP related architectures
+
+ class Actor(Module):
+     def __init__(
+         self,
+         dim_in,
+         num_actions,
+         dim_hiddens: tuple[int, ...],
+         dim_latent = 0,
+     ):
+         super().__init__()
+
+         assert len(dim_hiddens) >= 2
+         dim_first, *_, dim_last = dim_hiddens
+
+         self.init_layer = nn.Sequential(
+             nn.Linear(dim_in, dim_first),
+             nn.SiLU()
+         )
+
+         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+         self.to_out = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(dim_last, num_actions),
+         )
+
+     def forward(
+         self,
+         state,
+         latent
+     ):
+
+         hidden = self.init_layer(state)
+
+         hidden = self.mlp(state, latent)
+
+         return self.to_out(hidden)
+
+ class Critic(Module):
+     def __init__(
+         self,
+         dim_in,
+         dim_hiddens: tuple[int, ...],
+         dim_latent = 0,
+     ):
+         super().__init__()
+
+         assert len(dim_hiddens) >= 2
+         dim_first, *_, dim_last = dim_hiddens
+
+         self.init_layer = nn.Sequential(
+             nn.Linear(dim_in, dim_first),
+             nn.SiLU()
+         )
+
+         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+
+         self.to_out = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(dim_last, 1),
+             Rearrange('... 1 -> ...')
+         )
+
+     def forward(
+         self,
+         state,
+         latent
+     ):
+
+         hidden = self.init_layer(state)
+
+         hidden = self.mlp(state, latent)
+
+         return self.to_out(hidden)
+
+ class Agent(Module):
+     def __init__(
+         self,
+         actor: Actor,
+         critic: Critic,
+     ):
+         super().__init__()
+
+         self.actor = actor
+         self.critic = critic
+
+     def forward(
+         self,
+         memories: list[Memory]
+     ):
+         raise NotImplementedError
+
  # classes

  class LatentGenePool(Module):
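The new `Actor` head emits raw action logits. In a PPO-style loop those logits would typically be wrapped in a categorical distribution to sample an action and record its log-prob; a hedged, self-contained sketch of that step (not part of the package at this version):

```python
# Illustrative only: consuming actor logits in a PPO-style rollout step.
import torch
from torch.distributions import Categorical

logits = torch.randn(4, 5)        # stand-in for actor(state, latent), shape (batch, num_actions)

dist = Categorical(logits = logits)
action = dist.sample()            # (batch,)
log_prob = dist.log_prob(action)  # stored alongside the action for the PPO ratio later
```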
@@ -225,6 +342,7 @@ class LatentGenePool(Module):
          self,
          num_latents,                # same as gene pool size
          dim_latent,                 # gene dimension
+         num_latent_sets = 1,        # allow for sets of latents / gene per individual, expression of a set controlled by the environment
          crossover_random = True,    # random interp from parent1 to parent2 for crossover, set to `False` for averaging (0.5 constant value)
          l2norm_latent = False,      # whether to enforce latents on hypersphere,
          frac_tournaments = 0.25,    # fraction of genes to participate in tournament - the lower the value, the more chance a less fit gene could be selected
@@ -237,12 +355,13 @@ class LatentGenePool(Module):

          maybe_l2norm = l2norm if l2norm_latent else identity

-         latents = torch.randn(num_latents, dim_latent)
+         latents = torch.randn(num_latents, num_latent_sets, dim_latent)

          if l2norm_latent:
              latents = maybe_l2norm(latents, dim = -1)

          self.num_latents = num_latents
+         self.num_latent_sets = num_latent_sets
          self.latents = nn.Parameter(latents, requires_grad = False)

          self.maybe_l2norm = maybe_l2norm
@@ -268,9 +387,17 @@ class LatentGenePool(Module):
          # network for the latent / gene

          if isinstance(net, dict):
+             assert 'dim_latent' not in net
+             assert 'num_latent_sets' not in net
+
+             net.update(dim_latent = dim_latent)
+             net.update(num_latent_sets = num_latent_sets)
+
              net = MLP(**net)

          assert net.dim_latent == dim_latent, f'the latent dimension set on the MLP {net.dim_latent} must be what was passed into the latent gene pool module ({dim_latent})'
+         assert net.num_latent_sets == num_latent_sets, 'number of latent sets must be equal between MLP and and latent gene pool container'
+
          self.net = net

      @torch.no_grad()
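Based only on the constructor parameters visible in this diff, building a pool with multiple latent sets might look like the sketch below; the dims and counts are made up. When `net` is passed as a dict, `dim_latent` and `num_latent_sets` are injected before the MLP is constructed, as the hunk above shows.

```python
# Hypothetical construction, inferred from the __init__ parameters shown in this diff.
from evolutionary_policy_optimization.epo import LatentGenePool

pool = LatentGenePool(
    num_latents = 32,            # population size
    dim_latent = 16,             # gene dimension
    num_latent_sets = 4,         # latent sets per individual, gated by the environment input
    net = dict(dims = (64, 64))  # dim_latent / num_latent_sets are added to the dict automatically
)
```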
@@ -283,6 +410,7 @@ class LatentGenePool(Module):
          """
          p - population
          g - gene dimension
+         n - number of genes per individual
          """
          assert self.num_latents > 1

@@ -307,7 +435,7 @@ class LatentGenePool(Module):

          tournament_winner_indices = participant_fitness.topk(2, dim = -1).indices

-         tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... g', g = self.dim_latent)
+         tournament_winner_indices = repeat(tournament_winner_indices, '... -> ... n g', g = self.dim_latent, n = self.num_latent_sets)

          parents = participants.gather(-2, tournament_winner_indices)

@@ -368,3 +496,35 @@ class LatentGenePool(Module):
              latent = latent,
              **kwargs
          )
+
+ # EPO - which is just PPO with natural selection of a population of latent variables conditioning the agent
+ # the tricky part is that the latent ids for each episode / trajectory needs to be tracked
+
+ Memory = namedtuple('Memory', [
+     'state',
+     'latent_gene_id',
+     'action',
+     'log_prob',
+     'reward',
+     'values',
+     'done'
+ ])
+
+ class EPO(Module):
+
+     def __init__(
+         self,
+         agent: Agent,
+         latent_gene_pool: LatentGenePool
+     ):
+         super().__init__()
+
+         self.agent = agent
+         self.latent_gene_pool = latent_gene_pool
+
+     def forward(
+         self,
+         env
+     ) -> list[Memory]:
+
+         raise NotImplementedError
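`EPO.forward` is left unimplemented at this version, so the sketch below only illustrates how the new `Memory` namedtuple is meant to be filled during a rollout: one tuple per step, with the sampled gene's id tracked alongside the transition. The `env`, `actor`, and `critic` interfaces here are assumptions for illustration, not the package API.

```python
# Hedged sketch of accumulating Memory tuples; all interfaces are illustrative.
import torch
from evolutionary_policy_optimization.epo import Memory

def rollout_one_episode(env, actor, critic, latent, latent_gene_id, max_steps = 512):
    memories = []
    state = env.reset()                      # assumed env interface

    for _ in range(max_steps):
        logits = actor(state, latent)
        value = critic(state, latent)

        dist = torch.distributions.Categorical(logits = logits)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        state, reward, done = env.step(action)  # assumed env interface

        # the latent / gene id is stored so trajectories can be attributed to individuals
        memories.append(Memory(state, latent_gene_id, action, log_prob, reward, value, done))

        if done:
            break

    return memories
```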
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.0.4
+ Version: 0.0.6
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -46,6 +46,8 @@ Provides-Extra: examples-gym
  Requires-Dist: box2d-py; extra == 'examples-gym'
  Requires-Dist: gymnasium[box2d]>=1.0.0; extra == 'examples-gym'
  Requires-Dist: tqdm; extra == 'examples-gym'
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
  Description-Content-Type: text/markdown

  <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -56,7 +58,9 @@ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.

  This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

- Besides their latent variable method, I'll also throw in some attempts with crossover in weight space
+ Besides their latent variable strategy, I'll also throw in some attempts with crossover in weight space
+
+ Update: I see, mixing genetic algorithms with gradient-based methods is already a research field, known as [Memetic algorithms](https://en.wikipedia.org/wiki/Memetic_algorithm). This is also, incidentally, what I have concluded science is. I am directly exposed to this phenomenon on a daily basis.

  ## Usage

@@ -0,0 +1,7 @@
+ evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
+ evolutionary_policy_optimization/epo.py,sha256=vXkwsQE0CNEUPpguZP-XXsuDyIBN-bS3xDJDXpYlTHM,14772
+ evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+ evolutionary_policy_optimization-0.0.6.dist-info/METADATA,sha256=M_0SbTqdifHQ_R9LWIe7ZfHMXgCiFDJ0sDpD29ctiNk,4460
+ evolutionary_policy_optimization-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ evolutionary_policy_optimization-0.0.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ evolutionary_policy_optimization-0.0.6.dist-info/RECORD,,
@@ -1,7 +0,0 @@
- evolutionary_policy_optimization/__init__.py,sha256=Nu-_OMA8abe4AiW9Gw6MvbZH0JZpMHMqjeXmkC9-7UU,81
- evolutionary_policy_optimization/epo.py,sha256=jW6wZ_IbTdO05agc9AghDHawLb0rStfOzHKpSh-vEe0,10783
- evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
- evolutionary_policy_optimization-0.0.4.dist-info/METADATA,sha256=ZmVUGRQkqOYs1fAyPXjyvIeyc_mShKVTfRVZsIE_Z1Q,4098
- evolutionary_policy_optimization-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- evolutionary_policy_optimization-0.0.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- evolutionary_policy_optimization-0.0.4.dist-info/RECORD,,