evolutionary-policy-optimization 0.1.19__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization-0.2.0/.github/workflows/lint.yml +21 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/PKG-INFO +29 -1
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/README.md +27 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/distributed.py +0 -2
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/epo.py +108 -16
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/experimental.py +9 -1
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/pyproject.toml +16 -2
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/tests/test_epo.py +5 -5
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/train_gym.py +3 -3
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.github/workflows/python-publish.yml +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.github/workflows/test.yml +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.gitignore +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/LICENSE +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/__init__.py +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/env_wrappers.py +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/mock_env.py +0 -0
- {evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/requirements.txt +0 -0
evolutionary_policy_optimization-0.2.0/.github/workflows/lint.yml
ADDED
@@ -0,0 +1,21 @@
+name: Ruff
+on: [push, pull_request]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v5
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        python -m pip install uv
+        python -m uv pip install ruff
+    - name: Lint with Ruff
+      run: |
+        ruff check evolutionary_policy_optimization/
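Contributors can run the same check locally before pushing. A minimal, hypothetical helper (not part of the package, assuming `ruff` is installed in the active environment) might look like:

```python
# Hypothetical helper: run the same Ruff check the CI workflow above performs.
import subprocess
import sys

def lint(package_dir: str = "evolutionary_policy_optimization/") -> int:
    # equivalent to the CI step `ruff check evolutionary_policy_optimization/`
    result = subprocess.run([sys.executable, "-m", "ruff", "check", package_dir])
    return result.returncode

if __name__ == "__main__":
    raise SystemExit(lint())
```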
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.19
+Version: 0.2.0
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -49,6 +49,7 @@ Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
+Requires-Dist: ruff>=0.4.2; extra == 'test'
 Description-Content-Type: text/markdown
 
 <img width="450px" alt="fig1" src="https://github.com/user-attachments/assets/33bef569-e786-4f09-bdee-56bad7ea9e6d" />
@@ -144,6 +145,22 @@ agent.save('./agent.pt', overwrite = True)
 agent.load('./agent.pt')
 ```
 
+## Contributing
+
+At the project root, run
+
+```bash
+$ pip install '.[test]' # or `uv pip install '.[test]'`
+```
+
+Then add your tests to `tests/test_epo.py` and run
+
+```bash
+$ pytest tests/
+```
+
+That's it
+
 ## Citations
 
 ```bibtex
@@ -237,4 +254,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024SimBaSB,
+    title = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+    author = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.09754},
+    url = {https://api.semanticscholar.org/CorpusID:273346233}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/README.md
RENAMED
@@ -91,6 +91,22 @@ agent.save('./agent.pt', overwrite = True)
 agent.load('./agent.pt')
 ```
 
+## Contributing
+
+At the project root, run
+
+```bash
+$ pip install '.[test]' # or `uv pip install '.[test]'`
+```
+
+Then add your tests to `tests/test_epo.py` and run
+
+```bash
+$ pytest tests/
+```
+
+That's it
+
 ## Citations
 
 ```bibtex
@@ -184,4 +200,15 @@ agent.load('./agent.pt')
 }
 ```
 
+```bibtex
+@article{Lee2024SimBaSB,
+    title = {SimBa: Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning},
+    author = {Hojoon Lee and Dongyoon Hwang and Donghu Kim and Hyunseung Kim and Jun Jet Tai and Kaushik Subramanian and Peter R. Wurman and Jaegul Choo and Peter Stone and Takuma Seno},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2410.09754},
+    url = {https://api.semanticscholar.org/CorpusID:273346233}
+}
+```
+
 *Evolution is cleverer than you are.* - Leslie Orgel
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/distributed.py
RENAMED
@@ -61,8 +61,6 @@ def has_only_one_value(t):
     return (t == t[0]).all()
 
 def all_gather_variable_dim(t, dim = 0, sizes = None):
-    device, rank, world_size = t.device, dist.get_rank(), dist.get_world_size()
-
     if not exists(sizes):
         sizes = gather_sizes(t, dim = dim)
 
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/epo.py
RENAMED
@@ -19,7 +19,7 @@ from torch.utils.data import TensorDataset, DataLoader
 from torch.utils._pytree import tree_map
 
 import einx
-from einops import rearrange, repeat, einsum, pack
+from einops import rearrange, repeat, reduce, einsum, pack
 from einops.layers.torch import Rearrange
 
 from evolutionary_policy_optimization.distributed import (
@@ -192,7 +192,6 @@ def calc_generalized_advantage_estimate(
     use_accelerated = None
 ):
     use_accelerated = default(use_accelerated, rewards.is_cuda)
-    device = rewards.device
 
     values = F.pad(values, (0, 1), value = 0.)
     values, values_next = values[:-1], values[1:]
@@ -202,7 +201,7 @@ def calc_generalized_advantage_estimate(
 
     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
 
-    return scan(gates, delta)
+    return scan(gates, delta)
 
 # evolution related functions
 
(whitespace-only change: the removed `return scan(gates, delta)` line carried trailing spaces)
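For context, `calc_generalized_advantage_estimate` folds the TD residuals with a reverse associative scan. A plain-loop reference sketch of the quantity it computes follows; the exact `gates`/`delta` definitions sit outside the hunk above, so the standard GAE forms and input names (`rewards`, `values`, `masks` as equal-length 1-D tensors) are assumptions here, and this is not the package's `AssocScan` path.

```python
import torch

def gae_reference(rewards, values, masks, gamma = 0.99, lam = 0.95):
    # append one bootstrap value, mirroring F.pad(values, (0, 1)) above
    values = torch.cat((values, values.new_zeros(1)))
    values, values_next = values[:-1], values[1:]

    delta = rewards + gamma * values_next * masks - values  # TD residuals
    gates = gamma * lam * masks                             # per-step decay

    # reverse scan: advantage[t] = delta[t] + gates[t] * advantage[t + 1]
    advantages = torch.zeros_like(delta)
    running = torch.tensor(0.)
    for t in reversed(range(len(delta))):
        running = delta[t] + gates[t] * running
        advantages[t] = running

    return advantages
```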
@@ -336,6 +335,53 @@ class DynamicLIMe(Module):
 
         return einsum(hiddens, weights, 'l b d, b l -> b d')
 
+# state normalization
+
+class StateNorm(Module):
+    def __init__(
+        self,
+        dim,
+        eps = 1e-5
+    ):
+        # equation (3) in https://arxiv.org/abs/2410.09754 - 'RSMNorm'
+
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+
+        self.register_buffer('step', tensor(1))
+        self.register_buffer('running_mean', torch.zeros(dim))
+        self.register_buffer('running_variance', torch.ones(dim))
+
+    def forward(
+        self,
+        state
+    ):
+        assert state.shape[-1] == self.dim, f'expected feature dimension of {self.dim} but received {x.shape[-1]}'
+
+        time = self.step.item()
+        mean = self.running_mean
+        variance = self.running_variance
+
+        normed = (state - mean) / variance.sqrt().clamp(min = self.eps)
+
+        if not self.training:
+            return normed
+
+        # update running mean and variance
+
+        new_obs_mean = reduce(state, '... d -> d', 'mean')
+        delta = new_obs_mean - mean
+
+        new_mean = mean + delta / time
+        new_variance = (time - 1) / time * (variance + (delta ** 2) / time)
+
+        self.step.add_(1)
+        self.running_mean.copy_(new_mean)
+        self.running_variance.copy_(new_variance)
+
+        return normed
+
 # simple MLP networks, but with latent variables
 # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
 
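The running statistics above follow a Welford-style update. A self-contained sketch of the same idea with plain tensors (illustrative names, not the package's API):

```python
import torch

# Illustrative re-implementation of the running-statistics update used by
# StateNorm above (RSMNorm, eq. 3 of https://arxiv.org/abs/2410.09754).
def update_running_stats(state, mean, variance, step, eps = 1e-5):
    # normalize with the statistics accumulated so far
    normed = (state - mean) / variance.sqrt().clamp(min = eps)

    # fold the current batch mean into the running mean / variance
    batch_mean = state.mean(dim = tuple(range(state.ndim - 1)))  # reduce '... d -> d'
    delta = batch_mean - mean

    new_mean = mean + delta / step
    new_variance = (step - 1) / step * (variance + delta.square() / step)
    return normed, new_mean, new_variance, step + 1

# usage sketch
dim = 4
mean, variance, step = torch.zeros(dim), torch.ones(dim), 1

for _ in range(10):
    states = torch.randn(8, dim) * 3. + 2.
    normed, mean, variance, step = update_running_stats(states, mean, variance, step)
```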
@@ -444,10 +490,13 @@ class Actor(Module):
         num_actions,
         dim,
         mlp_depth,
+        state_norm: StateNorm | None = None,
         dim_latent = 0,
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -467,6 +516,10 @@ class Actor(Module):
         state,
         latent
     ):
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
 
         hidden = self.init_layer(state)
 
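The pattern used here, a shared normalizer applied inside `forward` but always in eval mode and without gradients so its statistics are only updated by an explicit pass elsewhere, looks like this in isolation (a generic sketch with made-up names, not the package's `Actor`):

```python
import torch
from torch import nn

class NormalizedHead(nn.Module):
    def __init__(self, dim, hidden = 64, norm: nn.Module | None = None):
        super().__init__()
        self.norm = norm  # e.g. a running-statistics normalizer shared with other heads
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, 1))

    def forward(self, x):
        if self.norm is not None:
            # freeze the normalizer for inference: no stat updates, no gradients
            with torch.no_grad():
                self.norm.eval()
                x = self.norm(x)
        return self.net(x)

head = NormalizedHead(8, norm = nn.BatchNorm1d(8))
out = head(torch.randn(4, 8))
```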
@@ -482,6 +535,7 @@ class Critic(Module):
         mlp_depth,
         dim_latent = 0,
         use_regression = False,
+        state_norm: StateNorm | None = None,
         hl_gauss_loss_kwargs: dict = dict(
             min_value = -100.,
             max_value = 100.,
@@ -490,6 +544,8 @@ class Critic(Module):
     ):
         super().__init__()
 
+        self.state_norm = state_norm
+
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
@@ -523,6 +579,12 @@ class Critic(Module):
         eps_clip = 0.4,
         use_improved = True
     ):
+
+        if exists(self.state_norm):
+            with torch.no_grad():
+                self.state_norm.eval()
+                state = self.state_norm(state)
+
         logits = self.forward(state, latent, return_logits = True)
 
         value = self.maybe_bins_to_value(logits)
@@ -535,7 +597,8 @@ class Critic(Module):
         old_values_lo = old_values - eps_clip
         old_values_hi = old_values + eps_clip
 
-        is_between …
+        def is_between(lo, hi):
+            return (lo < value) & (value < hi)
 
         clipped_loss = loss_fn(logits, clipped_target)
         loss = loss_fn(logits, target)
@@ -921,6 +984,7 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
+        state_norm: StateNorm | None = None,
         actor_lr = 8e-4,
         critic_lr = 8e-4,
         latent_lr = 1e-5,
@@ -965,12 +1029,20 @@ class Agent(Module):
         accelerate = Accelerator(**accelerate_kwargs)
         self.accelerate = accelerate
 
+        # state norm
+
+        self.state_norm = state_norm
+
         # actor, critic, and their shared latent gene pool
 
         self.actor = actor
 
         self.critic = critic
 
+        if exists(state_norm):
+            # insurance
+            actor.state_norm = critic.state_norm = state_norm
+
         self.use_critic_ema = use_critic_ema
         self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
 
@@ -1034,6 +1106,7 @@ class Agent(Module):
         self.clip_grad_norm_ = self.accelerate.clip_grad_norm_
 
         (
+            self.state_norm,
             self.actor,
             self.critic,
             self.latent_gene_pool,
@@ -1042,6 +1115,7 @@ class Agent(Module):
             self.latent_optim,
         ) = tuple(
             maybe(self.accelerate.prepare)(m) for m in (
+                self.state_norm,
                 self.actor,
                 self.critic,
                 self.latent_gene_pool,
@@ -1076,31 +1150,36 @@ class Agent(Module):
 
     def save(self, path, overwrite = False):
         path = Path(path)
+        unwrap = self.unwrap_model
 
         assert not path.exists() or overwrite
 
         pkg = dict(
-            actor = self.actor.state_dict(),
-            critic = self.critic.state_dict(),
+            state_norm = unwrap(self.state_norm).state_dict() if self.state_norm else None,
+            actor = unwrap(self.actor).state_dict(),
+            critic = unwrap(self.critic).state_dict(),
             critic_ema = self.critic_ema.state_dict() if self.use_critic_ema else None,
-            latents = self.latent_gene_pool.state_dict() if self.has_latent_genes else None,
-            actor_optim = self.actor_optim.state_dict(),
-            critic_optim = self.critic_optim.state_dict(),
-            latent_optim = self.latent_optim.state_dict() if exists(self.latent_optim) else None
+            latents = unwrap(self.latent_gene_pool).state_dict() if self.has_latent_genes else None,
+            actor_optim = unwrap(self.actor_optim).state_dict(),
+            critic_optim = unwrap(self.critic_optim).state_dict(),
+            latent_optim = unwrap(self.latent_optim).state_dict() if exists(self.latent_optim) else None
        )
 
         torch.save(pkg, str(path))
 
     def load(self, path):
+        unwrap = self.unwrap_model
         path = Path(path)
 
         assert path.exists()
 
         pkg = torch.load(str(path), weights_only = True)
 
-        self.actor.load_state_dict(pkg['actor'])
+        unwrap(self.actor).load_state_dict(pkg['actor'])
+
+        unwrap(self.critic).load_state_dict(pkg['critic'])
 
-        self. …
+        unwrap(self.latent_gene_pool).load_state_dict(pkg['latents'])
 
         if self.use_critic_ema:
             self.critic_ema.load_state_dict(pkg['critic_ema'])
@@ -1108,11 +1187,11 @@ class Agent(Module):
         if exists(pkg.get('latents', None)):
             self.latent_gene_pool.load_state_dict(pkg['latents'])
 
-        self.actor_optim.load_state_dict(pkg['actor_optim'])
-        self.critic_optim.load_state_dict(pkg['critic_optim'])
+        unwrap(self.actor_optim).load_state_dict(pkg['actor_optim'])
+        unwrap(self.critic_optim).load_state_dict(pkg['critic_optim'])
 
         if exists(pkg.get('latent_optim', None)):
-            self.latent_optim.load_state_dict(pkg['latent_optim'])
+            unwrap(self.latent_optim).load_state_dict(pkg['latent_optim'])
 
     @move_input_tensors_to_device
     def get_actor_actions(
@@ -1326,6 +1405,14 @@ class Agent(Module):
                 diversity_loss = diversity_loss.item()
             )
 
+        # update state norm if needed
+
+        if exists(self.state_norm):
+            self.state_norm.train()
+
+            for _, states, *_ in tqdm(dataloader, desc = 'state norm learning'):
+                self.state_norm(states)
+
         # apply evolution
 
         if self.has_latent_genes:
@@ -1406,12 +1493,15 @@ def create_agent(
         **latent_gene_pool_kwargs
     ) if has_latent_genes else None
 
+    state_norm = StateNorm(dim = dim_state)
+
     actor = Actor(
         num_actions = actor_num_actions,
         dim_state = dim_state,
         dim_latent = dim_latent,
         dim = actor_dim,
         mlp_depth = actor_mlp_depth,
+        state_norm = state_norm,
         **actor_kwargs
     )
 
@@ -1420,12 +1510,14 @@ def create_agent(
         dim_latent = dim_latent,
         dim = critic_dim,
         mlp_depth = critic_mlp_depth,
+        state_norm = state_norm,
         **critic_kwargs
     )
 
     agent = Agent(
         actor = actor,
         critic = critic,
+        state_norm = state_norm,
         latent_gene_pool = latent_gene_pool,
         use_critic_ema = use_critic_ema,
         **kwargs
@@ -1639,4 +1731,4 @@ class EPO(Module):
 
             agent.learn_from(memories)
 
-        print( …
+        print('training complete')
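The save/load changes above route every module and optimizer through `unwrap_model` before touching its `state_dict`. With Hugging Face Accelerate, `prepare` can wrap modules (for example in `DistributedDataParallel`), so checkpoints written from the wrapped object would carry wrapper-prefixed keys; unwrapping first keeps them loadable everywhere. A minimal sketch of the idea (illustrative, not the package's `Agent` API):

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(4, 2)
model = accelerator.prepare(model)

# saving: unwrap first so the state dict keys are the plain module's keys,
# not those of any wrapper that prepare() may have added
unwrapped = accelerator.unwrap_model(model)
torch.save(unwrapped.state_dict(), 'model.pt')

# loading: also done against the unwrapped module
unwrapped.load_state_dict(torch.load('model.pt', weights_only = True))
```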
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/experimental.py
RENAMED
@@ -1,6 +1,10 @@
 import torch
+import torch.nn.functional as F
 from einops import rearrange
 
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim)
+
 def crossover_weights(w1, w2, transpose = False):
     assert w2.shape == w2.shape
 
@@ -27,7 +31,7 @@ def crossover_weights(w1, w2, transpose = False):
 
     u = torch.where(mask[:, None, :], u1, u2)
     s = torch.where(mask, s1, s2)
-    v = torch.where(mask[:, None …
+    v = torch.where(mask[:, :, None], v1, v2)
 
     out = u @ torch.diag_embed(s) @ v.mT
 
@@ -52,9 +56,13 @@ def mutate_weight(
     assert rank >= 2
 
     u, s, v = torch.svd(w)
+
     u = u + torch.randn_like(u) * mutation_strength
     v = v + torch.randn_like(v) * mutation_strength
 
+    u = l2norm(u, dim = -2)
+    v = l2norm(v, dim = -1)
+
     out = u @ torch.diag_embed(s) @ v.mT
 
     if transpose:
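For context, the crossover above recombines two weight matrices in their SVD bases, taking each singular component (a column of U, a singular value, a row of Vh) from one parent or the other; the fix corrects which axis of `v` the mask broadcasts over. A rough standalone sketch of that idea for a single, non-batched pair of matrices (an illustration using `torch.linalg.svd`, not the package's exact `crossover_weights`):

```python
import torch

def svd_crossover(w1, w2):
    # decompose both parents
    u1, s1, v1h = torch.linalg.svd(w1, full_matrices = False)
    u2, s2, v2h = torch.linalg.svd(w2, full_matrices = False)

    # choose each singular component from parent 1 or parent 2 at random
    mask = torch.rand_like(s1) < 0.5            # shape: (k,)

    u  = torch.where(mask[None, :], u1, u2)     # mix columns of U
    s  = torch.where(mask, s1, s2)              # mix singular values
    vh = torch.where(mask[:, None], v1h, v2h)   # mix rows of Vh

    return u @ torch.diag_embed(s) @ vh

child = svd_crossover(torch.randn(16, 8), torch.randn(16, 8))
```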
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/pyproject.toml
RENAMED
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.19"
+version = "0.2.0"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -47,7 +47,8 @@ examples = [
 ]
 
 test = [
-    "pytest"
+    "pytest",
+    "ruff>=0.4.2",
 ]
 
 [tool.pytest.ini_options]
@@ -55,6 +56,19 @@ pythonpath = [
     "."
 ]
 
+[tool.ruff]
+line-length = 1000
+
+lint.ignore = [
+    "F722", # for jaxtyping shape annotation
+    "F401",
+    "F821"
+]
+
+lint.extend-select = [
+    "W291"
+]
+
 [build-system]
 requires = ["hatchling"]
 build-backend = "hatchling.build"
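For orientation, the Ruff rules named in this config are F722 (invalid forward annotation, commonly triggered by jaxtyping-style string shapes), F401 (unused import), F821 (undefined name), all ignored here, and W291 (trailing whitespace), newly enforced. A small illustrative snippet, not taken from the package, of code each rule reacts to:

```python
import os                            # F401: imported but unused (ignored by this config)

def head(x: "float[batch dim]"):     # F722: jaxtyping-style shape annotation (ignored)
    return x

def bad():
    return undefined_name            # F821: undefined name (ignored)

value = 1                            # W291 fires if a line ends with trailing spaces
```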
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/tests/test_epo.py
RENAMED
@@ -19,7 +19,7 @@ def test_readme(
 
     latent_pool = LatentGenePool(
         num_latents = 128,
-        dim_latent = 32,
+        dim_latent = 32,
         num_islands = num_islands,
         fast_genetic_algorithm = sampled_mutation_strengths
     )
(whitespace-only change: the removed `dim_latent = 32,` line carried trailing spaces)
@@ -31,8 +31,8 @@ def test_readme(
 
     latent = latent_pool(latent_id = latent_ids, state = state)
 
-    actions = actor(state, latent)
-    value = critic(state, latent)
+    actions = actor(state, latent) # noqa: F841
+    value = critic(state, latent) # noqa: F841
 
     # interact with environment and receive rewards, termination etc
 
@@ -63,8 +63,8 @@ def test_create_agent(
 
     state = torch.randn(2, 512)
 
-    actions = agent.get_actor_actions(state, latent_id = latent_ids)
-    value = agent.get_critic_values(state, latent_id = latent_ids)
+    actions = agent.get_actor_actions(state, latent_id = latent_ids) # noqa: F841
+    value = agent.get_critic_values(state, latent_id = latent_ids) # noqa: F841
 
     # interact with environment and receive rewards, termination etc
 
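The `# noqa: F841` markers keep Ruff from flagging these intentionally unused assignments; the calls are still exercised so the tests fail on shape or runtime errors. The same pattern in isolation (illustrative only):

```python
import torch

def test_forward_runs():
    layer = torch.nn.Linear(8, 4)
    out = layer(torch.randn(2, 8))  # noqa: F841 - value unused, call kept to surface errors
```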
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/train_gym.py
RENAMED
@@ -21,7 +21,7 @@ env = gym.wrappers.RecordVideo(
     env = env,
     video_folder = './recordings',
     name_prefix = 'lunar-video',
-    episode_trigger = lambda eps_num: (eps_num % …
+    episode_trigger = lambda eps_num: (eps_num % 250) == 0,
     disable_logger = True
 )
 
@@ -53,8 +53,8 @@ agent = env.to_epo_agent(
 
 epo = EPO(
     agent,
-    episodes_per_latent = …
-    max_episode_length = …
+    episodes_per_latent = 10,
+    max_episode_length = 250,
     action_sample_temperature = 1.,
 )
 
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.github/workflows/python-publish.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.github/workflows/test.yml
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/.gitignore
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/LICENSE
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/__init__.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/env_wrappers.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/evolutionary_policy_optimization/mock_env.py
RENAMED
File without changes
{evolutionary_policy_optimization-0.1.19 → evolutionary_policy_optimization-0.2.0}/requirements.txt
RENAMED
File without changes