evolutionary-policy-optimization 0.2.5__tar.gz → 0.2.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (18)
  1. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/experimental.py +19 -0
  3. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/train_crossover_weight_space.py +2 -2
  5. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/.github/workflows/lint.yml +0 -0
  6. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/.github/workflows/python-publish.yml +0 -0
  7. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/.github/workflows/test.yml +0 -0
  8. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/.gitignore +0 -0
  9. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/LICENSE +0 -0
  10. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/README.md +0 -0
  11. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/__init__.py +0 -0
  12. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/distributed.py +0 -0
  13. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  14. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/epo.py +0 -0
  15. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/evolutionary_policy_optimization/mock_env.py +0 -0
  16. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/requirements.txt +0 -0
  17. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/tests/test_epo.py +0 -0
  18. {evolutionary_policy_optimization-0.2.5 → evolutionary_policy_optimization-0.2.6}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.2.5
+Version: 0.2.6
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/experimental.py
@@ -1,4 +1,5 @@
 from random import uniform
+from copy import deepcopy
 
 import torch
 import torch.nn.functional as F
@@ -100,6 +101,7 @@ class PopulationWrapper(Module):
         assert num_selected < pop_size
         assert tournament_size < num_selected
 
+        self.pop_size = pop_size
         self.num_selected = num_selected
         self.tournament_size = tournament_size
         self.num_offsprings = pop_size - num_selected
@@ -123,6 +125,14 @@ class PopulationWrapper(Module):
     def pop_params(self):
         return dict(zip(self.param_names, self.param_values))
 
+    def individual(self, id) -> Module:
+        assert 0 <= id < self.pop_size
+        state_dict = {key: param[id] for key, param in self.pop_params.items()}
+
+        net = deepcopy(self.net)
+        net.load_state_dict(state_dict)
+        return net
+
     def parameters(self):
         return self.pop_params.values()
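The new `individual` method materializes a single population member as a standalone module: it slices that member's parameters out of the stacked population tensors and loads them into a deep copy of the template network. A minimal usage sketch — the base network and the `PopulationWrapper` constructor arguments below are assumptions inferred from the asserts in this diff, not shown here:

    import torch
    from torch import nn
    from evolutionary_policy_optimization.experimental import PopulationWrapper

    # hypothetical base network and wrapper arguments, for illustration only
    net = nn.Linear(16, 4)

    pop_net = PopulationWrapper(
        net,
        pop_size = 10,
        num_selected = 5,
        tournament_size = 2
    )

    # pull out member 3 as an ordinary nn.Module with its own copied weights
    solo = pop_net.individual(3)
    logits = solo(torch.randn(2, 16))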
@@ -162,9 +172,18 @@ class PopulationWrapper(Module):
     def forward(
         self,
         data,
+        *,
+        individual_id = None,
         labels = None,
         return_logits_with_loss = False
     ):
+        # if `individual_id` passed in, will forward for only that one network
+
+        if exists(individual_id):
+            assert 0 <= individual_id < self.pop_size
+            params = {key: param[individual_id] for key, param in self.pop_params.items()}
+            return functional_call(self.net, params, data)
+
         out = self.forward_pop_nets(dict(self.pop_params), data)
 
         if not exists(labels):
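With the new keyword-only `individual_id` argument, a forward pass can be routed through one member's parameter slice via `torch.func.functional_call` on the template network, skipping the population-wide forward entirely. A hedged sketch, reusing the assumed `pop_net` from above:

    data = torch.randn(8, 16)

    # population forward: outputs presumably stacked across all members
    all_out = pop_net(data)

    # single-member forward, new in 0.2.6; note the argument is keyword-only
    one_out = pop_net(data, individual_id = 3)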
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.2.5"
+version = "0.2.6"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
train_crossover_weight_space.py
@@ -109,7 +109,7 @@ for i in range(1000):
 
     data, labels = next(iter_train_dl)
 
-    losses = pop_net(data, labels)
+    losses = pop_net(data, labels = labels)
 
     losses.sum(dim = 0).mean().backward()
 
@@ -127,7 +127,7 @@ for i in range(1000):
     pop_net.eval()
 
     eval_data, labels = next(iter_eval_dl)
-    eval_loss, logits = pop_net(eval_data, labels, return_logits_with_loss = True)
+    eval_loss, logits = pop_net(eval_data, labels = labels, return_logits_with_loss = True)
 
     total = labels.shape[0] * pop_size
     correct = (logits.argmax(dim = -1) == labels).long().sum().item()
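The two call-site edits in train_crossover_weight_space.py are forced by the new `*` in the `forward` signature: every argument after `data` is now keyword-only, so the old positional `labels` would raise a `TypeError` under 0.2.6. For example:

    losses = pop_net(data, labels)           # TypeError under 0.2.6: labels is keyword-only
    losses = pop_net(data, labels = labels)  # correct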