homa 0.1.2__tar.gz → 0.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of homa has been flagged as potentially problematic; see the package's registry page for details.
- {homa-0.1.2 → homa-0.1.5}/PKG-INFO +1 -1
- {homa-0.1.2 → homa-0.1.5}/pyproject.toml +1 -1
- homa-0.1.5/src/homa/torch/__init__.py +1 -0
- homa-0.1.5/src/homa/vision/ClassificationModel.py +5 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/Resnet.py +5 -4
- homa-0.1.5/src/homa/vision/concerns/Predicts.py +9 -0
- homa-0.1.5/src/homa/vision/concerns/ReportsAccuracy.py +29 -0
- homa-0.1.5/src/homa/vision/concerns/ReportsLogits.py +9 -0
- homa-0.1.5/src/homa/vision/concerns/ReportsMetrics.py +6 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/concerns/Trainable.py +3 -1
- homa-0.1.5/src/homa/vision/concerns/__init__.py +5 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/PKG-INFO +1 -1
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/SOURCES.txt +5 -1
- {homa-0.1.2 → homa-0.1.5}/tests/test_resnet.py +13 -0
- homa-0.1.2/src/homa/torch/Module.py +0 -8
- homa-0.1.2/src/homa/torch/__init__.py +0 -2
- homa-0.1.2/src/homa/vision/concerns/__init__.py +0 -1
- {homa-0.1.2 → homa-0.1.5}/README.md +0 -0
- {homa-0.1.2 → homa-0.1.5}/setup.cfg +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/APLU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/GALU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/MELU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/PDELU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/SReLU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/SmallGALU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/StochasticActivation.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/WideMELU.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/classes/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/activations/utils.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/cli/HomaCommand.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/cli/namespaces/CacheNamespace.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/cli/namespaces/MakeNamespace.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/cli/namespaces/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/device.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/Ensemble.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/CalculatesMetricNecessities.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/HasNetwork.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/HasStateDicts.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/PredictsProbabilities.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/RecordsStateDictionaries.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsClassificationMetrics.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsEnsembleAccuracy.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsEnsembleF1.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsEnsembleKappa.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsLogits.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/ReportsSize.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/ensemble/concerns/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/settings.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/torch/helpers.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/utils.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/Model.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/StochasticResnet.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/modules/ResnetModule.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/modules/StochasticResnetModule.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/modules/__init__.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa/vision/utils.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/dependency_links.txt +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/entry_points.txt +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/requires.txt +0 -0
- {homa-0.1.2 → homa-0.1.5}/src/homa.egg-info/top_level.txt +0 -0
- {homa-0.1.2 → homa-0.1.5}/tests/test_ensemble.py +0 -0
- {homa-0.1.2 → homa-0.1.5}/tests/test_stochastic_resnet.py +0 -0
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "homa"
|
|
7
|
-
version = "0.1.2"
|
|
7
|
+
version = "0.1.5"
|
|
8
8
|
description = "A curated list of machine learning and deep learning helpers."
|
|
9
9
|
authors = [
|
|
10
10
|
{ name="Taha Shieenavaz", email="tahashieenavaz@gmail.com" },
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .helpers import *
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from .modules import ResnetModule
|
|
3
|
-
from .
|
|
4
|
-
from .concerns import Trainable
|
|
3
|
+
from .ClassificationModel import ClassificationModel
|
|
4
|
+
from .concerns import Trainable, ReportsMetrics
|
|
5
|
+
from ..device import get_device
|
|
5
6
|
|
|
6
7
|
|
|
7
|
-
class Resnet(
|
|
8
|
+
class Resnet(ClassificationModel, Trainable, ReportsMetrics):
|
|
8
9
|
def __init__(self, num_classes: int, lr: float):
|
|
9
10
|
super().__init__()
|
|
10
|
-
self.network = ResnetModule(num_classes)
|
|
11
|
+
self.network = ResnetModule(num_classes).to(get_device())
|
|
11
12
|
self.criterion = torch.nn.CrossEntropyLoss()
|
|
12
13
|
self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr, momentum=0.9)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from torch import Tensor, no_grad
|
|
2
|
+
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from ...device import get_device
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ReportsAccuracy:
|
|
7
|
+
def __init__(self, *args, **kwargs):
|
|
8
|
+
super().__init__(*args, **kwargs)
|
|
9
|
+
|
|
10
|
+
@no_grad()
|
|
11
|
+
def accuracy_tensors(self, x: Tensor, y: Tensor) -> float:
|
|
12
|
+
predictions = self.predict(x)
|
|
13
|
+
return (predictions == y).float().mean().item()
|
|
14
|
+
|
|
15
|
+
@no_grad()
|
|
16
|
+
def accuracy_dataloader(self, dataloader: DataLoader):
|
|
17
|
+
correct, total = 0, 0
|
|
18
|
+
for x, y in dataloader:
|
|
19
|
+
x, y = x.to(get_device()), y.to(get_device())
|
|
20
|
+
predictions = self.predict(x)
|
|
21
|
+
correct += (predictions == y).sum().item()
|
|
22
|
+
total += y.numel()
|
|
23
|
+
return correct / total if total > 0 else 0.0
|
|
24
|
+
|
|
25
|
+
def accuracy(self, x: Tensor | DataLoader, y: Tensor | None = None) -> float:
|
|
26
|
+
self.network.eval()
|
|
27
|
+
if isinstance(x, DataLoader):
|
|
28
|
+
return self.accuracy_dataloader(x)
|
|
29
|
+
return self.accuracy_tensors(x, y)
|
|
@@ -1,9 +1,11 @@
|
|
|
1
1
|
from torch import Tensor
|
|
2
2
|
from torch.utils.data.dataloader import DataLoader
|
|
3
|
+
from .ReportsLogits import ReportsLogits
|
|
4
|
+
from .Predicts import Predicts
|
|
3
5
|
from ...device import get_device
|
|
4
6
|
|
|
5
7
|
|
|
6
|
-
class Trainable:
|
|
8
|
+
class Trainable(ReportsLogits, Predicts):
|
|
7
9
|
def __init__(self, *args, **kwargs):
|
|
8
10
|
super().__init__(*args, **kwargs)
|
|
9
11
|
|
|
@@ -39,14 +39,18 @@ src/homa/ensemble/concerns/ReportsEnsembleKappa.py
|
|
|
39
39
|
src/homa/ensemble/concerns/ReportsLogits.py
|
|
40
40
|
src/homa/ensemble/concerns/ReportsSize.py
|
|
41
41
|
src/homa/ensemble/concerns/__init__.py
|
|
42
|
-
src/homa/torch/Module.py
|
|
43
42
|
src/homa/torch/__init__.py
|
|
44
43
|
src/homa/torch/helpers.py
|
|
44
|
+
src/homa/vision/ClassificationModel.py
|
|
45
45
|
src/homa/vision/Model.py
|
|
46
46
|
src/homa/vision/Resnet.py
|
|
47
47
|
src/homa/vision/StochasticResnet.py
|
|
48
48
|
src/homa/vision/__init__.py
|
|
49
49
|
src/homa/vision/utils.py
|
|
50
|
+
src/homa/vision/concerns/Predicts.py
|
|
51
|
+
src/homa/vision/concerns/ReportsAccuracy.py
|
|
52
|
+
src/homa/vision/concerns/ReportsLogits.py
|
|
53
|
+
src/homa/vision/concerns/ReportsMetrics.py
|
|
50
54
|
src/homa/vision/concerns/Trainable.py
|
|
51
55
|
src/homa/vision/concerns/__init__.py
|
|
52
56
|
src/homa/vision/modules/ResnetModule.py
|
|
@@ -2,6 +2,7 @@ import pytest
|
|
|
2
2
|
import torch
|
|
3
3
|
from homa.vision import Resnet, Model
|
|
4
4
|
from homa.vision.modules import ResnetModule
|
|
5
|
+
from homa import get_device
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
@pytest.fixture
|
|
@@ -19,3 +20,15 @@ def test_resnet_initialization(resnet_model):
|
|
|
19
20
|
assert isinstance(resnet_model.network, ResnetModule)
|
|
20
21
|
assert isinstance(resnet_model.optimizer, torch.optim.SGD)
|
|
21
22
|
assert isinstance(resnet_model.criterion, torch.nn.CrossEntropyLoss)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def test_reports_accuracy(resnet_model):
    """Accuracy is a float in [0, 1] for both tensor and DataLoader inputs."""
    x = torch.randn(10, 3, 84, 84).to(get_device())
    # Labels are 1-D: one class index per sample. The earlier (10, 1) shape
    # would broadcast against per-sample predictions of shape (10,) and turn
    # the element-wise comparison into a (10, 10) one. The predictions shape is
    # assumed here, since Predicts.predict is not shown.
    # Assumes the resnet_model fixture has at least 3 classes — TODO confirm.
    y = torch.randint(0, 3, (10,)).to(get_device())

    accuracy = resnet_model.accuracy(x, y)
    assert isinstance(accuracy, float)
    assert 0.0 <= accuracy <= 1.0

    dataset = torch.utils.data.TensorDataset(x, y)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    accuracy = resnet_model.accuracy(dataloader)
    assert isinstance(accuracy, float)
    # A broadcasting shape mismatch can push the DataLoader path above 1.0.
    assert 0.0 <= accuracy <= 1.0
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
from .Trainable import Trainable
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|