homa 0.2.93.tar.gz → 0.3.11.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {homa-0.2.93 → homa-0.3.11}/PKG-INFO +1 -1
- {homa-0.2.93 → homa-0.3.11}/pyproject.toml +1 -1
- homa-0.3.11/src/homa/core/__init__.py +0 -0
- homa-0.3.11/src/homa/core/concerns/MovesNetworkToDevice.py +13 -0
- homa-0.3.11/src/homa/core/concerns/__init__.py +1 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/device.py +4 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/Ensemble.py +4 -2
- homa-0.2.93/src/homa/ensemble/concerns/ReportsSize.py → homa-0.3.11/src/homa/ensemble/concerns/ReportsEnsembleSize.py +3 -3
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsLogits.py +3 -1
- homa-0.3.11/src/homa/ensemble/concerns/SavesEnsembleModels.py +13 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/StoresModels.py +9 -6
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/__init__.py +2 -1
- homa-0.3.11/src/homa/ensemble/utils.py +9 -0
- homa-0.3.11/src/homa/graph/GraphAttention.py +13 -0
- homa-0.3.11/src/homa/graph/__init__.py +1 -0
- homa-0.3.11/src/homa/graph/modules/GraphAttentionHeadModule.py +37 -0
- homa-0.3.11/src/homa/graph/modules/MultiHeadGraphAttentionModule.py +22 -0
- homa-0.3.11/src/homa/graph/modules/__init__.py +2 -0
- homa-0.3.11/src/homa/loss/Loss.py +5 -0
- homa-0.3.11/src/homa/rl/DQN.py +2 -0
- homa-0.3.11/src/homa/rl/DRQN.py +5 -0
- homa-0.3.11/src/homa/rl/DiversityIsAllYouNeed.py +96 -0
- homa-0.3.11/src/homa/rl/SoftActorCritic.py +64 -0
- homa-0.3.11/src/homa/rl/__init__.py +4 -0
- homa-0.3.11/src/homa/rl/buffers/Buffer.py +11 -0
- homa-0.3.11/src/homa/rl/buffers/DiversityIsAllYouNeedBuffer.py +50 -0
- homa-0.3.11/src/homa/rl/buffers/ImageBuffer.py +5 -0
- homa-0.3.11/src/homa/rl/buffers/SoftActorCriticBuffer.py +56 -0
- homa-0.3.11/src/homa/rl/buffers/__init__.py +4 -0
- homa-0.3.11/src/homa/rl/buffers/concerns/HasRecordAlternatives.py +12 -0
- homa-0.3.11/src/homa/rl/buffers/concerns/ResetsCollection.py +9 -0
- homa-0.3.11/src/homa/rl/buffers/concerns/__init__.py +2 -0
- homa-0.3.11/src/homa/rl/diayn/Actor.py +54 -0
- homa-0.3.11/src/homa/rl/diayn/Critic.py +41 -0
- homa-0.3.11/src/homa/rl/diayn/Discriminator.py +45 -0
- homa-0.3.11/src/homa/rl/diayn/__init__.py +3 -0
- homa-0.3.11/src/homa/rl/diayn/modules/ContinuousActorModule.py +42 -0
- homa-0.3.11/src/homa/rl/diayn/modules/CriticModule.py +28 -0
- homa-0.3.11/src/homa/rl/diayn/modules/DiscriminatorModule.py +24 -0
- homa-0.3.11/src/homa/rl/diayn/modules/__init__.py +3 -0
- homa-0.3.11/src/homa/rl/sac/SoftActor.py +69 -0
- homa-0.3.11/src/homa/rl/sac/SoftCritic.py +100 -0
- homa-0.3.11/src/homa/rl/sac/__init__.py +2 -0
- homa-0.3.11/src/homa/rl/sac/modules/DualSoftCriticModule.py +22 -0
- homa-0.3.11/src/homa/rl/sac/modules/SoftActorModule.py +35 -0
- homa-0.3.11/src/homa/rl/sac/modules/SoftCriticModule.py +30 -0
- homa-0.3.11/src/homa/rl/sac/modules/__init__.py +3 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Resnet.py +3 -3
- homa-0.3.11/src/homa/vision/Swin.py +25 -0
- homa-0.3.11/src/homa/vision/modules/SwinModule.py +31 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/PKG-INFO +1 -1
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/SOURCES.txt +39 -3
- homa-0.2.93/src/homa/loss/Loss.py +0 -2
- homa-0.2.93/src/homa/torch/__init__.py +0 -1
- homa-0.2.93/src/homa/torch/helpers.py +0 -6
- homa-0.2.93/src/homa/vision/Swin.py +0 -13
- homa-0.2.93/src/homa/vision/modules/SwinModule.py +0 -23
- {homa-0.2.93 → homa-0.3.11}/README.md +0 -0
- {homa-0.2.93 → homa-0.3.11}/setup.cfg +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/APLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ActivationFunction.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/AdaptiveActivationFunction.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/BaseDLReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/CaLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/DLReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ERF.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Elliot.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExpExpish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExponentialDLReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ExponentialSwish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GCU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GaLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GaussianReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/GeneralizedSwish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Gish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LaLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LogLogish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/LogSigmoid.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Logish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MeLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MexicanReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/MinSin.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/NReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/NoisyReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/PLogish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ParametricLogish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Phish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/RReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/RandomizedSlopedReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SGELU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SelfArctan.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/ShiftedReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SigmoidDerivative.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SineReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SlopedReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SmallGaLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Smish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/SoftsignRReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/Suish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TBSReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TSReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TangentBipolarSigmoidReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TangentSigmoidReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TeLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/TripleStateSwish.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/WideMeLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/AOAF.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/AReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/DPReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/DualLine.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/FReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/LeLeLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/PERU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/PiLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/ShiLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/StarReLU.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/concerns/ChannelBased.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/activations/learnable/concerns/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/Command.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/InitCommand.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/Commands/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/HomaCommand.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/cli/namespaces/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/loss/LogitNormLoss.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/loss/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/settings.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/utils.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Classifier.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/Model.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/StochasticClassifier.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/StochasticSwin.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasLabels.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasLogits.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/HasProbabilities.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/ReportsAccuracy.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/ReportsMetrics.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/Trainable.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/concerns/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/modules/ResnetModule.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/modules/__init__.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa/vision/utils.py +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/dependency_links.txt +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/entry_points.txt +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/requires.txt +0 -0
- {homa-0.2.93 → homa-0.3.11}/src/homa.egg-info/top_level.txt +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "homa"
|
|
7
|
-
version = "0.
|
|
7
|
+
version = "0.3.11"
|
|
8
8
|
description = "A curated list of machine learning and deep learning helpers."
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
|
|
File without changes
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from ...device import move
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MovesNetworkToDevice:
    """Mixin that hands ``self.network`` to ``move`` during initialisation.

    The attribute check happens inside this ``__init__``, so a subclass has
    to assign ``self.network`` *before* it calls ``super().__init__()``;
    otherwise a RuntimeError is raised.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if hasattr(self, "network"):
            # `move` is imported from homa.device (not shown here); presumably
            # it transfers the module to the configured torch device.
            move(self.network)
            return

        raise RuntimeError(
            "MovesNetworkToDevice assumes the underlying class has a network property."
        )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .MovesNetworkToDevice import MovesNetworkToDevice
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
from .concerns import (
|
|
2
|
-
|
|
2
|
+
ReportsEnsembleSize,
|
|
3
3
|
StoresModels,
|
|
4
4
|
ReportsClassificationMetrics,
|
|
5
5
|
PredictsProbabilities,
|
|
6
|
+
SavesEnsembleModels,
|
|
6
7
|
)
|
|
7
8
|
|
|
8
9
|
|
|
9
10
|
class Ensemble(
    ReportsEnsembleSize,
    ReportsClassificationMetrics,
    PredictsProbabilities,
    StoresModels,
    SavesEnsembleModels,
):
    """Model ensemble assembled entirely from the concern mixins above.

    The class adds no state of its own; construction is delegated up the
    method resolution order so each mixin can initialise itself.
    """

    def __init__(self) -> None:
        super(Ensemble, self).__init__()
|
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
class ReportsEnsembleSize:
    """Mixin reporting how many members an ensemble holds.

    Expects the host class to provide ``self.weights`` (a sized collection
    with one entry per recorded model, e.g. as maintained by StoresModels).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def size(self):
        """Number of recorded ensemble members."""
        stored_weights = self.weights
        return len(stored_weights)

    @property
    def length(self):
        """Alias of ``size``."""
        return self.size
