evolutionary-policy-optimization 0.1.18__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions exactly as they appear in their respective public registries.
@@ -61,8 +61,6 @@ def has_only_one_value(t):
     return (t == t[0]).all()
 
 def all_gather_variable_dim(t, dim = 0, sizes = None):
-    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
-
     if not exists(sizes):
         sizes = gather_sizes(t, dim = dim)
 
@@ -19,7 +19,7 @@ from torch.utils.data import TensorDataset, DataLoader
 from torch.utils._pytree import tree_map
 
 import einx
-from einops import rearrange, repeat, einsum, pack
+from einops import rearrange, repeat, reduce, einsum, pack
 from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
@@ -192,7 +192,6 @@ def calc_generalized_advantage_estimate(
     use_accelerated = None
 ):
     use_accelerated = default(use_accelerated, rewards.is_cuda)
-    device = rewards.device
 
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[:-1], values[1:]
@@ -202,7 +201,7 @@ def calc_generalized_advantage_estimate(
 
     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
 
-    return scan(gates, delta)
+    return scan(gates, delta)
 
 # evolution related functions
 
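For reference, the associative scan returned above evaluates the standard generalized advantage estimate recurrence. Below is a minimal loop-based sketch of the same computation, assuming `delta` holds the TD residuals r_t + γ·V_{t+1} − V_t and `gates` holds γ·λ (zeroed at episode boundaries); the names and shapes are illustrative, not this package's API.

```python
import torch

def gae_reference(delta, gates):
    # reverse scan: advantage_t = delta_t + gate_t * advantage_{t+1}
    advantages = torch.zeros_like(delta)
    running = torch.zeros_like(delta[-1])

    for t in reversed(range(delta.shape[0])):
        running = delta[t] + gates[t] * running
        advantages[t] = running

    return advantages
```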
@@ -336,6 +335,53 @@ class DynamicLIMe(Module):
 
         return einsum(hiddens, weights, 'l b d, b l -> b d')
 
+# state normalization
+
+class StateNorm(Module):
+    def __init__(
+        self,
+        dim,
+        eps = 1e-5
+    ):
+        # equation (3) in https://arxiv.org/abs/2410.09754 - 'RSMNorm'
+
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+
+        self.register_buffer('step', tensor(1))
+        self.register_buffer('running_mean', torch.zeros(dim))
+        self.register_buffer('running_variance', torch.ones(dim))
+
+    def forward(
+        self,
+        state
+    ):
+        assert state.shape[-1] == self.dim, f'expected feature dimension of {self.dim} but received {state.shape[-1]}'
+
+        time = self.step.item()
+        mean = self.running_mean
+        variance = self.running_variance
+
+        normed = (state - mean) / variance.sqrt().clamp(min = self.eps)
+
+        if not self.training:
+            return normed
+
+        # update running mean and variance
+
+        new_obs_mean = reduce(state, '... d -> d', 'mean')
+        delta = new_obs_mean - mean
+
+        new_mean = mean + delta / time
+        new_variance = (time - 1) / time * (variance + (delta ** 2) / time)
+
+        self.step.add_(1)
+        self.running_mean.copy_(new_mean)
+        self.running_variance.copy_(new_variance)
+
+        return normed
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
 
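As a quick sanity check on the new `StateNorm` above, here is a self-contained sketch of the same normalize-then-update cycle (equation (3) of the SimBa paper) run on single observations; all names are illustrative and nothing below is part of the package's API.

```python
import torch

dim = 4
step, mean, var = 1, torch.zeros(dim), torch.ones(dim)

for _ in range(10_000):
    state = torch.randn(dim) * 3. + 7.                       # raw observation, mean 7, std 3

    normed = (state - mean) / var.sqrt().clamp(min = 1e-5)   # what eval-mode StateNorm returns

    # Welford-style running statistics, mirroring the update in the class above
    delta = state - mean
    mean = mean + delta / step
    var = (step - 1) / step * (var + delta ** 2 / step)
    step += 1

print(mean, var.sqrt())  # converges towards the observation mean (~7) and std (~3)
```

Note that the class reduces a batch of states to a single feature-wise mean before updating, so each training batch counts as one observation in this recurrence.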
@@ -444,10 +490,13 @@ class Actor(Module):
         num_actions,
         dim,
         mlp_depth,
+        state_norm: StateNorm | None = None,
         dim_latent = 0,
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -467,6 +516,10 @@ class Actor(Module):
         state,
         latent
     ):
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
 
         hidden = self.init_layer(state)
 
@@ -482,6 +535,7 @@ class Critic(Module):
         mlp_depth,
         dim_latent = 0,
         use_regression = False,
+        state_norm: StateNorm | None = None,
         hl_gauss_loss_kwargs: dict = dict(
             min_value = -100.,
             max_value = 100.,
@@ -490,6 +544,8 @@ class Critic(Module):
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -523,6 +579,12 @@ class Critic(Module):
         eps_clip = 0.4,
         use_improved = True
     ):
+
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
+
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
@@ -535,7 +597,8 @@ class Critic(Module):
         old_values_lo = old_values - eps_clip
         old_values_hi = old_values + eps_clip
 
-        is_between = lambda lo, hi: (lo < value) & (value < hi)
+        def is_between(lo, hi):
+            return (lo < value) & (value < hi)
 
         clipped_loss = loss_fn(logits, clipped_target)
         loss = loss_fn(logits, target)
@@ -921,6 +984,7 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
+        state_norm: StateNorm | None = None,
         actor_lr = 8e-4,
         critic_lr = 8e-4,
         latent_lr = 1e-5,
@@ -965,12 +1029,20 @@ class Agent(Module):
         accelerate = Accelerator(**accelerate_kwargs)
         self.accelerate = accelerate
 
+        # state norm
+
+        self.state_norm = state_norm
+
         # actor, critic, and their shared latent gene pool
 
         self.actor = actor
 
         self.critic = critic
 
+        if exists(state_norm):
+            # insurance
+            actor.state_norm = critic.state_norm = state_norm
+
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
@@ -1034,6 +1106,7 @@ class Agent(Module):
         self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
 
         (
+            self.state_norm,
             self.actor,
             self.critic,
             self.latent_gene_pool,
@@ -1042,6 +1115,7 @@ class Agent(Module):
             self.latent_optim,
         ) = tuple(
             maybe(self.accelerate.prepare)(m) for m in (
+                self.state_norm,
                 self.actor,
                 self.critic,
                 self.latent_gene_pool,
@@ -1068,33 +1142,44 @@ class Agent(Module):
     def unwrapped_latent_gene_pool(self):
         return self.unwrap_model(self.latent_gene_pool)
 
+    def log(self, **data_kwargs):
+        if not self.wrap_with_accelerate:
+            return
+
+        self.accelerate.log(data_kwargs, step = self.step)
+
     def save(self, path, overwrite = False):
         path = Path(path)
+        unwrap = self.unwrap_model
 
         assert not path.exists() or overwrite
 
         pkg = dict(
-            actor = self.actor.state_dict(),
-            critic = self.critic.state_dict(),
+            state_norm = unwrap(self.state_norm).state_dict() if self.state_norm else None,
+            actor = unwrap(self.actor).state_dict(),
+            critic = unwrap(self.critic).state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
-            actor_optim = self.actor_optim.state_dict(),
-            critic_optim = self.critic_optim.state_dict(),
-            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+            latents = unwrap(self.latent_gene_pool).state_dict() if self.has_latent_genes else None,
+            actor_optim = unwrap(self.actor_optim).state_dict(),
+            critic_optim = unwrap(self.critic_optim).state_dict(),
+            latent_optim = unwrap(self.latent_optim).state_dict() if exists(self.latent_optim) else None
         )
 
         torch.save(pkg, str(path))
 
     def load(self, path):
+        unwrap = self.unwrap_model
         path = Path(path)
 
         assert path.exists()
 
         pkg = torch.load(str(path), weights_only = True)
 
-        self.actor.load_state_dict(pkg['actor'])
+        unwrap(self.actor).load_state_dict(pkg['actor'])
 
-        self.critic.load_state_dict(pkg['critic'])
+        unwrap(self.critic).load_state_dict(pkg['critic'])
+
+        unwrap(self.latent_gene_pool).load_state_dict(pkg['latents'])
 
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
@@ -1102,11 +1187,11 @@ class Agent(Module):
         if exists(pkg.get('latents', None)):
             self.latent_gene_pool.load_state_dict(pkg['latents'])
 
-        self.actor_optim.load_state_dict(pkg['actor_optim'])
-        self.critic_optim.load_state_dict(pkg['critic_optim'])
+        unwrap(self.actor_optim).load_state_dict(pkg['actor_optim'])
+        unwrap(self.critic_optim).load_state_dict(pkg['critic_optim'])
 
         if exists(pkg.get('latent_optim', None)):
-            self.latent_optim.load_state_dict(pkg['latent_optim'])
+            unwrap(self.latent_optim).load_state_dict(pkg['latent_optim'])
 
     @move_input_tensors_to_device
     def get_actor_actions(
@@ -1283,6 +1368,14 @@ class Agent(Module):
         self.critic_optim.step()
         self.critic_optim.zero_grad()
 
+        # log actor critic loss
+
+        self.log(
+            actor_loss = actor_loss.item(),
+            critic_loss = critic_loss.item(),
+            fitness_scores = fitness_scores
+        )
+
         # maybe ema update critic
 
         if self.use_critic_ema:
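The new `Agent.log` above simply forwards keyword arguments to `accelerate`'s tracker API. For reference, this is roughly how that API is used standalone; the tracker choice and project name below are assumptions for illustration, not something this package configures for you.

```python
from accelerate import Accelerator

# requires the chosen tracker (e.g. tensorboard or wandb) to be installed
accelerator = Accelerator(log_with = 'tensorboard', project_dir = './logs')
accelerator.init_trackers('epo-example')

accelerator.log({'actor_loss': 0.12, 'critic_loss': 0.34}, step = 1)

accelerator.end_training()
```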
@@ -1307,6 +1400,19 @@ class Agent(Module):
         self.latent_optim.step()
         self.latent_optim.zero_grad()
 
+        if self.has_diversity_loss:
+            self.log(
+                diversity_loss = diversity_loss.item()
+            )
+
+        # update state norm if needed
+
+        if exists(self.state_norm):
+            self.state_norm.train()
+
+            for _, states, *_ in tqdm(dataloader, desc = 'state norm learning'):
+                self.state_norm(states)
+
         # apply evolution
 
         if self.has_latent_genes:
@@ -1387,12 +1493,15 @@ def create_agent(
        **latent_gene_pool_kwargs
    ) if has_latent_genes else None
 
+    state_norm = StateNorm(dim = dim_state)
+
    actor = Actor(
        num_actions = actor_num_actions,
        dim_state = dim_state,
        dim_latent = dim_latent,
        dim = actor_dim,
        mlp_depth = actor_mlp_depth,
+        state_norm = state_norm,
        **actor_kwargs
    )
 
@@ -1401,12 +1510,14 @@ def create_agent(
        dim_latent = dim_latent,
        dim = critic_dim,
        mlp_depth = critic_mlp_depth,
+        state_norm = state_norm,
        **critic_kwargs
    )
 
    agent = Agent(
        actor = actor,
        critic = critic,
+        state_norm = state_norm,
        latent_gene_pool = latent_gene_pool,
        use_critic_ema = use_critic_ema,
        **kwargs
@@ -1620,4 +1731,4 @@ class EPO(Module):
 
         agent.learn_from(memories)
 
-        print(f'training complete')
+        print('training complete')
@@ -1,6 +1,10 @@
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim)
+
 def crossover_weights(w1, w2, transpose = False):
     assert w1.shape == w2.shape
 
 
@@ -27,7 +31,7 @@ def crossover_weights(w1, w2, transpose = False):
27
31
 
28
32
  u = torch.where(mask[:, None, :], u1, u2)
29
33
  s = torch.where(mask, s1, s2)
30
- v = torch.where(mask[:, None, :], v1, v2)
34
+ v = torch.where(mask[:, :, None], v1, v2)
31
35
 
32
36
  out = u @ torch.diag_embed(s) @ v.mT
33
37
 
@@ -52,9 +56,13 @@ def mutate_weight(
    assert rank >= 2
 
    u, s, v = torch.svd(w)
+
    u = u + torch.randn_like(u) * mutation_strength
    v = v + torch.randn_like(v) * mutation_strength
 
+    u = l2norm(u, dim = -2)
+    v = l2norm(v, dim = -1)
+
    out = u @ torch.diag_embed(s) @ v.mT
 
    if transpose:
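To see the shapes involved in the SVD-space mutation above, here is a short standalone sketch; it mirrors the hunk (including the dimensions passed to the re-normalization) but is only an illustration, not the package's API.

```python
import torch
import torch.nn.functional as F

w = torch.randn(3, 16, 8)          # a batch of weight matrices
u, s, v = torch.svd(w)             # u: (3, 16, 8), s: (3, 8), v: (3, 8, 8)

mutation_strength = 0.05
u = u + torch.randn_like(u) * mutation_strength
v = v + torch.randn_like(v) * mutation_strength

# restore unit norm along the same dimensions used in the hunk above
u = F.normalize(u, dim = -2)
v = F.normalize(v, dim = -1)

out = u @ torch.diag_embed(s) @ v.mT
assert out.shape == w.shape        # mutated weights keep the original shape
```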
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.18
+Version: 0.2.0
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -49,6 +49,7 @@ Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
+Requires-Dist: ruff>=0.4.2; extra == 'test'
 Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -144,6 +145,22 @@ agent.save('./agent.pt', overwrite = True)
 agent.load('./agent.pt')
 ```
 
+## Contributing
+
+At the project root, run
+
+```bash
+$ pip install '.[test]' # or `uv pip install '.[test]'`
+```
+
+Then add your tests to `tests/test_epo.py` and run
+
+```bash
+$ pytest tests/
+```
+
+That's it
+
 ## Citations
 
 ```bibtex
@@ -237,4 +254,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024SimBaSB,
+    title   = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+    author  = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2410.09754},
+    url     = {https://api.semanticscholar.org/CorpusID:273346233}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
@@ -0,0 +1,10 @@
+evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
+evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
+evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
+evolutionary_policy_optimization/epo.py,sha256=sMjCvwT6upNZ48DHegtZ40K9M8PloeW8KBCedx2fc-4,51796
+evolutionary_policy_optimization/experimental.py,sha256=ZyOGHbE4dXmt4zCljSzcUklua4vlOwQtslhFEm0JN94,1716
+evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
+evolutionary_policy_optimization-0.2.0.dist-info/METADATA,sha256=mqgk3bcYYLwGYYoOqWclUV2QuitxC_shw8nc6hrP8K0,8697
+evolutionary_policy_optimization-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.2.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.2.0.dist-info/RECORD,,
@@ -1,10 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=xnbp_dDUASvB_bD7qxuO2GSOJ4THD6L5DmL-FoRYHA0,48524
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.18.dist-info/METADATA,sha256=zVEsduf0ym0Blad2x-5XxmtW_MjYkyzN7iXG_YlcnaQ,7979
-evolutionary_policy_optimization-0.1.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.18.dist-info/RECORD,,