homa 0.2.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of homa might be problematic. Click here for more details.
- homa/__init__.py +2 -0
- homa/activations/APLU.py +49 -0
- homa/activations/ActivationFunction.py +6 -0
- homa/activations/AdaptiveActivationFunction.py +15 -0
- homa/activations/BaseDLReLU.py +34 -0
- homa/activations/CaLU.py +13 -0
- homa/activations/DLReLU.py +6 -0
- homa/activations/ERF.py +10 -0
- homa/activations/Elliot.py +10 -0
- homa/activations/ExpExpish.py +9 -0
- homa/activations/ExponentialDLReLU.py +6 -0
- homa/activations/ExponentialSwish.py +10 -0
- homa/activations/GCU.py +9 -0
- homa/activations/GaLU.py +11 -0
- homa/activations/GaussianReLU.py +50 -0
- homa/activations/GeneralizedSwish.py +10 -0
- homa/activations/Gish.py +11 -0
- homa/activations/LaLU.py +11 -0
- homa/activations/LogLogish.py +10 -0
- homa/activations/LogSigmoid.py +10 -0
- homa/activations/Logish.py +10 -0
- homa/activations/MeLU.py +11 -0
- homa/activations/MexicanReLU.py +49 -0
- homa/activations/MinSin.py +10 -0
- homa/activations/NReLU.py +12 -0
- homa/activations/NoisyReLU.py +6 -0
- homa/activations/PLogish.py +6 -0
- homa/activations/ParametricLogish.py +13 -0
- homa/activations/Phish.py +11 -0
- homa/activations/RReLU.py +16 -0
- homa/activations/RandomizedSlopedReLU.py +7 -0
- homa/activations/SGELU.py +12 -0
- homa/activations/SReLU.py +37 -0
- homa/activations/SelfArctan.py +9 -0
- homa/activations/ShiftedReLU.py +10 -0
- homa/activations/SigmoidDerivative.py +10 -0
- homa/activations/SineReLU.py +11 -0
- homa/activations/SlopedReLU.py +13 -0
- homa/activations/SmallGaLU.py +11 -0
- homa/activations/Smish.py +9 -0
- homa/activations/SoftsignRReLU.py +17 -0
- homa/activations/Suish.py +11 -0
- homa/activations/TBSReLU.py +13 -0
- homa/activations/TSReLU.py +10 -0
- homa/activations/TangentBipolarSigmoidReLU.py +6 -0
- homa/activations/TangentSigmoidReLU.py +6 -0
- homa/activations/TeLU.py +9 -0
- homa/activations/TripleStateSwish.py +15 -0
- homa/activations/WideMeLU.py +15 -0
- homa/activations/__init__.py +49 -0
- homa/activations/learnable/AOAF.py +16 -0
- homa/activations/learnable/AReLU.py +19 -0
- homa/activations/learnable/DPReLU.py +16 -0
- homa/activations/learnable/DualLine.py +18 -0
- homa/activations/learnable/FReLU.py +14 -0
- homa/activations/learnable/LeLeLU.py +14 -0
- homa/activations/learnable/PERU.py +16 -0
- homa/activations/learnable/PiLU.py +18 -0
- homa/activations/learnable/ShiLU.py +16 -0
- homa/activations/learnable/StarReLU.py +16 -0
- homa/activations/learnable/__init__.py +10 -0
- homa/activations/learnable/concerns/ChannelBased.py +38 -0
- homa/activations/learnable/concerns/__init__.py +1 -0
- homa/cli/Commands/Command.py +2 -0
- homa/cli/Commands/InitCommand.py +34 -0
- homa/cli/Commands/__init__.py +2 -0
- homa/cli/HomaCommand.py +16 -0
- homa/cli/namespaces/CacheNamespace.py +29 -0
- homa/cli/namespaces/MakeNamespace.py +18 -0
- homa/cli/namespaces/__init__.py +2 -0
- homa/device.py +25 -0
- homa/ensemble/Ensemble.py +16 -0
- homa/ensemble/__init__.py +1 -0
- homa/ensemble/concerns/CalculatesMetricNecessities.py +24 -0
- homa/ensemble/concerns/PredictsProbabilities.py +15 -0
- homa/ensemble/concerns/ReportsClassificationMetrics.py +13 -0
- homa/ensemble/concerns/ReportsEnsembleAccuracy.py +11 -0
- homa/ensemble/concerns/ReportsEnsembleF1.py +10 -0
- homa/ensemble/concerns/ReportsEnsembleKappa.py +10 -0
- homa/ensemble/concerns/ReportsLogits.py +17 -0
- homa/ensemble/concerns/ReportsSize.py +11 -0
- homa/ensemble/concerns/StoresModels.py +36 -0
- homa/ensemble/concerns/__init__.py +9 -0
- homa/loss/LogitNormLoss.py +12 -0
- homa/loss/Loss.py +2 -0
- homa/loss/__init__.py +2 -0
- homa/settings.py +12 -0
- homa/torch/__init__.py +1 -0
- homa/torch/helpers.py +6 -0
- homa/utils.py +2 -0
- homa/vision/Classifier.py +5 -0
- homa/vision/Model.py +2 -0
- homa/vision/Resnet.py +13 -0
- homa/vision/StochasticClassifier.py +29 -0
- homa/vision/StochasticSwin.py +11 -0
- homa/vision/Swin.py +13 -0
- homa/vision/__init__.py +5 -0
- homa/vision/concerns/HasLabels.py +13 -0
- homa/vision/concerns/HasLogits.py +12 -0
- homa/vision/concerns/HasProbabilities.py +9 -0
- homa/vision/concerns/ReportsAccuracy.py +27 -0
- homa/vision/concerns/ReportsMetrics.py +6 -0
- homa/vision/concerns/Trainable.py +29 -0
- homa/vision/concerns/__init__.py +6 -0
- homa/vision/modules/ResnetModule.py +23 -0
- homa/vision/modules/SwinModule.py +23 -0
- homa/vision/modules/__init__.py +2 -0
- homa/vision/utils.py +12 -0
- homa-0.2.94.dist-info/METADATA +75 -0
- homa-0.2.94.dist-info/RECORD +113 -0
- homa-0.2.94.dist-info/WHEEL +5 -0
- homa-0.2.94.dist-info/entry_points.txt +2 -0
- homa-0.2.94.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import io
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import List
|
|
5
|
+
from ...vision import Model
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class StoresModels:
    """Mixin that keeps independent snapshots of models in ``self.models``.

    Each recorded model is deep-copied, so further training of the original
    object does not change the stored snapshot.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.models: List[torch.nn.Module] = []

    def record(self, model: Model | torch.nn.Module):
        """Store a deep copy of ``model`` (or of ``model.network`` for a ``Model``).

        Args:
            model: A ``Model`` wrapper exposing ``.network``, or a bare
                ``torch.nn.Module``.

        Raises:
            TypeError: If ``model`` is neither a ``Model`` nor a ``torch.nn.Module``.
        """
        if isinstance(model, Model):
            network = model.network
        elif isinstance(model, torch.nn.Module):
            network = model
        else:
            raise TypeError("Wrong input to ensemble record")

        # ``deepcopy`` clones every parameter/buffer on its current device, so
        # the snapshot is fully independent and stays where the source lives.
        # No serialization round-trip is needed (``torch.nn.Module`` has no
        # ``.device`` attribute, and ``torch.load`` of a pickled module is
        # rejected by the ``weights_only`` default in recent torch releases).
        self.models.append(deepcopy(network))

    def push(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)

    def append(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)

    def add(self, *args, **kwargs):
        """Alias of :meth:`record`."""
        self.record(*args, **kwargs)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from .CalculatesMetricNecessities import CalculatesMetricNecessities
|
|
2
|
+
from .PredictsProbabilities import PredictsProbabilities
|
|
3
|
+
from .ReportsClassificationMetrics import ReportsClassificationMetrics
|
|
4
|
+
from .ReportsEnsembleAccuracy import ReportsEnsembleAccuracy
|
|
5
|
+
from .ReportsEnsembleF1 import ReportsEnsembleF1
|
|
6
|
+
from .ReportsEnsembleKappa import ReportsEnsembleKappa
|
|
7
|
+
from .ReportsLogits import ReportsLogits
|
|
8
|
+
from .ReportsSize import ReportsSize
|
|
9
|
+
from .StoresModels import StoresModels
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Loss import Loss
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class LogitNormLoss(Loss):
    """Cross-entropy on L2-normalised logits (LogitNorm).

    Each logit vector (last dimension) is divided by its L2 norm plus ``eps``
    and by ``temperature`` before standard cross-entropy is applied.

    Args:
        temperature: Positive scale applied after normalisation. The default
            of 1.0 gives plain normalised-logit cross-entropy.
        eps: Added to the norm so all-zero logit vectors do not divide by zero.
    """

    def __init__(self, *args, temperature: float = 1.0, eps: float = 1e-7, **kwargs):
        super().__init__(*args, **kwargs)
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature
        self.eps = eps

    def forward(self, logits, target):
        """Return the cross-entropy of the normalised, temperature-scaled logits."""
        norms = torch.linalg.vector_norm(logits, ord=2, dim=-1, keepdim=True) + self.eps
        normalized_logits = torch.div(logits, norms * self.temperature)
        return torch.nn.functional.cross_entropy(normalized_logits, target)
|
homa/loss/Loss.py
ADDED
homa/loss/__init__.py
ADDED
homa/settings.py
ADDED
homa/torch/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .helpers import *
|
homa/torch/helpers.py
ADDED
homa/utils.py
ADDED
homa/vision/Model.py
ADDED
homa/vision/Resnet.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .modules import ResnetModule
|
|
3
|
+
from .Classifier import Classifier
|
|
4
|
+
from .concerns import Trainable, ReportsMetrics
|
|
5
|
+
from ..device import get_device
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Resnet(Classifier, Trainable, ReportsMetrics):
    """ResNet-based image classifier trained with SGD and cross-entropy.

    Args:
        num_classes: Number of output classes of the linear head.
        lr: Learning rate for SGD (momentum fixed at 0.9).
    """

    def __init__(self, num_classes: int, lr: float = 0.001):
        super().__init__()
        backbone = ResnetModule(num_classes)
        self.network = backbone.to(get_device())
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(
            params=self.network.parameters(),
            lr=lr,
            momentum=0.9,
        )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from ..activations import (
|
|
2
|
+
AOAF,
|
|
3
|
+
AReLU,
|
|
4
|
+
DPReLU,
|
|
5
|
+
DualLine,
|
|
6
|
+
FReLU,
|
|
7
|
+
LeLeLU,
|
|
8
|
+
PERU,
|
|
9
|
+
PiLU,
|
|
10
|
+
ShiLU,
|
|
11
|
+
StarReLU,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class StochasticClassifier:
    """Mixin that exposes ``self._activation_pool``.

    The pool is a list of activation classes from which replacements are
    drawn when randomising a network's activations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        candidates = (
            AOAF, AReLU, DPReLU, DualLine, FReLU,
            LeLeLU, PERU, PiLU, ShiLU, StarReLU,
        )
        self._activation_pool = list(candidates)
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Swin import Swin
|
|
3
|
+
from .StochasticClassifier import StochasticClassifier
|
|
4
|
+
from .utils import replace_activations
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StochasticSwin(Swin, StochasticClassifier):
    """Swin classifier whose GELU/ReLU activations are replaced at random.

    Every ``torch.nn.GELU`` and ``torch.nn.ReLU`` inside ``self.network`` is
    swapped for a fresh instance of a class drawn from
    ``self._activation_pool``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        replace_activations(self.network, torch.nn.GELU, self._activation_pool)
        replace_activations(self.network, torch.nn.ReLU, self._activation_pool)
        self._adopt_new_parameters()

    def _adopt_new_parameters(self):
        """Move replacement activations to the network's device and let the
        existing optimizer train their parameters."""
        # The optimizer was created by the parent constructor before the swap,
        # so it only knows the original parameters. They sit on the target
        # device, so one of them serves as the device reference.
        tracked = [p for group in self.optimizer.param_groups for p in group["params"]]
        if not tracked:
            return
        self.network.to(tracked[0].device)
        tracked_ids = {id(p) for p in tracked}
        fresh = [p for p in self.network.parameters() if id(p) not in tracked_ids]
        # NOTE(review): activations that create parameters lazily on their
        # first forward pass are not picked up here — confirm against the
        # activation implementations.
        if fresh:
            self.optimizer.add_param_group({"params": fresh})
|
homa/vision/Swin.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .Classifier import Classifier
|
|
3
|
+
from .concerns import Trainable, ReportsMetrics
|
|
4
|
+
from .modules import SwinModule
|
|
5
|
+
from ..device import get_device
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Swin(Classifier, Trainable, ReportsMetrics):
    """Swin-based image classifier trained with AdamW and cross-entropy.

    Args:
        num_classes: Number of output classes of the linear head.
        lr: Learning rate for AdamW.
    """

    def __init__(self, num_classes: int, lr: float = 0.0001):
        super().__init__()
        module = SwinModule(num_classes=num_classes)
        self.network = module.to(get_device())
        self.optimizer = torch.optim.AdamW(params=self.network.parameters(), lr=lr)
        self.criterion = torch.nn.CrossEntropyLoss()
|
homa/vision/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class HasLabels:
    """Mixin that converts ``self.logits`` output into hard class labels."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def predict(self, x: torch.Tensor):
        """Return the arg-max index along dimension 1 of ``self.logits(x)``."""
        scores = self.logits(x)
        return scores.argmax(dim=1)

    @torch.no_grad()
    def predict_(self, x: torch.Tensor):
        """Like ``predict`` but evaluated with gradient tracking disabled."""
        scores = self.logits(x)
        return scores.argmax(dim=1)
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class HasLogits:
    """Mixin exposing the raw outputs (logits) of ``self.network``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def logits(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.network(x)`` with autograd behaving as the caller configured."""
        return self.network(x)

    @torch.no_grad()
    def logits_(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.network(x)`` with gradient tracking disabled.

        The trailing underscore marks the inference-only variant, matching the
        ``predict`` / ``predict_`` convention used for labels.
        """
        return self.network(x)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from torch import Tensor, no_grad
|
|
2
|
+
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from ...device import get_device
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ReportsAccuracy:
|
|
7
|
+
def __init__(self, *args, **kwargs):
|
|
8
|
+
super().__init__(*args, **kwargs)
|
|
9
|
+
|
|
10
|
+
def accuracy_tensors(self, x: Tensor, y: Tensor) -> float:
|
|
11
|
+
predictions = self.predict_(x)
|
|
12
|
+
return (predictions == y).float().mean().item()
|
|
13
|
+
|
|
14
|
+
def accuracy_dataloader(self, dataloader: DataLoader):
|
|
15
|
+
correct, total = 0, 0
|
|
16
|
+
for x, y in dataloader:
|
|
17
|
+
x, y = x.to(get_device()), y.to(get_device())
|
|
18
|
+
predictions = self.predict_(x)
|
|
19
|
+
correct += (predictions == y).sum().item()
|
|
20
|
+
total += y.numel()
|
|
21
|
+
return correct / total if total > 0 else 0.0
|
|
22
|
+
|
|
23
|
+
def accuracy(self, x: Tensor | DataLoader, y: Tensor | None = None) -> float:
|
|
24
|
+
self.network.eval()
|
|
25
|
+
if isinstance(x, DataLoader):
|
|
26
|
+
return self.accuracy_dataloader(x)
|
|
27
|
+
return self.accuracy_tensors(x, y)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from torch import Tensor
|
|
2
|
+
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from .HasLogits import HasLogits
|
|
4
|
+
from .HasProbabilities import HasProbabilities
|
|
5
|
+
from .HasLabels import HasLabels
|
|
6
|
+
from ...device import get_device
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Trainable(HasLogits, HasProbabilities, HasLabels):
    """Mixin providing a minimal supervised training loop.

    The host class sets ``self.network``, ``self.optimizer`` and
    ``self.criterion``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self, x: Tensor | DataLoader, y: Tensor | None = None):
        """Train on one batch, or for one pass over a ``DataLoader``.

        ``x`` is treated as a ``DataLoader`` only when ``y`` is omitted;
        otherwise ``(x, y)`` is a single batch.
        """
        if isinstance(x, DataLoader) and y is None:
            self.train_dataloader(x)
        else:
            self.train_tensors(x, y)

    def train_tensors(self, x: Tensor, y: Tensor):
        """Perform one forward/backward/optimizer step on a single batch."""
        self.network.train()
        self.optimizer.zero_grad()
        outputs = self.network(x).float()
        loss = self.criterion(outputs, y)
        loss.backward()
        self.optimizer.step()

    def train_dataloader(self, dataloader: DataLoader):
        """Call ``train_tensors`` on every ``(x, y)`` batch, moved to ``get_device()``."""
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = (t.to(get_device()) for t in (batch_x, batch_y))
            self.train_tensors(batch_x, batch_y)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torchvision.models import resnet50
