evolutionary-policy-optimization 0.2.3.tar.gz → 0.2.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/.gitignore +2 -0
  2. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/PKG-INFO +5 -2
  3. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/README.md +1 -1
  4. evolutionary_policy_optimization-0.2.6/evolutionary_policy_optimization/experimental.py +217 -0
  5. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/pyproject.toml +6 -1
  6. evolutionary_policy_optimization-0.2.6/train_crossover_weight_space.py +146 -0
  7. evolutionary_policy_optimization-0.2.3/evolutionary_policy_optimization/experimental.py +0 -80
  8. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/.github/workflows/lint.yml +0 -0
  9. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/.github/workflows/python-publish.yml +0 -0
  10. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/.github/workflows/test.yml +0 -0
  11. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/LICENSE +0 -0
  12. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/__init__.py +0 -0
  13. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/distributed.py +0 -0
  14. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  15. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/epo.py +0 -0
  16. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/mock_env.py +0 -0
  17. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/requirements.txt +0 -0
  18. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/tests/test_epo.py +0 -0
  19. {evolutionary_policy_optimization-0.2.3 → evolutionary_policy_optimization-0.2.6}/train_gym.py +0 -0
.gitignore
@@ -1,3 +1,5 @@
+ data/
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: evolutionary-policy-optimization
- Version: 0.2.3
+ Version: 0.2.6
  Summary: EPO - Pytorch
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -47,6 +47,9 @@ Provides-Extra: examples
  Requires-Dist: numpy; extra == 'examples'
  Requires-Dist: pufferlib>=2.0.6; extra == 'examples'
  Requires-Dist: tqdm; extra == 'examples'
+ Provides-Extra: experimental
+ Requires-Dist: tensordict; extra == 'experimental'
+ Requires-Dist: torchvision; extra == 'experimental'
  Provides-Extra: test
  Requires-Dist: pytest; extra == 'test'
  Requires-Dist: ruff>=0.4.2; extra == 'test'
@@ -56,7 +59,7 @@ Description-Content-Type: text/markdown

  ## Evolutionary Policy Optimization

- Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University

  This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

README.md
@@ -2,7 +2,7 @@

  ## Evolutionary Policy Optimization

- Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from Wang et al. of the Robotics Institute at Carnegie Mellon University
+ Pytorch implementation of [Evolutionary Policy Optimization](https://web3.arxiv.org/abs/2503.19037), from [Wang](https://www.jianrenw.com/) et al. of the Robotics Institute at Carnegie Mellon University

  This paper stands out, as I have witnessed the positive effects first hand in an [exploratory project](https://github.com/lucidrains/firefly-torch) (mixing evolution with gradient based methods). Perhaps the Alexnet moment for genetic algorithms has not come to pass yet.

evolutionary_policy_optimization-0.2.6/evolutionary_policy_optimization/experimental.py (new file)
@@ -0,0 +1,217 @@
+ from random import uniform
+ from copy import deepcopy
+
+ import torch
+ import torch.nn.functional as F
+ from torch.func import vmap, functional_call
+ from torch.nn import Module, ParameterList
+
+ from einops import rearrange, reduce, repeat
+
+ def exists(v):
+     return v is not None
+
+ def l2norm(t, dim = -1):
+     return F.normalize(t, dim = dim)
+
+ def crossover_weights(w1, w2):
+     assert w1.shape == w2.shape
+
+     no_batch = w1.ndim == 2
+
+     if no_batch:
+         w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
+
+     assert w1.ndim == 3
+
+     i, j = w1.shape[-2:]
+     transpose = i < j
+
+     if transpose:
+         w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
+
+     rank = min(w2.shape[1:])
+     assert rank >= 2
+
+     batch = w1.shape[0]
+
+     u1, s1, v1 = torch.svd(w1)
+     u2, s2, v2 = torch.svd(w2)
+
+     # for each child, randomly take half the singular directions from each parent
+
+     batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
+     mask = batch_randperm < (rank // 2)
+
+     u = torch.where(mask[:, None, :], u1, u2)
+     s = torch.where(mask, s1, s2)
+     v = torch.where(mask[:, :, None], v1, v2)
+
+     out = u @ torch.diag_embed(s) @ v.mT
+
+     if transpose:
+         out = rearrange(out, 'b j i -> b i j')
+
+     if no_batch:
+         out = rearrange(out, '1 ... -> ...')
+
+     return out
+
+ def mutate_weight(
+     w,
+     mutation_strength = 1.
+ ):
+
+     i, j = w.shape[-2:]
+     transpose = i < j
+
+     if transpose:
+         w = w.transpose(-1, -2)
+
+     rank = min(w.shape[1:])
+     assert rank >= 2
+
+     u, s, v = torch.svd(w)
+
+     # perturb the singular vectors, then renormalize them back onto the hypersphere
+
+     u = u + torch.randn_like(u) * mutation_strength
+     v = v + torch.randn_like(v) * mutation_strength
+
+     u = l2norm(u, dim = -2)
+     v = l2norm(v, dim = -1)
+
+     out = u @ torch.diag_embed(s) @ v.mT
+
+     if transpose:
+         out = out.transpose(-1, -2)
+
+     return out
+
+ # wrapper that manages network to population
+ # able to receive fitness and employ selection + crossover
+
+ class PopulationWrapper(Module):
+     def __init__(
+         self,
+         net: Module,
+         pop_size,
+         num_selected,
+         tournament_size,
+         learning_rate = 1e-3,
+         init_std_dev = 1e-1
+     ):
+         super().__init__()
+         assert num_selected < pop_size
+         assert tournament_size < num_selected
+
+         self.pop_size = pop_size
+         self.num_selected = num_selected
+         self.tournament_size = tournament_size
+         self.num_offsprings = pop_size - num_selected
+
+         self.net = net
+
+         params = dict(net.named_parameters())
+         device = next(iter(params.values())).device
+
+         pop_params = {name: (torch.randn((pop_size, *param.shape), device = device) * init_std_dev).requires_grad_() for name, param in params.items()}
+
+         self.param_names = pop_params.keys()
+         self.param_values = ParameterList(list(pop_params.values()))
+
+         def _forward(params, data):
+             return functional_call(net, params, data)
+
+         self.forward_pop_nets = vmap(_forward, in_dims = (0, None))
+
+     @property
+     def pop_params(self):
+         return dict(zip(self.param_names, self.param_values))
+
+     def individual(self, id) -> Module:
+         assert 0 <= id < self.pop_size
+         state_dict = {key: param[id] for key, param in self.pop_params.items()}
+
+         net = deepcopy(self.net)
+         net.load_state_dict(state_dict)
+         return net
+
+     def parameters(self):
+         return self.pop_params.values()
+
+     def genetic_algorithm_step_(
+         self,
+         fitnesses
+     ):
+         fitnesses = reduce(fitnesses, 'p b -> p', 'mean') # average fitness across samples for each individual
+
+         num_selected = self.num_selected
+
+         # selection
+
+         sel_fitnesses, sel_indices = fitnesses.topk(num_selected, dim = -1)
+
+         # tournaments - sample participants from the selected pool, the two fittest become parents
+
+         tourn_ids = torch.randn((self.num_offsprings, num_selected), device = fitnesses.device).argsort(dim = -1)[..., :self.tournament_size]
+         tourn_scores = sel_fitnesses[tourn_ids]
+
+         winner_ids = tourn_ids.gather(-1, tourn_scores.topk(2, dim = -1).indices)
+         winner_ids = rearrange(winner_ids, 'offsprings couple -> couple offsprings')
+         parent_ids = sel_indices[winner_ids]
+
+         # crossover
+
+         for param in self.param_values:
+             parents = param[sel_indices]
+             parent1, parent2 = param[parent_ids]
+
+             children = parent1.lerp_(parent2, uniform(0.25, 0.75))
+
+             pop = torch.cat((parents, children))
+
+             param.data.copy_(pop)
+
+     def forward(
+         self,
+         data,
+         *,
+         individual_id = None,
+         labels = None,
+         return_logits_with_loss = False
+     ):
+         # if `individual_id` passed in, will forward for only that one network
+
+         if exists(individual_id):
+             assert 0 <= individual_id < self.pop_size
+             params = {key: param[individual_id] for key, param in self.pop_params.items()}
+             return functional_call(self.net, params, data)
+
+         out = self.forward_pop_nets(dict(self.pop_params), data)
+
+         if not exists(labels):
+             return out
+
+         logits = out
+         pop_size = logits.shape[0]
+
+         losses = F.cross_entropy(
+             rearrange(logits, 'p b ... l -> (p b) l ...'),
+             repeat(labels, 'b ... -> (p b) ...', p = pop_size),
+             reduction = 'none'
+         )
+
+         losses = rearrange(losses, '(p b) ... -> p b ...', p = pop_size)
+
+         if not return_logits_with_loss:
+             return losses
+
+         return losses, logits
+
+ # test
+
+ if __name__ == '__main__':
+     w1 = torch.randn(2, 32, 16)
+     w2 = torch.randn(2, 32, 16)
+
+     child = crossover_weights(w1, w2)
+     mutated_w1 = mutate_weight(w1)
+
+     assert child.shape == w2.shape
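For orientation, a minimal usage sketch of the `PopulationWrapper` added above (the toy network, shapes, and hyperparameters are illustrative, not from the package):

    import torch
    from torch import nn
    from evolutionary_policy_optimization.experimental import PopulationWrapper

    net = nn.Sequential(nn.Linear(16, 32, bias = False), nn.ReLU(), nn.Linear(32, 10, bias = False))

    # a population of 8 networks sharing the architecture of `net`
    pop_net = PopulationWrapper(net, pop_size = 8, num_selected = 4, tournament_size = 2)

    data, labels = torch.randn(3, 16), torch.randint(0, 10, (3,))

    losses = pop_net(data, labels = labels)   # (pop_size, batch) cross entropy per individual
    losses.sum(dim = 0).mean().backward()     # one gradient step for the whole population

    with torch.no_grad():
        fitnesses = 1. / losses.detach()      # lower loss -> higher fitness
        pop_net.genetic_algorithm_step_(fitnesses)  # selection + crossover, in place

Gradient descent trains all population members in parallel via `vmap` over `functional_call`, while the genetic algorithm step periodically replaces the non-selected members with crossovers of tournament winners.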
pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "evolutionary-policy-optimization"
- version = "0.2.3"
+ version = "0.2.6"
  description = "EPO - Pytorch"
  authors = [
      { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -46,6 +46,11 @@ examples = [
      "tqdm",
  ]

+ experimental = [
+     "tensordict",
+     "torchvision"
+ ]
+
  test = [
      "pytest",
      "ruff>=0.4.2",
evolutionary_policy_optimization-0.2.6/train_crossover_weight_space.py (new file)
@@ -0,0 +1,146 @@
+ from random import uniform
+
+ import torch
+ from torch import nn, tensor, randn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from torch.optim import Adam
+
+ import torchvision
+ import torchvision.transforms as T
+
+ from einops.layers.torch import Rearrange
+ from einops import repeat, rearrange
+
+ from evolutionary_policy_optimization.experimental import PopulationWrapper
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ def divisible_by(num, den):
+     return (num % den) == 0
+
+ # data
+
+ class MnistDataset(Dataset):
+     def __init__(self, train):
+         self.mnist = torchvision.datasets.MNIST('./data/mnist', train = train, download = True)
+
+     def __len__(self):
+         return len(self.mnist)
+
+     def __getitem__(self, idx):
+         pil, labels = self.mnist[idx]
+         digit_tensor = T.PILToTensor()(pil)
+         return (digit_tensor / 255.).float().to(device), tensor(labels, device = device)
+
+ batch = 32
+
+ train_dataset = MnistDataset(train = True)
+ dl = DataLoader(train_dataset, batch_size = batch, shuffle = True, drop_last = True)
+
+ eval_dataset = MnistDataset(train = False)
+ eval_dl = DataLoader(eval_dataset, batch_size = batch, shuffle = True, drop_last = True)
+
+ def cycle(dl):
+     while True:
+         for batch in dl:
+             yield batch
+
+ # network
+
+ net = nn.Sequential(
+     Rearrange('... c h w -> ... (c h w)'),
+     nn.Linear(784, 64, bias = False),
+     nn.ReLU(),
+     nn.Linear(64, 10, bias = False),
+ ).to(device)
+
+ # regular gradient descent
+
+ optim = Adam(net.parameters(), lr = 1e-3)
+
+ iter_train_dl = cycle(dl)
+ iter_eval_dl = cycle(eval_dl)
+
+ for i in range(1000):
+
+     data, labels = next(iter_train_dl)
+
+     logits = net(data)
+
+     loss = F.cross_entropy(logits, labels)
+     loss.backward()
+
+     print(f'{i}: {loss.item():.3f}')
+
+     optim.step()
+     optim.zero_grad()
+
+     if divisible_by(i + 1, 100):
+         with torch.no_grad():
+             eval_data, labels = next(iter_eval_dl)
+             logits = net(eval_data)
+             eval_loss = F.cross_entropy(logits, labels)
+
+             total = labels.shape[0]
+             correct = (logits.argmax(dim = -1) == labels).long().sum().item()
+
+             print(f'{i}: eval loss: {eval_loss.item():.3f}')
+             print(f'{i}: accuracy: {correct} / {total}')
+
+ # periodic crossover from genetic algorithm on population of networks
+ # pop stands for population
+
+ pop_size = 100
+ learning_rate = 3e-4
+
+ pop_net = PopulationWrapper(
+     net,
+     pop_size = pop_size,
+     num_selected = 25,
+     tournament_size = 5,
+     learning_rate = 1e-3
+ )
+
+ optim = Adam(pop_net.parameters(), lr = learning_rate)
+
+ for i in range(1000):
+     pop_net.train()
+
+     data, labels = next(iter_train_dl)
+
+     losses = pop_net(data, labels = labels)
+
+     losses.sum(dim = 0).mean().backward()
+
+     print(f'{i}: loss: {losses.mean().item():.3f}')
+
+     optim.step()
+     optim.zero_grad()
+
+     # evaluate
+
+     if divisible_by(i + 1, 100):
+
+         with torch.no_grad():
+
+             pop_net.eval()
+
+             eval_data, labels = next(iter_eval_dl)
+             eval_loss, logits = pop_net(eval_data, labels = labels, return_logits_with_loss = True)
+
+             total = labels.shape[0] * pop_size
+             correct = (logits.argmax(dim = -1) == labels).long().sum().item()
+
+             print(f'{i}: eval loss: {eval_loss.mean().item():.3f}')
+             print(f'{i}: accuracy: {correct} / {total}')
+
+             # genetic algorithm on population
+
+             fitnesses = 1. / eval_loss
+
+             pop_net.genetic_algorithm_step_(fitnesses)
+
+             # new optim
+
+             optim = Adam(pop_net.parameters(), lr = learning_rate)
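After training, a single member of the population can be pulled out as a standalone module via `PopulationWrapper.individual`. A brief sketch continuing from the script above (the `best_id` selection is illustrative; `fitnesses` has shape `(pop_size, batch)` here):

    # pick the individual with the highest average fitness and run it on its own
    best_id = fitnesses.mean(dim = -1).argmax().item()
    best_net = pop_net.individual(best_id)

    eval_data, _ = next(iter_eval_dl)
    logits = best_net(eval_data)  # plain forward pass, no population dimension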
evolutionary_policy_optimization-0.2.3/evolutionary_policy_optimization/experimental.py (deleted)
@@ -1,80 +0,0 @@
- import torch
- import torch.nn.functional as F
- from einops import rearrange
-
- def l2norm(t, dim = -1):
-     return F.normalize(t, dim = dim)
-
- def crossover_weights(w1, w2, transpose = False):
-     assert w2.shape == w2.shape
-
-     no_batch = w1.ndim == 2
-
-     if no_batch:
-         w1, w2 = tuple(rearrange(t, '... -> 1 ...') for t in (w1, w2))
-
-     assert w1.ndim == 3
-
-     if transpose:
-         w1, w2 = tuple(rearrange(t, 'b i j -> b j i') for t in (w1, w2))
-
-     rank = min(w2.shape[1:])
-     assert rank >= 2
-
-     batch = w1.shape[0]
-
-     u1, s1, v1 = torch.svd(w1)
-     u2, s2, v2 = torch.svd(w2)
-
-     batch_randperm = torch.randn((batch, rank), device = w1.device).argsort(dim = -1)
-     mask = batch_randperm < (rank // 2)
-
-     u = torch.where(mask[:, None, :], u1, u2)
-     s = torch.where(mask, s1, s2)
-     v = torch.where(mask[:, :, None], v1, v2)
-
-     out = u @ torch.diag_embed(s) @ v.mT
-
-     if transpose:
-         out = rearrange(out, 'b j i -> b i j')
-
-     if no_batch:
-         out = rearrange(out, '1 ... -> ...')
-
-     return out
-
- def mutate_weight(
-     w,
-     transpose = False,
-     mutation_strength = 1.
- ):
-
-     if transpose:
-         w = w.transpose(-1, -2)
-
-     rank = min(w2.shape[1:])
-     assert rank >= 2
-
-     u, s, v = torch.svd(w)
-
-     u = u + torch.randn_like(u) * mutation_strength
-     v = v + torch.randn_like(v) * mutation_strength
-
-     u = l2norm(u, dim = -2)
-     v = l2norm(v, dim = -1)
-
-     out = u @ torch.diag_embed(s) @ v.mT
-
-     if transpose:
-         out = out.transpose(-1, -2)
-
-     return out
-
- if __name__ == '__main__':
-     w1 = torch.randn(32, 16)
-     w2 = torch.randn(32, 16)
-
-     child = crossover_weights(w1, w2)
-     mutated_w1 = mutate_weight(w1)
-
-     assert child.shape == w2.shape