|
|
@@ -8,7 +8,9 @@ class ReportsLogits:
|
|
|
8
8
|
def logits(self, x: torch.Tensor) -> torch.Tensor:
    """Sum the raw outputs of every stored ensemble member for batch ``x``.

    Each member is rebuilt by calling its recorded factory with
    ``num_classes=self.num_classes``, loaded with its stored state dict,
    switched to evaluation mode and placed on ``x``'s device before the
    forward pass.

    Args:
        x: input batch; its first dimension is used as the batch size.

    Returns:
        Tensor of shape ``(batch_size, self.num_classes)`` on ``x.device``
        holding the element-wise sum of all members' outputs.
    """
    batch_size = x.shape[0]
    logits = torch.zeros((batch_size, self.num_classes), device=x.device)
    for factory, weight in zip(self.factories, self.weights):
        model = factory(num_classes=self.num_classes)
        model.load_state_dict(weight)
        # A freshly constructed module starts in training mode: dropout would
        # stay active and batch-norm would use per-batch statistics, making
        # ensemble predictions noisy and dependent on the batch composition.
        model.eval()
        # Keep parameters on the same device as the input (no-op on CPU).
        model.to(x.device)
        logits += model(x)
    return logits
|
|
14
16
|
|
|
@@ -1,23 +1,26 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
from
|
|
3
|
-
from
|
|
2
|
+
from typing import List, Type
|
|
3
|
+
from collections import OrderedDict
|
|
4
4
|
from ...vision import Model
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
class StoresModels:
    """Mixin that stores ensemble members as (factory, weights) snapshots.

    ``self.factories[i]`` is the class of the i-th recorded network and
    ``self.weights[i]`` a copy of its state dict taken when it was recorded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.factories: List[Type[torch.nn.Module]] = []
        self.weights: List[OrderedDict] = []

    def record(self, model: Model | torch.nn.Module):
        """Snapshot ``model`` into the ensemble.

        Args:
            model: a ``Model`` wrapper (its ``network`` attribute is stored)
                or a plain ``torch.nn.Module``.

        Raises:
            TypeError: if ``model`` is neither of those types.
        """
        import copy

        if isinstance(model, Model):
            network = model.network
        elif isinstance(model, torch.nn.Module):
            network = model
        else:
            raise TypeError("Wrong input to ensemble record")

        self.factories.append(network.__class__)
        # state_dict() is a shallow copy whose tensors alias the live
        # parameters/buffers; deep-copy it so that training the same network
        # further cannot silently overwrite snapshots already in the ensemble.
        self.weights.append(copy.deepcopy(network.state_dict()))

    def push(self, *args, **kwargs):
        """Alias of ``record``."""
        self.record(*args, **kwargs)
|
|
@@ -5,5 +5,6 @@ from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
|
|
|
5
5
|
from .ReportsEnsembleF1 import ReportsEnsembleF1
|
|
6
6
|
from .ReportsEnsembleKappa import ReportsEnsembleKappa
|
|
7
7
|
from .ReportsLogits import ReportsLogits
|
|
8
|
-
from .
|
|
8
|
+
from .ReportsEnsembleSize import ReportsEnsembleSize
|
|
9
9
|
from .StoresModels import StoresModels
|
|
10
|
+
from .SavesEnsembleModels import SavesEnsembleModels
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .modules import GraphAttentionModule
|
|
3
|
+
from ..core.concerns import MovesNetworkToDevice
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GraphAttention(MovesNetworkToDevice):
    """Bundle of a graph-attention network, its AdamW optimizer and loss.

    Args:
        lr: AdamW learning rate.
        decay: AdamW weight decay.
        dropout: accepted for API compatibility but currently not used.
    """

    def __init__(self, lr: float = 0.005, decay: float = 5e-4, dropout: float = 0.6):
        # MovesNetworkToDevice.__init__ raises unless self.network already
        # exists, so build the network first and only then run the mixin.
        self.network = GraphAttentionModule()
        super().__init__()
        # AdamW lives in torch.optim; torch.nn has no AdamW attribute.
        self.optimizer = torch.optim.AdamW(
            self.network.parameters(), lr=lr, weight_decay=decay
        )
        self.criterion = torch.nn.CrossEntropyLoss()
        # NOTE(review): `dropout` is never forwarded anywhere. Also confirm that
        # `.modules` really exports GraphAttentionModule and that it can be
        # built without arguments; only GraphAttentionHeadModule and
        # MultiHeadGraphAttentionModule are visible in this diff.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .GraphAttention import GraphAttention
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GraphAttentionHeadModule(torch.nn.Module):
    """Single graph-attention head over a dense adjacency matrix.

    Node features are projected by ``W``. Every node pair is scored with the
    attention vectors ``a_1``/``a_2`` followed by LeakyReLU, and pairs whose
    adjacency entry is not positive are masked out. Scores are normalised
    row-wise with softmax, used to aggregate the projected features, and the
    result goes through ELU.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        alpha: negative slope of the LeakyReLU applied to attention scores.
    """

    def __init__(self, in_features: int, out_features: int, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha

        self.W = torch.nn.Linear(in_features, out_features, bias=False)
        self.a_1 = torch.nn.Parameter(torch.randn(out_features, 1))
        self.a_2 = torch.nn.Parameter(torch.randn(out_features, 1))

        self.leaky_relu = torch.nn.LeakyReLU(self.alpha)
        self.elu = torch.nn.ELU()
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform init (gain 1.414) of the projection and attention vectors."""
        torch.nn.init.xavier_uniform_(self.W.weight, gain=1.414)
        torch.nn.init.xavier_uniform_(self.a_1, gain=1.414)
        torch.nn.init.xavier_uniform_(self.a_2, gain=1.414)

    def forward(self, node_features, adj_matrix):
        """Apply one attention head.

        Args:
            node_features: ``(N, in_features)``. The ``.T`` below assumes a
                2-D input without a batch dimension.
            adj_matrix: broadcastable to ``(N, N)``; entries > 0 mark edges.

        Returns:
            ``(N, out_features)`` tensor of aggregated node features.
        """
        h_prime = self.W(node_features)
        s1 = torch.matmul(h_prime, self.a_1)  # (N, 1)
        s2 = torch.matmul(h_prime, self.a_2)  # (N, 1)
        scores = self.leaky_relu(s1 + s2.T)  # (N, N) pairwise scores
        # A large negative fill makes non-edges vanish after softmax. The fill
        # tensor is built from `scores`, so it already lives on the right device.
        masked = torch.where(adj_matrix > 0, scores, torch.full_like(scores, -9e15))
        attention_weights = torch.softmax(masked, dim=1)
        h_new = torch.matmul(attention_weights, h_prime)
        return self.elu(h_new)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .GraphAttentionHeadModule import GraphAttentionHeadModule
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MultiHeadGraphAttentionModule(torch.nn.Module):
    """Run ``num_heads`` independent attention heads and concatenate them.

    Outputs of the heads are joined along dimension 1, so with heads that
    each return ``out_features`` columns the result has
    ``num_heads * out_features`` columns.
    """

    def __init__(self, num_heads: int, in_features: int, out_features: int, alpha=0.2):
        super().__init__()
        self.num_heads = num_heads
        self.head_out_features = out_features

        built_heads = []
        for _ in range(num_heads):
            built_heads.append(
                GraphAttentionHeadModule(in_features, out_features, alpha=alpha)
            )
        self.heads = torch.nn.ModuleList(built_heads)

    def forward(
        self, node_features: torch.Tensor, adj_matrix: torch.Tensor
    ) -> torch.Tensor:
        """Concatenate every head's output for the given graph along dim 1."""
        per_head = tuple(head(node_features, adj_matrix) for head in self.heads)
        return torch.cat(per_head, dim=1)
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .diayn.Actor import Actor
|
|
3
|
+
from .diayn.Critic import Critic
|
|
4
|
+
from .diayn.Discriminator import Discriminator
|
|
5
|
+
from .buffers import DiversityIsAllYouNeedBuffer, Buffer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DiversityIsAllYouNeed:
    """DIAYN agent: skill-conditioned actor, critic and skill discriminator.

    Args:
        state_dimension: size of the environment state vector.
        action_dimension: size of the action vector.
        hidden_dimension: hidden width shared by all three networks.
        num_skills: number of discrete skills.
        critic_decay / actor_decay / discriminator_decay: weight decays.
        actor_lr / critic_lr / discriminator_lr: learning rates.
        buffer_capacity: capacity passed to the transition buffer.
        actor_epsilon: numerical epsilon used by the actor's tanh correction.
        gamma: discount factor used for the advantage targets.
        min_std / max_std: bounds forwarded to the actor network.
    """

    def __init__(
        self,
        state_dimension: int,
        action_dimension: int,
        hidden_dimension: int = 256,
        num_skills: int = 10,
        critic_decay: float = 0.0,
        actor_decay: float = 0.0,
        discriminator_decay: float = 0.0,
        actor_lr: float = 0.0001,
        critic_lr: float = 0.001,
        discriminator_lr: float = 0.001,
        buffer_capacity: int = 1_000_000,
        actor_epsilon: float = 1e-6,
        gamma: float = 0.99,
        min_std: float = -20.0,
        max_std: float = 2.0,
    ):
        self.buffer: DiversityIsAllYouNeedBuffer = DiversityIsAllYouNeedBuffer(
            capacity=buffer_capacity
        )
        self.num_skills: int = num_skills
        # `advantages` reads self.gamma; it was previously never assigned.
        self.gamma: float = gamma
        self.actor = Actor(
            state_dimension=state_dimension,
            action_dimension=action_dimension,
            hidden_dimension=hidden_dimension,
            num_skills=num_skills,
            lr=actor_lr,
            decay=actor_decay,
            epsilon=actor_epsilon,
            min_std=min_std,
            max_std=max_std,
        )
        self.critic = Critic(
            state_dimension=state_dimension,
            hidden_dimension=hidden_dimension,
            num_skills=num_skills,
            lr=critic_lr,
            decay=critic_decay,
            gamma=gamma,
        )
        self.discriminator = Discriminator(
            state_dimension=state_dimension,
            hidden_dimension=hidden_dimension,
            num_skills=num_skills,
            lr=discriminator_lr,
            decay=discriminator_decay,
        )

    def one_hot(self, indices, max_index) -> torch.Tensor:
        """One-hot encode integer ``indices`` into a ``(len, max_index)`` float tensor.

        ``indices`` may be 1-D ``(N,)`` or column-shaped ``(N, 1)``; it is
        flattened first so both produce ``N`` rows.
        """
        flat = indices.reshape(-1).long()
        encoded = torch.zeros(flat.size(0), max_index)
        encoded.scatter_(1, flat.unsqueeze(1), 1)
        return encoded

    def skill_index(self) -> torch.Tensor:
        """Draw one skill index uniformly; returns a tensor of shape ``(1,)``."""
        return torch.randint(0, self.num_skills, (1,))

    def skill(self) -> torch.Tensor:
        """Draw one skill and return it one-hot encoded, shape ``(1, num_skills)``."""
        return self.one_hot(self.skill_index(), self.num_skills)

    def advantages(
        self,
        states: torch.Tensor,
        skills: torch.Tensor,
        rewards: torch.Tensor,
        terminations: torch.Tensor,
        next_states: torch.Tensor,
    ) -> torch.Tensor:
        """One-step TD advantages: ``r + gamma * V'(s') * (1 - done) - V(s)``."""
        values = self.critic.values(states=states, skills=skills)
        # Bool tensors do not support `1 - x`; cast so bool or numeric flags work.
        termination_mask = 1 - terminations.float()
        next_values = self.critic.values_(states=next_states, skills=skills)
        update = self.gamma * next_values * termination_mask
        return rewards + update - values

    def train(self, skill: torch.Tensor):
        """Run one update of discriminator, critic and actor on the whole buffer.

        Args:
            skill: a single skill index (e.g. from ``skill_index()``).

        Raises:
            ValueError: if ``skill`` does not hold exactly one element.
        """
        if skill.numel() != 1:
            raise ValueError("skill must be a single skill index, e.g. from skill_index()")

        data = self.buffer.all_tensor()
        batch_size = data.states.size(0)
        skill_indices = skill.reshape(1).repeat(batch_size).long()
        skills_one_hot = self.one_hot(skill_indices, self.num_skills)
        self.discriminator.train(states=data.states, skills_indices=skills_one_hot)
        advantages = self.advantages(
            states=data.states,
            rewards=data.rewards,
            terminations=data.terminations,
            next_states=data.next_states,
            skills=skills_one_hot,
        )
        self.critic.train(advantages=advantages)
        # NOTE(review): the buffer's log-probabilities are rebuilt from numpy
        # and carry no autograd graph, while Actor.loss detaches the
        # advantages, so Actor's loss.backward() will fail. The log-probs need
        # to be recomputed under the current policy. TODO.
        self.actor.train(advantages=advantages, probabilities=data.probabilities)
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from .sac import SoftActor, SoftCritic
|
|
2
|
+
from .buffers import SoftActorCriticBuffer
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class SoftActorCritic:
    """Soft Actor-Critic agent wiring a SoftActor, a SoftCritic and a replay buffer.

    Args:
        state_dimension / action_dimension / hidden_dimension: network sizes.
        buffer_capacity: capacity passed to the replay buffer.
        batch_size: minibatch size sampled per training step.
        actor_lr / critic_lr: learning rates.
        actor_decay / critic_decay: weight decays.
        tau: target-network mixing factor forwarded to the critic.
        alpha: entropy temperature forwarded to actor and critic.
        gamma: discount factor forwarded to the critic.
        min_std / max_std: bounds forwarded to the actor.
        warmup: number of stored transitions required before training starts.
    """

    def __init__(
        self,
        state_dimension: int,
        action_dimension: int,
        hidden_dimension: int = 256,
        buffer_capacity: int = 1_000_000,
        batch_size: int = 256,
        actor_lr: float = 0.0002,
        critic_lr: float = 0.0003,
        actor_decay: float = 0.0,
        critic_decay: float = 0.0,
        tau: float = 0.005,
        alpha: float = 0.2,
        gamma: float = 0.99,
        min_std: float = -20,
        max_std: float = 2,
        warmup: int = 10_000,
    ):
        self.batch_size: int = batch_size
        self.warmup: int = warmup

        self.actor = SoftActor(
            state_dimension=state_dimension,
            action_dimension=action_dimension,
            hidden_dimension=hidden_dimension,
            lr=actor_lr,
            weight_decay=actor_decay,
            alpha=alpha,
            min_std=min_std,
            max_std=max_std,
        )
        self.critic = SoftCritic(
            state_dimension=state_dimension,
            action_dimension=action_dimension,
            hidden_dimension=hidden_dimension,
            lr=critic_lr,
            weight_decay=critic_decay,
            tau=tau,
            gamma=gamma,
            alpha=alpha,
        )
        self.buffer = SoftActorCriticBuffer(capacity=buffer_capacity)

    def train(self):
        """Do one SAC step: critic update, actor update, then critic.update().

        Does nothing until the buffer holds at least ``warmup`` transitions
        and at least ``batch_size`` transitions, since sampling a minibatch
        larger than the buffer would raise.
        """
        if self.buffer.size < max(self.warmup, self.batch_size):
            return

        data = self.buffer.sample_torch(self.batch_size)
        self.critic.train(
            states=data.states,
            actions=data.actions,
            rewards=data.rewards,
            terminations=data.terminations,
            next_states=data.next_states,
            actor=self.actor,
        )
        self.actor.train(states=data.states, critic_network=self.critic.network)
        self.critic.update()
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .concerns import ResetsCollection, HasRecordAlternatives
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Buffer(ResetsCollection, HasRecordAlternatives):
    """Base class for the replay buffers in this package.

    Args:
        capacity: intended maximum number of transitions. It is only stored
            here. Whether it actually bounds ``self.collection`` depends on
            ``ResetsCollection.reset`` (not shown); the visible subclasses
            append without checking it. TODO confirm.
    """

    def __init__(self, capacity: int):
        self.capacity: int = capacity
        # reset() is provided by ResetsCollection; the code here relies on it
        # creating ``self.collection``.
        self.reset()

    @property
    def size(self):
        """Number of transitions currently held in ``self.collection``."""
        held = self.collection
        return len(held)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy
|
|
3
|
+
from types import SimpleNamespace
|
|
4
|
+
from .Buffer import Buffer
|
|
5
|
+
from .concerns import HasRecordAlternatives
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DiversityIsAllYouNeedBuffer(Buffer, HasRecordAlternatives):
    """Transition store for DIAYN; ``all()`` returns every stored transition."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def all_tensor(self) -> SimpleNamespace:
        """Shortcut for ``all(tensor=True)``."""
        return self.all(tensor=True)

    def all(self, tensor: bool = False) -> SimpleNamespace:
        """Return all stored transitions grouped field by field.

        Args:
            tensor: when True, each field is converted to a float32 torch
                tensor; otherwise each field is the tuple of recorded values.

        Returns:
            SimpleNamespace with ``states``, ``actions``, ``rewards``,
            ``next_states``, ``terminations`` and ``probabilities``.

        Raises:
            ValueError: if the buffer is empty (there is nothing to unpack).
        """
        states, actions, rewards, next_states, terminations, probabilities = zip(
            *self.collection
        )

        if tensor:
            # Cast to float32, as SoftActorCriticBuffer does. numpy yields
            # float64 (and bool for terminations), which float32 torch modules
            # reject and which breaks arithmetic such as `1 - terminations`.
            # Actions are continuous for the DIAYN actor (tanh-squashed), so
            # they are floats as well.
            states = torch.from_numpy(numpy.array(states)).float()
            actions = torch.from_numpy(numpy.array(actions)).float()
            rewards = torch.from_numpy(numpy.array(rewards)).float()
            next_states = torch.from_numpy(numpy.array(next_states)).float()
            terminations = torch.from_numpy(numpy.array(terminations)).float()
            probabilities = torch.from_numpy(numpy.array(probabilities)).float()

        return SimpleNamespace(
            **{
                "states": states,
                "actions": actions,
                "rewards": rewards,
                "next_states": next_states,
                "terminations": terminations,
                "probabilities": probabilities,
            }
        )

    def record(
        self,
        state: numpy.ndarray,
        action: int,
        reward: float,
        next_state: numpy.ndarray,
        termination: bool,
        probability: numpy.ndarray,
    ) -> None:
        """Append one transition tuple to ``self.collection``."""
        self.collection.append(
            (state, action, reward, next_state, termination, probability)
        )
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import numpy
|
|
2
|
+
import random
|
|
3
|
+
import torch
|
|
4
|
+
from types import SimpleNamespace
|
|
5
|
+
from .Buffer import Buffer
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SoftActorCriticBuffer(Buffer):
    """Replay buffer for Soft Actor-Critic with uniform random minibatches."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def record(
        self,
        state: numpy.ndarray,
        action: int,
        reward: float,
        next_state: numpy.ndarray,
        termination: float,
        probability: numpy.ndarray,
    ):
        """Append one transition tuple to ``self.collection``."""
        self.collection.append(
            (state, action, reward, next_state, termination, probability)
        )

    def sample(self, k: int, as_tensor: bool = False):
        """Sample ``k`` distinct transitions uniformly at random.

        Args:
            k: minibatch size; ``random.sample`` raises ValueError if it
                exceeds the number of stored transitions.
            as_tensor: convert every field to a float32 torch tensor.

        Returns:
            SimpleNamespace with ``states``, ``actions``, ``rewards``,
            ``next_states``, ``terminations`` and ``probabilities`` (numpy
            arrays, or tensors when ``as_tensor`` is True).
        """
        batch = random.sample(self.collection, k)
        states, actions, rewards, next_states, terminations, probabilities = zip(*batch)

        states = numpy.array(states)
        actions = numpy.array(actions)
        rewards = numpy.array(rewards)
        next_states = numpy.array(next_states)
        terminations = numpy.array(terminations)
        probabilities = numpy.array(probabilities)

        if as_tensor:
            states = torch.from_numpy(states).float()
            # The SAC actor here is a Gaussian policy (it receives
            # min_std/max_std) and the critic scores (state, action) pairs, so
            # actions are continuous. Casting them to long would truncate
            # every tanh-squashed action in (-1, 1) to 0.
            actions = torch.from_numpy(actions).float()
            rewards = torch.from_numpy(rewards).float()
            next_states = torch.from_numpy(next_states).float()
            terminations = torch.from_numpy(terminations).float()
            probabilities = torch.from_numpy(probabilities).float()

        return SimpleNamespace(
            **{
                "states": states,
                "actions": actions,
                "rewards": rewards,
                "next_states": next_states,
                "terminations": terminations,
                "probabilities": probabilities,
            }
        )

    def sample_torch(self, k: int):
        """Shortcut for ``sample(k, as_tensor=True)``."""
        return self.sample(k=k, as_tensor=True)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
