evolutionary-policy-optimization 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -382,6 +382,52 @@ class StateNorm(Module):
382
382
 
383
383
  return normed
384
384
 
385
+ # style mapping network from StyleGAN2
386
+
387
class EqualLinear(Module):
    """Linear layer whose stored weight and bias are multiplied by `lr_mul` at call time.

    Used as the building block of the StyleGAN2-style latent mapping network.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: size of the last dimension of the output.
        lr_mul: scalar the weight (and bias, if present) is multiplied by on every forward.
        bias: when False, no bias parameter is created and none is added to the output.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        lr_mul = 1,
        bias = True
    ):
        super().__init__()
        self.lr_mul = lr_mul

        self.weight = nn.Parameter(torch.randn(dim_out, dim_in))

        # honor the `bias` flag - it was previously accepted but ignored,
        # so `bias = False` still produced (and applied) a bias parameter

        if bias:
            self.bias = nn.Parameter(torch.zeros(dim_out))
        else:
            # keeps `self.bias` defined (as None), same convention as nn.Linear
            self.register_parameter('bias', None)

    def forward(
        self,
        input
    ):
        # the multiplier is applied on the fly, the stored parameters stay unscaled
        weight = self.weight * self.lr_mul
        bias = self.bias * self.lr_mul if self.bias is not None else None

        return F.linear(input, weight, bias = bias)
407
+
408
class LatentMappingNetwork(Module):
    """Stack of `depth` (EqualLinear -> LeakyReLU) stages that maps a latent to a latent of the same width.

    Args:
        dim_latent: input and output width of every EqualLinear stage.
        depth: number of (EqualLinear, LeakyReLU) pairs; 0 yields an empty Sequential.
        lr_mul: multiplier forwarded to each EqualLinear.
        leaky_relu_p: negative slope handed to each nn.LeakyReLU.
    """

    def __init__(
        self,
        dim_latent,
        depth,
        lr_mul = 0.1,
        leaky_relu_p = 2e-2
    ):
        super().__init__()

        # flatten the per-stage pairs into one list, linear first then activation

        stages = [
            module
            for _ in range(depth)
            for module in (
                EqualLinear(dim_latent, dim_latent, lr_mul),
                nn.LeakyReLU(leaky_relu_p)
            )
        ]

        self.net = nn.Sequential(*stages)

    def forward(self, x):
        mapped = self.net(x)
        return mapped
430
+
385
431
  # simple MLP networks, but with latent variables
386
432
  # the latent variables are the "genes" with the rest of the network as the scaffold for "gene expression" - as suggested in the paper
387
433
 
@@ -391,6 +437,7 @@ class MLP(Module):
391
437
  dim,
392
438
  depth,
393
439
  dim_latent = 0,
440
+ latent_mapping_network_depth = 2,
394
441
  expansion_factor = 2.
395
442
  ):
396
443
  super().__init__()
@@ -401,6 +448,7 @@ class MLP(Module):
401
448
  self.needs_latent = dim_latent > 0
402
449
 
403
450
  self.encode_latent = nn.Sequential(
451
+ LatentMappingNetwork(dim_latent, depth = latent_mapping_network_depth),
404
452
  Linear(dim_latent, dim * 2),
405
453
  nn.SiLU()
406
454
  ) if self.needs_latent else None
@@ -1,4 +1,5 @@
1
1
  from random import uniform
2
+ from copy import deepcopy
2
3
 
3
4
  import torch
4
5
  import torch.nn.functional as F
@@ -100,6 +101,7 @@ class PopulationWrapper(Module):
100
101
  assert num_selected < pop_size
101
102
  assert tournament_size < num_selected
102
103
 
104
+ self.pop_size = pop_size
103
105
  self.num_selected = num_selected
104
106
  self.tournament_size = tournament_size
105
107
  self.num_offsprings = pop_size - num_selected
@@ -123,6 +125,14 @@ class PopulationWrapper(Module):
123
125
  def pop_params(self):
124
126
  return dict(zip(self.param_names, self.param_values))
125
127
 
128
def individual(self, id) -> Module:
    # returns a standalone copy of the wrapped network carrying the parameters of population member `id`

    assert 0 <= id < self.pop_size

    # index entry `id` out of every population parameter, keyed by parameter name

    member_state = dict()

    for name, population_param in self.pop_params.items():
        member_state[name] = population_param[id]

    # `self.net` itself is left untouched - the weights are loaded into a deep copy

    clone = deepcopy(self.net)
    clone.load_state_dict(member_state)

    return clone
135
+
126
136
  def parameters(self):
127
137
  return self.pop_params.values()
128
138
 
@@ -162,9 +172,18 @@ class PopulationWrapper(Module):
162
172
  def forward(
163
173
  self,
164
174
  data,
175
+ *,
176
+ individual_id = None,
165
177
  labels = None,
166
178
  return_logits_with_loss = False
167
179
  ):
180
+ # if `individual_id` passed in, will forward for only that one network
181
+
182
+ if exists(individual_id):
183
+ assert 0 <= individual_id < self.pop_size
184
+ params = {key: param[individual_id] for key, param in self.pop_params.items()}
185
+ return functional_call(self.net, params, data)
186
+
168
187
  out = self.forward_pop_nets(dict(self.pop_params), data)
169
188
 
170
189
  if not exists(labels):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: evolutionary-policy-optimization
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: EPO - Pytorch
5
5
  Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
6
6
  Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
@@ -268,4 +268,14 @@ That's it
268
268
  }
269
269
  ```
270
270
 
271
+ ```bibtex
272
+ @article{Karras2019stylegan2,
273
+ title = {Analyzing and Improving the Image Quality of {StyleGAN}},
274
+ author = {Tero Karras and Samuli Laine and Miika Aittala and Janne Hellsten and Jaakko Lehtinen and Timo Aila},
275
+ journal = {CoRR},
276
+ volume = {abs/1912.04958},
277
+ year = {2019},
278
+ }
279
+ ```
280
+
271
281
  *Evolution is cleverer than you are.* - Leslie Orgel
@@ -1,10 +1,10 @@
1
1
  evolutionary_policy_optimization/__init__.py,sha256=NyiYDYU7DlpmOTM7xiBQET3r1WwX0ebrgMCBLSQrW3c,288
2
2
  evolutionary_policy_optimization/distributed.py,sha256=MxyxqxANAuOm8GYb0Yu09EHd_aVLhK2uwgrfuVWciPU,2342
3
3
  evolutionary_policy_optimization/env_wrappers.py,sha256=bDL06o9_b1iW6k3fw2xifnOnYlzs643tdW6Yv2gsIdw,803
4
- evolutionary_policy_optimization/epo.py,sha256=81rf249ykaPrAEMGk9KsF98qDkCUhW8xL3-2UXIvI2E,51838
5
- evolutionary_policy_optimization/experimental.py,sha256=NKtAeOS8AVoX5HHwwxAl1ngUi7ZE19uV1-NILFK6Tu8,4877
4
+ evolutionary_policy_optimization/epo.py,sha256=OKumVrOH7DSKZhbbx-5oCI_JJwJNYf4lrEpCNmbj6ZY,52991
5
+ evolutionary_policy_optimization/experimental.py,sha256=7LOrMIaU4fr2Vme1ZpHNIvlvFEIdWj0-uemhQoNJcPQ,5549
6
6
  evolutionary_policy_optimization/mock_env.py,sha256=TLyyRm6tOD0Kdn9QqJJQriaSnsR-YmNQHo4OohmZFG4,1410
7
- evolutionary_policy_optimization-0.2.5.dist-info/METADATA,sha256=WzU43HLqNw6kjQ-s7IMW9bFv9nojiQsLEd2pHTvKbfw,8858
8
- evolutionary_policy_optimization-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
- evolutionary_policy_optimization-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
10
- evolutionary_policy_optimization-0.2.5.dist-info/RECORD,,
7
+ evolutionary_policy_optimization-0.2.7.dist-info/METADATA,sha256=MiG_AYp6KoANhdGuaGM37-zaciW8dCrO0KvkXk7hO7w,9171
8
+ evolutionary_policy_optimization-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
9
+ evolutionary_policy_optimization-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
10
+ evolutionary_policy_optimization-0.2.7.dist-info/RECORD,,