evolutionary-policy-optimization 0.0.26-py3-none-any.whl → 0.0.27-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +6 -37
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/METADATA +1 -1
- evolutionary_policy_optimization-0.0.27.dist-info/RECORD +7 -0
- evolutionary_policy_optimization-0.0.26.dist-info/RECORD +0 -7
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py
@@ -303,7 +303,6 @@ class LatentGenePool(Module):
         self,
         num_latents, # same as gene pool size
         dim_latent, # gene dimension
-        num_latent_sets = 1, # allow for sets of latents / gene per individual, expression of a set controlled by the environment
         num_islands = 1, # add the island strategy, which has been effectively used in a few recent works
         dim_state = None,
         frozen_latents = True,
@@ -320,29 +319,17 @@ class LatentGenePool(Module):

         maybe_l2norm = l2norm if l2norm_latent else identity

-        latents = torch.randn(num_latents,
+        latents = torch.randn(num_latents, dim_latent)

         if l2norm_latent:
             latents = maybe_l2norm(latents, dim = -1)

         self.num_latents = num_latents
-        self.
+        self.frozen_latents = frozen_latents
         self.latents = nn.Parameter(latents, requires_grad = not frozen_latents)

         self.maybe_l2norm = maybe_l2norm

-        # gene expression as a function of environment
-
-        self.num_latent_sets = num_latent_sets
-
-        if self.needs_latent_gate:
-            assert exists(dim_state), '`dim_state` must be passed in if using gated gene expression'
-
-            self.to_latent_gate = nn.Sequential(
-                Linear(dim_state, num_latent_sets),
-                nn.Softmax(dim = -1)
-            ) if self.needs_latent_gate else None
-
         # some derived values

         assert num_islands >= 1
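For orientation, this is roughly what the simplified constructor path now does: genes are plain latent vectors, optionally kept unit-norm, and gradients only flow into them when `frozen_latents = False`. A minimal standalone sketch (the helper name and defaults are illustrative, not the package's exact code):

    import torch
    import torch.nn.functional as F
    from torch import nn

    def init_latent_pool(num_latents, dim_latent, l2norm_latent = False, frozen_latents = True):
        # one gene (latent vector) per individual in the pool
        latents = torch.randn(num_latents, dim_latent)

        if l2norm_latent:
            # keep genes on the unit hypersphere
            latents = F.normalize(latents, dim = -1)

        # frozen genes are updated only by the evolutionary operators, never by backprop
        return nn.Parameter(latents, requires_grad = not frozen_latents)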
@@ -460,7 +447,7 @@ class LatentGenePool(Module):

            return genes

-        genes = rearrange(genes, '(i p)
+        genes = rearrange(genes, '(i p) ... -> i p ...', i = islands)

         orig_genes = genes

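The rearrange above splits the flat gene pool into islands along a new leading dimension, so that selection and crossover run per island. A toy illustration, assuming the pool size divides evenly by the island count:

    import torch
    from einops import rearrange

    num_islands, pool_per_island, dim_latent = 4, 32, 16
    genes = torch.randn(num_islands * pool_per_island, dim_latent)

    # (islands * per-island population, ...) -> (islands, per-island population, ...)
    genes = rearrange(genes, '(i p) ... -> i p ...', i = num_islands)
    assert genes.shape == (num_islands, pool_per_island, dim_latent)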
@@ -469,7 +456,7 @@ class LatentGenePool(Module):

         sorted_indices = fitness.sort(dim = -1).indices
         natural_selected_indices = sorted_indices[..., -self.num_natural_selected:]
-        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ...
+        natural_select_gene_indices = repeat(natural_selected_indices, '... -> ... g', g = genes.shape[-1])

         genes, fitness = genes.gather(1, natural_select_gene_indices), fitness.gather(1, natural_selected_indices)

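The repeat on the selected indices is there so gather can pull out whole gene vectors rather than single scalars: the index tensor has to be broadcast over the trailing gene dimension. A small sketch of the pattern, with hypothetical shapes:

    import torch
    from einops import repeat

    islands, pop, dim_gene, keep = 2, 8, 4, 3
    genes = torch.randn(islands, pop, dim_gene)
    fitness = torch.randn(islands, pop)

    # indices of the `keep` fittest individuals per island
    selected = fitness.sort(dim = -1).indices[..., -keep:]                 # (islands, keep)

    # expand indices over the gene dimension so gather returns full gene vectors
    gene_indices = repeat(selected, '... -> ... g', g = genes.shape[-1])   # (islands, keep, dim_gene)

    survivors = genes.gather(1, gene_indices)                              # (islands, keep, dim_gene)
    surviving_fitness = fitness.gather(1, selected)                        # (islands, keep)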
@@ -484,7 +471,7 @@ class LatentGenePool(Module):
         parent_indices_at_tournament = participant_fitness.topk(2, dim = -1).indices
         parent_gene_ids = rand_tournament_gene_ids.gather(-1, parent_indices_at_tournament)

-        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents)
+        parent_gene_ids_for_gather = repeat(parent_gene_ids, 'i p parents -> i (p parents) g', g = genes.shape[-1])

         parents = genes.gather(1, parent_gene_ids_for_gather)
         parents = rearrange(parents, 'i (p parents) ... -> i p parents ...', parents = 2)
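The tournament step itself is unchanged by this diff; roughly, each child gets a small random tournament, the two fittest participants win via topk, and their genes are gathered as parents. A toy version, under the assumption that participants are drawn uniformly at random:

    import torch
    from einops import rearrange, repeat

    islands, pop, dim_gene, tourney, children = 2, 8, 4, 4, 5
    genes = torch.randn(islands, pop, dim_gene)
    fitness = torch.randn(islands, pop)

    # random tournament participants per child
    participant_ids = torch.randint(0, pop, (islands, children, tourney))
    participant_fitness = fitness.gather(1, rearrange(participant_ids, 'i c t -> i (c t)'))
    participant_fitness = rearrange(participant_fitness, 'i (c t) -> i c t', t = tourney)

    # the two tournament winners become the parents of each child
    winners = participant_fitness.topk(2, dim = -1).indices                # (islands, children, 2)
    parent_ids = participant_ids.gather(-1, winners)

    parent_gene_ids = repeat(parent_ids, 'i c parents -> i (c parents) g', g = dim_gene)
    parents = genes.gather(1, parent_gene_ids)
    parents = rearrange(parents, 'i (c parents) g -> i c parents g', parents = 2)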
@@ -555,22 +542,6 @@ class LatentGenePool(Module):

         latent = self.latents[latent_id]

-        if self.needs_latent_gate:
-            assert exists(state), 'state must be passed in if greater than number of 1 latent set'
-
-            if not fetching_multiple_latents:
-                latent = repeat(latent, '... -> b ...', b = state.shape[0])
-
-            assert latent.shape[0] == state.shape[0]
-
-            gates = self.to_latent_gate(state)
-            latent = einsum(latent, gates, 'b n g, b n -> b g')
-
-        elif fetching_multiple_latents:
-            latent = latent[:, 0]
-        else:
-            latent = latent[0]
-
         latent = self.maybe_l2norm(latent)

         if not exists(net):
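For reference, the gating that was removed mixed a set of latents per individual with a softmax over the state; reconstructed from the deleted lines, it behaved roughly like the snippet below (this mechanism no longer exists in 0.0.27):

    import torch
    from torch import nn
    from einops import einsum

    batch, num_latent_sets, dim_latent, dim_state = 3, 4, 16, 8

    latent_sets = torch.randn(batch, num_latent_sets, dim_latent)  # a set of genes per sample
    state = torch.randn(batch, dim_state)

    to_latent_gate = nn.Sequential(nn.Linear(dim_state, num_latent_sets), nn.Softmax(dim = -1))

    gates = to_latent_gate(state)                                   # (batch, num_latent_sets)
    latent = einsum(latent_sets, gates, 'b n g, b n -> b g')        # state-weighted mix of the set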
@@ -612,7 +583,7 @@ class Agent(Module):
         self.actor_optim = optim_klass(actor.parameters(), lr = actor_lr, **actor_optim_kwargs)
         self.critic_optim = optim_klass(critic.parameters(), lr = critic_lr, **critic_optim_kwargs)

-        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if latent_gene_pool.
+        self.latent_optim = optim_klass(latent_gene_pool.parameters(), lr = latent_lr, **latent_optim_kwargs) if not latent_gene_pool.frozen_latents else None

     def get_actor_actions(
         self,
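This one-line change fixes the condition so the latent optimizer is created only when the latents are trainable. A minimal sketch of the intent (optimizer class and learning rate are placeholders):

    import torch
    from torch import nn

    latents = nn.Parameter(torch.randn(128, 32), requires_grad = False)   # frozen_latents = True
    frozen_latents = not latents.requires_grad

    # build an optimizer only when the latents actually receive gradients, otherwise keep None
    latent_optim = torch.optim.Adam([latents], lr = 1e-4) if not frozen_latents else None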
@@ -687,7 +658,6 @@ def create_agent(
     actor_num_actions,
     actor_dim_hiddens: int | tuple[int, ...],
     critic_dim_hiddens: int | tuple[int, ...],
-    num_latent_sets = 1
 ) -> Agent:

     actor = Actor(
@@ -707,7 +677,6 @@ def create_agent(
         dim_state = dim_state,
         num_latents = num_latents,
         dim_latent = dim_latent,
-        num_latent_sets = num_latent_sets
     )

     return Agent(actor = actor, critic = critic, latent_gene_pool = latent_gene_pool)
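With `num_latent_sets` dropped from `create_agent`, a call now uses only the remaining arguments visible in this diff. A hedged usage sketch (values are arbitrary, and the full signature may include more parameters and defaults than these hunks show):

    from evolutionary_policy_optimization.epo import create_agent

    agent = create_agent(
        dim_state = 512,
        num_latents = 128,
        dim_latent = 32,
        actor_num_actions = 5,
        actor_dim_hiddens = (256, 128),
        critic_dim_hiddens = (256, 128)
    )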
{evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.26
+Version: 0.0.27
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization-0.0.27.dist-info/RECORD
@@ -0,0 +1,7 @@
+evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
+evolutionary_policy_optimization/epo.py,sha256=UCCwYK-b20X-5Cq-pah1NTeHFc_35b4xZ3y0aSR8aaI,20783
+evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
+evolutionary_policy_optimization-0.0.27.dist-info/METADATA,sha256=pJ2kQD5YtKDSUp1TCO_hsrRMh6FCMm8dyu6WrpVHiQk,4958
+evolutionary_policy_optimization-0.0.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.27.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.26.dist-info/RECORD
@@ -1,7 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=Qavcia0n13jjaWIS_LPW7QrxSLT_BBeKujCjF9kQjbA,133
-evolutionary_policy_optimization/epo.py,sha256=zYKRKUkvFdxgHkc2yduN76Hph3asWX33mnpDF3isDfo,22019
-evolutionary_policy_optimization/experimental.py,sha256=ktBKxRF27Qsj7WIgBpYlWXqMVxO9zOx2oD1JuDYRAwM,548
-evolutionary_policy_optimization-0.0.26.dist-info/METADATA,sha256=l24aFXZu4kp1oxZeIdFTUw1mwkyzln9C64S3HNqebF4,4958
-evolutionary_policy_optimization-0.0.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.26.dist-info/RECORD,,
{evolutionary_policy_optimization-0.0.26.dist-info → evolutionary_policy_optimization-0.0.27.dist-info}/WHEEL and /licenses/LICENSE: files without changes