evolutionary-policy-optimization 0.1.9__tar.gz → 0.1.10__tar.gz

This diff compares the content of publicly available package versions released to one of the supported registries, and is provided for informational purposes only, reflecting the packages as they appear in their respective public registries.
Files changed (16)
  1. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/PKG-INFO +1 -1
  2. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/epo.py +7 -5
  3. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/pyproject.toml +1 -1
  4. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/.github/workflows/python-publish.yml +0 -0
  5. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/.github/workflows/test.yml +0 -0
  6. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/.gitignore +0 -0
  7. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/LICENSE +0 -0
  8. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/README.md +0 -0
  9. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/__init__.py +0 -0
  10. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/distributed.py +0 -0
  11. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/env_wrappers.py +0 -0
  12. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/experimental.py +0 -0
  13. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/evolutionary_policy_optimization/mock_env.py +0 -0
  14. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/requirements.txt +0 -0
  15. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/tests/test_epo.py +0 -0
  16. {evolutionary_policy_optimization-0.1.9 → evolutionary_policy_optimization-0.1.10}/train_gym.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: evolutionary-policy-optimization
-Version: 0.1.9
+Version: 0.1.10
 Summary: EPO - Pytorch
 Project-URL: Homepage, https://pypi.org/project/evolutionary-policy-optimization/
 Project-URL: Repository, https://github.com/lucidrains/evolutionary-policy-optimization
evolutionary_policy_optimization/epo.py
@@ -431,6 +431,8 @@ class Critic(Module):

         value = self.maybe_bins_to_value(logits)

+        loss_fn = partial(self.loss_fn, reduction = 'none')
+
         if use_improved:
             clipped_target = target.clamp(-eps_clip, eps_clip)

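The first epo.py change is a small refactor: functools.partial pre-binds reduction = 'none' once, so the call sites in the hunks below no longer repeat the keyword argument. A minimal sketch of the pattern, using F.mse_loss as a hypothetical stand-in for the critic's self.loss_fn (which, depending on configuration, may instead be a cross-entropy over value bins):

import torch
import torch.nn.functional as F
from functools import partial

# pre-bind the kwarg once; F.mse_loss is only a stand-in for self.loss_fn
loss_fn = partial(F.mse_loss, reduction = 'none')

pred, target = torch.randn(4), torch.randn(4)
loss = loss_fn(pred, target)  # same as F.mse_loss(pred, target, reduction = 'none')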
@@ -439,8 +441,8 @@ class Critic(Module):

             is_between = lambda lo, hi: (lo < value) & (value < hi)

-            clipped_loss = self.loss_fn(logits, clipped_target, reduction = 'none')
-            loss = self.loss_fn(logits, target, reduction = 'none')
+            clipped_loss = loss_fn(logits, clipped_target)
+            loss = loss_fn(logits, target)

             value_loss = torch.where(
                 is_between(target, old_values_lo) | is_between(old_values_hi, target),
@@ -448,10 +450,10 @@ class Critic(Module):
                 torch.min(loss, clipped_loss)
             )
         else:
-            clipped_value = old_values + (value - old_values).clamp(1. - eps_clip, 1. + eps_clip)
+            clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)

-            loss = self.loss_fn(logits, target, reduction = 'none')
-            clipped_loss = self.loss_fn(clipped_value, target, reduction = 'none')
+            loss = loss_fn(logits, target)
+            clipped_loss = loss_fn(clipped_value, target)

             value_loss = torch.max(loss, clipped_loss)

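The else-branch change is the substantive fix. PPO-style value clipping keeps the new value estimate within eps_clip of the old one, so the additive delta (value - old_values) is clamped to (-eps_clip, eps_clip); the previous bounds (1. - eps_clip, 1. + eps_clip) are the ones used for the multiplicative probability ratio in the policy loss and do not fit an additive delta. A minimal sketch of the corrected rule, assuming plain value tensors and MSE (the real Critic passes logits that may encode binned values, and its eps_clip default may differ from the illustrative 0.4 used here):

import torch
import torch.nn.functional as F

def clipped_value_loss(value, old_values, target, eps_clip = 0.4):
    # clamp the change in value relative to the old estimate, not the value itself
    clipped_value = old_values + (value - old_values).clamp(-eps_clip, eps_clip)

    loss = F.mse_loss(value, target, reduction = 'none')
    clipped_loss = F.mse_loss(clipped_value, target, reduction = 'none')

    # pessimistic bound: take the larger of the unclipped and clipped losses
    return torch.max(loss, clipped_loss).mean()

value, old_values, target = torch.randn(8), torch.randn(8), torch.randn(8)
print(clipped_value_loss(value, old_values, target))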
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "evolutionary-policy-optimization"
-version = "0.1.9"
+version = "0.1.10"
 description = "EPO - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }