evolutionary-policy-optimization 0.1.19__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. evolutionary_policy_optimization-0.2.1/.github/workflows/lint.yml +21 -0
  2. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/PKG-INFO +29 -1
  3. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/README.md +27 -0
  4. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/distributed.py +0 -2
  5. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/epo.py +106 -16
  6. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/experimental.py +9 -1
  7. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/pyproject.toml +16 -2
  8. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/tests/test_epo.py +5 -5
  9. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/train_gym.py +3 -3
  10. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/.github/workflows/python-publish.yml +0 -0
  11. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/.github/workflows/test.yml +0 -0
  12. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/.gitignore +0 -0
  13. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/LICENSE +0 -0
  14. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/__init__.py +0 -0
  15. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  16. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/evolutionary_policy_optimization/mock_env.py +0 -0
  17. {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.1}/requirements.txt +0 -0
.github/workflows/lint.yml (new file)

@@ -0,0 +1,21 @@
+ name: Ruff
+ on: [push, pull_request]
+
+ jobs:
+   build:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v4
+     - name: Set up Python 3.10
+       uses: actions/setup-python@v5
+       with:
+         python-version: "3.10"
+     - name: Install dependencies
+       run: |
+         python -m pip install uv
+         python -m uv pip install ruff
+     - name: Lint with Ruff
+       run: |
+         ruff check evolutionary_policy_optimization/
PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.1.19
+ Version: 0.2.1
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -49,6 +49,7 @@ Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
  Requires-Dist: tqdm; extra == 'examples'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
+ Requires-Dist: ruff>=0.4.2; extra == 'test'
  Description-Content-Type: text/markdown

  <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -144,6 +145,22 @@ agent.save('./agent.pt', overwrite = True)
  agent.load('./agent.pt')
  ```

+ ## Contributing
+
+ At the project root, run
+
+ ```bash
+ $ pip install '.[test]' # or `uv pip install '.[test]'`
+ ```
+
+ Then add your tests to `tests/test_epo.py` and run
+
+ ```bash
+ $ pytest tests/
+ ```
+
+ That's it
+
  ## Citations

  ```bibtex
@@ -237,4 +254,15 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @article{Lee2024SimBaSB,
+     title = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+     author = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+     journal = {ArXiv},
+     year = {2024},
+     volume = {abs/2410.09754},
+     url = {https://api.semanticscholar.org/CorpusID:273346233}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
README.md

@@ -91,6 +91,22 @@ agent.save('./agent.pt', overwrite = True)
  agent.load('./agent.pt')
  ```

+ ## Contributing
+
+ At the project root, run
+
+ ```bash
+ $ pip install '.[test]' # or `uv pip install '.[test]'`
+ ```
+
+ Then add your tests to `tests/test_epo.py` and run
+
+ ```bash
+ $ pytest tests/
+ ```
+
+ That's it
+
  ## Citations

  ```bibtex
@@ -184,4 +200,15 @@ agent.load('./agent.pt')
  }
  ```

+ ```bibtex
+ @article{Lee2024SimBaSB,
+     title = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+     author = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+     journal = {ArXiv},
+     year = {2024},
+     volume = {abs/2410.09754},
+     url = {https://api.semanticscholar.org/CorpusID:273346233}
+ }
+ ```
+
  *Evolution is cleverer than you are.* - Leslie Orgel
evolutionary_policy_optimization/distributed.py

@@ -61,8 +61,6 @@ def has_only_one_value(t):
      return (t == t[0]).all()

  def all_gather_variable_dim(t, dim = 0, sizes = None):
-     device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
-
      if not exists(sizes):
          sizes = gather_sizes(t, dim = dim)

evolutionary_policy_optimization/epo.py

@@ -19,7 +19,7 @@ from torch.utils.data import TensorDataset, DataLoader
  from torch.utils._pytree import tree_map

  import einx
- from einops import rearrange, repeat, einsum, pack
+ from einops import rearrange, repeat, reduce, einsum, pack
  from einops.layers.torch import Rearrange

  from evolutionary_policy_optimization.distributed import (
@@ -192,7 +192,6 @@ def calc_generalized_advantage_estimate(
      use_accelerated = None
  ):
      use_accelerated = default(use_accelerated, rewards.is_cuda)
-     device = rewards.device

      values = F.pad(values, (0, 1), value = 0.)
      values, values_next = values[:-1], values[1:]
@@ -202,7 +201,7 @@

      scan = AssocScan(reverse = True, use_accelerated = use_accelerated)

-     return scan(gates, delta)
+     return scan(gates, delta)

  # evolution related functions

@@ -336,6 +335,53 @@ class DynamicLIMe(Module):

          return einsum(hiddens, weights, 'l b d, b l -> b d')

+ # state normalization
+
+ class StateNorm(Module):
+     def __init__(
+         self,
+         dim,
+         eps = 1e-5
+     ):
+         # equation (3) in https://arxiv.org/abs/2410.09754 - 'RSMNorm'
+
+         super().__init__()
+         self.dim = dim
+         self.eps = eps
+
+         self.register_buffer('step', tensor(1))
+         self.register_buffer('running_mean', torch.zeros(dim))
+         self.register_buffer('running_variance', torch.ones(dim))
+
+     def forward(
+         self,
+         state
+     ):
+         assert state.shape[-1] == self.dim, f'expected feature dimension of {self.dim} but received {state.shape[-1]}'
+
+         time = self.step.item()
+         mean = self.running_mean
+         variance = self.running_variance
+
+         normed = (state - mean) / variance.sqrt().clamp(min = self.eps)
+
+         if not self.training:
+             return normed
+
+         # update running mean and variance
+
+         new_obs_mean = reduce(state, '... d -> d', 'mean')
+         delta = new_obs_mean - mean
+
+         new_mean = mean + delta / time
+         new_variance = (time - 1) / time * (variance + (delta ** 2) / time)
+
+         self.step.add_(1)
+         self.running_mean.copy_(new_mean)
+         self.running_variance.copy_(new_variance)
+
+         return normed
+
  # simple MLP networks, but with latent variables
  # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper

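A quick aside, not part of the published diff: the `StateNorm` module added above (the "RSMNorm" of SimBa) accumulates running per-feature statistics only while in training mode, and is applied frozen (eval mode, under `no_grad`) inside `Actor` and `Critic`. A minimal sketch of exercising it on its own, assuming the class is importable from `evolutionary_policy_optimization.epo` and using made-up shapes and synthetic data:

```python
# Illustrative sketch only - exercises the running-statistics behaviour of the
# StateNorm ('RSMNorm') module added in 0.2.x; dimensions and data are made up.
import torch
from evolutionary_policy_optimization.epo import StateNorm

norm = StateNorm(dim = 8)

# in train mode, forward() normalizes with the current running statistics
# and then updates the running mean / variance buffers from the batch mean
norm.train()
for _ in range(100):
    states = torch.randn(32, 8) * 3. + 2.   # synthetic observations with non-zero mean
    norm(states)

# in eval mode (as Actor / Critic call it, under torch.no_grad()),
# the buffers are frozen and only the normalization is applied
norm.eval()
with torch.no_grad():
    normed = norm(torch.randn(32, 8) * 3. + 2.)

print(normed.shape)        # torch.Size([32, 8])
print(norm.running_mean)   # should have drifted toward the data mean (~2)
```

Nothing here is prescriptive; it only illustrates the train/eval split that the rest of this diff wires into `Actor.forward`, `Critic`, and `Agent.learn_from`.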
@@ -444,10 +490,13 @@ class Actor(Module):
          num_actions,
          dim,
          mlp_depth,
+         state_norm: StateNorm | None = None,
          dim_latent = 0,
      ):
          super().__init__()

+         self.state_norm = state_norm
+
          self.dim_latent = dim_latent

          self.init_layer = nn.Sequential(
@@ -467,6 +516,10 @@ class Actor(Module):
          state,
          latent
      ):
+         if exists(self.state_norm):
+             with torch.no_grad():
+                 self.state_norm.eval()
+                 state = self.state_norm(state)

          hidden = self.init_layer(state)

@@ -482,6 +535,7 @@ class Critic(Module):
          mlp_depth,
          dim_latent = 0,
          use_regression = False,
+         state_norm: StateNorm | None = None,
          hl_gauss_loss_kwargs: dict = dict(
              min_value = -100.,
              max_value = 100.,
@@ -490,6 +544,8 @@ class Critic(Module):
      ):
          super().__init__()

+         self.state_norm = state_norm
+
          self.dim_latent = dim_latent

          self.init_layer = nn.Sequential(
@@ -523,6 +579,12 @@ class Critic(Module):
          eps_clip = 0.4,
          use_improved = True
      ):
+
+         if exists(self.state_norm):
+             with torch.no_grad():
+                 self.state_norm.eval()
+                 state = self.state_norm(state)
+
          logits = self.forward(state, latent, return_logits = True)

          value = self.maybe_bins_to_value(logits)
@@ -535,7 +597,8 @@ class Critic(Module):
          old_values_lo = old_values - eps_clip
          old_values_hi = old_values + eps_clip

-         is_between = lambda lo, hi: (lo < value) & (value < hi)
+         def is_between(lo, hi):
+             return (lo < value) & (value < hi)

          clipped_loss = loss_fn(logits, clipped_target)
          loss = loss_fn(logits, target)
@@ -921,6 +984,7 @@ class Agent(Module):
          critic: Critic,
          latent_gene_pool: LatentGenePool | None,
          optim_klass = AdoptAtan2,
+         state_norm: StateNorm | None = None,
          actor_lr = 8e-4,
          critic_lr = 8e-4,
          latent_lr = 1e-5,
@@ -965,12 +1029,20 @@ class Agent(Module):
          accelerate = Accelerator(**accelerate_kwargs)
          self.accelerate = accelerate

+         # state norm
+
+         self.state_norm = state_norm
+
          # actor, critic, and their shared latent gene pool

          self.actor = actor

          self.critic = critic

+         if exists(state_norm):
+             # insurance
+             actor.state_norm = critic.state_norm = state_norm
+
          self.use_critic_ema = use_critic_ema
          self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None

@@ -1034,6 +1106,7 @@ class Agent(Module):
          self.clip_grad_norm_ = self.accelerate.clip_grad_norm_

          (
+             self.state_norm,
              self.actor,
              self.critic,
              self.latent_gene_pool,
@@ -1042,6 +1115,7 @@ class Agent(Module):
              self.latent_optim,
          ) = tuple(
              maybe(self.accelerate.prepare)(m) for m in (
+                 self.state_norm,
                  self.actor,
                  self.critic,
                  self.latent_gene_pool,
@@ -1076,31 +1150,34 @@ class Agent(Module):

      def save(self, path, overwrite = False):
          path = Path(path)
+         unwrap = self.unwrap_model

          assert not path.exists() or overwrite

          pkg = dict(
-             actor = self.actor.state_dict(),
-             critic = self.critic.state_dict(),
+             state_norm = unwrap(self.state_norm).state_dict() if self.state_norm else None,
+             actor = unwrap(self.actor).state_dict(),
+             critic = unwrap(self.critic).state_dict(),
              critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-             latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
-             actor_optim = self.actor_optim.state_dict(),
-             critic_optim = self.critic_optim.state_dict(),
-             latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+             latents = unwrap(self.latent_gene_pool).state_dict() if self.has_latent_genes else None,
+             actor_optim = unwrap(self.actor_optim).state_dict(),
+             critic_optim = unwrap(self.critic_optim).state_dict(),
+             latent_optim = unwrap(self.latent_optim).state_dict() if exists(self.latent_optim) else None
          )

          torch.save(pkg, str(path))

      def load(self, path):
+         unwrap = self.unwrap_model
          path = Path(path)

          assert path.exists()

          pkg = torch.load(str(path), weights_only = True)

-         self.actor.load_state_dict(pkg['actor'])
+         unwrap(self.actor).load_state_dict(pkg['actor'])

-         self.critic.load_state_dict(pkg['critic'])
+         unwrap(self.critic).load_state_dict(pkg['critic'])

          if self.use_critic_ema:
              self.critic_ema.load_state_dict(pkg['critic_ema'])
@@ -1108,11 +1185,11 @@ class Agent(Module):
          if exists(pkg.get('latents', None)):
              self.latent_gene_pool.load_state_dict(pkg['latents'])

-         self.actor_optim.load_state_dict(pkg['actor_optim'])
-         self.critic_optim.load_state_dict(pkg['critic_optim'])
+         unwrap(self.actor_optim).load_state_dict(pkg['actor_optim'])
+         unwrap(self.critic_optim).load_state_dict(pkg['critic_optim'])

          if exists(pkg.get('latent_optim', None)):
-             self.latent_optim.load_state_dict(pkg['latent_optim'])
+             unwrap(self.latent_optim).load_state_dict(pkg['latent_optim'])

      @move_input_tensors_to_device
      def get_actor_actions(
@@ -1326,6 +1403,14 @@ class Agent(Module):
                  diversity_loss = diversity_loss.item()
              )

+         # update state norm if needed
+
+         if exists(self.state_norm):
+             self.state_norm.train()
+
+             for _, states, *_ in tqdm(dataloader, desc = 'state norm learning'):
+                 self.state_norm(states)
+
          # apply evolution

          if self.has_latent_genes:
@@ -1406,12 +1491,15 @@ def create_agent(
          **latent_gene_pool_kwargs
      ) if has_latent_genes else None

+     state_norm = StateNorm(dim = dim_state)
+
      actor = Actor(
          num_actions = actor_num_actions,
          dim_state = dim_state,
          dim_latent = dim_latent,
          dim = actor_dim,
          mlp_depth = actor_mlp_depth,
+         state_norm = state_norm,
          **actor_kwargs
      )

@@ -1420,12 +1508,14 @@ def create_agent(
          dim_latent = dim_latent,
          dim = critic_dim,
          mlp_depth = critic_mlp_depth,
+         state_norm = state_norm,
          **critic_kwargs
      )

      agent = Agent(
          actor = actor,
          critic = critic,
+         state_norm = state_norm,
          latent_gene_pool = latent_gene_pool,
          use_critic_ema = use_critic_ema,
          **kwargs
@@ -1639,4 +1729,4 @@ class EPO(Module):

              agent.learn_from(memories)

-         print(f'training complete')
+         print('training complete')
evolutionary_policy_optimization/experimental.py

@@ -1,6 +1,10 @@
  import torch
+ import torch.nn.functional as F
  from einops import rearrange

+ def l2norm(t, dim = -1):
+     return F.normalize(t, dim = dim)
+
  def crossover_weights(w1, w2, transpose = False):
      assert w2.shape == w2.shape

@@ -27,7 +31,7 @@ def crossover_weights(w1, w2, transpose = False):

      u = torch.where(mask[:, None, :], u1, u2)
      s = torch.where(mask, s1, s2)
-     v = torch.where(mask[:, None, :], v1, v2)
+     v = torch.where(mask[:, :, None], v1, v2)

      out = u @ torch.diag_embed(s) @ v.mT


@@ -52,9 +56,13 @@ def mutate_weight(
      assert rank >= 2

      u, s, v = torch.svd(w)
+
      u = u + torch.randn_like(u) * mutation_strength
      v = v + torch.randn_like(v) * mutation_strength

+     u = l2norm(u, dim = -2)
+     v = l2norm(v, dim = -1)
+
      out = u @ torch.diag_embed(s) @ v.mT

      if transpose:
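Another aside, not part of the published diff: `mutate_weight` now re-normalizes the noised singular vectors before reconstructing the weight, so mutation perturbs directions while leaving the singular values untouched. A rough standalone sketch of that idea (the helper name and the normalization axes here are illustrative choices, not the package's exact code):

```python
# Rough sketch of SVD-space mutation with re-normalized singular vectors.
# Not the package's mutate_weight - names and the normalization axis are illustrative.
import torch
import torch.nn.functional as F

def svd_space_mutate(w, mutation_strength = 1e-2):
    u, s, v = torch.svd(w)                      # w ~= u @ diag(s) @ v.T

    # jitter the singular vectors, keep the singular values fixed
    u = u + torch.randn_like(u) * mutation_strength
    v = v + torch.randn_like(v) * mutation_strength

    # pull the perturbed vectors back to unit norm before reconstructing
    u = F.normalize(u, dim = -2)
    v = F.normalize(v, dim = -2)

    return u @ torch.diag_embed(s) @ v.mT

w = torch.randn(64, 32)
mutated = svd_space_mutate(w)
print((mutated - w).norm() / w.norm())          # modest relative change, scales with mutation_strength
```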
pyproject.toml

@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.1.19"
+ version = "0.2.1"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -47,7 +47,8 @@ examples = [
  ]

  test = [
-     "pytest"
+     "pytest",
+     "ruff>=0.4.2",
  ]

  [tool.pytest.ini_options]
@@ -55,6 +56,19 @@ pythonpath = [
      "."
  ]

+ [tool.ruff]
+ line-length = 1000
+
+ lint.ignore = [
+     "F722", # for jaxtyping shape annotation
+     "F401",
+     "F821"
+ ]
+
+ lint.extend-select = [
+     "W291"
+ ]
+
  [build-system]
  requires = ["hatchling"]
  build-backend = "hatchling.build"
tests/test_epo.py

@@ -19,7 +19,7 @@ def test_readme(

      latent_pool = LatentGenePool(
          num_latents = 128,
-         dim_latent = 32,
+         dim_latent = 32,
          num_islands = num_islands,
          fast_genetic_algorithm = sampled_mutation_strengths
      )
@@ -31,8 +31,8 @@ def test_readme(

      latent = latent_pool(latent_id = latent_ids, state = state)

-     actions = actor(state, latent)
-     value = critic(state, latent)
+     actions = actor(state, latent) # noqa: F841
+     value = critic(state, latent) # noqa: F841

      # interact with environment and receive rewards, termination etc

@@ -63,8 +63,8 @@ def test_create_agent(

      state = torch.randn(2, 512)

-     actions = agent.get_actor_actions(state, latent_id = latent_ids)
-     value = agent.get_critic_values(state, latent_id = latent_ids)
+     actions = agent.get_actor_actions(state, latent_id = latent_ids) # noqa: F841
+     value = agent.get_critic_values(state, latent_id = latent_ids) # noqa: F841

      # interact with environment and receive rewards, termination etc

train_gym.py

@@ -21,7 +21,7 @@ env = gym.wrappers.RecordVideo(
      env = env,
      video_folder = './recordings',
      name_prefix = 'lunar-video',
-     episode_trigger = lambda eps_num: (eps_num % (250 * 4)) == 0,
+     episode_trigger = lambda eps_num: (eps_num % 250) == 0,
      disable_logger = True
  )

@@ -53,8 +53,8 @@ agent = env.to_epo_agent(

  epo = EPO(
      agent,
-     episodes_per_latent = 5,
-     max_episode_length = 500,
+     episodes_per_latent = 10,
+     max_episode_length = 250,
      action_sample_temperature = 1.,
  )