evolutionary-policy-optimization 0.1.12__tar.gz → 0.1.15__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17):
  1. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/PKG-INFO +10 -1
  2. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/README.md +9 -0
  3. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/epo.py +111 -55
  4. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/pyproject.toml +1 -1
  5. evolutionary_policy_optimization-0.1.15/train_gym.py +63 -0
  6. evolutionary_policy_optimization-0.1.12/train_gym.py +0 -44
  7. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/.github/workflows/python-publish.yml +0 -0
  8. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/.github/workflows/test.yml +0 -0
  9. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/.gitignore +0 -0
  10. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/LICENSE +0 -0
  11. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/__init__.py +0 -0
  12. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/distributed.py +0 -0
  13. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  14. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/experimental.py +0 -0
  15. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/evolutionary_policy_optimization/mock_env.py +0 -0
  16. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/requirements.txt +0 -0
  17. {evolutionary_policy_optimization-0.1.12 → evolutionary_policy_optimization-0.1.15}/tests/test_epo.py +0 -0
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.1.12
+ Version: 0.1.15
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -226,4 +226,13 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @inproceedings{Gerasimov2025YouDN,
+     title  = {You Do Not Fully Utilize Transformer's Representation Capacity},
+     author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+     year   = {2025},
+     url    = {https://api.semanticscholar.org/CorpusID:276317819}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
README.md

@@ -173,4 +173,13 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @inproceedings{Gerasimov2025YouDN,
+     title  = {You Do Not Fully Utilize Transformer's Representation Capacity},
+     author = {Gleb Gerasimov and Yaroslav Aksenov and Nikita Balagansky and Viacheslav Sinii and Daniil Gavrilov},
+     year   = {2025},
+     url    = {https://api.semanticscholar.org/CorpusID:276317819}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
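(For reference: the newly cited Gerasimov et al. 2025 paper introduces LIMe, layer-integrated memory, which appears to be the motivation for the DynamicLIMe module added to evolutionary_policy_optimization/epo.py in the hunks below.)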
evolutionary_policy_optimization/epo.py

@@ -76,6 +76,14 @@ def maybe(fn):
  def interface_torch_numpy(fn, device):
      # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor

+     def to_torch_tensor(t):
+         if isinstance(t, (np.ndarray, np.float64)):
+             t = from_numpy(np.array(t))
+         elif isinstance(t, (float, int, bool)):
+             t = tensor(t)
+
+         return t.to(device)
+
      @wraps(fn)
      def decorated_fn(*args, **kwargs):

@@ -83,7 +91,7 @@ def interface_torch_numpy(fn, device):

          out = fn(*args, **kwargs)

-         out = tree_map(lambda t: from_numpy(np.array(t)).to(device) if isinstance(t, (np.ndarray, np.float64)) else t, out)
+         out = tree_map(to_torch_tensor, out)
          return out

      return decorated_fn
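For readers skimming the diff: the change above factors the numpy-to-torch conversion out of the tree_map lambda into a named helper that additionally converts plain Python scalars (float, int, bool) to tensors on the target device. A minimal standalone sketch of the same pattern, with illustrative names rather than the package's exact code:

```python
import numpy as np
import torch
from functools import wraps
from torch.utils._pytree import tree_map  # generic map over nested tuples / dicts / lists

def to_torch_tensor(t, device = 'cpu'):
    # numpy arrays, numpy scalars and python scalars -> torch tensors on `device`; anything else passes through
    if isinstance(t, (np.ndarray, np.float64)):
        t = torch.from_numpy(np.array(t))
    elif isinstance(t, (float, int, bool)):
        t = torch.tensor(t)
    else:
        return t
    return t.to(device)

def torch_outputs(fn, device = 'cpu'):
    # wrap an env-style function so any numpy outputs come back as torch tensors
    @wraps(fn)
    def inner(*args, **kwargs):
        out = fn(*args, **kwargs)
        return tree_map(lambda t: to_torch_tensor(t, device), out)
    return inner

step = torch_outputs(lambda: (np.zeros(4), 1.0, False))
print(step())  # a tuple of torch tensors
```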
@@ -279,43 +287,87 @@ class PowerLawDist(Module):

          return self.values[sampled]

+ # layer integrated memory
+
+ class DynamicLIMe(Module):
+     def __init__(
+         self,
+         dim,
+         num_layers
+     ):
+         super().__init__()
+         self.num_layers = num_layers
+
+         self.to_weights = nn.Sequential(
+             nn.RMSNorm(dim),
+             nn.Linear(dim, num_layers),
+             nn.ReLU()
+         )
+
+     def forward(
+         self,
+         x,
+         hiddens
+     ):
+
+         if not is_tensor(hiddens):
+             hiddens = stack(hiddens)
+
+         assert hiddens.shape[0] == self.num_layers, f'expected hiddens to have {self.num_layers} layers but received {tuple(hiddens.shape)} instead (first dimension must be layers)'
+
+         weights = self.to_weights(x)
+
+         return einsum(hiddens, weights, 'l b d, b l -> b d')
+
  # simple MLP networks, but with latent variables
  # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper

  class MLP(Module):
      def __init__(
          self,
-         dims: tuple[int, ...],
+         dim,
+         depth,
          dim_latent = 0,
+         expansion_factor = 2.
      ):
          super().__init__()
          dim_latent = default(dim_latent, 0)

-         assert len(dims) >= 2, 'must have at least two dimensions'
-
-         # add the latent to the first dim
-
-         first_dim, *rest_dims = dims
-         dims = (first_dim + dim_latent, *rest_dims)
-
          self.dim_latent = dim_latent

          self.needs_latent = dim_latent > 0

          self.encode_latent = nn.Sequential(
-             Linear(dim_latent, dim_latent),
+             Linear(dim_latent, dim),
              nn.SiLU()
          ) if self.needs_latent else None

-         # pairs of dimension
+         dim_hidden = int(dim * expansion_factor)

-         dim_pairs = tuple(zip(dims[:-1], dims[1:]))
+         # layers

-         # modules across layers
+         layers = []
+
+         for ind in range(depth):
+             is_first = ind == 0

-         layers = ModuleList([Linear(dim_in, dim_out) for dim_in, dim_out in dim_pairs])
+             lime = DynamicLIMe(dim, num_layers = ind + 1) if not is_first else None
+
+             layer = nn.Sequential(
+                 nn.RMSNorm(dim),
+                 nn.Linear(dim, dim_hidden),
+                 nn.SiLU(),
+                 nn.Linear(dim_hidden, dim),
+             )

-         self.layers = layers
+             layers.append(ModuleList([
+                 lime,
+                 layer
+             ]))
+
+         # modules across layers
+
+         self.layers = ModuleList(layers)

      def forward(
          self,
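The headline addition in this hunk is DynamicLIMe: given the current activation and the stack of hidden states from all earlier layers, it predicts one non-negative weight per layer and returns the weighted mixture of those hidden states. A self-contained re-implementation of just that mixing step, assuming nothing beyond what the hunk shows (TinyLIMe and the compact einsum string are illustrative):

```python
import torch
from torch import nn

class TinyLIMe(nn.Module):
    # illustrative stand-in for DynamicLIMe: mix hidden states from prior layers,
    # with per-layer mixing weights predicted from the current activation
    def __init__(self, dim, num_layers):
        super().__init__()
        self.to_weights = nn.Sequential(
            nn.RMSNorm(dim),
            nn.Linear(dim, num_layers),
            nn.ReLU()
        )

    def forward(self, x, hiddens):        # x: (batch, dim), hiddens: list of (batch, dim)
        hiddens = torch.stack(hiddens)    # (layers, batch, dim)
        weights = self.to_weights(x)      # (batch, layers), non-negative
        return torch.einsum('lbd,bl->bd', hiddens, weights)

lime = TinyLIMe(dim = 16, num_layers = 3)
x = torch.randn(2, 16)
mixed = lime(x, [torch.randn(2, 16) for _ in range(3)])
print(mixed.shape)  # torch.Size([2, 16])
```

In the reworked MLP, every block after the first routes its input through this mixer over all previous layer inputs, rather than consuming only the immediately preceding output.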
@@ -337,17 +389,22 @@ class MLP(Module):

          assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'

-         x = cat((x, latent), dim = -1)
+         x = x * latent

          # layers

-         for ind, layer in enumerate(self.layers, start = 1):
-             is_last = ind == len(self.layers)
+         prev_layer_inputs = [x]
+
+         for lime, layer in self.layers:

-             x = layer(x)
+             layer_inp = x

-             if not is_last:
-                 x = F.silu(x)
+             if exists(lime):
+                 layer_inp = lime(x, prev_layer_inputs)
+
+             x = layer(layer_inp) + x
+
+             prev_layer_inputs.append(x)

          return x

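The forward pass changes to match: the latent now conditions the input multiplicatively (x = x * latent, with encode_latent projecting the latent to dim) instead of being concatenated, and the body is a stack of pre-norm residual feedforward blocks in which each block after the first mixes all previous layer inputs before its feedforward. A schematic of that loop; the mixer here is a plain average purely for illustration, whereas the real module uses DynamicLIMe's learned weights:

```python
import torch
from torch import nn

dim, depth = 16, 3

# pre-norm residual feedforward blocks, as in the new MLP
blocks = nn.ModuleList([
    nn.Sequential(nn.RMSNorm(dim), nn.Linear(dim, dim * 2), nn.SiLU(), nn.Linear(dim * 2, dim))
    for _ in range(depth)
])

x = torch.randn(2, dim)
prev_layer_inputs = [x]

for ind, block in enumerate(blocks):
    # first block reads x directly; later blocks read a mixture of every earlier layer input
    layer_inp = x if ind == 0 else torch.stack(prev_layer_inputs).mean(dim = 0)
    x = block(layer_inp) + x              # residual connection around each block
    prev_layer_inputs.append(x)

print(x.shape)  # torch.Size([2, 16])
```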
@@ -359,26 +416,24 @@ class Actor(Module):
          self,
          dim_state,
          num_actions,
-         dim_hiddens: tuple[int, ...],
+         dim,
+         mlp_depth,
          dim_latent = 0,
      ):
          super().__init__()

-         assert len(dim_hiddens) >= 2
-         dim_first, *_, dim_last = dim_hiddens
-
          self.dim_latent = dim_latent

          self.init_layer = nn.Sequential(
-             nn.Linear(dim_state, dim_first),
+             nn.Linear(dim_state, dim),
              nn.SiLU()
          )

-         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+         self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)

          self.to_out = nn.Sequential(
-             nn.SiLU(),
-             nn.Linear(dim_last, num_actions),
+             nn.RMSNorm(dim),
+             nn.Linear(dim, num_actions, bias = False),
          )

      def forward(
@@ -397,34 +452,31 @@ class Critic(Module):
      def __init__(
          self,
          dim_state,
-         dim_hiddens: tuple[int, ...],
+         dim,
+         mlp_depth,
          dim_latent = 0,
          use_regression = False,
          hl_gauss_loss_kwargs: dict = dict(
-             min_value = -10.,
-             max_value = 10.,
-             num_bins = 25,
-             sigma = 0.5
+             min_value = -100.,
+             max_value = 100.,
+             num_bins = 200
          )
      ):
          super().__init__()

-         assert len(dim_hiddens) >= 2
-         dim_first, *_, dim_last = dim_hiddens
-
          self.dim_latent = dim_latent

          self.init_layer = nn.Sequential(
-             nn.Linear(dim_state, dim_first),
+             nn.Linear(dim_state, dim),
              nn.SiLU()
          )

-         self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+         self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)

-         self.final_act = nn.SiLU()
+         self.final_norm = nn.RMSNorm(dim)

          self.to_pred = HLGaussLayer(
-             dim = dim_last,
+             dim = dim,
              use_regression = use_regression,
              hl_gauss_loss = hl_gauss_loss_kwargs
          )
@@ -488,7 +540,7 @@ class Critic(Module):

          hidden = self.mlp(hidden, latent)

-         hidden = self.final_act(hidden)
+         hidden = self.final_norm(hidden)

          pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
          return self.to_pred(hidden, **pred_kwargs)
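The critic keeps its HL-Gauss value head (classification over value bins rather than plain regression), but the bin range widens from [-10, 10] to [-100, 100] with 200 bins, and the final SiLU becomes an RMSNorm. As a reminder of the general idea behind HL-Gauss targets, here is a sketch that is not the hl-gauss-pytorch internals, and whose sigma default is an assumption: a scalar target is smeared into a soft histogram by integrating a Gaussian centered on it over each bin, and the head is trained with cross-entropy against that histogram.

```python
import torch

def hl_gauss_target(value, min_value = -100., max_value = 100., num_bins = 200, sigma = None):
    # soft histogram target for a scalar: probability mass of N(value, sigma^2) falling in each bin
    edges = torch.linspace(min_value, max_value, num_bins + 1)
    sigma = sigma if sigma is not None else (max_value - min_value) / num_bins  # assumed ~bin width
    cdf = torch.distributions.Normal(value, sigma).cdf(edges)
    probs = cdf[1:] - cdf[:-1]
    return probs / probs.sum()

target = hl_gauss_target(torch.tensor(3.7))
print(target.shape, target.sum())  # torch.Size([200]) tensor(1.)
```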
@@ -843,16 +895,16 @@ class Agent(Module):
          critic: Critic,
          latent_gene_pool: LatentGenePool | None,
          optim_klass = AdoptAtan2,
-         actor_lr = 1e-4,
-         critic_lr = 1e-4,
+         actor_lr = 8e-4,
+         critic_lr = 8e-4,
          latent_lr = 1e-5,
-         actor_weight_decay = 1e-3,
-         critic_weight_decay = 1e-3,
+         actor_weight_decay = 5e-4,
+         critic_weight_decay = 5e-4,
          diversity_aux_loss_weight = 0.,
          use_critic_ema = True,
-         critic_ema_beta = 0.99,
-         max_grad_norm = 0.5,
-         batch_size = 16,
+         critic_ema_beta = 0.95,
+         max_grad_norm = 1.0,
+         batch_size = 32,
          calc_gae_kwargs: dict = dict(
              use_accelerated = False,
              gamma = 0.99,
@@ -1269,8 +1321,10 @@ def create_agent(
      num_latents,
      dim_latent,
      actor_num_actions,
-     actor_dim_hiddens: int | tuple[int, ...],
-     critic_dim_hiddens: int | tuple[int, ...],
+     actor_dim,
+     actor_mlp_depth,
+     critic_dim,
+     critic_mlp_depth,
      use_critic_ema = True,
      latent_gene_pool_kwargs: dict = dict(),
      actor_kwargs: dict = dict(),
@@ -1293,14 +1347,16 @@ def create_agent(
          num_actions = actor_num_actions,
          dim_state = dim_state,
          dim_latent = dim_latent,
-         dim_hiddens = actor_dim_hiddens,
+         dim = actor_dim,
+         mlp_depth = actor_mlp_depth,
          **actor_kwargs
      )

      critic = Critic(
          dim_state = dim_state,
          dim_latent = dim_latent,
-         dim_hiddens = critic_dim_hiddens,
+         dim = critic_dim,
+         mlp_depth = critic_mlp_depth,
          **critic_kwargs
      )

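For downstream callers, the per-layer hidden-size tuples are gone; each network is now specified by one width plus one MLP depth. Roughly, based on the train_gym.py change further down, a call migrates as sketched below (argument list abridged to what these hunks show; dim_state = 8 and actor_num_actions = 4 are illustrative LunarLander-like values, and the import path assumes the package's usual top-level exports):

```python
from evolutionary_policy_optimization import create_agent

# 0.1.12-style (old signature):
#   create_agent(..., actor_dim_hiddens = (256, 128), critic_dim_hiddens = (256, 128, 64), ...)

# 0.1.15-style: one width and one depth per network
agent = create_agent(
    dim_state = 8,
    num_latents = 1,
    dim_latent = 32,
    actor_num_actions = 4,
    actor_dim = 128,
    actor_mlp_depth = 2,
    critic_dim = 256,
    critic_mlp_depth = 4,
)
```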
1306
1362
 
@@ -1475,7 +1531,7 @@ class EPO(Module):
              log_prob,
              reward,
              value,
-             tensor(terminated)
+             terminated
          )

          memory = Memory(*tuple(t.cpu() for t in memory))
@@ -1487,7 +1543,7 @@ class EPO(Module):
          if not terminated:
              # add bootstrap value if truncated

-             next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+             next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)

              memory_for_gae = memory._replace(
                  episode_id = invalid_episode,
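The last two hunks touch bookkeeping at episode end: terminated no longer needs an explicit tensor(...) wrap (the torch/numpy interface above now converts Python bools to tensors), and when an episode is truncated rather than terminated, the bootstrap value for the final state is taken from the EMA critic when available, on the unwrapped (presumably accelerate-unwrapped) model. As a reminder of what that bootstrap value does in return estimation, a generic sketch, not the package's GAE code:

```python
import torch

def bootstrapped_returns(rewards, terminated, next_value, gamma = 0.99):
    # discounted returns for one rollout; if the rollout was truncated (not terminated),
    # the critic's estimate of the final state's value seeds the recursion
    returns = torch.zeros_like(rewards)
    running = 0. if terminated else float(next_value)
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

print(bootstrapped_returns(torch.tensor([1., 1., 1.]), terminated = False, next_value = 5.))
```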
pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.1.12"
+ version = "0.1.15"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
evolutionary_policy_optimization-0.1.15/train_gym.py (new file)

@@ -0,0 +1,63 @@
+ import torch
+
+ from evolutionary_policy_optimization import (
+     EPO,
+     GymnasiumEnvWrapper
+ )
+
+ # gymnasium
+
+ from shutil import rmtree
+ import gymnasium as gym
+
+ env = gym.make(
+     'LunarLander-v3',
+     render_mode = 'rgb_array'
+ )
+
+ rmtree('./recordings', ignore_errors = True)
+
+ env = gym.wrappers.RecordVideo(
+     env = env,
+     video_folder = './recordings',
+     name_prefix = 'lunar-video',
+     episode_trigger = lambda eps_num: (eps_num % 250) == 0,
+     disable_logger = True
+ )
+
+ env = GymnasiumEnvWrapper(env)
+
+ # epo
+
+ agent = env.to_epo_agent(
+     num_latents = 1,
+     dim_latent = 32,
+     actor_dim = 128,
+     actor_mlp_depth = 2,
+     critic_dim = 256,
+     critic_mlp_depth = 4,
+     latent_gene_pool_kwargs = dict(
+         frac_natural_selected = 0.5,
+         frac_tournaments = 0.5
+     ),
+     accelerate_kwargs = dict(
+         cpu = False
+     ),
+     actor_optim_kwargs = dict(
+         cautious_factor = 0.1,
+     ),
+     critic_optim_kwargs = dict(
+         cautious_factor = 0.1,
+     ),
+ )
+
+ epo = EPO(
+     agent,
+     episodes_per_latent = 50,
+     max_episode_length = 500,
+     action_sample_temperature = 1.,
+ )
+
+ epo(agent, env, num_learning_cycles = 100)
+
+ agent.save('./agent.pt', overwrite = True)
evolutionary_policy_optimization-0.1.12/train_gym.py (removed)

@@ -1,44 +0,0 @@
- import torch
-
- from evolutionary_policy_optimization import (
-     EPO,
-     GymnasiumEnvWrapper
- )
-
- # gymnasium
-
- import gymnasium as gym
-
- env = gym.make(
-     'LunarLander-v3',
-     render_mode = 'rgb_array'
- )
-
- env = GymnasiumEnvWrapper(env)
-
- # epo
-
- agent = env.to_epo_agent(
-     num_latents = 8,
-     dim_latent = 32,
-     actor_dim_hiddens = (256, 128),
-     critic_dim_hiddens = (256, 128, 64),
-     latent_gene_pool_kwargs = dict(
-         frac_natural_selected = 0.5,
-         frac_tournaments = 0.5
-     ),
-     accelerate_kwargs = dict(
-         cpu = False
-     )
- )
-
- epo = EPO(
-     agent,
-     episodes_per_latent = 5,
-     max_episode_length = 10,
-     action_sample_temperature = 1.,
- )
-
- epo(agent, env, num_learning_cycles = 5)
-
- agent.save('./agent.pt', overwrite = True)