|
|
3
|
+
from torch.nn.init import kaiming_uniform_ as kaiming
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ResnetModule(torch.nn.Module):
    """Pretrained ResNet-50 encoder followed by a freshly initialised linear head.

    Args:
        num_classes: Output width of the linear head.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self._create_encoder()
        self._create_fc()

    def _create_encoder(self):
        """Load ResNet-50 with default weights and drop its classifier."""
        self.encoder = resnet50(weights="DEFAULT")
        # Read the pooled-feature width from the backbone before discarding its
        # classifier, rather than hard-coding it.
        self.feature_dim = self.encoder.fc.in_features
        self.encoder.fc = torch.nn.Identity()

    def _create_fc(self):
        """Create the classification head and Kaiming-initialise its weights."""
        self.fc = torch.nn.Linear(self.feature_dim, self.num_classes)
        kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")

    def forward(self, images: torch.Tensor):
        """Return class logits for ``images``."""
        features = self.encoder(images)
        return self.fc(features)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torchvision.models import swin_v2_b
|
|
3
|
+
from torch.nn.init import kaiming_uniform_ as kaiming
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SwinModule(torch.nn.Module):
    """Pretrained Swin-V2-B encoder followed by a freshly initialised linear head.

    Args:
        num_classes: Output width of the linear head.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self._create_encoder()
        self._create_fc()

    def _create_encoder(self):
        """Load Swin-V2-B with default weights and drop its classification head."""
        self.encoder = swin_v2_b(weights="DEFAULT")
        # Read the feature width from the backbone's own head before replacing
        # it, rather than hard-coding it.
        self.feature_dim = self.encoder.head.in_features
        self.encoder.head = torch.nn.Identity()

    def _create_fc(self):
        """Create the classification head and Kaiming-initialise its weights."""
        self.fc = torch.nn.Linear(self.feature_dim, self.num_classes)
        kaiming(self.fc.weight, mode="fan_in", nonlinearity="relu")

    def forward(self, images: torch.Tensor):
        """Return class logits for ``images``."""
        features = self.encoder(images)
        return self.fc(features)
