evolutionary-policy-optimization 0.1.15-py3-none-any.whl → 0.1.17-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
evolutionary_policy_optimization/epo.py

@@ -76,6 +76,7 @@ def maybe(fn):
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor

+    @maybe
     def to_torch_tensor(t):
         if isinstance(t, (np.ndarray, np.float64)):
             t = from_numpy(np.array(t))
@@ -287,6 +288,22 @@ class PowerLawDist(Module):

         return self.values[sampled]

+# FiLM for latent to mlp conditioning
+
+class FiLM(Module):
+    def __init__(self, dim, dim_out):
+        super().__init__()
+        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
+        self.to_beta = nn.Linear(dim, dim_out, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.zeros_(self.to_beta.weight)
+
+    def forward(self, x, cond):
+        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
+
+        return x * (gamma + 1.) + beta
+
 # layer integrated memory

 class DynamicLIMe(Module):
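The new FiLM module maps a conditioning vector to a per-feature scale (gamma) and shift (beta). Because both projections are zero-initialized, gamma + 1 starts at 1 and beta at 0, so the block begins as an identity and the conditioning is learned gradually. A minimal, self-contained usage sketch of the class added above, with made-up dimensions and tensors (the class body is reproduced from the hunk so the snippet runs on its own):

import torch
from torch import nn
from torch.nn import Module

class FiLM(Module):
    def __init__(self, dim, dim_out):
        super().__init__()
        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
        self.to_beta = nn.Linear(dim, dim_out, bias = False)

        # zero init -> gamma + 1 == 1 and beta == 0, so FiLM starts as the identity
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.zeros_(self.to_beta.weight)

    def forward(self, x, cond):
        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
        return x * (gamma + 1.) + beta

film = FiLM(dim = 64, dim_out = 32)   # conditioning vector has 64 features, activations have 32

x = torch.randn(2, 32)                # batch of activations
cond = torch.randn(2, 64)             # batch of conditioning vectors

out = film(x, cond)
assert out.shape == x.shape
assert torch.allclose(out, x)         # identity at initialization, thanks to the zero init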
@@ -301,7 +318,7 @@ class DynamicLIMe(Module):
         self.to_weights = nn.Sequential(
             nn.RMSNorm(dim),
             nn.Linear(dim, num_layers),
-            nn.ReLU()
+            nn.Softmax(dim = -1)
         )

     def forward(
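Swapping ReLU for Softmax turns DynamicLIMe's per-layer routing weights into a convex combination: they are non-negative and sum to one across the previous layers, rather than being unbounded and possibly all zero. A standalone sketch of that mixing step under these assumptions (this is not the module's actual forward, which is not shown in this diff):

import torch
import torch.nn.functional as F

num_layers, batch, dim = 3, 2, 8

# outputs of the current layer's predecessors, stacked along a "layers" axis
prev_layer_outputs = torch.randn(num_layers, batch, dim)

# unnormalized routing logits per previous layer (would come from the RMSNorm -> Linear above)
logits = torch.randn(batch, num_layers)

weights = F.softmax(logits, dim = -1)   # convex weights, summing to 1 per sample

# weighted mix of the previous layers' outputs
mixed = torch.einsum('b l, l b d -> b d', weights, prev_layer_outputs)

assert torch.allclose(weights.sum(dim = -1), torch.ones(batch))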
@@ -338,7 +355,7 @@ class MLP(Module):
         self.needs_latent = dim_latent > 0

         self.encode_latent = nn.Sequential(
-            Linear(dim_latent, dim),
+            Linear(dim_latent, dim * 2),
             nn.SiLU()
         ) if self.needs_latent else None

@@ -351,6 +368,11 @@ class MLP(Module):
         for ind in range(depth):
             is_first = ind == 0

+            film = None
+
+            if self.needs_latent:
+                film = FiLM(dim * 2, dim)
+
             lime = DynamicLIMe(dim, num_layers = ind + 1) if not is_first else None

             layer = nn.Sequential(
@@ -362,6 +384,7 @@ class MLP(Module):

             layers.append(ModuleList([
                 lime,
+                film,
                 layer
             ]))

@@ -389,19 +412,20 @@ class MLP(Module):

         assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'

-        x = x * latent
-
         # layers

         prev_layer_inputs = [x]

-        for lime, layer in self.layers:
+        for lime, film, layer in self.layers:

             layer_inp = x

             if exists(lime):
                 layer_inp = lime(x, prev_layer_inputs)

+            if exists(film):
+                layer_inp = film(layer_inp, latent)
+
             x = layer(layer_inp) + x

             prev_layer_inputs.append(x)
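Taken together, these MLP changes replace the single multiplicative gate x = x * latent with per-layer FiLM conditioning: the latent is encoded once into a dim * 2 vector, and each layer derives its own scale and shift from it before the residual feedforward. A condensed sketch of the new data flow, with illustrative dimensions, the FiLM projections written inline, and the DynamicLIMe mixing omitted (the feedforward shape here is an assumption, not the package's exact layer):

import torch
from torch import nn

dim, dim_latent, depth, batch = 32, 16, 3, 4

# latent -> wider conditioning vector, mirroring the encode_latent change to dim * 2
encode_latent = nn.Sequential(nn.Linear(dim_latent, dim * 2), nn.SiLU())

# per layer: a FiLM-style pair of projections plus an illustrative feedforward
to_gammas = nn.ModuleList([nn.Linear(dim * 2, dim, bias = False) for _ in range(depth)])
to_betas  = nn.ModuleList([nn.Linear(dim * 2, dim, bias = False) for _ in range(depth)])
ffs       = nn.ModuleList([nn.Sequential(nn.Linear(dim, dim * 4), nn.SiLU(), nn.Linear(dim * 4, dim)) for _ in range(depth)])

x = torch.randn(batch, dim)
latent = torch.randn(batch, dim_latent)

cond = encode_latent(latent)   # encoded once, reused by every layer

for to_gamma, to_beta, ff in zip(to_gammas, to_betas, ffs):
    layer_inp = x * (to_gamma(cond) + 1.) + to_beta(cond)   # FiLM conditioning of this layer's input
    x = ff(layer_inp) + x                                   # residual, as in the hunk above

print(x.shape)   # torch.Size([4, 32])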
@@ -919,6 +943,8 @@ class Agent(Module):
             eps_clip = 0.4
         ),
         use_improved_critic_loss = True,
+        shrink_and_perturb_every = None,
+        shrink_and_perturb_kwargs: dict = dict(),
         ema_kwargs: dict = dict(),
         actor_optim_kwargs: dict = dict(),
         critic_optim_kwargs: dict = dict(),
@@ -983,6 +1009,12 @@ class Agent(Module):

         self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if exists(latent_gene_pool) and not latent_gene_pool.frozen_latents else None

+        # shrink and perturb every
+
+        self.should_noise_weights = exists(shrink_and_perturb_every)
+        self.shrink_and_perturb_every = shrink_and_perturb_every
+        self.shrink_and_perturb_ = partial(shrink_and_perturb_, **shrink_and_perturb_kwargs)
+
         # promotes latents to be farther apart for diversity maintenance

         self.has_diversity_loss = diversity_aux_loss_weight > 0.
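Agent.__init__ stores shrink_and_perturb_ with its keyword arguments pre-bound via functools.partial, so the learn step later only needs to pass a module. A small sketch of that pattern; the stand-in function below and its keyword names are hypothetical, not the signature of the package's actual shrink_and_perturb_ (the canonical shrink-and-perturb recipe interpolates toward freshly initialized weights; plain Gaussian noise is used here only to keep the sketch short):

from functools import partial

import torch
from torch import nn

# hypothetical stand-in for the package's shrink_and_perturb_ helper
@torch.no_grad()
def shrink_and_perturb_(module, shrink = 0.9, perturb = 0.01):
    for p in module.parameters():
        p.mul_(shrink).add_(torch.randn_like(p) * perturb)   # shrink toward zero, then add noise

shrink_and_perturb_kwargs = dict(shrink = 0.8, perturb = 0.05)

# same pattern as the hunk above: kwargs are bound once, later calls only pass the module
apply_shrink_and_perturb = partial(shrink_and_perturb_, **shrink_and_perturb_kwargs)

net = nn.Linear(4, 4)
apply_shrink_and_perturb(net)   # equivalent to shrink_and_perturb_(net, shrink = 0.8, perturb = 0.05)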
@@ -992,7 +1024,7 @@ class Agent(Module):

         self.unwrap_model = identity if not wrap_with_accelerate else self.accelerate.unwrap_model

-        dummy = tensor(0)
+        step = tensor(0)

         self.clip_grad_norm_ = nn.utils.clip_grad_norm_

@@ -1020,15 +1052,15 @@ class Agent(Module):
             if exists(self.critic_ema):
                 self.critic_ema.to(self.accelerate.device)

-            dummy = dummy.to(self.accelerate.device)
+            step = step.to(self.accelerate.device)

         # device tracking

-        self.register_buffer('dummy', dummy)
+        self.register_buffer('step', step)

     @property
     def device(self):
-        return self.dummy.device
+        return self.step.device

     @property
     def unwrapped_latent_gene_pool(self):
@@ -1278,6 +1310,16 @@ class Agent(Module):
         if self.has_latent_genes:
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)

+        # maybe shrink and perturb
+
+        if self.should_noise_weights and divisible_by(self.step.item(), self.shrink_and_perturb_every):
+            self.shrink_and_perturb_(self.actor)
+            self.shrink_and_perturb_(self.critic)
+
+        # increment step
+
+        self.step.add_(1)
+
     # reinforcement learning related - ppo

     def actor_loss(
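The gating above reads the new step buffer, which starts at zero and is incremented only after the check. So, assuming the package's divisible_by(num, den) helper is simply num % den == 0, the weight noising fires on the very first call and then once every shrink_and_perturb_every calls. A tiny sketch of the schedule under that assumption:

def divisible_by(num, den):
    # assumed definition of the helper used in the hunk above
    return (num % den) == 0

shrink_and_perturb_every = 3

for step in range(8):
    if divisible_by(step, shrink_and_perturb_every):
        print(f'step {step}: shrink and perturb actor + critic')

# prints for steps 0, 3 and 6 - immediately on the first call, then every third call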
evolutionary_policy_optimization-{0.1.15 → 0.1.17}.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.15
+Version: 0.1.17
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -118,8 +118,10 @@ agent = create_agent(
     num_latents = 16,
     dim_latent = 32,
     actor_num_actions = 5,
-    actor_dim_hiddens = (256, 128),
-    critic_dim_hiddens = (256, 128, 64),
+    actor_dim = 256,
+    actor_mlp_depth = 2,
+    critic_dim = 256,
+    critic_mlp_depth = 3,
     latent_gene_pool_kwargs = dict(
         frac_natural_selected = 0.5
     )
evolutionary_policy_optimization-{0.1.15 → 0.1.17}.dist-info/RECORD

@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=iUxd7gbT1GPGwso4utTaxgtjcxvvGNA8AGGUpSOImTM,47108
+evolutionary_policy_optimization/epo.py,sha256=TW-9l9oRN8XQZxSeG5Glkk4rWuxO9JOjjRJO7hJgHZs,48433
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.15.dist-info/METADATA,sha256=e8ofJe5rpGIyEiMd3mJBU-2VjOfFJ8TpGGv7adSKjRM,7962
-evolutionary_policy_optimization-0.1.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.15.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.17.dist-info/METADATA,sha256=-2Lt54pfQkw2_LAwB8pjcvLsc4vSMVQjAt-ZiObCRdw,7979
+evolutionary_policy_optimization-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.17.dist-info/RECORD,,