homa 0.3.11__tar.gz → 0.3.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154)
  1. {homa-0.3.11 → homa-0.3.12}/PKG-INFO +1 -1
  2. {homa-0.3.11 → homa-0.3.12}/pyproject.toml +1 -1
  3. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/Buffer.py +3 -1
  4. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/SoftActorCriticBuffer.py +1 -1
  5. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/SoftCritic.py +8 -6
  6. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/modules/SoftActorModule.py +1 -1
  7. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/PKG-INFO +1 -1
  8. {homa-0.3.11 → homa-0.3.12}/README.md +0 -0
  9. {homa-0.3.11 → homa-0.3.12}/setup.cfg +0 -0
  10. {homa-0.3.11 → homa-0.3.12}/src/homa/__init__.py +0 -0
  11. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/APLU.py +0 -0
  12. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ActivationFunction.py +0 -0
  13. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/AdaptiveActivationFunction.py +0 -0
  14. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/BaseDLReLU.py +0 -0
  15. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/CaLU.py +0 -0
  16. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/DLReLU.py +0 -0
  17. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ERF.py +0 -0
  18. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Elliot.py +0 -0
  19. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ExpExpish.py +0 -0
  20. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ExponentialDLReLU.py +0 -0
  21. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ExponentialSwish.py +0 -0
  22. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/GCU.py +0 -0
  23. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/GaLU.py +0 -0
  24. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/GaussianReLU.py +0 -0
  25. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/GeneralizedSwish.py +0 -0
  26. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Gish.py +0 -0
  27. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/LaLU.py +0 -0
  28. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/LogLogish.py +0 -0
  29. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/LogSigmoid.py +0 -0
  30. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Logish.py +0 -0
  31. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/MeLU.py +0 -0
  32. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/MexicanReLU.py +0 -0
  33. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/MinSin.py +0 -0
  34. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/NReLU.py +0 -0
  35. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/NoisyReLU.py +0 -0
  36. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/PLogish.py +0 -0
  37. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ParametricLogish.py +0 -0
  38. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Phish.py +0 -0
  39. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/RReLU.py +0 -0
  40. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/RandomizedSlopedReLU.py +0 -0
  41. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SGELU.py +0 -0
  42. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SReLU.py +0 -0
  43. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SelfArctan.py +0 -0
  44. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/ShiftedReLU.py +0 -0
  45. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SigmoidDerivative.py +0 -0
  46. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SineReLU.py +0 -0
  47. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SlopedReLU.py +0 -0
  48. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SmallGaLU.py +0 -0
  49. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Smish.py +0 -0
  50. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/SoftsignRReLU.py +0 -0
  51. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/Suish.py +0 -0
  52. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TBSReLU.py +0 -0
  53. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TSReLU.py +0 -0
  54. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TangentBipolarSigmoidReLU.py +0 -0
  55. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TangentSigmoidReLU.py +0 -0
  56. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TeLU.py +0 -0
  57. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/TripleStateSwish.py +0 -0
  58. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/WideMeLU.py +0 -0
  59. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/__init__.py +0 -0
  60. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/AOAF.py +0 -0
  61. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/AReLU.py +0 -0
  62. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/DPReLU.py +0 -0
  63. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/DualLine.py +0 -0
  64. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/FReLU.py +0 -0
  65. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/LeLeLU.py +0 -0
  66. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/PERU.py +0 -0
  67. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/PiLU.py +0 -0
  68. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/ShiLU.py +0 -0
  69. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/StarReLU.py +0 -0
  70. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/__init__.py +0 -0
  71. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/concerns/ChannelBased.py +0 -0
  72. {homa-0.3.11 → homa-0.3.12}/src/homa/activations/learnable/concerns/__init__.py +0 -0
  73. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/Commands/Command.py +0 -0
  74. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/Commands/InitCommand.py +0 -0
  75. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/Commands/__init__.py +0 -0
  76. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/HomaCommand.py +0 -0
  77. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
  78. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
  79. {homa-0.3.11 → homa-0.3.12}/src/homa/cli/namespaces/__init__.py +0 -0
  80. {homa-0.3.11 → homa-0.3.12}/src/homa/core/__init__.py +0 -0
  81. {homa-0.3.11 → homa-0.3.12}/src/homa/core/concerns/MovesNetworkToDevice.py +0 -0
  82. {homa-0.3.11 → homa-0.3.12}/src/homa/core/concerns/__init__.py +0 -0
  83. {homa-0.3.11 → homa-0.3.12}/src/homa/device.py +0 -0
  84. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/Ensemble.py +0 -0
  85. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/__init__.py +0 -0
  86. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -0
  87. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
  88. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
  89. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
  90. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
  91. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
  92. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsEnsembleSize.py +0 -0
  93. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/ReportsLogits.py +0 -0
  94. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/SavesEnsembleModels.py +0 -0
  95. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/StoresModels.py +0 -0
  96. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/concerns/__init__.py +0 -0
  97. {homa-0.3.11 → homa-0.3.12}/src/homa/ensemble/utils.py +0 -0
  98. {homa-0.3.11 → homa-0.3.12}/src/homa/graph/GraphAttention.py +0 -0
  99. {homa-0.3.11 → homa-0.3.12}/src/homa/graph/__init__.py +0 -0
  100. {homa-0.3.11 → homa-0.3.12}/src/homa/graph/modules/GraphAttentionHeadModule.py +0 -0
  101. {homa-0.3.11 → homa-0.3.12}/src/homa/graph/modules/MultiHeadGraphAttentionModule.py +0 -0
  102. {homa-0.3.11 → homa-0.3.12}/src/homa/graph/modules/__init__.py +0 -0
  103. {homa-0.3.11 → homa-0.3.12}/src/homa/loss/LogitNormLoss.py +0 -0
  104. {homa-0.3.11 → homa-0.3.12}/src/homa/loss/Loss.py +0 -0
  105. {homa-0.3.11 → homa-0.3.12}/src/homa/loss/__init__.py +0 -0
  106. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/DQN.py +0 -0
  107. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/DRQN.py +0 -0
  108. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/DiversityIsAllYouNeed.py +0 -0
  109. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/SoftActorCritic.py +0 -0
  110. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/__init__.py +0 -0
  111. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +0 -0
  112. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/ImageBuffer.py +0 -0
  113. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/__init__.py +0 -0
  114. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/concerns/HasRecordAlternatives.py +0 -0
  115. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/concerns/ResetsCollection.py +0 -0
  116. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/buffers/concerns/__init__.py +0 -0
  117. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/Actor.py +0 -0
  118. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/Critic.py +0 -0
  119. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/Discriminator.py +0 -0
  120. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/__init__.py +0 -0
  121. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/modules/ContinuousActorModule.py +0 -0
  122. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/modules/CriticModule.py +0 -0
  123. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/modules/DiscriminatorModule.py +0 -0
  124. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/diayn/modules/__init__.py +0 -0
  125. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/SoftActor.py +0 -0
  126. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/__init__.py +0 -0
  127. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/modules/DualSoftCriticModule.py +0 -0
  128. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/modules/SoftCriticModule.py +0 -0
  129. {homa-0.3.11 → homa-0.3.12}/src/homa/rl/sac/modules/__init__.py +0 -0
  130. {homa-0.3.11 → homa-0.3.12}/src/homa/settings.py +0 -0
  131. {homa-0.3.11 → homa-0.3.12}/src/homa/utils.py +0 -0
  132. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/Classifier.py +0 -0
  133. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/Model.py +0 -0
  134. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/Resnet.py +0 -0
  135. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/StochasticClassifier.py +0 -0
  136. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/StochasticSwin.py +0 -0
  137. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/Swin.py +0 -0
  138. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/__init__.py +0 -0
  139. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/HasLabels.py +0 -0
  140. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/HasLogits.py +0 -0
  141. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/HasProbabilities.py +0 -0
  142. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/ReportsAccuracy.py +0 -0
  143. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
  144. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/Trainable.py +0 -0
  145. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/concerns/__init__.py +0 -0
  146. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/modules/ResnetModule.py +0 -0
  147. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/modules/SwinModule.py +0 -0
  148. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/modules/__init__.py +0 -0
  149. {homa-0.3.11 → homa-0.3.12}/src/homa/vision/utils.py +0 -0
  150. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/SOURCES.txt +0 -0
  151. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/dependency_links.txt +0 -0
  152. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/entry_points.txt +0 -0
  153. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/requires.txt +0 -0
  154. {homa-0.3.11 → homa-0.3.12}/src/homa.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: homa
