evolutionary-policy-optimization 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +51 -9
- {evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/METADATA +5 -3
- {evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/RECORD +5 -5
- {evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -76,6 +76,7 @@ def maybe(fn):
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor

+    @maybe
     def to_torch_tensor(t):
         if isinstance(t, (np.ndarray, np.float64)):
             t = from_numpy(np.array(t))
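The only change in this hunk is the new `@maybe` decorator on `to_torch_tensor`, which lets `None` values pass through the torch/numpy interface untouched. `maybe(fn)` is defined earlier in `epo.py` (visible in the hunk header); below is a minimal sketch of how such a decorator is commonly written, which may differ from the package's actual helper:

```python
from functools import wraps

def exists(v):
    return v is not None

def maybe(fn):
    # wrap fn so a None input is returned unchanged instead of being processed
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)
    return inner
```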
@@ -287,6 +288,22 @@ class PowerLawDist(Module):

         return self.values[sampled]

+# FiLM for latent to mlp conditioning
+
+class FiLM(Module):
+    def __init__(self, dim, dim_out):
+        super().__init__()
+        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
+        self.to_beta = nn.Linear(dim, dim_out, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.zeros_(self.to_beta.weight)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
+
+        return x * (gamma + 1.) + beta
+
 # layer integrated memory

 class DynamicLIMe(Module):
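The added `FiLM` class (feature-wise linear modulation) conditions hidden features on the latent. Because `to_gamma` and `to_beta` are zero-initialized, the module starts out as the identity (`x * (0 + 1) + 0 = x`) and the conditioning is learned gradually. A small usage sketch with illustrative sizes:

```python
import torch
from evolutionary_policy_optimization.epo import FiLM

film = FiLM(dim = 64, dim_out = 32)   # a 64-dim condition modulating 32-dim features

x = torch.randn(2, 32)                # hidden features
cond = torch.randn(2, 64)             # conditioning vector (the encoded latent in this package)

out = film(x, cond)

# freshly initialized, gamma and beta are zero, so FiLM acts as the identity
assert torch.allclose(out, x)
```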
@@ -301,7 +318,7 @@ class DynamicLIMe(Module):
         self.to_weights = nn.Sequential(
             nn.RMSNorm(dim),
             nn.Linear(dim, num_layers),
-            nn.
+            nn.Softmax(dim = -1)
         )

     def forward(
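After this change the `to_weights` head ends in a `Softmax`, so `DynamicLIMe` mixes the outputs of all previous layers with convex weights (non-negative, summing to one). A standalone sketch of just that head, assuming a PyTorch recent enough to ship `nn.RMSNorm`:

```python
import torch
from torch import nn

dim, num_layers = 16, 3

# the same shape of head as in the hunk above, shown in isolation
to_weights = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, num_layers),
    nn.Softmax(dim = -1)
)

x = torch.randn(2, dim)
weights = to_weights(x)

# each row is a proper mixing distribution over the previous layers
assert torch.allclose(weights.sum(dim = -1), torch.ones(2))
```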
@@ -338,7 +355,7 @@ class MLP(Module):
         self.needs_latent = dim_latent > 0

         self.encode_latent = nn.Sequential(
-            Linear(dim_latent, dim),
+            Linear(dim_latent, dim * 2),
             nn.SiLU()
         ) if self.needs_latent else None

@@ -351,6 +368,11 @@ class MLP(Module):
         for ind in range(depth):
             is_first = ind == 0

+            film = None
+
+            if self.needs_latent:
+                film = FiLM(dim * 2, dim)
+
             lime = DynamicLIMe(dim, num_layers = ind + 1) if not is_first else None

             layer = nn.Sequential(
@@ -362,6 +384,7 @@ class MLP(Module):

             layers.append(ModuleList([
                 lime,
+                film,
                 layer
             ]))

@@ -389,19 +412,20 @@ class MLP(Module):

             assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'

-            x = x * latent
-
         # layers

         prev_layer_inputs = [x]

-        for lime, layer in self.layers:
+        for lime, film, layer in self.layers:

             layer_inp = x

             if exists(lime):
                 layer_inp = lime(x, prev_layer_inputs)

+            if exists(film):
+                layer_inp = film(layer_inp, latent)
+
             x = layer(layer_inp) + x

             prev_layer_inputs.append(x)
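Taken together, the MLP hunks change how the latent conditions the network: previously the latent was multiplied into the input once (`x = x * latent`); now it is encoded to `dim * 2` and a per-layer `FiLM` modulates every block's input. The pattern, shown standalone with illustrative sizes and a plain `nn.Linear` standing in for one MLP block:

```python
import torch
from torch import nn
from evolutionary_policy_optimization.epo import FiLM  # the module added in the earlier hunk

dim, dim_latent = 128, 32

encode_latent = nn.Sequential(nn.Linear(dim_latent, dim * 2), nn.SiLU())
film = FiLM(dim * 2, dim)
block = nn.Linear(dim, dim)        # stand-in for one MLP layer

x = torch.randn(4, dim)
latent = torch.randn(4, dim_latent)

cond = encode_latent(latent)       # latent encoded once, up front
x = block(film(x, cond)) + x       # each block sees a FiLM-conditioned input, with a residual
```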
@@ -919,6 +943,8 @@ class Agent(Module):
            eps_clip = 0.4
         ),
         use_improved_critic_loss = True,
+        shrink_and_perturb_every = None,
+        shrink_and_perturb_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -983,6 +1009,12 @@ class Agent(Module):

         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None

+        # shrink and perturb every
+
+        self.should_noise_weights = exists(shrink_and_perturb_every)
+        self.shrink_and_perturb_every = shrink_and_perturb_every
+        self.shrink_and_perturb_ = partial(shrink_and_perturb_, **shrink_and_perturb_kwargs)
+
         # promotes latents to be farther apart for diversity maintenance

         self.has_diversity_loss = diversity_aux_loss_weight > 0.
@@ -992,7 +1024,7 @@ class Agent(Module):

         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model

-
+        step = tensor(0)

         self.clip_grad_norm_ = nn.utils.clip_grad_norm_

@@ -1020,15 +1052,15 @@ class Agent(Module):
             if exists(self.critic_ema):
                 self.critic_ema.to(self.accelerate.device)

-
+            step = step.to(self.accelerate.device)

         # device tracking

-        self.register_buffer('
+        self.register_buffer('step', step)

     @property
     def device(self):
-        return self.
+        return self.step.device

     @property
     def unwrapped_latent_gene_pool(self):
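Keeping the step counter as a registered buffer means it is included in `state_dict` and follows the module across devices, which is why the `device` property can now simply read `self.step.device`. The pattern in isolation:

```python
import torch
from torch import nn, tensor

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        # buffers are saved in state_dict and follow .to(device), but are not trained
        self.register_buffer('step', tensor(0))

    @property
    def device(self):
        return self.step.device

counter = Counter()
counter.step.add_(1)

assert counter.step.item() == 1
assert counter.device == counter.step.device
```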
@@ -1278,6 +1310,16 @@ class Agent(Module):
         if self.has_latent_genes:
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)

+        # maybe shrink and perturb
+
+        if self.should_noise_weights and divisible_by(self.step.item(), self.shrink_and_perturb_every):
+            self.shrink_and_perturb_(self.actor)
+            self.shrink_and_perturb_(self.critic)
+
+        # increment step
+
+        self.step.add_(1)
+
     # reinforcement learning related - ppo

     def actor_loss(
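`shrink_and_perturb_` itself is imported from elsewhere in the package and is not shown in this diff. It follows the shrink-and-perturb recipe (Ash & Adams, "On Warm-Starting Neural Network Training"): periodically scale existing weights toward zero and mix in fresh noise, which helps a network trained for a long time keep learning. A minimal sketch of that recipe, with hypothetical factor names and defaults that may not match the package's implementation:

```python
import torch
from torch import nn

@torch.no_grad()
def shrink_and_perturb_(module: nn.Module, shrink_factor = 0.5, perturb_factor = 0.01):
    # scale every weight toward zero, then add a small amount of fresh gaussian noise
    for param in module.parameters():
        noise = torch.randn_like(param) * perturb_factor
        param.mul_(shrink_factor).add_(noise)
```

In this diff the Agent applies it to both actor and critic every `shrink_and_perturb_every` learning steps, gated by the new `step` buffer.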
{evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.15
+Version: 0.1.17
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -118,8 +118,10 @@ agent = create_agent(
     num_latents = 16,
     dim_latent = 32,
     actor_num_actions = 5,
-
-
+    actor_dim = 256,
+    actor_mlp_depth = 2,
+    critic_dim = 256,
+    critic_mlp_depth = 3,
     latent_gene_pool_kwargs = dict(
         frac_natural_selected = 0.5
     )
{evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=
+evolutionary_policy_optimization/epo.py,sha256=TW-9l9oRN8XQZxSeG5Glkk4rWuxO9JOjjRJO7hJgHZs,48433
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
-evolutionary_policy_optimization-0.1.
+evolutionary_policy_optimization-0.1.17.dist-info/METADATA,sha256=-2Lt54pfQkw2_LAwB8pjcvLsc4vSMVQjAt-ZiObCRdw,7979
+evolutionary_policy_optimization-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.17.dist-info/RECORD,,
{evolutionary_policy_optimization-0.1.15.dist-info → evolutionary_policy_optimization-0.1.17.dist-info}/WHEEL and /licenses/LICENSE: files without changes