|
homa/vision/utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import random
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def replace_activations(module, needle: torch.nn.Module, candidates: list):
|
|
6
|
+
for name, child in module.named_children():
|
|
7
|
+
if isinstance(child, needle):
|
|
8
|
+
new_activation = random.choice(candidates)
|
|
9
|
+
setattr(module, name, new_activation())
|
|
10
|
+
else:
|
|
11
|
+
replace_activations(child, needle, candidates)
|
|
12
|
+
return module
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: homa
|
|
3
|
+
Version: 0.2.94
|
|
4
|
+
Summary: A curated list of machine learning and deep learning helpers.
|
|
5
|
+
Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
|
|
6
|
+
Requires-Python: >=3.7
|
|
7
|
+
Description-Content-Type: text/markdown
|
|
8
|
+
Requires-Dist: numpy
|
|
9
|
+
Requires-Dist: torch
|
|
10
|
+
Requires-Dist: fire
|
|
11
|
+
|
|
12
|
+
# Core
|
|
13
|
+
|
|
14
|
+
### Device Management
|
|
15
|
+
|
|
16
|
+
```py
|
|
17
|
+
from homa import cpu, mps, cuda, device
|
|
18
|
+
|
|
19
|
+
torch.tensor([1, 2, 3, 4, 5]).to(cpu())
|
|
20
|
+
torch.tensor([1, 2, 3, 4, 5]).to(cuda())
|
|
21
|
+
torch.tensor([1, 2, 3, 4, 5]).to(mps())
|
|
22
|
+
torch.tensor([1, 2, 3, 4, 5]).to(device())
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
# Vision
|
|
26
|
+
|
|
27
|
+
## Resnet
|
|
28
|
+
|
|
29
|
+
This is the standard ResNet50 module.
|
|
30
|
+
|
|
31
|
+
You can train the model with a `DataLoader` object.
|
|
32
|
+
|
|
33
|
+
```py
|
|
34
|
+
from homa.vision import Resnet
|
|
35
|
+
|
|
36
|
+
model = Resnet(num_classes=10, lr=0.001)
|
|
37
|
+
for epoch in range(10):
|
|
38
|
+
model.train(train_dataloader)
|
|
39
|
+
```
|
|
40
|
+
|
|
41
|
+
Alternatively, you can iterate over the `DataLoader` yourself and pass each batch to `train`.
|
|
42
|
+
|
|
43
|
+
```py
|
|
44
|
+
from homa.vision import Resnet
|
|
45
|
+
|
|
46
|
+
model = Resnet(num_classes=10, lr=0.001)
|
|
47
|
+
for epoch in range(10):
|
|
48
|
+
for x, y in train_dataloader:
|
|
49
|
+
model.train(x, y)
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## StochasticResnet
|
|
53
|
+
|
|
54
|
+
This is a ResNet module whose activation functions are randomly replaced with ones drawn from a pool of different activation functions. Read more in the [paper](https://www.mdpi.com/1424-8220/22/16/6129).
|
|
55
|
+
|
|
56
|
+
You can train the model with a `DataLoader` object.
|
|
57
|
+
|
|
58
|
+
```py
|
|
59
|
+
from homa.vision import StochasticResnet
|
|
60
|
+
|
|
61
|
+
model = StochasticResnet(num_classes=10, lr=0.001)
|
|
62
|
+
for epoch in range(10):
|
|
63
|
+
model.train(train_dataloader)
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
Alternatively, you can iterate over the `DataLoader` yourself and pass each batch to `train`.
|
|
67
|
+
|
|
68
|
+
```py
|
|
69
|
+
from homa.vision import StochasticResnet
|
|
70
|
+
|
|
71
|
+
model = StochasticResnet(num_classes=10, lr=0.001)
|
|
72
|
+
for epoch in range(10):
|
|
73
|
+
for x, y in train_dataloader:
|
|
74
|
+
model.train(x, y)
|
|
75
|
+
```
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
homa/__init__.py,sha256=NBYFKizG8UASiz5HLsEBqzXNGlWr78xm4sLr5hxKvjU,46
|
|
2
|
+
homa/device.py,sha256=9kKXfpYfnEk2cFQWPfcJrVloHgC_SSbP4I8IRY9TYk4,343
|
|
3
|
+
homa/settings.py,sha256=CPZDPvs1380O7SY7FcSKol8kBVFVVYFgSJl3YEyJuZ0,263
|
|
4
|
+
homa/utils.py,sha256=dPp6TItJwWxBqxmkMzUuCtX_BzdPT-kMOZyXRGVMCbQ,70
|
|
5
|
+
homa/activations/APLU.py,sha256=cUf6LUjY8TewXe_V1avO_7IcOtY66Hd6Dyk_1K4R3Ms,1555
|
|
6
|
+
homa/activations/ActivationFunction.py,sha256=XUw7Pa5E-CPG6rPL8Us_pDH7xCZqY0c2P9xtnJMyX44,141
|
|
7
|
+
homa/activations/AdaptiveActivationFunction.py,sha256=p_bqAq7527UOhVm47kdUtgdC1DlApxgiLOA4ZPBFdCE,386
|
|
8
|
+
homa/activations/BaseDLReLU.py,sha256=iRmDhhbFaO8N9G8u5M01s8-y-09t7poP96oA6uQkVq8,1186
|
|
9
|
+
homa/activations/CaLU.py,sha256=n0drKwp4GstHql69p4S58KeVctdaQ1B5oK_AIoI_okk,331
|
|
10
|
+
homa/activations/DLReLU.py,sha256=Q8l2zpR5q_tSgfgbz90uDXbXMbBT3b_7BWKw6JbtpQE,191
|
|
11
|
+
homa/activations/ERF.py,sha256=tDgHbo7UNFU93XPlcQCBRRxPMksr-FOE19mlsqfzmU8,252
|
|
12
|
+
homa/activations/Elliot.py,sha256=RDxERH9vFh6FYwtZXKHMDmLVG2ia1UfOoW18Gm2_8hM,298
|
|
13
|
+
homa/activations/ExpExpish.py,sha256=iq_uOmmV9EIz2eKowEy7SCeW-OMgGcEeMcivTnPc-Y0,202
|
|
14
|
+
homa/activations/ExponentialDLReLU.py,sha256=aVtah3c4sokB-aSdbVa5F_uq06IyXHwovnHtXlKGYlw,199
|
|
15
|
+
homa/activations/ExponentialSwish.py,sha256=nJtGu1TRHa2GSQ35w2MN0HEWzFogVvA9R2pGEkFvJX4,266
|
|
16
|
+
homa/activations/GCU.py,sha256=hXwty6WPovnhPGAxQDd4bIixujdoMOORN-77imVri7s,199
|
|
17
|
+
homa/activations/GaLU.py,sha256=5QHnHsUsLAy28s-LTxtwRN-t1hO1tg9xtWmkzE1T7Ck,308
|
|
18
|
+
homa/activations/GaussianReLU.py,sha256=ufNeVnod6dxkPLmdd9ye-xt0SIWap2dehX14_YxSZVM,2051
|
|
19
|
+
homa/activations/GeneralizedSwish.py,sha256=zv6CX83cOTVnN0yoCIXIvIgkjXLnmm_T_LsvyoN7lOY,236
|
|
20
|
+
homa/activations/Gish.py,sha256=Xohk5tTmeGTmQ4PXtHF5sPBDikmoNTjdEJzy2KPDmOI,249
|
|
21
|
+
homa/activations/LaLU.py,sha256=UiulXzSTmnoU_Gp8qKigFoL6efonqbldUlsBBlm9mB8,356
|
|
22
|
+
homa/activations/LogLogish.py,sha256=lfVRNhnDGbYYakTsUmePmr5azkzz_NQwEy6NvSSD-Do,205
|
|
23
|
+
homa/activations/LogSigmoid.py,sha256=PUvr84dRRd6L-VZ_9UeWAN9lhUFr2Otj8VrAIQ3eOEM,239
|
|
24
|
+
homa/activations/Logish.py,sha256=CnL-10b76C2EaDm56N4n2GYCaYJUKl_k7H82UBJI5to,257
|
|
25
|
+
homa/activations/MeLU.py,sha256=f13h2AAQCwp9soR3RWbMAA4Bl38oqRdBAsdzh6Bf4k8,321
|
|
26
|
+
homa/activations/MexicanReLU.py,sha256=vfDa1lWI-PgY4ztDY34aeBMaJ2rOyAYt5ifZBG0DS0c,1946
|
|
27
|
+
homa/activations/MinSin.py,sha256=JzQsmuffRAGGcD40nlz2ZnOGhQMEU0JYBIeFHIC1qcE,250
|
|
28
|
+
homa/activations/NReLU.py,sha256=mX4B2OXw28M8zyd6RpkaSoOCGZuB3FaksO4oFyr3YD8,314
|
|
29
|
+
homa/activations/NoisyReLU.py,sha256=2YFkOS_h8EijvCugYQTfq_gQg5uNEkcypcm0iDgEHIg,134
|
|
30
|
+
homa/activations/PLogish.py,sha256=ia_V0xewAS6mmX3G8JNQTxWl66ea8P3MuRLQkEPv-I8,165
|
|
31
|
+
homa/activations/ParametricLogish.py,sha256=grCGG61xTDytA4iOK3kS70V9m2bYoiSmuilJV3U8vIw,360
|
|
32
|
+
homa/activations/Phish.py,sha256=CLAV1fLHRAq-GxBut-_FsSYJRMlk5sOFVlcXs3G3w9c,280
|
|
33
|
+
homa/activations/RReLU.py,sha256=ILpkmoWk8WatXusrPqSLu15xMWQALwRQVZjhzwmw1PM,476
|
|
34
|
+
homa/activations/RandomizedSlopedReLU.py,sha256=O20XX3vRRmkERxwhLSNgue-fn0qSRoF7rlIN1LSWlyI,169
|
|
35
|
+
homa/activations/SGELU.py,sha256=AaNmXRoFQ68Xsgt4sSWMZxnSCTR5OD5ZEuqxxg1mvfg,358
|
|
36
|
+
homa/activations/SReLU.py,sha256=xyChK3G2HPpM7C8icQNfMzrOm142boDLY31n9yXqPtg,1472
|
|
37
|
+
homa/activations/SelfArctan.py,sha256=Sq3yWGXjxdP32J-rSZ38BQ5S_XErr5H1ZyPsMF1VKfI,193
|
|
38
|
+
homa/activations/ShiftedReLU.py,sha256=JVsf2F6C13PRICjnVOSVEsx9IoQ9rcM2TFn55DguZQs,229
|
|
39
|
+
homa/activations/SigmoidDerivative.py,sha256=4PPT-QX4MW9ySKU4Qv9K-y--lxlqFxvKVviC2S3e6Z0,274
|
|
40
|
+
homa/activations/SineReLU.py,sha256=gzYF1ZEZAFYmUuABWJf18LIer1oPAS38i_5NLcIhP-I,357
|
|
41
|
+
homa/activations/SlopedReLU.py,sha256=j6YfM4msg6It-ANbpMzEaXkiHvgEdhFNFbB6NkY6KpE,421
|
|
42
|
+
homa/activations/SmallGaLU.py,sha256=ERrK-g3QMZTNFDzUyiSLAovymEpV5h1x1696CN5K4Zg,289
|
|
43
|
+
homa/activations/Smish.py,sha256=hsr5FS4KywsCmsuFUKP-4pKoXkJK0hhRVDleq_CFGX0,198
|
|
44
|
+
homa/activations/SoftsignRReLU.py,sha256=bBSjYDLUVKxXPyaJExYXndEO3oORnP3M6NKoU-hiCCQ,564
|
|
45
|
+
homa/activations/Suish.py,sha256=I459CV24NV1JlLbko4oHUOh98fxoLaM-2SH71pVMcwA,279
|
|
46
|
+
homa/activations/TBSReLU.py,sha256=ZfYY_M6msDimJAOHr1HyrG1HHnWiJ7hnZ5hWjCFPecU,320
|
|
47
|
+
homa/activations/TSReLU.py,sha256=gbU0Q7zhf3X6oWvKUSry6sVdRhuaxIQt8keFH3WsxV8,256
|
|
48
|
+
homa/activations/TangentBipolarSigmoidReLU.py,sha256=YtrFHkFbEbx7aeIpIRc9TLCxhpveUFSgAnvTKaKLZ4E,156
|
|
49
|
+
homa/activations/TangentSigmoidReLU.py,sha256=C47UK6ADWsG2ueaZe9FUt-sPBzeuBLkiNjpkDZOCYGc,146
|
|
50
|
+
homa/activations/TeLU.py,sha256=qU5x0EskjQs6d5rCtbL91C6cMAm8vjDnjQNMX0LcEt8,180
|
|
51
|
+
homa/activations/TripleStateSwish.py,sha256=UG5BGY29wUEJaryClB2rDM90s0jt5vMJF9Kv-5M4Rgo,507
|
|
52
|
+
homa/activations/WideMeLU.py,sha256=ieJjTjnK9JJtApPFGpmTynu3G8YlyH5jw6qnhkJkStI,421
|
|
53
|
+
homa/activations/__init__.py,sha256=2GHNqrOp6WoLAtFFJcSj6j4GP-W8-YAYRZGX9vZbcmU,1659
|
|
54
|
+
homa/activations/learnable/AOAF.py,sha256=1ArhgpI6PfCRePgvFq8VqKDQ9rDMHZb0bm6g4Tiz13s,510
|
|
55
|
+
homa/activations/learnable/AReLU.py,sha256=Pfyv_7EEwGgW4_UyKc8CiSg7lhTcO7LZ7uIUeVQWLpA,737
|
|
56
|
+
homa/activations/learnable/DPReLU.py,sha256=xQhYTJ0-mfRGdld950xoTh8c9O08WIY50K0FjPtVVFs,507
|
|
57
|
+
homa/activations/learnable/DualLine.py,sha256=cgqyE7dVqXflT8ulCuOyKQQa09FYSj8vJkeVUEOaeIU,600
|
|
58
|
+
homa/activations/learnable/FReLU.py,sha256=qQ8GjjWWGeoE6qW9tw49mZPs29app0QK1AFOuMc5ASU,413
|
|
59
|
+
homa/activations/learnable/LeLeLU.py,sha256=ya2m60QRcpVlTwMejJTgMTxM3RRHF0RgNe72_EdD1-U,425
|
|
60
|
+
homa/activations/learnable/PERU.py,sha256=y2OxRLIA1HTUnFyRHs0zgLhLMJhQz9Q4F6QrqBSkQ00,513
|
|
61
|
+
homa/activations/learnable/PiLU.py,sha256=w7LkBBs_hr07pvizUie5Z49UkHg3O8LHA-wFK4hbnjE,612
|
|
62
|
+
homa/activations/learnable/ShiLU.py,sha256=35VC1pCAWMaxHKWYBeXd2DrXn1tepvQaT7a-KwoNdHY,475
|
|
63
|
+
homa/activations/learnable/StarReLU.py,sha256=hrscp-A0HnIvebFPLGr86K5Uf_U--EWtpNDqdNgonA0,485
|
|
64
|
+
homa/activations/learnable/__init__.py,sha256=yDzcgM_n5sNEU0kz9I0aVgGihpw_2RvtkCCylaTCPEQ,260
|
|
65
|
+
homa/activations/learnable/concerns/ChannelBased.py,sha256=pSKnWOKVOdb0GoiBobSSUANaZPGNwT9rxBnJUpZ9Eac,1206
|
|
66
|
+
homa/activations/learnable/concerns/__init__.py,sha256=CubRRYQEQMAK2-igsYKD8tcyesPOYoZYF_IlHzRZXi4,39
|
|
67
|
+
homa/cli/HomaCommand.py,sha256=w-Dg6dFpoXbQx2tvWSLdND2pdhqB2cPSORyi4MfY8XY,307
|
|
68
|
+
homa/cli/Commands/Command.py,sha256=DnmsEwpaxdQaLjzyYBO7qtIQTLwYzyhJS31YazA1IHg,24
|
|
69
|
+
homa/cli/Commands/InitCommand.py,sha256=3whh2mWLuevXpUyRpDEMbo_KNeAIdO2aLMFnC2nz_0c,1159
|
|
70
|
+
homa/cli/Commands/__init__.py,sha256=PYKkcG06R5LqLnp2x8otuimzRpL4oMbziL3xEMkCffc,66
|
|
71
|
+
homa/cli/namespaces/CacheNamespace.py,sha256=QXGljzj287stzTx0y_MXnqvCgPLqd7WjSPop2WDe14E,784
|
|
72
|
+
homa/cli/namespaces/MakeNamespace.py,sha256=5G6LHk3lDkXROz7uq4jYE0DyO_V7JvnhJ33IFCiqYro,590
|
|
73
|
+
homa/cli/namespaces/__init__.py,sha256=zAKUGPH4wcacxfH5Qvidp-uOuHdfzhan6kvVI6eMKA8,84
|
|
74
|
+
homa/ensemble/Ensemble.py,sha256=GNkXEV7Nli8lHSTQ3qTTCTeSBwST1PLZS5wxpKpeC5U,290
|
|
75
|
+
homa/ensemble/__init__.py,sha256=1pk2W-NbgfDFh9WLKZVLUk2E3PTjVZ5Bap9dQEnrs9o,31
|
|
76
|
+
homa/ensemble/concerns/CalculatesMetricNecessities.py,sha256=QccROg_FOp_X2T_lZDg8p1DMZhPYdO-7aEdnebRXMsY,825
|
|
77
|
+
homa/ensemble/concerns/PredictsProbabilities.py,sha256=7rmI66DzE7-QGoJgZEk-9fu5YQvJW-4ZnMn_dWEEhqU,440
|
|
78
|
+
homa/ensemble/concerns/ReportsClassificationMetrics.py,sha256=bg__cdCKp2U1H9qN1aOJH4BoX98oIvt8XaPDGApJhSM,395
|
|
79
|
+
homa/ensemble/concerns/ReportsEnsembleAccuracy.py,sha256=AX5X3VGOm7DfdonW0N7FFgUwEr7wnsojRSVEULEii7c,380
|
|
80
|
+
homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4FYnpkBMlYLjzRZg,296
|
|
81
|
+
homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
|
|
82
|
+
homa/ensemble/concerns/ReportsLogits.py,sha256=vTGuC9NR4rno3Mkbm0MhL8f7YopuCErGyjIorxamKTM,461
|
|
83
|
+
homa/ensemble/concerns/ReportsSize.py,sha256=S7lo_Wu6rDnuqyAcv6AI6jspaBhcpfsirpp9RVD8c20,238
|
|
84
|
+
homa/ensemble/concerns/StoresModels.py,sha256=VWTBZbRepa_AW82BX5FMhTlUZ0OFthiRPY71Rvx0mYs,1047
|
|
85
|
+
homa/ensemble/concerns/__init__.py,sha256=X0F_b2Jsv0XpiNhYwJsl-dfPsBOdEeW53LQPE4xQD0w,479
|
|
86
|
+
homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
|
|
87
|
+
homa/loss/Loss.py,sha256=COUr_idShYgAP8xKCxcaXbyUyAoJg7IOON0ARTQykmQ,21
|
|
88
|
+
homa/loss/__init__.py,sha256=4mPVzme2_-M64bgBu1cANIfBFAL0voa5I71-ceMr_qk,64
|
|
89
|
+
homa/torch/__init__.py,sha256=HTxCVaw1TLgpHMH8guB3hHYQ80cX6_fSEoPT_hz2Y8w,23
|
|
90
|
+
homa/torch/helpers.py,sha256=CLbTCXRrroM0n4PfM-K_xFavs4dCZJEu_L7hdgb1DCI,134
|
|
91
|
+
homa/vision/Classifier.py,sha256=bAypqREQVuPamnc8hpbLCwmW9Uly3T1rvrlbMxXp1eA,61
|
|
92
|
+
homa/vision/Model.py,sha256=JIeVpHJwirHfsDfYYbLsu0kt7bGf4nhMQGIOagUDKw4,22
|
|
93
|
+
homa/vision/Resnet.py,sha256=Uitf58bEzIKkZd-F4FTvJ8nmhoFHlzZjJTvBPXEt2Iw,513
|
|
94
|
+
homa/vision/StochasticClassifier.py,sha256=6-o0TaH4iWXiPFefL7DOdLr3ZrTnjnJ9PIgQLlygN8w,497
|
|
95
|
+
homa/vision/StochasticSwin.py,sha256=FggzfaVYrP4fnjAFcdMpDozwQHc7CQhl3iRw78oQh0o,425
|
|
96
|
+
homa/vision/Swin.py,sha256=W3XbfUTrjaIhMH8fI_whPP6XO9fVA2R34LlGfQ1hoyo,508
|
|
97
|
+
homa/vision/__init__.py,sha256=w5OkcmdU6Ik5wHIJzeV1Z2UElQtvCsUZks1Q-xViSVg,153
|
|
98
|
+
homa/vision/utils.py,sha256=WB2b7eMDaf6UO3SuS7cB6IJk-9NRQesLavuzWUZRZyg,389
|
|
99
|
+
homa/vision/concerns/HasLabels.py,sha256=fM6nHLeQaEaWDlV6R8NQ5hgOSiwspPxOIwj-nvYXbP0,321
|
|
100
|
+
homa/vision/concerns/HasLogits.py,sha256=oStX4NCV7zwxI7Vj23M8wQSlY1xoSmAYJ_6cBNJpVCk,290
|
|
101
|
+
homa/vision/concerns/HasProbabilities.py,sha256=m1_ObS2BNYO-WVCNVMiHXzC3XAsyb88_0N4BWVDwCw0,221
|
|
102
|
+
homa/vision/concerns/ReportsAccuracy.py,sha256=DD0YTr5i8JMllIJTQn88Dn711yjZ2uiecaTi7WqpOEw,986
|
|
103
|
+
homa/vision/concerns/ReportsMetrics.py,sha256=93Hw_JBUbwfkrJNJA1xFSQ4cqRwzbSv4nPU524PGF6I,169
|
|
104
|
+
homa/vision/concerns/Trainable.py,sha256=SRCW3XpG9_DQgubyqhALlYDHwAWNzVVFjshUv1ecuEQ,988
|
|
105
|
+
homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioURsg,234
|
|
106
|
+
homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
|
|
107
|
+
homa/vision/modules/SwinModule.py,sha256=h7wq1YdKoN6-7C3FVFA0bpkAET_30002iTRbjZxziFQ,714
|
|
108
|
+
homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
|
|
109
|
+
homa-0.2.94.dist-info/METADATA,sha256=BHw_3Yfudz7Q8NC1ukYYINXLS7hR_dehKMhy4mtS24g,1760
|
|
110
|
+
homa-0.2.94.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
111
|
+
homa-0.2.94.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
|
|
112
|
+
homa-0.2.94.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
|
|
113
|
+
homa-0.2.94.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
homa
|