3
- Version: 0.3.11
3
+ Version: 0.3.12
4
4
  Summary: A curated list of machine learning and deep learning helpers.
5
5
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
6
6
  Requires-Python: >=3.7
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "homa"
7
- version = "0.3.11"
7
+ version = "0.3.12"
8
8
  description = "A curated list of machine learning and deep learning helpers."
9
9
  authors = [
10
10
  { name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
@@ -1,10 +1,12 @@
1
+ from collections import deque
2
+ from typing import Type
1
3
  from .concerns import ResetsCollection, HasRecordAlternatives
2
4
 
3
5
 
4
6
  class Buffer(ResetsCollection, HasRecordAlternatives):
5
7
  def __init__(self, capacity: int):
6
8
  self.capacity: int = capacity
7
- self.reset()
9
+ self.collection: Type[deque] = deque(maxlen=self.capacity)
8
10
 
9
11
  @property
10
12
  def size(self):
@@ -35,7 +35,7 @@ class SoftActorCriticBuffer(Buffer):
35
35
 
36
36
  if as_tensor:
37
37
  states = torch.from_numpy(states).float()
38
- actions = torch.from_numpy(actions).long()
38
+ actions = torch.from_numpy(actions).float()
39
39
  rewards = torch.from_numpy(rewards).float()
40
40
  next_states = torch.from_numpy(next_states).float()
41
41
  terminations = torch.from_numpy(terminations).float()
@@ -1,6 +1,5 @@
1
1
  import torch
2
2
  from torch.nn.functional import mse_loss as mse
3
- from typing import Type
4
3
  from .modules import DualSoftCriticModule
5
4
  from .SoftActor import SoftActor
6
5
 
@@ -31,6 +30,10 @@ class SoftCritic:
31
30
  hidden_dimension=hidden_dimension,
32
31
  action_dimension=action_dimension,
33
32
  )
33
+
34
+ # copy source to target when initiated
35
+ self.target.load_state_dict(self.network.state_dict())
36
+
34
37
  self.optimizer = torch.optim.AdamW(
35
38
  self.network.parameters(), lr=lr, weight_decay=weight_decay
36
39
  )
@@ -65,7 +68,7 @@ class SoftCritic:
65
68
  next_states: torch.Tensor,
66
69
  actor: torch.nn.Module,
67
70
  ):
