evolutionary-policy-optimization 0.2.1.tar.gz → 0.2.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/.gitignore +2 -0
  2. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/PKG-INFO +5 -2
  3. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/README.md +1 -1
  4. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/epo.py +8 -1
  5. evolutionary_policy_optimization-0.2.5/evolutionary_policy_optimization/experimental.py +198 -0
  6. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/pyproject.toml +6 -1
  7. evolutionary_policy_optimization-0.2.5/train_crossover_weight_space.py +146 -0
  8. evolutionary_policy_optimization-0.2.1/evolutionary_policy_optimization/experimental.py +0 -80
  9. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/.github/workflows/lint.yml +0 -0
  10. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/.github/workflows/python-publish.yml +0 -0
  11. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/.github/workflows/test.yml +0 -0
  12. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/LICENSE +0 -0
  13. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/__init__.py +0 -0
  14. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/distributed.py +0 -0
  15. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  16. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/mock_env.py +0 -0
  17. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/requirements.txt +0 -0
  18. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/tests/test_epo.py +0 -0
  19. {evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/train_gym.py +0 -0
{evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/.gitignore

@@ -1,3 +1,5 @@
+data/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
{evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.2.1
+Version: 0.2.5
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -47,6 +47,9 @@ Provides-Extra: examples
 Requires-Dist: numpy; extra == 'examples'
 Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
+Provides-Extra: experimental
+Requires-Dist: tensordict; extra == 'experimental'
+Requires-Dist: torchvision; extra == 'experimental'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Requires-Dist: ruff>=0.4.2; extra == 'test'
@@ -56,7 +59,7 @@ Description-Content-Type: text/markdown

 ## Evolutionary Policy Optimization

-Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University

 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

{evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/README.md

@@ -2,7 +2,7 @@

 ## Evolutionary Policy Optimization

-Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University

 This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

{evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/evolutionary_policy_optimization/epo.py

@@ -1044,7 +1044,14 @@ class Agent(Module):
         actor.state_norm = critic.state_norm = state_norm

         self.use_critic_ema = use_critic_ema
-        self.critic_ema = EMA(critic, beta = critic_ema_beta, include_online_model = False, **ema_kwargs) if use_critic_ema else None
+
+        self.critic_ema = EMA(
+            critic,
+            beta = critic_ema_beta,
+            include_online_model = False,
+            ignore_startswith_names = {'state_norm'},
+            **ema_kwargs
+        ) if use_critic_ema else None

         self.latent_gene_pool = latent_gene_pool
         self.num_latents = latent_gene_pool.num_latents if exists(latent_gene_pool) else 1
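Context for this change: a few lines up, the same state_norm module is assigned to both the actor and the critic, so an EMA copy of the critic that also averaged state_norm would maintain a second, drifting set of the shared normalization statistics. The new ignore_startswith_names argument tells ema-pytorch's EMA wrapper to skip those parameters. A minimal sketch of the same pattern, with a hypothetical toy critic standing in for the real one:

    import torch
    from torch import nn
    from ema_pytorch import EMA

    # hypothetical stand-in for the critic, with a shared normalization submodule
    critic = nn.Sequential()
    critic.state_norm = nn.LayerNorm(4)  # updated online, shared with the actor
    critic.head = nn.Linear(4, 1)

    critic_ema = EMA(
        critic,
        beta = 0.99,
        include_online_model = False,             # register only the EMA copy
        ignore_startswith_names = {'state_norm'}  # leave the shared norm out of the averaging
    )

    critic_ema.update()  # called after each optimizer step on the critic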
evolutionary_policy_optimization-0.2.5/evolutionary_policy_optimization/experimental.py (new file)

@@ -0,0 +1,198 @@
+from random import uniform
+
+import torch
+import torch.nn.functional as F
+from torch.func import vmap, functional_call
+from torch.nn import Module, ParameterList
+
+from einops import rearrange, reduce, repeat
+
+def exists(v):
+    return v is not None
+
+def l2norm(t, dim = -1):
+    return F.normalize(t, dim = dim)
+
+def crossover_weights(w1, w2):
+    assert w2.shape == w2.shape
+
+    no_batch = w1.ndim == 2
+
+    if no_batch:
+        w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+    assert w1.ndim == 3
+
+    i, j = w1.shape[-2:]
+    transpose = i < j
+
+    if transpose:
+        w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
+
+    rank = min(w2.shape[1:])
+    assert rank >= 2
+
+    batch = w1.shape[0]
+
+    u1, s1, v1 = torch.svd(w1)
+    u2, s2, v2 = torch.svd(w2)
+
+    batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+    mask = batch_randperm < (rank // 2)
+
+    u = torch.where(mask[:, None, :], u1, u2)
+    s = torch.where(mask, s1, s2)
+    v = torch.where(mask[:, :, None], v1, v2)
+
+    out = u @ torch.diag_embed(s) @ v.mT
+
+    if transpose:
+        out = rearrange(out, 'b j i -> b i j')
+
+    if no_batch:
+        out = rearrange(out, '1 ... -> ...')
+
+    return out
+
+def mutate_weight(
+    w,
+    mutation_strength = 1.
+):
+
+    i, j = w.shape[-2:]
+    transpose = i < j
+
+    if transpose:
+        w = w.transpose(-1, -2)
+
+    rank = min(w.shape[1:])
+    assert rank >= 2
+
+    u, s, v = torch.svd(w)
+
+    u = u + torch.randn_like(u) * mutation_strength
+    v = v + torch.randn_like(v) * mutation_strength
+
+    u = l2norm(u, dim = -2)
+    v = l2norm(v, dim = -1)
+
+    out = u @ torch.diag_embed(s) @ v.mT
+
+    if transpose:
+        out = out.transpose(-1, -2)
+
+    return out
+
+# wrapper that manages network to population
+# able to receive fitness and employ selection + crossover
+
+class PopulationWrapper(Module):
+    def __init__(
+        self,
+        net: Module,
+        pop_size,
+        num_selected,
+        tournament_size,
+        learning_rate = 1e-3,
+        init_std_dev = 1e-1
+    ):
+        super().__init__()
+        assert num_selected < pop_size
+        assert tournament_size < num_selected
+
+        self.num_selected = num_selected
+        self.tournament_size = tournament_size
+        self.num_offsprings = pop_size - num_selected
+
+        self.net = net
+
+        params = dict(net.named_parameters())
+        device = next(iter(params.values())).device
+
+        pop_params = {name: (torch.randn((pop_size, *param.shape), device = device) * init_std_dev).requires_grad_() for name, param in params.items()}
+
+        self.param_names = pop_params.keys()
+        self.param_values = ParameterList(list(pop_params.values()))
+
+        def _forward(params, data):
+            return functional_call(net, params, data)
+
+        self.forward_pop_nets = vmap(_forward, in_dims = (0, None))
+
+    @property
+    def pop_params(self):
+        return dict(zip(self.param_names, self.param_values))
+
+    def parameters(self):
+        return self.pop_params.values()
+
+    def genetic_algorithm_step_(
+        self,
+        fitnesses
+    ):
+        fitnesses = reduce(fitnesses, 'b p -> p', 'mean') # average across samples
+
+        num_selected = self.num_selected
+
+        # selection
+
+        sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)
+
+        # tournaments
+
+        tourn_ids = torch.randn((self.num_offsprings, self.tournament_size)).argsort(dim = -1)
+        tourn_scores = sel_fitnesses[tourn_ids]
+
+        winner_ids = tourn_scores.topk(2, dim = -1).indices
+        winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')
+        parent_ids = sel_indices[winner_ids]
+
+        # crossover
+
+        for param in self.param_values:
+            parents = param[sel_indices]
+            parent1, parent2 = param[parent_ids]
+
+            children = parent1.lerp_(parent2, uniform(0.25, 0.75))
+
+            pop = torch.cat((parents, children))
+
+            param.data.copy_(pop)
+
+    def forward(
+        self,
+        data,
+        labels = None,
+        return_logits_with_loss = False
+    ):
+        out = self.forward_pop_nets(dict(self.pop_params), data)
+
+        if not exists(labels):
+            return out
+
+        logits = out
+        pop_size = logits.shape[0]
+
+        losses = F.cross_entropy(
+            rearrange(logits, 'p b ... l -> (p b) l ...'),
+            repeat(labels, 'b ... -> (p b) ...', p = pop_size),
+            reduction = 'none'
+        )
+
+        losses = rearrange(losses, '(p b) ... -> p b ...', p = pop_size)
+
+        if not return_logits_with_loss:
+            return losses
+
+        return losses, logits
+
+# test
+
+if __name__ == '__main__':
+    w1 = torch.randn(2, 32, 16)
+    w2 = torch.randn(2, 32, 16)
+
+    child = crossover_weights(w1, w2)
+    mutated_w1 = mutate_weight(w1)
+
+    assert child.shape == w2.shape
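The crossover above works in weight space through the SVD: each parent matrix is factored as U diag(S) Vᵀ, a random half of the singular triplets (uᵢ, sᵢ, vᵢ) is taken from each parent, and the child is reassembled from the mixed factors, so weights are blended along aligned low-rank directions rather than element-wise; mutate_weight likewise perturbs the singular vectors rather than raw entries, renormalizing them afterward. A hedged usage sketch on a concrete layer (the nn.Linear parents here are illustrative, not part of the package API):

    import torch
    from torch import nn
    from evolutionary_policy_optimization.experimental import crossover_weights, mutate_weight

    parent_a = nn.Linear(16, 32, bias = False)
    parent_b = nn.Linear(16, 32, bias = False)

    with torch.no_grad():
        # child inherits roughly half its singular triplets from each parent's (32, 16) weight
        child_weight = crossover_weights(parent_a.weight, parent_b.weight)

        # small random perturbation of the singular vectors
        child_weight = mutate_weight(child_weight, mutation_strength = 0.05)

        child = nn.Linear(16, 32, bias = False)
        child.weight.copy_(child_weight)

    assert child.weight.shape == parent_a.weight.shape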
{evolutionary_policy_optimization-0.2.1 → evolutionary_policy_optimization-0.2.5}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.2.1"
+version = "0.2.5"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -46,6 +46,11 @@ examples = [
     "tqdm",
 ]

+experimental = [
+    "tensordict",
+    "torchvision"
+]
+
 test = [
     "pytest",
     "ruff>=0.4.2",
evolutionary_policy_optimization-0.2.5/train_crossover_weight_space.py (new file)

@@ -0,0 +1,146 @@
+from random import uniform
+
+import torch
+from torch import nn, tensor, randn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+from torch.optim import Adam
+
+import torchvision
+import torchvision.transforms as T
+
+from einops.layers.torch import Rearrange
+from einops import repeat, rearrange
+
+from evolutionary_policy_optimization.experimental import PopulationWrapper
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+def divisible_by(num, den):
+    return (num % den) == 0
+
+#data
+
+class MnistDataset(Dataset):
+    def __init__(self, train):
+        self.mnist = torchvision.datasets.MNIST('./data/mnist', train = train, download = True)
+
+    def __len__(self):
+        return len(self.mnist)
+
+    def __getitem__(self, idx):
+        pil, labels = self.mnist[idx]
+        digit_tensor = T.PILToTensor()(pil)
+        return (digit_tensor / 255.).float().to(device), tensor(labels, device = device)
+
+batch = 32
+
+train_dataset = MnistDataset(train = True)
+dl = DataLoader(train_dataset, batch_size = batch, shuffle = True, drop_last = True)
+
+eval_dataset = MnistDataset(train = False)
+eval_dl = DataLoader(eval_dataset, batch_size = batch, shuffle = True, drop_last = True)
+
+def cycle(dl):
+    while True:
+        for batch in dl:
+            yield batch
+
+# network
+
+net = nn.Sequential(
+    Rearrange('... c h w -> ... (c h w)'),
+    nn.Linear(784, 64, bias = False),
+    nn.ReLU(),
+    nn.Linear(64, 10, bias = False),
+).to(device)
+
+# regular gradient descent
+
+optim = Adam(net.parameters(), lr = 1e-3)
+
+iter_train_dl = cycle(dl)
+iter_eval_dl = cycle(eval_dl)
+
+for i in range(1000):
+
+    data, labels = next(iter_train_dl)
+
+    logits = net(data)
+
+    loss = F.cross_entropy(logits, labels)
+    loss.backward()
+
+    print(f'{i}: {loss.item():.3f}')
+
+    optim.step()
+    optim.zero_grad()
+
+    if divisible_by(i + 1, 100):
+        with torch.no_grad():
+            eval_data, labels = next(iter_eval_dl)
+            logits = net(eval_data)
+            eval_loss = F.cross_entropy(logits, labels)
+
+            total = labels.shape[0]
+            correct = (logits.argmax(dim = -1) == labels).long().sum().item()
+
+            print(f'{i}: eval loss: {eval_loss.item():.3f}')
+            print(f'{i}: accuracy: {correct} / {total}')
+
+# periodic crossover from genetic algorithm on population of networks
+# pop stands for population
+
+pop_size = 100
+learning_rate = 3e-4
+
+pop_net = PopulationWrapper(
+    net,
+    pop_size = pop_size,
+    num_selected = 25,
+    tournament_size = 5,
+    learning_rate = 1e-3
+)
+
+optim = Adam(pop_net.parameters(), lr = learning_rate)
+
+for i in range(1000):
+    pop_net.train()
+
+    data, labels = next(iter_train_dl)
+
+    losses = pop_net(data, labels)
+
+    losses.sum(dim = 0).mean().backward()
+
+    print(f'{i}: loss: {losses.mean().item():.3f}')
+
+    optim.step()
+    optim.zero_grad()
+
+    # evaluate
+
+    if divisible_by(i + 1, 100):
+
+        with torch.no_grad():
+
+            pop_net.eval()
+
+            eval_data, labels = next(iter_eval_dl)
+            eval_loss, logits = pop_net(eval_data, labels, return_logits_with_loss = True)
+
+            total = labels.shape[0] * pop_size
+            correct = (logits.argmax(dim = -1) == labels).long().sum().item()
+
+            print(f'{i}: eval loss: {eval_loss.mean().item():.3f}')
+            print(f'{i}: accuracy: {correct} / {total}')
+
+        # genetic algorithm on population
+
+        fitnesses = 1. / eval_loss
+
+        pop_net.genetic_algorithm_step_(fitnesses)
+
+        # new optim
+
+        optim = Adam(pop_net.parameters(), lr = learning_rate)
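To make the genetic_algorithm_step_ call at the bottom concrete, here is a toy walk-through of the wrapper's selection and tournament logic with small numbers (a sketch mirroring the internals shown in experimental.py above, not an additional API):

    import torch
    from einops import rearrange

    # one averaged fitness per population member (population of 6)
    fitnesses = torch.tensor([0.1, 0.9, 0.3, 0.7, 0.5, 0.2])

    num_selected = 4      # survivors kept as-is
    num_offsprings = 2    # remaining slots refilled by crossover
    tournament_size = 3

    # selection: keep the top-k by fitness
    sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)

    # tournaments: each offspring gets a random ordering over `tournament_size`
    # survivor slots, and the two fittest entrants become its parents
    tourn_ids = torch.randn((num_offsprings, tournament_size)).argsort(dim = -1)
    tourn_scores = sel_fitnesses[tourn_ids]

    winner_ids = tourn_scores.topk(2, dim = -1).indices
    winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')

    # (2, num_offsprings) indices into the original population
    parent_ids = sel_indices[winner_ids]
    print(parent_ids)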
@@ -1,80 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from einops import rearrange
4
-
5
- def l2norm(t, dim = -1):
6
- return F.normalize(t, dim = dim)
7
-
8
- def crossover_weights(w1, w2, transpose = False):
9
- assert w2.shape == w2.shape
10
-
11
- no_batch = w1.ndim == 2
12
-
13
- if no_batch:
14
- w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
15
-
16
- assert w1.ndim == 3
17
-
18
- if transpose:
19
- w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
20
-
21
- rank = min(w2.shape[1:])
22
- assert rank >= 2
23
-
24
- batch = w1.shape[0]
25
-
26
- u1, s1, v1 = torch.svd(w1)
27
- u2, s2, v2 = torch.svd(w2)
28
-
29
- batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
30
- mask = batch_randperm < (rank // 2)
31
-
32
- u = torch.where(mask[:, None, :], u1, u2)
33
- s = torch.where(mask, s1, s2)
34
- v = torch.where(mask[:, :, None], v1, v2)
35
-
36
- out = u @ torch.diag_embed(s) @ v.mT
37
-
38
- if transpose:
39
- out = rearrange(out, 'b j i -> b i j')
40
-
41
- if no_batch:
42
- out = rearrange(out, '1 ... -> ...')
43
-
44
- return out
45
-
46
- def mutate_weight(
47
- w,
48
- transpose = False,
49
- mutation_strength = 1.
50
- ):
51
-
52
- if transpose:
53
- w = w.transpose(-1, -2)
54
-
55
- rank = min(w2.shape[1:])
56
- assert rank >= 2
57
-
58
- u, s, v = torch.svd(w)
59
-
60
- u = u + torch.randn_like(u) * mutation_strength
61
- v = v + torch.randn_like(v) * mutation_strength
62
-
63
- u = l2norm(u, dim = -2)
64
- v = l2norm(v, dim = -1)
65
-
66
- out = u @ torch.diag_embed(s) @ v.mT
67
-
68
- if transpose:
69
- out = out.transpose(-1, -2)
70
-
71
- return out
72
-
73
- if __name__ == '__main__':
74
- w1 = torch.randn(32, 16)
75
- w2 = torch.randn(32, 16)
76
-
77
- child = crossover_weights(w1, w2)
78
- mutated_w1 = mutate_weight(w1)
79
-
80
- assert child.shape == w2.shape