evolutionary-policy-optimization 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
--- evolutionary_policy_optimization/epo.py (0.1.12)
+++ evolutionary_policy_optimization/epo.py (0.1.14)
@@ -76,6 +76,14 @@ def maybe(fn):
 def interface_torch_numpy(fn, device):
     # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
 
+    def to_torch_tensor(t):
+        if isinstance(t, (np.ndarray, np.float64)):
+            t = from_numpy(np.array(t))
+        elif isinstance(t, (float, int, bool)):
+            t = tensor(t)
+
+        return t.to(device)
+
     @wraps(fn)
     def decorated_fn(*args, **kwargs):
 
@@ -83,7 +91,7 @@ def interface_torch_numpy(fn, device):
 
         out = fn(*args, **kwargs)
 
-        out = tree_map(lambda t: from_numpy(np.array(t)).to(device) if isinstance(t, (np.ndarray, np.float64)) else t, out)
+        out = tree_map(to_torch_tensor, out)
         return out
 
     return decorated_fn
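The inlined lambda is factored out into a named `to_torch_tensor` helper that now also converts plain Python scalars (`float`, `int`, `bool`), not just numpy arrays. A minimal, self-contained sketch of the output-side conversion this enables; the step output below is hypothetical and only illustrates the idea:

    import numpy as np
    import torch
    from torch import tensor, from_numpy

    def to_torch_tensor(t, device = 'cpu'):
        # numpy arrays and numpy scalars become torch tensors
        if isinstance(t, (np.ndarray, np.float64)):
            t = from_numpy(np.array(t))
        # plain python scalars are handled as well (new in this version)
        elif isinstance(t, (float, int, bool)):
            t = tensor(t)
        return t.to(device)

    # e.g. a numpy-based environment step returning (state, reward, done)
    step_out = (np.random.randn(4).astype(np.float32), 1.0, False)
    state, reward, done = map(to_torch_tensor, step_out)
    assert all(torch.is_tensor(t) for t in (state, reward, done))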
@@ -285,37 +293,42 @@ class PowerLawDist(Module):
 class MLP(Module):
     def __init__(
         self,
-        dims: tuple[int, ...],
+        dim,
+        depth,
         dim_latent = 0,
+        expansion_factor = 2.
     ):
         super().__init__()
         dim_latent = default(dim_latent, 0)
 
-        assert len(dims) >= 2, 'must have at least two dimensions'
-
-        # add the latent to the first dim
-
-        first_dim, *rest_dims = dims
-        dims = (first_dim + dim_latent, *rest_dims)
-
         self.dim_latent = dim_latent
 
         self.needs_latent = dim_latent > 0
 
         self.encode_latent = nn.Sequential(
-            Linear(dim_latent, dim_latent),
+            Linear(dim_latent, dim),
             nn.SiLU()
         ) if self.needs_latent else None
 
-        # pairs of dimension
+        dim_hidden = int(dim * expansion_factor)
 
-        dim_pairs = tuple(zip(dims[:-1], dims[1:]))
+        # layers
 
-        # modules across layers
+        layers = []
 
-        layers = ModuleList([Linear(dim_in, dim_out) for dim_in, dim_out in dim_pairs])
+        for _ in range(depth):
+            layer = nn.Sequential(
+                nn.LayerNorm(dim, bias = False),
+                nn.Linear(dim, dim_hidden),
+                nn.SiLU(),
+                nn.Linear(dim_hidden, dim),
+            )
+
+            layers.append(layer)
+
+        # modules across layers
 
-        self.layers = layers
+        self.layers = ModuleList(layers)
 
     def forward(
         self,
@@ -337,17 +350,14 @@ class MLP(Module):
 
             assert latent.shape[0] == x.shape[0], f'received state with batch size {x.shape[0]} but latent ids received had batch size {latent_id.shape[0]}'
 
-            x = cat((x, latent), dim = -1)
+            x = x * latent
 
         # layers
 
         for ind, layer in enumerate(self.layers, start = 1):
             is_last = ind == len(self.layers)
 
-            x = layer(x)
-
-            if not is_last:
-                x = F.silu(x)
+            x = layer(x) + x
 
         return x
 
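The MLP changes from a chain of `Linear` layers over an arbitrary `dims` tuple to a fixed-width stack of pre-norm residual blocks (`LayerNorm -> Linear -> SiLU -> Linear`, applied as `x = layer(x) + x`), and the latent now modulates the input multiplicatively instead of being concatenated. A self-contained sketch of that block pattern, written independently of the package; names and defaults mirror the diff but this is not its exact class:

    import torch
    from torch import nn

    class ResidualMLP(nn.Module):
        def __init__(self, dim, depth, expansion_factor = 2.):
            super().__init__()
            dim_hidden = int(dim * expansion_factor)
            self.layers = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(dim, bias = False),   # bias = False needs a recent PyTorch (>= 2.1)
                    nn.Linear(dim, dim_hidden),
                    nn.SiLU(),
                    nn.Linear(dim_hidden, dim),
                )
                for _ in range(depth)
            ])

        def forward(self, x, latent = None):
            if latent is not None:
                x = x * latent          # multiplicative conditioning replaces concatenation
            for layer in self.layers:
                x = layer(x) + x        # pre-norm residual update
            return x

    mlp = ResidualMLP(dim = 64, depth = 3)
    assert mlp(torch.randn(2, 64)).shape == (2, 64)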
@@ -359,26 +369,24 @@ class Actor(Module):
         self,
         dim_state,
         num_actions,
-        dim_hiddens: tuple[int, ...],
+        dim,
+        mlp_depth,
         dim_latent = 0,
     ):
         super().__init__()
 
-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state, dim_first),
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )
 
-        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)
 
         self.to_out = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(dim_last, num_actions),
+            nn.LayerNorm(dim, bias = False),
+            nn.Linear(dim, num_actions, bias = False),
         )
 
     def forward(
@@ -397,34 +405,31 @@ class Critic(Module):
     def __init__(
         self,
         dim_state,
-        dim_hiddens: tuple[int, ...],
+        dim,
+        mlp_depth,
         dim_latent = 0,
         use_regression = False,
         hl_gauss_loss_kwargs: dict = dict(
-            min_value = -10.,
-            max_value = 10.,
-            num_bins = 25,
-            sigma = 0.5
+            min_value = -100.,
+            max_value = 100.,
+            num_bins = 200
         )
     ):
         super().__init__()
 
-        assert len(dim_hiddens) >= 2
-        dim_first, *_, dim_last = dim_hiddens
-
         self.dim_latent = dim_latent
 
         self.init_layer = nn.Sequential(
-            nn.Linear(dim_state, dim_first),
+            nn.Linear(dim_state, dim),
             nn.SiLU()
         )
 
-        self.mlp = MLP(dims = dim_hiddens, dim_latent = dim_latent)
+        self.mlp = MLP(dim = dim, depth = mlp_depth, dim_latent = dim_latent)
 
-        self.final_act = nn.SiLU()
+        self.final_norm = nn.LayerNorm(dim, bias = False)
 
         self.to_pred = HLGaussLayer(
-            dim = dim_last,
+            dim = dim,
             use_regression = use_regression,
             hl_gauss_loss = hl_gauss_loss_kwargs
         )
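The critic's HL-Gauss head keeps the same classification-style value prediction but widens the support from [-10, 10] with 25 bins to [-100, 100] with 200 bins, and no longer pins `sigma` explicitly. For background, a generic sketch of how an HL-Gauss target is built from a scalar return by integrating a Gaussian over the bin edges; this illustrates the technique only, it is not the `HLGaussLayer` API, and the sigma default below is an assumption:

    import torch

    def hl_gauss_target(value, min_value = -100., max_value = 100., num_bins = 200, sigma = None):
        support = torch.linspace(min_value, max_value, num_bins + 1)        # bin edges
        # assumed default: sigma proportional to the bin width when not given explicitly
        sigma = sigma if sigma is not None else 0.75 * (support[1] - support[0]).item()
        dist = torch.distributions.Normal(value[..., None], sigma)          # one Gaussian per scalar target
        cdf = dist.cdf(support)                                             # CDF evaluated at every bin edge
        probs = cdf[..., 1:] - cdf[..., :-1]                                # probability mass per bin
        return probs / probs.sum(dim = -1, keepdim = True)                  # renormalize the truncated tails

    target = hl_gauss_target(torch.tensor([3.7]))
    assert target.shape == (1, 200)
    assert torch.allclose(target.sum(dim = -1), torch.tensor(1.))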
@@ -488,7 +493,7 @@ class Critic(Module):
 
         hidden = self.mlp(hidden, latent)
 
-        hidden = self.final_act(hidden)
+        hidden = self.final_norm(hidden)
 
         pred_kwargs = dict(return_logits = return_logits) if not self.use_regression else dict()
         return self.to_pred(hidden, **pred_kwargs)
@@ -843,16 +848,16 @@ class Agent(Module):
         critic: Critic,
         latent_gene_pool: LatentGenePool | None,
         optim_klass = AdoptAtan2,
-        actor_lr = 1e-4,
-        critic_lr = 1e-4,
+        actor_lr = 8e-4,
+        critic_lr = 8e-4,
         latent_lr = 1e-5,
-        actor_weight_decay = 1e-3,
-        critic_weight_decay = 1e-3,
+        actor_weight_decay = 5e-4,
+        critic_weight_decay = 5e-4,
         diversity_aux_loss_weight = 0.,
         use_critic_ema = True,
-        critic_ema_beta = 0.99,
-        max_grad_norm = 0.5,
-        batch_size = 16,
+        critic_ema_beta = 0.95,
+        max_grad_norm = 1.0,
+        batch_size = 32,
         calc_gae_kwargs: dict = dict(
             use_accelerated = False,
             gamma = 0.99,
@@ -1269,8 +1274,10 @@ def create_agent(
     num_latents,
     dim_latent,
     actor_num_actions,
-    actor_dim_hiddens: int | tuple[int, ...],
-    critic_dim_hiddens: int | tuple[int, ...],
+    actor_dim,
+    actor_mlp_depth,
+    critic_dim,
+    critic_mlp_depth,
     use_critic_ema = True,
     latent_gene_pool_kwargs: dict = dict(),
     actor_kwargs: dict = dict(),
@@ -1293,14 +1300,16 @@ def create_agent(
         num_actions = actor_num_actions,
         dim_state = dim_state,
         dim_latent = dim_latent,
-        dim_hiddens = actor_dim_hiddens,
+        dim = actor_dim,
+        mlp_depth = actor_mlp_depth,
         **actor_kwargs
     )
 
     critic = Critic(
         dim_state = dim_state,
         dim_latent = dim_latent,
-        dim_hiddens = critic_dim_hiddens,
+        dim = critic_dim,
+        mlp_depth = critic_mlp_depth,
         **critic_kwargs
    )
 
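With the hidden-dimension tuples gone, `create_agent` is now configured by a width and a depth per network. A hypothetical call under the new signature; the keyword names come from the hunks above, the numeric values are placeholders, and the import assumes `create_agent` is re-exported from the package root as in the project README:

    from evolutionary_policy_optimization import create_agent

    # width and depth now replace the old actor_dim_hiddens / critic_dim_hiddens tuples
    agent = create_agent(
        dim_state = 512,
        num_latents = 16,
        dim_latent = 32,
        actor_num_actions = 5,
        actor_dim = 256,
        actor_mlp_depth = 2,
        critic_dim = 256,
        critic_mlp_depth = 4,
    )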
@@ -1475,7 +1484,7 @@ class EPO(Module):
                     log_prob,
                     reward,
                     value,
-                    tensor(terminated)
+                    terminated
                 )
 
                 memory = Memory(*tuple(t.cpu() for t in memory))
@@ -1487,7 +1496,7 @@ class EPO(Module):
             if not terminated:
                 # add bootstrap value if truncated
 
-                next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent)
+                next_value = temp_batch_dim(self.agent.get_critic_values)(state, latent = latent, use_ema_if_available = True, use_unwrapped_model = True)
 
                 memory_for_gae = memory._replace(
                     episode_id = invalid_episode,
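When a rollout ends by truncation rather than true termination, the value of the final state is now fetched from the EMA critic (and from the unwrapped model, which matters under distributed wrappers) and used as the bootstrap value for advantage estimation. A generic illustration of why that bootstrap matters, using a minimal GAE computation that is not the package's own implementation (the one configured via `calc_gae_kwargs` above):

    import torch

    def gae(rewards, values, next_value, gamma = 0.99, lam = 0.95):
        values = torch.cat((values, next_value[None]))   # append bootstrap value V(s_T)
        advantages = torch.zeros_like(rewards)
        running = 0.
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            running = delta + gamma * lam * running
            advantages[t] = running
        return advantages

    rewards = torch.tensor([1., 1., 1.])
    values  = torch.tensor([0.5, 0.6, 0.7])

    # truncated: bootstrap with the critic's estimate of the next state's value
    adv_truncated  = gae(rewards, values, next_value = torch.tensor(10.))
    # terminated: no future return, so bootstrap with zero
    adv_terminated = gae(rewards, values, next_value = torch.tensor(0.))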
--- evolutionary_policy_optimization-0.1.12.dist-info/METADATA
+++ evolutionary_policy_optimization-0.1.14.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.12
+Version: 0.1.14
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
--- evolutionary_policy_optimization-0.1.12.dist-info/RECORD
+++ evolutionary_policy_optimization-0.1.14.dist-info/RECORD
@@ -1,10 +1,10 @@
 evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
 evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
 evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
-evolutionary_policy_optimization/epo.py,sha256=aOltJBkZVi2FxXao51zdfYaLynIi8T8v3qO1ex2HElg,46058
+evolutionary_policy_optimization/epo.py,sha256=RTMVCo1joMEKIkqTQLsLgTeOuJVbvkNbX9hcOOL0oCw,46088
 evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
 evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
-evolutionary_policy_optimization-0.1.12.dist-info/METADATA,sha256=B_eK4c8-dp4rr4K0HxRiNJqY6fK10XmwBbsPm-PE0_k,7625
-evolutionary_policy_optimization-0.1.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.1.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.1.12.dist-info/RECORD,,
+evolutionary_policy_optimization-0.1.14.dist-info/METADATA,sha256=TCj1gzhViuNYP-TnVHCcvWHkTIyqiz-AAO-xhoVahBo,7625
+evolutionary_policy_optimization-0.1.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.1.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.1.14.dist-info/RECORD,,