evolutionary-policy-optimization 0.0.69__py3-none-any.whl → 0.0.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,8 +8,10 @@ from functools import partial, wraps
 from collections import namedtuple
 from random import randrange
 
+import numpy as np
+
 import torch
-from torch import nn, cat, stack, is_tensor, tensor, Tensor
+from torch import nn, cat, stack, is_tensor, tensor, from_numpy, Tensor
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.nn import Linear, Module, ModuleList
@@ -58,6 +60,21 @@ def divisible_by(num, den):
 def to_device(inp, device):
     return tree_map(lambda t: t.to(device) if is_tensor(t) else t, inp)
 
+def interface_torch_numpy(fn, device):
+    # for a given function, move all inputs from torch tensor to numpy, and all outputs from numpy to torch tensor
+
+    @wraps(fn)
+    def decorated_fn(*args, **kwargs):
+
+        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
+
+        out = fn(*args, **kwargs)
+
+        out = tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
+        return out
+
+    return decorated_fn
+
 # tensor helpers
 
 def l2norm(t):
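
The new `interface_torch_numpy` helper keeps the rollout code in torch while the wrapped callable only ever sees numpy. Below is a minimal round-trip sketch, assuming `tree_map` is `torch.utils._pytree.tree_map` (as used by `to_device` above) and using a hypothetical numpy-only step function purely for illustration:

```python
from functools import wraps

import numpy as np
import torch
from torch import Tensor, from_numpy
from torch.utils._pytree import tree_map

# same shape as the decorator added to epo.py above
def interface_torch_numpy(fn, device):
    @wraps(fn)
    def decorated_fn(*args, **kwargs):
        # torch -> numpy on the way in
        args, kwargs = tree_map(lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t, (args, kwargs))
        out = fn(*args, **kwargs)
        # numpy -> torch on the way out
        return tree_map(lambda t: from_numpy(t).to(device) if isinstance(t, np.ndarray) else t, out)
    return decorated_fn

# hypothetical numpy-only environment step, not part of the package
def numpy_step(action: np.ndarray):
    next_state = action * 2.0
    reward = np.array(1.0, dtype = np.float32)
    return next_state, reward

wrapped_step = interface_torch_numpy(numpy_step, device = 'cpu')

state, reward = wrapped_step(torch.randn(4))                     # torch tensors in ...
assert isinstance(state, Tensor) and isinstance(reward, Tensor)  # ... torch tensors out
```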
@@ -416,6 +433,8 @@ class LatentGenePool(Module):
         self.num_natural_selected = int(frac_natural_selected * latents_per_island)
         self.num_tournament_participants = int(frac_tournaments * self.num_natural_selected)
 
+        assert self.num_tournament_participants >= 2
+
         self.crossover_random = crossover_random
 
         self.mutation_strength = mutation_strength
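
The new assertion in `LatentGenePool.__init__` catches configurations where integer truncation leaves fewer than two tournament participants, which would make tournament selection meaningless. A quick arithmetic sketch with hypothetical fractions (not the library defaults) showing how the guard can trip:

```python
# hypothetical values, purely to illustrate the failure mode the assert catches
latents_per_island = 16
frac_natural_selected = 0.25
frac_tournaments = 0.25

num_natural_selected = int(frac_natural_selected * latents_per_island)      # int(4.0) == 4
num_tournament_participants = int(frac_tournaments * num_natural_selected)  # int(1.0) == 1

# a tournament needs at least two competitors, so this configuration now fails
# `assert self.num_tournament_participants >= 2` instead of silently degenerating
```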
@@ -845,7 +864,7 @@ class Agent(Module):
 
         return self.latent_gene_pool.genetic_algorithm_step(fitnesses)
 
-    def forward(
+    def learn_from(
         self,
         memories_and_cumulative_rewards: MemoriesAndCumulativeRewards,
         epochs = 2
@@ -853,11 +872,13 @@ class Agent(Module):
     ):
         memories_and_cumulative_rewards = to_device(memories_and_cumulative_rewards, self.device)
 
-        memories, rewards_per_latent_episode = memories_and_cumulative_rewards
+        memories_list, rewards_per_latent_episode = memories_and_cumulative_rewards
 
         # stack memories
 
-        memories = map(stack, zip(*memories))
+        memories = map(stack, zip(*memories_list))
+
+        memories_list.clear()
 
         maybe_barrier()
 
@@ -979,7 +1000,6 @@ class Agent(Module):
         # apply evolution
 
         if self.has_latent_genes:
-
             self.latent_gene_pool.genetic_algorithm_step(fitness_scores)
 
         # reinforcement learning related - ppo
@@ -1159,9 +1179,10 @@ class EPO(Module):
             yield latent_id, episode_id, maybe_seed
 
     @torch.no_grad()
-    def forward(
+    def gather_experience_from(
         self,
         env,
+        memories: list[Memory] | None = None,
         fix_environ_across_latents = None
     ) -> MemoriesAndCumulativeRewards:
 
@@ -1171,9 +1192,10 @@ class EPO(Module):
 
         invalid_episode = tensor(-1) # will use `episode_id` value of `-1` for the `next_value`, needed for not discarding last reward for generalized advantage estimate
 
-        memories: list[Memory] = []
+        if not exists(memories):
+            memories = []
 
-        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent))
+        rewards_per_latent_episode = torch.zeros((self.num_latents, self.episodes_per_latent), device = self.device)
 
         rollout_gen = self.rollouts_for_machine(fix_environ_across_latents)
 
@@ -1188,7 +1210,7 @@ class EPO(Module):
             if fix_environ_across_latents:
                 reset_kwargs.update(seed = maybe_seed)
 
-            state = env.reset(**reset_kwargs)
+            state = interface_torch_numpy(env.reset, device = self.device)(**reset_kwargs)
 
             # get latent from pool
 
@@ -1210,7 +1232,7 @@ class EPO(Module):
 
                 # get the next state, action, and reward
 
-                state, reward, truncated, terminated = env(action)
+                state, reward, truncated, terminated = interface_torch_numpy(env.forward, device = self.device)(action)
 
                 done = truncated or terminated
 
@@ -1254,3 +1276,18 @@ class EPO(Module):
             memories = memories,
             cumulative_rewards = rewards_per_latent_episode
         )
+
+    def forward(
+        self,
+        agent: Agent,
+        env,
+        num_learning_cycles
+    ):
+
+        for _ in tqdm(range(num_learning_cycles), desc = 'learning cycle'):
+
+            memories = self.gather_experience_from(env)
+
+            agent.learn_from(memories)
+
+        print(f'training complete')
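
With `Agent.forward` renamed to `learn_from` and the old `EPO.forward` renamed to `gather_experience_from`, the new `EPO.forward` is a thin driver that alternates experience gathering and learning. A sketch of the equivalence, assuming `agent`, `epo`, and `env` are constructed as in the README changes further down; the new optional `memories` argument is shown as well:

```python
# what the new EPO.forward does, written out by hand
for _ in range(5):
    memories = epo.gather_experience_from(env)  # rollouts across all latents and episodes
    agent.learn_from(memories)                  # PPO epochs plus a genetic algorithm step on the latent pool

# rollouts can also be accumulated into a caller-owned list across calls,
# since gather_experience_from now accepts an optional memories buffer
buffer = []
epo.gather_experience_from(env, memories = buffer)
epo.gather_experience_from(env, memories = buffer)  # extends the same list

# equivalent one-liner using the new driver
epo(agent, env, num_learning_cycles = 5)
```

Note that `learn_from` clears the memories list it is handed (see the `memories_list.clear()` hunk above), so a shared buffer is emptied once its contents have been stacked for learning.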
@@ -34,7 +34,7 @@ class Env(Module):
     ):
         state = randn(self.state_shape, device = self.device)
         self.step.zero_()
-        return state
+        return state.numpy()
 
     def forward(
         self,
@@ -51,4 +51,5 @@ class Env(Module):
 
         self.step.add_(1)
 
-        return state, reward, truncated, terminated
+        out = state, reward, truncated, terminated
+        return tuple(t.numpy() for t in out)
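
The mock `Env` now returns numpy arrays from both `reset` and `forward`, so the environment boundary is numpy end to end; on the EPO side those calls are wrapped with `interface_torch_numpy`, which converts back to torch. A small sketch of that round trip, assuming the mock env lives on CPU and that `reset` can be called with no arguments:

```python
from evolutionary_policy_optimization.epo import interface_torch_numpy
from evolutionary_policy_optimization.mock_env import Env

env = Env((512,))

state = env.reset()
print(type(state))   # numpy.ndarray at the environment boundary

wrapped_reset = interface_torch_numpy(env.reset, device = 'cpu')
state = wrapped_reset()
print(type(state))   # torch.Tensor back inside the rollout loop
```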
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.0.69
+Version: 0.0.71
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -114,11 +114,14 @@ from evolutionary_policy_optimization import (
 
 agent = create_agent(
     dim_state = 512,
-    num_latents = 8,
+    num_latents = 16,
     dim_latent = 32,
     actor_num_actions = 5,
     actor_dim_hiddens = (256, 128),
-    critic_dim_hiddens = (256, 128, 64)
+    critic_dim_hiddens = (256, 128, 64),
+    latent_gene_pool_kwargs = dict(
+        frac_natural_selected = 0.5
+    )
 )
 
 epo = EPO(
@@ -130,9 +133,7 @@ epo = EPO(
 
 env = Env((512,))
 
-memories = epo(env)
-
-agent(memories)
+epo(agent, env, num_learning_cycles = 5)
 
 # saving and loading
 
@@ -0,0 +1,9 @@
+evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
+evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
+evolutionary_policy_optimization/epo.py,sha256=7_SFYcOcQA2TQcquIYswIJc-42HgwHj4hKYU9wGWZB0,39333
+evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
+evolutionary_policy_optimization/mock_env.py,sha256=QhPYIazU3uFPKTMZyFw70KhXiDUBr5aU3v1idotfVFI,1391
+evolutionary_policy_optimization-0.0.71.dist-info/METADATA,sha256=veUHctceDr5Xvym2fCDND3fIjYb6ivsCsIfNnOtkdCg,6304
+evolutionary_policy_optimization-0.0.71.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+evolutionary_policy_optimization-0.0.71.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+evolutionary_policy_optimization-0.0.71.dist-info/RECORD,,
@@ -1,9 +0,0 @@
-evolutionary_policy_optimization/__init__.py,sha256=0q0aBuFgWi06MLMD8FiHzBYQ3_W4LYWrwmCtF3u5H2A,201
-evolutionary_policy_optimization/distributed.py,sha256=7KgZdeS_wxBHo_du9XZFB1Cu318J-Bp66Xdr6Log_20,2423
-evolutionary_policy_optimization/epo.py,sha256=e83tghTNXfCW0zhhb4nIjvfbzDvzWRxgTlm3vKJd4rM,38189
-evolutionary_policy_optimization/experimental.py,sha256=-IgqjJ_Wk_CMB1y9YYWpoYqTG9GZHAS6kbRdTluVevg,1563
-evolutionary_policy_optimization/mock_env.py,sha256=Bv9ONFRbma8wpjUurc9aCk19A6ceiWitRnS3nwrIR64,1339
-evolutionary_policy_optimization-0.0.69.dist-info/METADATA,sha256=UZEaCY5lfTRMkuyEQs5PLA1AZzSOcsRzXey9kgdd9i0,6220
-evolutionary_policy_optimization-0.0.69.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-evolutionary_policy_optimization-0.0.69.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-evolutionary_policy_optimization-0.0.69.dist-info/RECORD,,