evolutionary-policy-optimization 0.0.69__py3-none-any.whl → 0.0.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evolutionary_policy_optimization/epo.py +47 -10
- evolutionary_policy_optimization/mock_env.py +3 -2
- {evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/METADATA +7 -6
- evolutionary_policy_optimization-0.0.71.dist-info/RECORD +9 -0
- evolutionary_policy_optimization-0.0.69.dist-info/RECORD +0 -9
- {evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/WHEEL +0 -0
- {evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/licenses/LICENSE +0 -0
evolutionary_policy_optimization/epo.py

@@ -8,8 +8,10 @@ from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
+import numpy as np
+
 import torch
-from torch import nn, cat, stack, is_tensor, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, from_numpy, Tensor
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
@@ -58,6 +60,21 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def interface_torch_numpy(fn, device):
+    # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
+
+    @wraps(fn)
+    def decorated_fn(*args, **kwargs):
+
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
+        return out
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
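The `interface_torch_numpy` helper added above is the core of this release: it wraps an environment callable so that torch tensor arguments are converted to numpy arrays before the call, and any numpy arrays in the result are converted back to torch tensors on the requested device. The standalone sketch below mirrors that pattern outside the package; the `tree_map` import path and the `numpy_env_step` function are assumptions for illustration, not part of the library.

```python
from functools import wraps

import numpy as np
import torch
from torch.utils._pytree import tree_map  # assumed import path for tree_map

def torch_to_numpy_interface(fn, device):
    # mirror of the `interface_torch_numpy` helper added in epo.py:
    # torch tensor inputs -> numpy arrays, numpy array outputs -> torch tensors on `device`
    @wraps(fn)
    def inner(*args, **kwargs):
        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, torch.Tensor) else t, (args, kwargs))
        out = fn(*args, **kwargs)
        return tree_map(lambda t: torch.from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
    return inner

# hypothetical numpy-only environment step, for illustration
def numpy_env_step(action):
    next_state = np.random.randn(4).astype(np.float32)
    reward = np.array(1.0, dtype = np.float32)
    return next_state, reward

step = torch_to_numpy_interface(numpy_env_step, device = 'cpu')
state, reward = step(torch.tensor(3))   # torch in, torch out
print(state.dtype, reward.item())       # torch.float32 1.0
```

The same wrapper is applied to `env.reset` and `env.forward` in the rollout loop further down in this diff, which is why the mock environment now returns numpy arrays.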
@@ -416,6 +433,8 @@ class LatentGenePool(Module):
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
 
+        assert self.num_tournament_participants >= 2
+
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
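The new assertion makes an implicit configuration constraint explicit: both counts are produced by flooring products of fractions, so a small island population can silently leave fewer than two tournament participants. A quick illustration of the arithmetic (the parameter values are made up):

```python
# how the two counts are derived in LatentGenePool (values are illustrative)
latents_per_island = 32
frac_natural_selected = 0.25
frac_tournaments = 0.25

num_natural_selected = int(frac_natural_selected * latents_per_island)       # 8
num_tournament_participants = int(frac_tournaments * num_natural_selected)   # 2 -> passes the new assert

# with a smaller pool, the same fractions now fail
latents_per_island = 16
num_natural_selected = int(frac_natural_selected * latents_per_island)       # 4
num_tournament_participants = int(frac_tournaments * num_natural_selected)   # 1 -> assert would trip
```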
@@ -845,7 +864,7 @@ class Agent(Module):
 
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
-    def
+    def learn_from(
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
@@ -853,11 +872,13 @@ class Agent(Module):
     ):
         memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
 
-        memories, rewards_per_latent_episode = memories_and_cumulative_rewards
+        memories_list, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
 
-        memories = map(stack, zip(*memories))
+        memories = map(stack, zip(*memories_list))
+
+        memories_list.clear()
 
         maybe_barrier()
 
@@ -979,7 +1000,6 @@ class Agent(Module):
         # apply evolution
 
         if self.has_latent_genes:
-
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
         # reinforcement learning related - ppo
@@ -1159,9 +1179,10 @@ class EPO(Module):
             yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
-    def
+    def gather_experience_from(
         self,
         env,
+        memories: list[Memory] | None = None,
         fix_environ_across_latents = None
     ) -> MemoriesAndCumulativeRewards:
 
@@ -1171,9 +1192,10 @@ class EPO(Module):
 
         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
 
-        memories:
+        if not exists(memories):
+            memories = []
 
-        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent), device = self.device)
 
         rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
@@ -1188,7 +1210,7 @@ class EPO(Module):
             if fix_environ_across_latents:
                 reset_kwargs.update(seed = maybe_seed)
 
-            state = env.reset(**reset_kwargs)
+            state = interface_torch_numpy(env.reset, device = self.device)(**reset_kwargs)
 
             # get latent from pool
 
@@ -1210,7 +1232,7 @@ class EPO(Module):
 
                 # get the next state, action, and reward
 
-                state, reward, truncated, terminated = env(action)
+                state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = self.device)(action)
 
                 done = truncated or terminated
 
@@ -1254,3 +1276,18 @@ class EPO(Module):
             memories = memories,
             cumulative_rewards = rewards_per_latent_episode
         )
+
+    def forward(
+        self,
+        agent: Agent,
+        env,
+        num_learning_cycles
+    ):
+
+        for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):
+
+            memories = self.gather_experience_from(env)
+
+            agent.learn_from(memories)
+
+        print(f'training complete')
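Besides the rename to `gather_experience_from`, the collector now accepts an optional caller-owned `memories` list, so several collection passes can accumulate into one buffer before a single learning step. Below is a sketch of that usage, assuming `epo`, `agent`, and `env` are constructed as in the README excerpt reproduced later in this diff; note that the `cumulative_rewards` in the returned tuple cover only the latest pass, since that tensor is recreated on every call.

```python
# illustrative sketch: accumulate two collection passes into one caller-owned buffer
memories = []

epo.gather_experience_from(env, memories = memories)            # appends Memory entries in place
rollout = epo.gather_experience_from(env, memories = memories)  # same list, returned inside MemoriesAndCumulativeRewards

agent.learn_from(rollout)   # stacks the memories, then clears the shared list
```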
evolutionary_policy_optimization/mock_env.py

@@ -34,7 +34,7 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         self.step.zero_()
-        return state
+        return state.numpy()
 
     def forward(
         self,
@@ -51,4 +51,5 @@ class Env(Module):
 
         self.step.add_(1)
 
-        return state, reward, truncated, terminated
+        out = state, reward, truncated, terminated
+        return tuple(t.numpy() for t in out)
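The mock environment now hands numpy arrays back from both `reset` and `forward`, matching the torch-to-numpy bridge on the EPO side. A custom environment would presumably follow the same contract; the class below is a hypothetical sketch of that contract, not part of the package.

```python
import numpy as np

class MyNumpyEnv:
    # hypothetical environment following the same contract as the bundled mock Env:
    # reset(seed = ...) returns a numpy state, forward(action) returns
    # (state, reward, truncated, terminated), all as numpy values
    def __init__(self, state_dim = 512, max_steps = 10):
        self.state_dim = state_dim
        self.max_steps = max_steps
        self.steps = 0
        self.rng = np.random.default_rng()

    def reset(self, seed = None):
        self.rng = np.random.default_rng(seed)
        self.steps = 0
        return self.rng.standard_normal(self.state_dim, dtype = np.float32)

    def forward(self, action):
        self.steps += 1
        state = self.rng.standard_normal(self.state_dim, dtype = np.float32)
        reward = np.asarray(float(action), dtype = np.float32)
        truncated = np.asarray(self.steps >= self.max_steps)
        terminated = np.asarray(False)
        return state, reward, truncated, terminated

env = MyNumpyEnv()
state = env.reset(seed = 0)
state, reward, truncated, terminated = env.forward(np.int64(2))
```

Whether plain Python scalars would also round-trip correctly depends on code paths not shown in this diff, so the sketch keeps everything as numpy values.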
{evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.69
+Version: 0.0.71
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -114,11 +114,14 @@ from evolutionary_policy_optimization import (
 
 agent = create_agent(
     dim_state = 512,
-    num_latents =
+    num_latents = 16,
     dim_latent = 32,
     actor_num_actions = 5,
     actor_dim_hiddens = (256, 128),
-    critic_dim_hiddens = (256, 128, 64)
+    critic_dim_hiddens = (256, 128, 64),
+    latent_gene_pool_kwargs = dict(
+        frac_natural_selected = 0.5
+    )
 )
 
 epo = EPO(
@@ -130,9 +133,7 @@ epo = EPO(
 
 env = Env((512,))
 
-
-
-agent(memories)
+epo(agent, env, num_learning_cycles = 5)
 
 # saving and loading
 
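Taken together, the README changes replace the old `agent(memories)` call with a single call into the `EPO` orchestrator. The sketch below assumes the `agent`, `epo`, and `Env` objects from the README snippet above and spells out roughly what the new `EPO.forward` does internally.

```python
env = Env((512,))                           # bundled mock environment, as in the README

epo(agent, env, num_learning_cycles = 5)    # new one-call training loop

# roughly equivalent to the explicit loop now implemented by EPO.forward
for _ in range(5):
    rollout = epo.gather_experience_from(env)   # collect episodes for every latent
    agent.learn_from(rollout)                   # update the agent from the gathered rollout (see Agent.learn_from above)
```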
evolutionary_policy_optimization-0.0.71.dist-info/RECORD

@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=7_SFYcOcQA2TQcquIYswIJc-42HgwHj4hKYU9wGWZB0,39333
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=QhPYIazU3uFPKTMZyFw70KhXiDUBr5aU3v1idotfVFI,1391
+evolutionary_policy_optimization-0.0.71.dist-info/METADATA,sha256=veUHctceDr5Xvym2fCDND3fIjYb6ivsCsIfNnOtkdCg,6304
+evolutionary_policy_optimization-0.0.71.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.71.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.71.dist-info/RECORD,,
evolutionary_policy_optimization-0.0.69.dist-info/RECORD

@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=e83tghTNXfCW0zhhb4nIjvfbzDvzWRxgTlm3vKJd4rM,38189
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.69.dist-info/METADATA,sha256=UZEaCY5lfTRMkuyEQs5PLA1AZzSOcsRzXey9kgdd9i0,6220
-evolutionary_policy_optimization-0.0.69.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.69.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.69.dist-info/RECORD,,
{evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/WHEEL
{evolutionary_policy_optimization-0.0.69.dist-info → evolutionary_policy_optimization-0.0.71.dist-info}/licenses/LICENSE

Files without changes