homa 0.2.93__tar.gz → 0.3.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. {homa-0.2.93 → homa-0.3.11}/PKG-INFO +1 -1
  2. {homa-0.2.93 → homa-0.3.11}/pyproject.toml +1 -1
  3. homa-0.3.11/src/homa/core/__init__.py +0 -0
  4. homa-0.3.11/src/homa/core/concerns/MovesNetworkToDevice.py +13 -0
  5. homa-0.3.11/src/homa/core/concerns/__init__.py +1 -0
  6. {homa-0.2.93 → homa-0.3.11}/src/homa/device.py +4 -0
  7. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/Ensemble.py +4 -2
  8. homa-0.2.93/src/homa/ensemble/concerns/ReportsSize.py → homa-0.3.11/src/homa/ensemble/concerns/ReportsEnsembleSize.py +3 -3
  9. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsLogits.py +3 -1
  10. homa-0.3.11/src/homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
  11. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/StoresModels.py +9 -6
  12. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/__init__.py +2 -1
  13. homa-0.3.11/src/homa/ensemble/utils.py +9 -0
  14. homa-0.3.11/src/homa/graph/GraphAttention.py +13 -0
  15. homa-0.3.11/src/homa/graph/__init__.py +1 -0
  16. homa-0.3.11/src/homa/graph/modules/GraphAttentionHeadModule.py +37 -0
  17. homa-0.3.11/src/homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
  18. homa-0.3.11/src/homa/graph/modules/__init__.py +2 -0
  19. homa-0.3.11/src/homa/loss/Loss.py +5 -0
  20. homa-0.3.11/src/homa/rl/DQN.py +2 -0
  21. homa-0.3.11/src/homa/rl/DRQN.py +5 -0
  22. homa-0.3.11/src/homa/rl/DiversityIsAllYouNeed.py +96 -0
  23. homa-0.3.11/src/homa/rl/SoftActorCritic.py +64 -0
  24. homa-0.3.11/src/homa/rl/__init__.py +4 -0
  25. homa-0.3.11/src/homa/rl/buffers/Buffer.py +11 -0
  26. homa-0.3.11/src/homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
  27. homa-0.3.11/src/homa/rl/buffers/ImageBuffer.py +5 -0
  28. homa-0.3.11/src/homa/rl/buffers/SoftActorCriticBuffer.py +56 -0
  29. homa-0.3.11/src/homa/rl/buffers/__init__.py +4 -0
  30. homa-0.3.11/src/homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
  31. homa-0.3.11/src/homa/rl/buffers/concerns/ResetsCollection.py +9 -0
  32. homa-0.3.11/src/homa/rl/buffers/concerns/__init__.py +2 -0
  33. homa-0.3.11/src/homa/rl/diayn/Actor.py +54 -0
  34. homa-0.3.11/src/homa/rl/diayn/Critic.py +41 -0
  35. homa-0.3.11/src/homa/rl/diayn/Discriminator.py +45 -0
  36. homa-0.3.11/src/homa/rl/diayn/__init__.py +3 -0
  37. homa-0.3.11/src/homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
  38. homa-0.3.11/src/homa/rl/diayn/modules/CriticModule.py +28 -0
  39. homa-0.3.11/src/homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
  40. homa-0.3.11/src/homa/rl/diayn/modules/__init__.py +3 -0
  41. homa-0.3.11/src/homa/rl/sac/SoftActor.py +69 -0
  42. homa-0.3.11/src/homa/rl/sac/SoftCritic.py +100 -0
  43. homa-0.3.11/src/homa/rl/sac/__init__.py +2 -0
  44. homa-0.3.11/src/homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
  45. homa-0.3.11/src/homa/rl/sac/modules/SoftActorModule.py +35 -0
  46. homa-0.3.11/src/homa/rl/sac/modules/SoftCriticModule.py +30 -0
  47. homa-0.3.11/src/homa/rl/sac/modules/__init__.py +3 -0
  48. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Resnet.py +3 -3
  49. homa-0.3.11/src/homa/vision/Swin.py +25 -0
  50. homa-0.3.11/src/homa/vision/modules/SwinModule.py +31 -0
  51. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/PKG-INFO +1 -1
  52. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/SOURCES.txt +39 -3
  53. homa-0.2.93/src/homa/loss/Loss.py +0 -2
  54. homa-0.2.93/src/homa/torch/__init__.py +0 -1
  55. homa-0.2.93/src/homa/torch/helpers.py +0 -6
  56. homa-0.2.93/src/homa/vision/Swin.py +0 -13
  57. homa-0.2.93/src/homa/vision/modules/SwinModule.py +0 -23
  58. {homa-0.2.93 → homa-0.3.11}/README.md +0 -0
  59. {homa-0.2.93 → homa-0.3.11}/setup.cfg +0 -0
  60. {homa-0.2.93 → homa-0.3.11}/src/homa/__init__.py +0 -0
  61. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/APLU.py +0 -0
  62. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ActivationFunction.py +0 -0
  63. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/AdaptiveActivationFunction.py +0 -0
  64. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/BaseDLReLU.py +0 -0
  65. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/CaLU.py +0 -0
  66. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/DLReLU.py +0 -0
  67. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ERF.py +0 -0
  68. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Elliot.py +0 -0
  69. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExpExpish.py +0 -0
  70. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExponentialDLReLU.py +0 -0
  71. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExponentialSwish.py +0 -0
  72. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GCU.py +0 -0
  73. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GaLU.py +0 -0
  74. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GaussianReLU.py +0 -0
  75. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GeneralizedSwish.py +0 -0
  76. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Gish.py +0 -0
  77. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LaLU.py +0 -0
  78. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LogLogish.py +0 -0
  79. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LogSigmoid.py +0 -0
  80. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Logish.py +0 -0
  81. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MeLU.py +0 -0
  82. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MexicanReLU.py +0 -0
  83. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MinSin.py +0 -0
  84. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/NReLU.py +0 -0
  85. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/NoisyReLU.py +0 -0
  86. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/PLogish.py +0 -0
  87. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ParametricLogish.py +0 -0
  88. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Phish.py +0 -0
  89. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/RReLU.py +0 -0
  90. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/RandomizedSlopedReLU.py +0 -0
  91. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SGELU.py +0 -0
  92. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SReLU.py +0 -0
  93. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SelfArctan.py +0 -0
  94. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ShiftedReLU.py +0 -0
  95. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SigmoidDerivative.py +0 -0
  96. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SineReLU.py +0 -0
  97. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SlopedReLU.py +0 -0
  98. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SmallGaLU.py +0 -0
  99. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Smish.py +0 -0
  100. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SoftsignRReLU.py +0 -0
  101. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Suish.py +0 -0
  102. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TBSReLU.py +0 -0
  103. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TSReLU.py +0 -0
  104. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TangentBipolarSigmoidReLU.py +0 -0
  105. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TangentSigmoidReLU.py +0 -0
  106. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TeLU.py +0 -0
  107. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TripleStateSwish.py +0 -0
  108. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/WideMeLU.py +0 -0
  109. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/__init__.py +0 -0
  110. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/AOAF.py +0 -0
  111. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/AReLU.py +0 -0
  112. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/DPReLU.py +0 -0
  113. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/DualLine.py +0 -0
  114. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/FReLU.py +0 -0
  115. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/LeLeLU.py +0 -0
  116. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/PERU.py +0 -0
  117. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/PiLU.py +0 -0
  118. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/ShiLU.py +0 -0
  119. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/StarReLU.py +0 -0
  120. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/__init__.py +0 -0
  121. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/concerns/ChannelBased.py +0 -0
  122. {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/concerns/__init__.py +0 -0
  123. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/Command.py +0 -0
  124. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/InitCommand.py +0 -0
  125. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/__init__.py +0 -0
  126. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/HomaCommand.py +0 -0
  127. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
  128. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
  129. {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/__init__.py +0 -0
  130. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/__init__.py +0 -0
  131. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -0
  132. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
  133. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
  134. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
  135. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
  136. {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
  137. {homa-0.2.93 → homa-0.3.11}/src/homa/loss/LogitNormLoss.py +0 -0
  138. {homa-0.2.93 → homa-0.3.11}/src/homa/loss/__init__.py +0 -0
  139. {homa-0.2.93 → homa-0.3.11}/src/homa/settings.py +0 -0
  140. {homa-0.2.93 → homa-0.3.11}/src/homa/utils.py +0 -0
  141. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Classifier.py +0 -0
  142. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Model.py +0 -0
  143. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/StochasticClassifier.py +0 -0
  144. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/StochasticSwin.py +0 -0
  145. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/__init__.py +0 -0
  146. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasLabels.py +0 -0
  147. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasLogits.py +0 -0
  148. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasProbabilities.py +0 -0
  149. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/ReportsAccuracy.py +0 -0
  150. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
  151. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/Trainable.py +0 -0
  152. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/__init__.py +0 -0
  153. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/modules/ResnetModule.py +0 -0
  154. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/modules/__init__.py +0 -0
  155. {homa-0.2.93 → homa-0.3.11}/src/homa/vision/utils.py +0 -0
  156. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/dependency_links.txt +0 -0
  157. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/entry_points.txt +0 -0
  158. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/requires.txt +0 -0
  159. {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/top_level.txt +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: homa
-Version: 0.2.93
+Version: 0.3.11
 Summary: A curated list of machine learning and deep learning helpers.
 Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
 Requires-Python: >=3.7
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "homa"
-version = "0.2.93"
+version = "0.3.11"
 description = "A curated list of machine learning and deep learning helpers."
 authors = [
     { name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
src/homa/core/__init__.py (new, empty file)
src/homa/core/concerns/MovesNetworkToDevice.py (new)
@@ -0,0 +1,13 @@
+from ...device import move
+
+
+class MovesNetworkToDevice:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if not hasattr(self, "network"):
+            raise RuntimeError(
+                "MovesNetworkToDevice assumes the underlying class has a network property."
+            )
+
+        move(self.network)
src/homa/core/concerns/__init__.py (new)
@@ -0,0 +1 @@
+from .MovesNetworkToDevice import MovesNetworkToDevice
src/homa/device.py
@@ -23,3 +23,7 @@ def mps():
 
 def device():
     return get_device()
+
+
+def move(module: torch.nn.Module):
+    module.to(get_device())
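
Together with the MovesNetworkToDevice concern added above, the new move() helper lets a wrapper class push its underlying module onto the selected device at construction time. A minimal sketch of the intended wiring, assuming get_device() resolves to CUDA/MPS/CPU as before; the ToyClassifier name is made up for illustration:

import torch
from homa.core.concerns import MovesNetworkToDevice


class ToyClassifier(MovesNetworkToDevice):
    def __init__(self):
        # the mixin checks for `self.network` inside its __init__, so the attribute
        # is assigned before the cooperative super().__init__() call
        self.network = torch.nn.Linear(4, 2)
        super().__init__()  # raises RuntimeError if `network` is missing, otherwise calls move()


model = ToyClassifier()
print(next(model.network.parameters()).device)  # cuda / mps / cpu depending on the host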
src/homa/ensemble/Ensemble.py
@@ -1,16 +1,18 @@
 from .concerns import (
-    ReportsSize,
+    ReportsEnsembleSize,
     StoresModels,
     ReportsClassificationMetrics,
     PredictsProbabilities,
+    SavesEnsembleModels,
 )
 
 
 class Ensemble(
-    ReportsSize,
+    ReportsEnsembleSize,
     ReportsClassificationMetrics,
     PredictsProbabilities,
     StoresModels,
+    SavesEnsembleModels,
 ):
     def __init__(self):
         super().__init__()
src/homa/ensemble/concerns/ReportsSize.py → ReportsEnsembleSize.py
@@ -1,11 +1,11 @@
-class ReportsSize:
+class ReportsEnsembleSize:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     @property
     def size(self):
-        return len(self.models)
+        return len(self.weights)
 
     @property
     def length(self):
-        return len(self.models)
+        return self.size
src/homa/ensemble/concerns/ReportsLogits.py
@@ -8,7 +8,9 @@ class ReportsLogits:
     def logits(self, x: torch.Tensor) -> torch.Tensor:
         batch_size = x.shape[0]
         logits = torch.zeros((batch_size, self.num_classes))
-        for model in self.models:
+        for factory, weight in zip(self.factories, self.weights):
+            model = factory(num_classes=self.num_classes)
+            model.load_state_dict(weight)
             logits += model(x)
         return logits
 
src/homa/ensemble/concerns/SavesEnsembleModels.py (new)
@@ -0,0 +1,13 @@
+class SavesEnsembleModels:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def save(self):
+        self.save_factories()
+        self.save_weights()
+
+    def save_factories(self):
+        pass
+
+    def save_weights(self):
+        pass
src/homa/ensemble/concerns/StoresModels.py
@@ -1,23 +1,26 @@
 import torch
-from copy import deepcopy
-from typing import List
+from typing import List, Type
+from collections import OrderedDict
 from ...vision import Model
 
 
 class StoresModels:
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.models: List[torch.nn.Module] = []
+        self.factories: List[Type[torch.nn.Module]] = []
+        self.weights: List[OrderedDict] = []
 
     def record(self, model: Model | torch.nn.Module):
         model_: torch.nn.Module | None = None
         if isinstance(model, Model):
-            model_ = deepcopy(model.network)
+            model_ = model.network
         elif isinstance(model, torch.nn.Module):
-            model_ = deepcopy(model)
+            model_ = model
         else:
             raise TypeError("Wrong input to ensemble record")
-        self.models.append(model_)
+
+        self.factories.append(model_.__class__)
+        self.weights.append(model_.state_dict())
 
     def push(self, *args, **kwargs):
         self.record(*args, **kwargs)
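
With this change the ensemble no longer holds deep copies of live modules: record() keeps each model's class as a factory together with its state_dict, and ReportsLogits above re-instantiates the network via factory(num_classes=...) and reloads the weights on demand. A rough sketch of recording plain torch modules, assuming homa.ensemble still exports Ensemble as in 0.2.x; ToyNet and the manual num_classes assignment are illustrative only:

import torch
from homa.ensemble import Ensemble


class ToyNet(torch.nn.Module):
    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.fc = torch.nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(x)


ensemble = Ensemble()
ensemble.num_classes = 3   # assumed here; ReportsLogits reads this attribute
ensemble.record(ToyNet())  # stores the ToyNet class and its state_dict
ensemble.push(ToyNet())    # push() is an alias for record()
print(ensemble.size)       # 2, now derived from ensemble.weights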
src/homa/ensemble/concerns/__init__.py
@@ -5,5 +5,6 @@ from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
 from .ReportsEnsembleF1 import ReportsEnsembleF1
 from .ReportsEnsembleKappa import ReportsEnsembleKappa
 from .ReportsLogits import ReportsLogits
-from .ReportsSize import ReportsSize
+from .ReportsEnsembleSize import ReportsEnsembleSize
 from .StoresModels import StoresModels
+from .SavesEnsembleModels import SavesEnsembleModels
src/homa/ensemble/utils.py (new)
@@ -0,0 +1,9 @@
+import torch
+
+
+def get_model_device(model: torch.nn.Module):
+    try:
+        device = next(model.parameters()).device
+    except StopIteration:
+        device = torch.device("cpu")
+    return device
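
The StopIteration guard covers parameter-free modules; a quick check of both branches:

import torch
from homa.ensemble.utils import get_model_device

print(get_model_device(torch.nn.Linear(3, 3)))  # device of the layer's parameters
print(get_model_device(torch.nn.ReLU()))        # no parameters, falls back to cpu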
src/homa/graph/GraphAttention.py (new)
@@ -0,0 +1,13 @@
+import torch
+from .modules import GraphAttentionModule
+from ..core.concerns import MovesNetworkToDevice
+
+
+class GraphAttention(MovesNetworkToDevice):
+    def __init__(self, lr: float = 0.005, decay: float = 5e-4, dropout: float = 0.6):
+        super().__init__()
+        self.network = GraphAttentionModule()
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+        self.criterion = torch.nn.CrossEntropyLoss()
src/homa/graph/__init__.py (new)
@@ -0,0 +1 @@
+from .GraphAttention import GraphAttention
src/homa/graph/modules/GraphAttentionHeadModule.py (new)
@@ -0,0 +1,37 @@
+import torch
+
+
+class GraphAttentionHeadModule(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.alpha = alpha
+
+        self.W = torch.nn.Linear(in_features, out_features, bias=False)
+        self.a_1 = torch.nn.Parameter(torch.randn(out_features, 1))
+        self.a_2 = torch.nn.Parameter(torch.randn(out_features, 1))
+
+        self.leaky_relu = torch.nn.LeakyReLU(self.alpha)
+        self.elu = torch.nn.ELU()
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.xavier_uniform_(self.W.weight, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_1, gain=1.414)
+        torch.nn.init.xavier_uniform_(self.a_2, gain=1.414)
+
+    def forward(self, node_features, adj_matrix):
+        N = node_features.size(0)
+        h_prime = self.W(node_features)
+        s1 = torch.matmul(h_prime, self.a_1)
+        s2 = torch.matmul(h_prime, self.a_2)
+        e = s1 + s2.T
+        e = self.leaky_relu(e)
+        zero_vec = -9e15 * torch.ones_like(e)
+        attention_mask = torch.where(
+            adj_matrix > 0, e, zero_vec.to(node_features.device)
+        )
+        attention_weights = torch.softmax(attention_mask, dim=1)
+        h_new = torch.matmul(attention_weights, h_prime)
+        return self.elu(h_new)
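
The head scores every node pair as s1 + s2.T (the two halves of the usual a·[Wh_i || Wh_j] term), replaces scores of non-edges with a large negative constant, and row-softmaxes so each node only attends over its neighbours. A small stand-alone sketch of that masking step; the 3-node adjacency matrix is made up:

import torch

# node 0 and node 1 are connected, node 2 is isolated (self-loops included)
adj = torch.tensor([[1.0, 1.0, 0.0],
                    [1.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0]])
e = torch.randn(3, 3)                                 # raw pairwise scores, i.e. s1 + s2.T
masked = torch.where(adj > 0, e, torch.full_like(e, -9e15))
weights = torch.softmax(masked, dim=1)                # rows sum to 1, zero mass on non-edges
print(weights)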
src/homa/graph/modules/MultiHeadGraphAttentionModule.py (new)
@@ -0,0 +1,22 @@
+import torch
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+
+
+class MultiHeadGraphAttentionModule(torch.nn.Module):
+    def __init__(self, num_heads: int, in_features: int, out_features: int, alpha=0.2):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_out_features = out_features
+        self.heads = torch.nn.ModuleList(
+            [
+                GraphAttentionHeadModule(in_features, out_features, alpha=alpha)
+                for _ in range(num_heads)
+            ]
+        )
+
+    def forward(
+        self, node_features: torch.Tensor, adj_matrix: torch.Tensor
+    ) -> torch.Tensor:
+        outputs = [head(node_features, adj_matrix) for head in self.heads]
+        h_new_concat = torch.cat(outputs, dim=1)
+        return h_new_concat
src/homa/graph/modules/__init__.py (new)
@@ -0,0 +1,2 @@
+from .GraphAttentionHeadModule import GraphAttentionHeadModule
+from .MultiHeadGraphAttentionModule import MultiHeadGraphAttentionModule
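
Each head returns an N x out_features embedding and the multi-head module concatenates them along the feature axis, so the output width is num_heads * out_features. A quick shape check on random data (the sizes are arbitrary):

import torch
from homa.graph.modules import MultiHeadGraphAttentionModule

module = MultiHeadGraphAttentionModule(num_heads=4, in_features=16, out_features=8)
x = torch.randn(10, 16)                                   # 10 nodes, 16 input features each
adj = (torch.rand(10, 10) > 0.5).float() + torch.eye(10)  # random edges plus self-loops
out = module(x, adj)
print(out.shape)                                          # torch.Size([10, 32]) == (N, num_heads * out_features)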
src/homa/loss/Loss.py (new)
@@ -0,0 +1,5 @@
+import torch
+
+
+class Loss(torch.nn.Module):
+    pass
src/homa/rl/DQN.py (new)
@@ -0,0 +1,2 @@
+class DQN:
+    pass
src/homa/rl/DRQN.py (new)
@@ -0,0 +1,5 @@
+from .DQN import DQN
+
+
+class DRQN(DQN):
+    pass
src/homa/rl/DiversityIsAllYouNeed.py (new)
@@ -0,0 +1,96 @@
+import torch
+from .diayn.Actor import Actor
+from .diayn.Critic import Critic
+from .diayn.Discriminator import Discriminator
+from .buffers import DiversityIsAllYouNeedBuffer, Buffer
+
+
+class DiversityIsAllYouNeed:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        num_skills: int = 10,
+        critic_decay: float = 0.0,
+        actor_decay: float = 0.0,
+        discriminator_decay: float = 0.0,
+        actor_lr: float = 0.0001,
+        critic_lr: float = 0.001,
+        discriminator_lr=0.001,
+        buffer_capacity: int = 1_000_000,
+        actor_epsilon: float = 1e-6,
+        gamma: float = 0.99,
+        min_std: float = -20.0,
+        max_std: float = 2.0,
+    ):
+        self.buffer: Buffer = DiversityIsAllYouNeedBuffer(capacity=buffer_capacity)
+        self.num_skills: int = num_skills
+        self.actor = Actor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=actor_lr,
+            decay=actor_decay,
+            epsilon=actor_epsilon,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = Critic(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=critic_lr,
+            decay=critic_decay,
+            gamma=gamma,
+        )
+        self.discriminator = Discriminator(
+            state_dimension=state_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            lr=discriminator_lr,
+            decay=discriminator_decay,
+        )
+
+    def one_hot(self, indices, max_index) -> torch.Tensor:
+        one_hot = torch.zeros(indices.size(0), max_index)
+        one_hot.scatter_(1, indices.unsqueeze(1), 1)
+        return one_hot
+
+    def skill_index(self) -> torch.Tensor:
+        return torch.randint(0, self.num_skills, (1,))
+
+    def skill(self) -> torch.Tensor:
+        return self.one_hot(self.skill_index(), self.num_skills)
+
+    def advantages(
+        self,
+        states: torch.Tensor,
+        skills: torch.Tensor,
+        rewards: torch.Tensor,
+        terminations: torch.Tensor,
+        next_states: torch.Tensor,
+    ) -> torch.Tensor:
+        values = self.critic.values(states=states, skills=skills)
+        termination_mask = 1 - terminations
+        next_values = self.critic.values_(states=next_states, skills=skills)
+        update = self.gamma * next_values * termination_mask
+        return rewards + update - values
+
+    def train(self, skill: torch.Tensor):
+        data = self.buffer.all_tensor()
+        skill_indices = skill.repeat(data.states.size(0), 1).long()
+        skills_indices_one_hot = self.one_hot(skill_indices, self.num_skills)
+        self.discriminator.train(
+            states=data.states, skills_indices=skills_indices_one_hot
+        )
+        advantages = self.advantages(
+            states=data.states,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            skills=skills,
+        )
+        self.critic.train(advantages=advantages)
+        self.actor.train(advantages=advantages)
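
In outline, the agent samples a one-hot skill, conditions the actor on (state, skill), stores the resulting transitions, and periodically calls train(skill) so the discriminator learns to recover the skill from visited states while the critic and actor update from the advantages. A rough interaction sketch under the assumption that the DIAYN modules accept a batch of states alongside the one-hot skill; the environment step and reward wiring are not part of homa and are left as comments:

import torch
from homa.rl import DiversityIsAllYouNeed

agent = DiversityIsAllYouNeed(state_dimension=8, action_dimension=2, num_skills=10)

skill = agent.skill()                      # one-hot tensor of shape (1, num_skills)
state = torch.zeros(1, 8)                  # stand-in for an environment observation
action, log_prob = agent.actor.action(state, skill)
# next_state, reward, terminated = env.step(action)                # environment side, not shown
# agent.buffer.record(state, action, reward, next_state, terminated, log_prob)
# agent.train(skill)                                               # after enough transitions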
src/homa/rl/SoftActorCritic.py (new)
@@ -0,0 +1,64 @@
+from .sac import SoftActor, SoftCritic
+from .buffers import SoftActorCriticBuffer
+
+
+class SoftActorCritic:
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        hidden_dimension: int = 256,
+        buffer_capacity: int = 1_000_000,
+        batch_size: int = 256,
+        actor_lr: float = 0.0002,
+        critic_lr: float = 0.0003,
+        actor_decay: float = 0.0,
+        critic_decay: float = 0.0,
+        tau: float = 0.005,
+        alpha: float = 0.2,
+        gamma: float = 0.99,
+        min_std: float = -20,
+        max_std: float = 2,
+        warmup: int = 10_000,
+    ):
+        self.batch_size: int = batch_size
+        self.warmup: int = warmup
+
+        self.actor = SoftActor(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=actor_lr,
+            weight_decay=actor_decay,
+            alpha=alpha,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.critic = SoftCritic(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            lr=critic_lr,
+            weight_decay=critic_decay,
+            tau=tau,
+            gamma=gamma,
+            alpha=alpha,
+        )
+        self.buffer = SoftActorCriticBuffer(capacity=buffer_capacity)
+
+    def train(self):
+        # don't train before warmup
+        if self.buffer.size < self.warmup:
+            return
+
+        data = self.buffer.sample_torch(self.batch_size)
+        self.critic.train(
+            states=data.states,
+            actions=data.actions,
+            rewards=data.rewards,
+            terminations=data.terminations,
+            next_states=data.next_states,
+            actor=self.actor,
+        )
+        self.actor.train(states=data.states, critic_network=self.critic.network)
+        self.critic.update()
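
train() is a no-op until the buffer holds at least warmup transitions, so a driver loop can simply record every step and call train() unconditionally. A rough sketch; the environment interaction and action sampling are illustrative comments, not homa API:

from homa.rl import SoftActorCritic

agent = SoftActorCritic(state_dimension=8, action_dimension=2, warmup=1_000, batch_size=64)

# for each environment step:
#     action, log_prob = ...                                  # sample from agent.actor
#     next_state, reward, terminated = env.step(action)
#     agent.buffer.record(state, action, reward, next_state, float(terminated), log_prob)
#     agent.train()                                           # skipped while buffer.size < warmup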
src/homa/rl/__init__.py (new)
@@ -0,0 +1,4 @@
+from .DiversityIsAllYouNeed import DiversityIsAllYouNeed
+from .SoftActorCritic import SoftActorCritic
+from .DQN import DQN
+from .DRQN import DRQN
src/homa/rl/buffers/Buffer.py (new)
@@ -0,0 +1,11 @@
+from .concerns import ResetsCollection, HasRecordAlternatives
+
+
+class Buffer(ResetsCollection, HasRecordAlternatives):
+    def __init__(self, capacity: int):
+        self.capacity: int = capacity
+        self.reset()
+
+    @property
+    def size(self):
+        return len(self.collection)
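
Buffer composes the two concerns shown further below: ResetsCollection.reset() builds the bounded deque from self.capacity, and HasRecordAlternatives forwards add/push/append to whatever record() a concrete buffer defines. A minimal subclass only needs record(); the ScalarBuffer below is hypothetical and exists only to show the wiring:

from homa.rl.buffers import Buffer


class ScalarBuffer(Buffer):
    def record(self, value: float) -> None:
        self.collection.append(value)


buffer = ScalarBuffer(capacity=3)
buffer.push(1.0)    # add()/push()/append() all route to record()
buffer.add(2.0)
buffer.append(3.0)
buffer.append(4.0)  # deque(maxlen=3) silently drops the oldest entry
print(buffer.size)  # 3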
src/homa/rl/buffers/DiversityIsAllYouNeedBuffer.py (new)
@@ -0,0 +1,50 @@
+import torch
+import numpy
+from types import SimpleNamespace
+from .Buffer import Buffer
+from .concerns import HasRecordAlternatives
+
+
+class DiversityIsAllYouNeedBuffer(Buffer, HasRecordAlternatives):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def all_tensor(self) -> SimpleNamespace:
+        return self.all(tensor=True)
+
+    def all(self, tensor: bool = False) -> SimpleNamespace:
+        states, actions, rewards, next_states, terminations, probabilities = zip(
+            *self.collection
+        )
+
+        if tensor:
+            states = torch.from_numpy(numpy.array(states))
+            actions = torch.from_numpy(numpy.array(actions))
+            rewards = torch.from_numpy(numpy.array(rewards))
+            next_states = torch.from_numpy(numpy.array(next_states))
+            terminations = torch.from_numpy(numpy.array(terminations))
+            probabilities = torch.from_numpy(numpy.array(probabilities))
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: bool,
+        probability: numpy.ndarray,
+    ) -> None:
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
src/homa/rl/buffers/ImageBuffer.py (new)
@@ -0,0 +1,5 @@
+from .Buffer import Buffer
+
+
+class ImageBuffer(Buffer):
+    pass
src/homa/rl/buffers/SoftActorCriticBuffer.py (new)
@@ -0,0 +1,56 @@
+import numpy
+import random
+import torch
+from types import SimpleNamespace
+from .Buffer import Buffer
+
+
+class SoftActorCriticBuffer(Buffer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def record(
+        self,
+        state: numpy.ndarray,
+        action: int,
+        reward: float,
+        next_state: numpy.ndarray,
+        termination: float,
+        probability: numpy.ndarray,
+    ):
+        self.collection.append(
+            (state, action, reward, next_state, termination, probability)
+        )
+
+    def sample(self, k: int, as_tensor: bool = False):
+        batch = random.sample(self.collection, k)
+        states, actions, rewards, next_states, terminations, probabilities = zip(*batch)
+
+        states = numpy.array(states)
+        actions = numpy.array(actions)
+        rewards = numpy.array(rewards)
+        next_states = numpy.array(next_states)
+        terminations = numpy.array(terminations)
+        probabilities = numpy.array(probabilities)
+
+        if as_tensor:
+            states = torch.from_numpy(states).float()
+            actions = torch.from_numpy(actions).long()
+            rewards = torch.from_numpy(rewards).float()
+            next_states = torch.from_numpy(next_states).float()
+            terminations = torch.from_numpy(terminations).float()
+            probabilities = torch.from_numpy(probabilities).float()
+
+        return SimpleNamespace(
+            **{
+                "states": states,
+                "actions": actions,
+                "rewards": rewards,
+                "next_states": next_states,
+                "terminations": terminations,
+                "probabilities": probabilities,
+            }
+        )
+
+    def sample_torch(self, k: int):
+        return self.sample(k=k, as_tensor=True)
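
Each call to record() appends a flat transition tuple; sample() rebuilds a batch of stacked arrays (optionally tensors) exposed as attributes of a SimpleNamespace. A small round trip with fabricated transitions:

import numpy
from homa.rl.buffers import SoftActorCriticBuffer

buffer = SoftActorCriticBuffer(capacity=100)
for _ in range(10):
    buffer.record(
        state=numpy.zeros(4),
        action=1,
        reward=0.5,
        next_state=numpy.ones(4),
        termination=0.0,
        probability=numpy.array([0.1]),
    )

batch = buffer.sample_torch(k=8)
print(batch.states.shape, batch.states.dtype)  # torch.Size([8, 4]) torch.float32
print(batch.actions.dtype)                     # torch.int64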
src/homa/rl/buffers/__init__.py (new)
@@ -0,0 +1,4 @@
+from .SoftActorCriticBuffer import SoftActorCriticBuffer
+from .ImageBuffer import ImageBuffer
+from .DiversityIsAllYouNeedBuffer import DiversityIsAllYouNeedBuffer
+from .Buffer import Buffer
src/homa/rl/buffers/concerns/HasRecordAlternatives.py (new)
@@ -0,0 +1,12 @@
+class HasRecordAlternatives:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def add(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def push(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
+
+    def append(self, *args, **kwargs) -> None:
+        self.record(*args, **kwargs)
src/homa/rl/buffers/concerns/ResetsCollection.py (new)
@@ -0,0 +1,9 @@
+from collections import deque
+
+
+class ResetsCollection:
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def reset(self):
+        self.collection = deque(maxlen=self.capacity)
src/homa/rl/buffers/concerns/__init__.py (new)
@@ -0,0 +1,2 @@
+from .HasRecordAlternatives import HasRecordAlternatives
+from .ResetsCollection import ResetsCollection
src/homa/rl/diayn/Actor.py (new)
@@ -0,0 +1,54 @@
+import torch
+from torch.distributions import Normal
+from .modules import ContinuousActorModule
+from ...core.concerns import MovesNetworkToDevice
+
+
+class Actor(MovesNetworkToDevice):
+    def __init__(
+        self,
+        state_dimension: int,
+        action_dimension: int,
+        num_skills: int,
+        hidden_dimension: int,
+        lr: float,
+        decay: float,
+        epsilon: float,
+        min_std: float,
+        max_std: float,
+    ):
+        self.epsilon: float = epsilon
+        self.network = ContinuousActorModule(
+            state_dimension=state_dimension,
+            action_dimension=action_dimension,
+            hidden_dimension=hidden_dimension,
+            num_skills=num_skills,
+            min_std=min_std,
+            max_std=max_std,
+        )
+        self.optimizer = torch.optim.AdamW(
+            self.network.parameters(), lr=lr, weight_decay=decay
+        )
+
+    def action(self, state: torch.Tensor, skill: torch.Tensor):
+        mean, std = self.network(state, skill)
+        std = std.exp()
+        distribution = Normal(mean, std)
+        raw_action = distribution.rsample()
+        action = torch.tanh(raw_action)
+        corrected_probabilities = torch.log(1.0 - action.pow(2) + self.epsilon)
+        probabilities = distribution.log_prob(raw_action) - corrected_probabilities
+        probabilities = probabilities.sum(dim=-1, keepdim=True)
+        return action, probabilities
+
+    def train(self, advantages: torch.Tensor, probabilities: torch.Tensor) -> float:
+        self.optimizer.zero_grad()
+        loss = self.loss(advantages=advantages, probabilities=probabilities)
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+    def loss(
+        self, advantages: torch.Tensor, probabilities: torch.Tensor
+    ) -> torch.Tensor:
+        return -(probabilities * advantages.detach()).mean()
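
action() implements the standard tanh-squashed Gaussian policy: it draws u ~ Normal(mean, std) with rsample() so gradients flow through the sample, squashes a = tanh(u), and subtracts log(1 - a^2 + epsilon) per dimension, which is the change-of-variables correction for the tanh. The correction can be sanity-checked against torch.distributions' own transform machinery; a sketch under that assumption:

import torch
from torch.distributions import Normal, TanhTransform, TransformedDistribution

mean, std, eps = torch.zeros(3), torch.ones(3), 1e-6
base = Normal(mean, std)

u = base.rsample()
a = torch.tanh(u)
manual = (base.log_prob(u) - torch.log(1.0 - a.pow(2) + eps)).sum()

reference = TransformedDistribution(base, [TanhTransform()]).log_prob(a).sum()
print(manual.item(), reference.item())  # agree up to the epsilon stabiliser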