class HasRecordAlternatives:
    """Mixin exposing ``add``, ``push`` and ``append`` as aliases of ``record``.

    The concrete class must define ``record``; the aliases forward all
    arguments to it and always return None.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _forward_to_record(self, *args, **kwargs) -> None:
        """Delegate to ``self.record`` and discard its return value."""
        self.record(*args, **kwargs)

    add = _forward_to_record
    push = _forward_to_record
    append = _forward_to_record
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.distributions import Normal
|
|
3
|
+
from .modules import ContinuousActorModule
|
|
4
|
+
from ...core.concerns import MovesNetworkToDevice
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Actor(MovesNetworkToDevice):
    """Skill-conditioned DIAYN actor with a tanh-squashed Gaussian policy.

    Args:
        state_dimension: size of the state vector fed to the network.
        action_dimension: size of the produced action vector.
        num_skills: number of skills the policy is conditioned on.
        hidden_dimension: hidden width of ``ContinuousActorModule``.
        lr: AdamW learning rate.
        decay: AdamW weight decay.
        epsilon: constant keeping ``log(1 - tanh(u)^2 + epsilon)`` finite.
        min_std / max_std: bounds forwarded to the network. Because ``action``
            exponentiates the network's second output, these are presumably
            log-std clamps. TODO confirm in ContinuousActorModule.
    """

    def __init__(
        self,
        state_dimension: int,
        action_dimension: int,
        num_skills: int,
        hidden_dimension: int,
        lr: float,
        decay: float,
        epsilon: float,
        min_std: float,
        max_std: float,
    ):
        self.epsilon: float = epsilon
        self.network = ContinuousActorModule(
            state_dimension=state_dimension,
            action_dimension=action_dimension,
            hidden_dimension=hidden_dimension,
            num_skills=num_skills,
            min_std=min_std,
            max_std=max_std,
        )
        # The mixin initializer was never invoked, so the network was never
        # moved. It must run after self.network exists (it checks for it).
        super().__init__()
        self.optimizer = torch.optim.AdamW(
            self.network.parameters(), lr=lr, weight_decay=decay
        )

    def action(self, state: torch.Tensor, skill: torch.Tensor):
        """Sample a squashed action and its log-probability.

        Returns:
            ``(action, log_prob)``: ``action = tanh(u)`` with
            ``u ~ Normal(mean, exp(log_std))``, drawn via ``rsample`` so
            gradients flow. ``log_prob`` is corrected for the tanh squashing
            and summed over the last dimension (kept as size 1).
        """
        mean, log_std = self.network(state, skill)
        distribution = Normal(mean, log_std.exp())
        raw_action = distribution.rsample()
        action = torch.tanh(raw_action)
        # Change-of-variables correction for the tanh squashing.
        squash_correction = torch.log(1.0 - action.pow(2) + self.epsilon)
        log_prob = distribution.log_prob(raw_action) - squash_correction
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob

    def train(self, advantages: torch.Tensor, probabilities: torch.Tensor) -> float:
        """Take one optimizer step on the policy-gradient loss; return its value."""
        self.optimizer.zero_grad()
        loss = self.loss(advantages=advantages, probabilities=probabilities)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def loss(
        self, advantages: torch.Tensor, probabilities: torch.Tensor
    ) -> torch.Tensor:
        """Return ``-mean(log_prob * advantage)`` with the advantages detached."""
        return -(probabilities * advantages.detach()).mean()
|