evolutionary-policy-optimization 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +10 -37
- {evolutionary_policy_optimization-0.0.25.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.27.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.25.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.25.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.25.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/licenses/LICENSE +0 -0
@@ -303,7 +303,6 @@ class LatentGenePool(Module):
|
|
303
303
|
self,
|
304
304
|
num_latents, # same as gene pool size
|
305
305
|
dim_latent, # gene dimension
|
306
|
-
num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
|
307
306
|
num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
|
308
307
|
dim_state = None,
|
309
308
|
frozen_latents = True,
|
@@ -320,29 +319,17 @@ class LatentGenePool(Module):
|
|
320
319
|
|
321
320
|
maybe_l2norm = l2norm if l2norm_latent else identity
|
322
321
|
|
323
|
-
latents = torch.randn(num_latents,
|
322
|
+
latents = torch.randn(num_latents, dim_latent)
|
324
323
|
|
325
324
|
if l2norm_latent:
|
326
325
|
latents = maybe_l2norm(latents, dim = -1)
|
327
326
|
|
328
327
|
self.num_latents = num_latents
|
329
|
-
self.
|
328
|
+
self.frozen_latents = frozen_latents
|
330
329
|
self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)
|
331
330
|
|
332
331
|
self.maybe_l2norm = maybe_l2norm
|
333
332
|
|
334
|
-
# gene expression as a function of environment
|
335
|
-
|
336
|
-
self.num_latent_sets = num_latent_sets
|
337
|
-
|
338
|
-
if self.needs_latent_gate:
|
339
|
-
assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
|
340
|
-
|
341
|
-
self.to_latent_gate = nn.Sequential(
|
342
|
-
Linear(dim_state, num_latent_sets),
|
343
|
-
nn.Softmax(dim = -1)
|
344
|
-
) if self.needs_latent_gate else None
|
345
|
-
|
346
333
|
# some derived values
|
347
334
|
|
348
335
|
assert num_islands >= 1
|
@@ -410,6 +397,10 @@ class LatentGenePool(Module):
|
|
410
397
|
|
411
398
|
fireflies = rearrange(fireflies, 'i p ... -> (i p) ...')
|
412
399
|
|
400
|
+
# maybe fireflies on hypersphere
|
401
|
+
|
402
|
+
fireflies = self.maybe_l2norm(fireflies)
|
403
|
+
|
413
404
|
if not inplace:
|
414
405
|
return fireflies
|
415
406
|
|
@@ -456,7 +447,7 @@ class LatentGenePool(Module):
|
|
456
447
|
|
457
448
|
return genes
|
458
449
|
|
459
|
-
genes = rearrange(genes, '(i p)
|
450
|
+
genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)
|
460
451
|
|
461
452
|
orig_genes = genes
|
462
453
|
|
@@ -465,7 +456,7 @@ class LatentGenePool(Module):
|
|
465
456
|
|
466
457
|
sorted_indices = fitness.sort(dim = -1).indices
|
467
458
|
natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
|
468
|
-
natural_select_gene_indices = repeat(natural_selected_indices, '... -> ...
|
459
|
+
natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])
|
469
460
|
|
470
461
|
genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)
|
471
462
|
|
@@ -480,7 +471,7 @@ class LatentGenePool(Module):
|
|
480
471
|
parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
|
481
472
|
parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)
|
482
473
|
|
483
|
-
parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents)
|
474
|
+
parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])
|
484
475
|
|
485
476
|
parents = genes.gather(1, parent_gene_ids_for_gather)
|
486
477
|
parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
|
@@ -551,22 +542,6 @@ class LatentGenePool(Module):
|
|
551
542
|
|
552
543
|
latent = self.latents[latent_id]
|
553
544
|
|
554
|
-
if self.needs_latent_gate:
|
555
|
-
assert exists(state), 'state must be passed in if greater than number of 1 latent set'
|
556
|
-
|
557
|
-
if not fetching_multiple_latents:
|
558
|
-
latent = repeat(latent, '... -> b ...', b = state.shape[0])
|
559
|
-
|
560
|
-
assert latent.shape[0] == state.shape[0]
|
561
|
-
|
562
|
-
gates = self.to_latent_gate(state)
|
563
|
-
latent = einsum(latent, gates, 'b n g, b n -> b g')
|
564
|
-
|
565
|
-
elif fetching_multiple_latents:
|
566
|
-
latent = latent[:, 0]
|
567
|
-
else:
|
568
|
-
latent = latent[0]
|
569
|
-
|
570
545
|
latent = self.maybe_l2norm(latent)
|
571
546
|
|
572
547
|
if not exists(net):
|
@@ -608,7 +583,7 @@ class Agent(Module):
|
|
608
583
|
self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
|
609
584
|
self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)
|
610
585
|
|
611
|
-
self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.
|
586
|
+
self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None
|
612
587
|
|
613
588
|
def get_actor_actions(
|
614
589
|
self,
|
@@ -683,7 +658,6 @@ def create_agent(
|
|
683
658
|
actor_num_actions,
|
684
659
|
actor_dim_hiddens: int | tuple[int, ...],
|
685
660
|
critic_dim_hiddens: int | tuple[int, ...],
|
686
|
-
num_latent_sets = 1
|
687
661
|
) -> Agent:
|
688
662
|
|
689
663
|
actor = Actor(
|
@@ -703,7 +677,6 @@ def create_agent(
|
|
703
677
|
dim_state = dim_state,
|
704
678
|
num_latents = num_latents,
|
705
679
|
dim_latent = dim_latent,
|
706
|
-
num_latent_sets = num_latent_sets
|
707
680
|
)
|
708
681
|
|
709
682
|
return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: evolutionary-policy-optimization
|
3
|
-
Version: 0.0.25
|
3
|
+
Version: 0.0.27
|
4
4
|
Summary: EPO - Pytorch
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
|
@@ -0,0 +1,7 @@
|
|
1
|
+
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
+
evolutionary_policy_optimization/epo.py,sha256=UCCwYK-b20X-5Cq-pah1NTeHFc_35b4xZ3y0aSR8aaI,20783
|
3
|
+
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
+
evolutionary_policy_optimization-0.0.27.dist-info/METADATA,sha256=pJ2kQD5YtKDSUp1TCO_hsrRMh6FCMm8dyu6WrpVHiQk,4958
|
5
|
+
evolutionary_policy_optimization-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
evolutionary_policy_optimization-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
evolutionary_policy_optimization-0.0.27.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
|
2
|
-
evolutionary_policy_optimization/epo.py,sha256=BLwy7PBZOjw6H7MFvMq9CC7Mdm3K8fpzBNH6HbNu6LY,21927
|
3
|
-
evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
|
4
|
-
evolutionary_policy_optimization-0.0.25.dist-info/METADATA,sha256=p3-_SuLvKs8E0z1l567qA0Pbsv2dOLlrJPX4WYoZaB4,4958
|
5
|
-
evolutionary_policy_optimization-0.0.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
evolutionary_policy_optimization-0.0.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
evolutionary_policy_optimization-0.0.25.dist-info/RECORD,,
|
File without changes
|