68
- q_alpha, q_beta = self.target(states, actions)
71
+ q_alpha, q_beta = self.network(states, actions)
69
72
  target = self.calculate_target(
70
73
  rewards=rewards,
71
74
  terminations=terminations,
@@ -82,16 +85,15 @@ class SoftCritic:
82
85
  next_states: torch.Tensor,
83
86
  actor: SoftActor,
84
87
  ):
88
+ termination_mask = 1 - terminations
85
89
  next_actions, next_probabilities = actor.sample(next_states)
86
90
  q_alpha, q_beta = self.target(next_states, next_actions)
87
91
  q = torch.min(q_alpha, q_beta)
88
- termination_mask = 1 - terminations
89
92
  entropy_q = q - self.alpha * next_probabilities * termination_mask
90
93
  return rewards + self.gamma * entropy_q
91
94
 
92
- def soft_update(
93
- self, network: Type[torch.nn.Module], target: Type[torch.nn.Module]
94
- ):
95
+ @torch.no_grad()
96
+ def soft_update(self, network: torch.nn.Module, target: torch.nn.Module):
95
97
  for s, t in zip(network.parameters(), target.parameters()):
96
98
  t.data.copy_(self.tau * s.data + (1 - self.tau) * t.data)
97
99
 
@@ -30,6 +30,6 @@ class SoftActorModule(torch.nn.Module):
30
30
  def forward(self, state: torch.Tensor):
31
31
  features = self.phi(state)
32
32
  mean = self.mu(features)
33
- std = self.mu(features)
33
+ std = self.xi(features)
34
34
  std = std.clamp(self.min_std, self.max_std)
35
35
  return mean, std
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: homa
3
- Version: 0.3.11
3
+ Version: 0.3.12
4
4
  Summary: A curated list of machine learning and deep learning helpers.
5
5
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
6
6
  Requires-Python: >=3.7
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes