evolutionary-policy-optimization 0.1.19__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -61,8 +61,6 @@ def has_only_one_value(t):
     return (t == t[0]).all()
 
 def all_gather_variable_dim(t, dim = 0, sizes = None):
-    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
-
     if not exists(sizes):
         sizes = gather_sizes(t, dim = dim)
 
@@ -19,7 +19,7 @@ from torch.utils.data import TensorDataset, DataLoader
 from torch.utils._pytree import tree_map
 
 import einx
-from einops import rearrange, repeat, einsum, pack
+from einops import rearrange, repeat, reduce, einsum, pack
 from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
@@ -192,7 +192,6 @@ def calc_generalized_advantage_estimate(
     use_accelerated = None
 ):
     use_accelerated = default(use_accelerated, rewards.is_cuda)
-    device = rewards.device
 
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[:-1], values[1:]
@@ -202,7 +201,7 @@
 
     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
 
-    return scan(gates, delta)
+    return scan(gates, delta)
 
 # evolution related functions
 
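For orientation on the two `calc_generalized_advantage_estimate` hunks above: the function builds TD residuals (`delta`) and per-step discount gates, then evaluates the GAE recursion with a reverse associative scan. Below is a minimal reference sketch of the same recursion as an explicit loop, for illustration only; it ignores episode masks and any other details not visible in this diff.

```python
import torch
import torch.nn.functional as F

def gae_reference(rewards, values, gamma = 0.99, lam = 0.95):
    # values has the same length as rewards; pad a zero on the end for bootstrapping,
    # mirroring F.pad(values, (0, 1), value = 0.) in the hunk above
    values = F.pad(values, (0, 1), value = 0.)
    values, values_next = values[:-1], values[1:]

    delta = rewards + gamma * values_next - values    # TD residuals
    gates = torch.full_like(delta, gamma * lam)       # per-step discount applied by the scan

    # reverse recursion A_t = delta_t + gate_t * A_{t+1}, which the AssocScan parallelizes
    advantages = torch.zeros_like(delta)
    running = torch.zeros(())

    for t in reversed(range(delta.shape[-1])):
        running = delta[t] + gates[t] * running
        advantages[t] = running

    return advantages
```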
@@ -336,6 +335,53 @@ class DynamicLIMe(Module):
 
         return einsum(hiddens, weights, 'l b d, b l -> b d')
 
+# state normalization
+
+class StateNorm(Module):
+    def __init__(
+        self,
+        dim,
+        eps = 1e-5
+    ):
+        # equation (3) in https://arxiv.org/abs/2410.09754 - 'RSMNorm'
+
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+
+        self.register_buffer('step', tensor(1))
+        self.register_buffer('running_mean', torch.zeros(dim))
+        self.register_buffer('running_variance', torch.ones(dim))
+
+    def forward(
+        self,
+        state
+    ):
+        assert state.shape[-1] == self.dim, f'expected feature dimension of {self.dim} but received {state.shape[-1]}'
+
+        time = self.step.item()
+        mean = self.running_mean
+        variance = self.running_variance
+
+        normed = (state - mean) / variance.sqrt().clamp(min = self.eps)
+
+        if not self.training:
+            return normed
+
+        # update running mean and variance
+
+        new_obs_mean = reduce(state, '... d -> d', 'mean')
+        delta = new_obs_mean - mean
+
+        new_mean = mean + delta / time
+        new_variance = (time - 1) / time * (variance + (delta ** 2) / time)
+
+        self.step.add_(1)
+        self.running_mean.copy_(new_mean)
+        self.running_variance.copy_(new_variance)
+
+        return normed
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
 
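The `StateNorm` module added above keeps running per-feature statistics (the 'RSMNorm' update, equation (3) of the SimBa paper referenced in the comment) and whitens incoming states with them. A minimal usage sketch, assuming the class is imported from `evolutionary_policy_optimization.epo`; the dimensions and data here are made up for illustration.

```python
import torch
from evolutionary_policy_optimization.epo import StateNorm

norm = StateNorm(dim = 4)

# in training mode, each forward pass also folds the current batch mean into the running statistics
norm.train()
for _ in range(100):
    states = torch.randn(32, 4) * 5. + 3.   # un-normalized observations
    norm(states)

# in eval mode (how Actor / Critic invoke it under torch.no_grad()), it only whitens
norm.eval()
normed = norm(torch.randn(2, 4) * 5. + 3.)  # states standardized by the running mean / variance
```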
@@ -444,10 +490,13 @@ class Actor(Module):
         num_actions,
         dim,
         mlp_depth,
+        state_norm: StateNorm | None = None,
         dim_latent = 0,
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -467,6 +516,10 @@ class Actor(Module):
         state,
         latent
     ):
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
 
         hidden = self.init_layer(state)
 
@@ -482,6 +535,7 @@ class Critic(Module):
         mlp_depth,
         dim_latent = 0,
         use_regression = False,
+        state_norm: StateNorm | None = None,
         hl_gauss_loss_kwargs: dict = dict(
             min_value = -100.,
             max_value = 100.,
@@ -490,6 +544,8 @@
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -523,6 +579,12 @@
         eps_clip = 0.4,
         use_improved = True
     ):
+
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
+
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
@@ -535,7 +597,8 @@
         old_values_lo = old_values - eps_clip
         old_values_hi = old_values + eps_clip
 
-        is_between = lambda lo, hi: (lo < value) & (value < hi)
+        def is_between(lo, hi):
+            return (lo < value) & (value < hi)
 
         clipped_loss = loss_fn(logits, clipped_target)
         loss = loss_fn(logits, target)
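For context on the `is_between` rewrite above: this block implements a clipped value objective, keeping the critic update conservative around the old value predictions (`old_values ± eps_clip`). As a point of comparison, here is a hedged sketch of the standard PPO-style clipped value loss in plain regression form; it is illustrative only and not the HL-Gauss classification loss this Critic actually optimizes.

```python
import torch

def clipped_value_loss(value, old_value, target, eps_clip = 0.4):
    # clamp the new prediction to an eps_clip band around the old one,
    # then take the worse of the clipped / unclipped squared errors
    value_clipped = old_value + (value - old_value).clamp(-eps_clip, eps_clip)

    return torch.maximum(
        (value - target) ** 2,
        (value_clipped - target) ** 2
    ).mean()
```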
@@ -921,6 +984,7 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
+        state_norm: StateNorm | None = None,
         actor_lr = 8e-4,
         critic_lr = 8e-4,
         latent_lr = 1e-5,
@@ -965,12 +1029,20 @@
         accelerate = Accelerator(**accelerate_kwargs)
         self.accelerate = accelerate
 
+        # state norm
+
+        self.state_norm = state_norm
+
         # actor, critic, and their shared latent gene pool
 
         self.actor = actor
 
         self.critic = critic
 
+        if exists(state_norm):
+            # insurance
+            actor.state_norm = critic.state_norm = state_norm
+
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
@@ -1034,6 +1106,7 @@ class Agent(Module):
         self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
 
         (
+            self.state_norm,
            self.actor,
            self.critic,
            self.latent_gene_pool,
@@ -1042,6 +1115,7 @@ class Agent(Module):
            self.latent_optim,
        ) = tuple(
            maybe(self.accelerate.prepare)(m) for m in (
+               self.state_norm,
               self.actor,
               self.critic,
               self.latent_gene_pool,
@@ -1076,31 +1150,36 @@ class Agent(Module):
 
     def save(self, path, overwrite = False):
         path = Path(path)
+        unwrap = self.unwrap_model
 
         assert not path.exists() or overwrite
 
         pkg = dict(
-            actor = self.actor.state_dict(),
-            critic = self.critic.state_dict(),
+            state_norm = unwrap(self.state_norm).state_dict() if self.state_norm else None,
+            actor = unwrap(self.actor).state_dict(),
+            critic = unwrap(self.critic).state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
-            actor_optim = self.actor_optim.state_dict(),
-            critic_optim = self.critic_optim.state_dict(),
-            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+            latents = unwrap(self.latent_gene_pool).state_dict() if self.has_latent_genes else None,
+            actor_optim = unwrap(self.actor_optim).state_dict(),
+            critic_optim = unwrap(self.critic_optim).state_dict(),
+            latent_optim = unwrap(self.latent_optim).state_dict() if exists(self.latent_optim) else None
         )
 
         torch.save(pkg, str(path))
 
     def load(self, path):
+        unwrap = self.unwrap_model
         path = Path(path)
 
         assert path.exists()
 
         pkg = torch.load(str(path), weights_only = True)
 
-        self.actor.load_state_dict(pkg['actor'])
+        unwrap(self.actor).load_state_dict(pkg['actor'])
+
+        unwrap(self.critic).load_state_dict(pkg['critic'])
 
-        self.critic.load_state_dict(pkg['critic'])
+        unwrap(self.latent_gene_pool).load_state_dict(pkg['latents'])
 
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
@@ -1108,11 +1187,11 @@ class Agent(Module):
         if exists(pkg.get('latents', None)):
             self.latent_gene_pool.load_state_dict(pkg['latents'])
 
-        self.actor_optim.load_state_dict(pkg['actor_optim'])
-        self.critic_optim.load_state_dict(pkg['critic_optim'])
+        unwrap(self.actor_optim).load_state_dict(pkg['actor_optim'])
+        unwrap(self.critic_optim).load_state_dict(pkg['critic_optim'])
 
         if exists(pkg.get('latent_optim', None)):
-            self.latent_optim.load_state_dict(pkg['latent_optim'])
+            unwrap(self.latent_optim).load_state_dict(pkg['latent_optim'])
 
     @move_input_tensors_to_device
     def get_actor_actions(
@@ -1326,6 +1405,14 @@ class Agent(Module):
             diversity_loss = diversity_loss.item()
         )
 
+        # update state norm if needed
+
+        if exists(self.state_norm):
+            self.state_norm.train()
+
+            for _, states, *_ in tqdm(dataloader, desc = 'state norm learning'):
+                self.state_norm(states)
+
         # apply evolution
 
         if self.has_latent_genes:
@@ -1406,12 +1493,15 @@ def create_agent(
        **latent_gene_pool_kwargs
    ) if has_latent_genes else None
 
+   state_norm = StateNorm(dim = dim_state)
+
    actor = Actor(
        num_actions = actor_num_actions,
        dim_state = dim_state,
        dim_latent = dim_latent,
        dim = actor_dim,
        mlp_depth = actor_mlp_depth,
+       state_norm = state_norm,
        **actor_kwargs
    )
 
@@ -1420,12 +1510,14 @@ def create_agent(
        dim_latent = dim_latent,
        dim = critic_dim,
        mlp_depth = critic_mlp_depth,
+       state_norm = state_norm,
        **critic_kwargs
    )
 
    agent = Agent(
        actor = actor,
        critic = critic,
+       state_norm = state_norm,
        latent_gene_pool = latent_gene_pool,
        use_critic_ema = use_critic_ema,
        **kwargs
@@ -1639,4 +1731,4 @@ class EPO(Module):
 
            agent.learn_from(memories)
 
-       print(f'training complete')
+       print('training complete')
@@ -1,6 +1,10 @@
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim)
+
 def crossover_weights(w1, w2, transpose = False):
     assert w2.shape == w2.shape
 
@@ -27,7 +31,7 @@ def crossover_weights(w1, w2, transpose = False):
 
     u = torch.where(mask[:, None, :], u1, u2)
     s = torch.where(mask, s1, s2)
-    v = torch.where(mask[:, None, :], v1, v2)
+    v = torch.where(mask[:, :, None], v1, v2)
 
     out = u @ torch.diag_embed(s) @ v.mT
 
@@ -52,9 +56,13 @@ def mutate_weight(
     assert rank >= 2
 
     u, s, v = torch.svd(w)
+
     u = u + torch.randn_like(u) * mutation_strength
     v = v + torch.randn_like(v) * mutation_strength
 
+    u = l2norm(u, dim = -2)
+    v = l2norm(v, dim = -1)
+
     out = u @ torch.diag_embed(s) @ v.mT
 
     if transpose:
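The `experimental.py` hunks above add an `l2norm` helper and apply it after the Gaussian mutation of the SVD factors, rescaling `u` along its columns and `v` along its last dimension back to unit length before the weight is reassembled. A small standalone sketch of that step, with arbitrary shapes and mutation strength, purely for illustration:

```python
import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim)

w = torch.randn(8, 6)
u, s, v = torch.svd(w)                        # factorize the weight

mutation_strength = 0.05
u = u + torch.randn_like(u) * mutation_strength
v = v + torch.randn_like(v) * mutation_strength

u = l2norm(u, dim = -2)                       # renormalize each column of u to unit length
v = l2norm(v, dim = -1)                       # renormalize v along its last dimension

mutated_w = u @ torch.diag_embed(s) @ v.mT    # reassemble the mutated weight
```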
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.19
+Version: 0.2.0
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -49,6 +49,7 @@ Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
+Requires-Dist: ruff>=0.4.2; extra == 'test'
 Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -144,6 +145,22 @@ agent.save('./agent.pt', overwrite = True)
 agent.load('./agent.pt')
 ```
 
+## Contributing
+
+At the project root, run
+
+```bash
+$ pip install '.[test]' # or `uv pip install '.[test]'`
+```
+
+Then add your tests to `tests/test_epo.py` and run
+```bash
+$ pytest tests/
+```
+
+That's it
+
+
 ## Citations
 
 ```bibtex
@@ -237,4 +254,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024SimBaSB,
+    title   = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+    author  = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.09754},
+    url     = {https://api.semanticscholar.org/CorpusID:273346233}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
@@ -0,0 +1,10 @@
+evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
+evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
+evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
+evolutionary_policy_optimization/epo.py,sha256=sMjCvwT6upNZ48DHegtZ40K9M8PloeW8KBCedx2fc-4,51796
+evolutionary_policy_optimization/experimental.py,sha256=ZyOGHbE4dXmt4zCljSzcUklua4vlOwQtslhFEm0JN94,1716
+evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
+evolutionary_policy_optimization-0.2.0.dist-info/METADATA,sha256=mqgk3bcYYLwGYYoOqWclUV2QuitxC_shw8nc6hrP8K0,8697
+evolutionary_policy_optimization-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.2.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.2.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=JzZdNbzerIMgPg6dlL4eLNJ9_LbmW0xNkgQrgSNoSKA,49084
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.19.dist-info/METADATA,sha256=Jmbgx_z8dJv1W-FQVRJCI14MW2Tv7wbc4VhP_YN3WNw,7979
-evolutionary_policy_optimization-0.1.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.19.dist-info/RECORD,,