homa 0.2.9__py3-none-any.whl → 0.2.95__py3-none-any.whl
This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It reflects the package contents as they appear in the public registry and is provided for informational purposes only.
- homa/activations/learnable/AOAF.py +1 -1
- homa/activations/learnable/AReLU.py +6 -3
- homa/activations/learnable/PiLU.py +1 -1
- homa/activations/learnable/__init__.py +2 -2
- homa/activations/learnable/concerns/ChannelBased.py +2 -0
- homa/ensemble/concerns/StoresModels.py +9 -3
- {homa-0.2.9.dist-info → homa-0.2.95.dist-info}/METADATA +1 -1
- {homa-0.2.9.dist-info → homa-0.2.95.dist-info}/RECORD +11 -11
- {homa-0.2.9.dist-info → homa-0.2.95.dist-info}/WHEEL +0 -0
- {homa-0.2.9.dist-info → homa-0.2.95.dist-info}/entry_points.txt +0 -0
- {homa-0.2.9.dist-info → homa-0.2.95.dist-info}/top_level.txt +0 -0
|
@@ -12,5 +12,5 @@ class AOAF(AdaptiveActivationFunction, ChannelBased):
|
|
|
12
12
|
|
|
13
13
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
14
14
|
self.initialize(x, "a")
|
|
15
|
-
a = self.a.view(self.
|
|
15
|
+
a = self.a.view(self.parameter_shape(x))
|
|
16
16
|
return torch.relu(x - self.b * a) + self.c * a
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
3
|
+
from ...device import get_device
|
|
3
4
|
|
|
4
5
|
|
|
5
6
|
class AReLU(AdaptiveActivationFunction):
|
|
@@ -7,10 +8,12 @@ class AReLU(AdaptiveActivationFunction):
|
|
|
7
8
|
super(AReLU, self).__init__()
|
|
8
9
|
self.a = torch.nn.Parameter(torch.tensor(0.9, requires_grad=True))
|
|
9
10
|
self.b = torch.nn.Parameter(torch.tensor(2.0, requires_grad=True))
|
|
11
|
+
self.a.to(get_device())
|
|
12
|
+
self.b.to(get_device())
|
|
10
13
|
|
|
11
|
-
def forward(self,
|
|
14
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
12
15
|
negative_slope = torch.clamp(self.a, 0.01, 0.99)
|
|
13
16
|
positive_slope = 1 + torch.sigmoid(self.b)
|
|
14
|
-
positive = positive_slope * torch.relu(
|
|
15
|
-
negative = negative_slope * (-torch.relu(-
|
|
17
|
+
positive = positive_slope * torch.relu(x)
|
|
18
|
+
negative = negative_slope * (-torch.relu(-x))
|
|
16
19
|
return positive + negative
|
|
@@ -3,7 +3,7 @@ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
|
|
|
3
3
|
from .concerns import ChannelBased
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
class
|
|
6
|
+
class PiLU(AdaptiveActivationFunction, ChannelBased):
|
|
7
7
|
def __init__(self):
|
|
8
8
|
super().__init__()
|
|
9
9
|
self.a = None
|
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from .StarReLU import StarReLU
|
|
2
1
|
from .DualLine import DualLine
|
|
3
2
|
from .LeLeLU import LeLeLU
|
|
4
3
|
from .AReLU import AReLU
|
|
5
4
|
from .PERU import PERU
|
|
6
5
|
from .ShiLU import ShiLU
|
|
6
|
+
from .StarReLU import StarReLU
|
|
7
7
|
from .DPReLU import DPReLU
|
|
8
|
-
from .PiLU import
|
|
8
|
+
from .PiLU import PiLU
|
|
9
9
|
from .FReLU import FReLU
|
|
10
10
|
from .AOAF import AOAF
|
|
@@ -21,12 +21,14 @@ class ChannelBased:
|
|
|
21
21
|
attrs = [attrs]
|
|
22
22
|
|
|
23
23
|
self.num_channels = x.shape[1]
|
|
24
|
+
device = x.device
|
|
24
25
|
for index, attr in enumerate(attrs):
|
|
25
26
|
if index < len(values) and values[index] is not None:
|
|
26
27
|
default_value = float(values[index])
|
|
27
28
|
else:
|
|
28
29
|
default_value = 1.0
|
|
29
30
|
param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
|
|
31
|
+
param = param.to(device)
|
|
30
32
|
setattr(self, attr, param)
|
|
31
33
|
self._initialized = True
|
|
32
34
|
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
|
|
2
|
+
import io
|
|
3
3
|
from typing import List
|
|
4
4
|
from ...vision import Model
|
|
5
5
|
|
|
@@ -12,11 +12,17 @@ class StoresModels:
|
|
|
12
12
|
def record(self, model: Model | torch.nn.Module):
|
|
13
13
|
model_: torch.nn.Module | None = None
|
|
14
14
|
if isinstance(model, Model):
|
|
15
|
-
model_ =
|
|
15
|
+
model_ = model.network
|
|
16
16
|
elif isinstance(model, torch.nn.Module):
|
|
17
|
-
model_ =
|
|
17
|
+
model_ = model
|
|
18
18
|
else:
|
|
19
19
|
raise TypeError("Wrong input to ensemble record")
|
|
20
|
+
|
|
21
|
+
device = model_.device
|
|
22
|
+
buffer = io.BytesIO()
|
|
23
|
+
torch.save(model_.to("cpu"), buffer)
|
|
24
|
+
buffer.seek(0)
|
|
25
|
+
model_ = torch.load(buffer, map_location=device)
|
|
20
26
|
self.models.append(model_)
|
|
21
27
|
|
|
22
28
|
def push(self, *args, **kwargs):
|
|
@@ -51,18 +51,18 @@ homa/activations/TeLU.py,sha256=qU5x0EskjQs6d5rCtbL91C6cMAm8vjDnjQNMX0LcEt8,180
|
|
|
51
51
|
homa/activations/TripleStateSwish.py,sha256=UG5BGY29wUEJaryClB2rDM90s0jt5vMJF9Kv-5M4Rgo,507
|
|
52
52
|
homa/activations/WideMeLU.py,sha256=ieJjTjnK9JJtApPFGpmTynu3G8YlyH5jw6qnhkJkStI,421
|
|
53
53
|
homa/activations/__init__.py,sha256=2GHNqrOp6WoLAtFFJcSj6j4GP-W8-YAYRZGX9vZbcmU,1659
|
|
54
|
-
homa/activations/learnable/AOAF.py,sha256=
|
|
55
|
-
homa/activations/learnable/AReLU.py,sha256
|
|
54
|
+
homa/activations/learnable/AOAF.py,sha256=1ArhgpI6PfCRePgvFq8VqKDQ9rDMHZb0bm6g4Tiz13s,510
|
|
55
|
+
homa/activations/learnable/AReLU.py,sha256=Pfyv_7EEwGgW4_UyKc8CiSg7lhTcO7LZ7uIUeVQWLpA,737
|
|
56
56
|
homa/activations/learnable/DPReLU.py,sha256=xQhYTJ0-mfRGdld950xoTh8c9O08WIY50K0FjPtVVFs,507
|
|
57
57
|
homa/activations/learnable/DualLine.py,sha256=cgqyE7dVqXflT8ulCuOyKQQa09FYSj8vJkeVUEOaeIU,600
|
|
58
58
|
homa/activations/learnable/FReLU.py,sha256=qQ8GjjWWGeoE6qW9tw49mZPs29app0QK1AFOuMc5ASU,413
|
|
59
59
|
homa/activations/learnable/LeLeLU.py,sha256=ya2m60QRcpVlTwMejJTgMTxM3RRHF0RgNe72_EdD1-U,425
|
|
60
60
|
homa/activations/learnable/PERU.py,sha256=y2OxRLIA1HTUnFyRHs0zgLhLMJhQz9Q4F6QrqBSkQ00,513
|
|
61
|
-
homa/activations/learnable/PiLU.py,sha256=
|
|
61
|
+
homa/activations/learnable/PiLU.py,sha256=w7LkBBs_hr07pvizUie5Z49UkHg3O8LHA-wFK4hbnjE,612
|
|
62
62
|
homa/activations/learnable/ShiLU.py,sha256=35VC1pCAWMaxHKWYBeXd2DrXn1tepvQaT7a-KwoNdHY,475
|
|
63
63
|
homa/activations/learnable/StarReLU.py,sha256=hrscp-A0HnIvebFPLGr86K5Uf_U--EWtpNDqdNgonA0,485
|
|
64
|
-
homa/activations/learnable/__init__.py,sha256=
|
|
65
|
-
homa/activations/learnable/concerns/ChannelBased.py,sha256=
|
|
64
|
+
homa/activations/learnable/__init__.py,sha256=yDzcgM_n5sNEU0kz9I0aVgGihpw_2RvtkCCylaTCPEQ,260
|
|
65
|
+
homa/activations/learnable/concerns/ChannelBased.py,sha256=pSKnWOKVOdb0GoiBobSSUANaZPGNwT9rxBnJUpZ9Eac,1206
|
|
66
66
|
homa/activations/learnable/concerns/__init__.py,sha256=CubRRYQEQMAK2-igsYKD8tcyesPOYoZYF_IlHzRZXi4,39
|
|
67
67
|
homa/cli/HomaCommand.py,sha256=w-Dg6dFpoXbQx2tvWSLdND2pdhqB2cPSORyi4MfY8XY,307
|
|
68
68
|
homa/cli/Commands/Command.py,sha256=DnmsEwpaxdQaLjzyYBO7qtIQTLwYzyhJS31YazA1IHg,24
|
|
@@ -81,7 +81,7 @@ homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4F
|
|
|
81
81
|
homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
|
|
82
82
|
homa/ensemble/concerns/ReportsLogits.py,sha256=vTGuC9NR4rno3Mkbm0MhL8f7YopuCErGyjIorxamKTM,461
|
|
83
83
|
homa/ensemble/concerns/ReportsSize.py,sha256=S7lo_Wu6rDnuqyAcv6AI6jspaBhcpfsirpp9RVD8c20,238
|
|
84
|
-
homa/ensemble/concerns/StoresModels.py,sha256=
|
|
84
|
+
homa/ensemble/concerns/StoresModels.py,sha256=tfql0sr_Y27cHEJxZkc9AUQYlQRe0HtbN4JD940lKqY,1001
|
|
85
85
|
homa/ensemble/concerns/__init__.py,sha256=X0F_b2Jsv0XpiNhYwJsl-dfPsBOdEeW53LQPE4xQD0w,479
|
|
86
86
|
homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
|
|
87
87
|
homa/loss/Loss.py,sha256=COUr_idShYgAP8xKCxcaXbyUyAoJg7IOON0ARTQykmQ,21
|
|
@@ -106,8 +106,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
|
|
|
106
106
|
homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
|
|
107
107
|
homa/vision/modules/SwinModule.py,sha256=h7wq1YdKoN6-7C3FVFA0bpkAET_30002iTRbjZxziFQ,714
|
|
108
108
|
homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
|
|
109
|
-
homa-0.2.
|
|
110
|
-
homa-0.2.
|
|
111
|
-
homa-0.2.
|
|
112
|
-
homa-0.2.
|
|
113
|
-
homa-0.2.
|
|
109
|
+
homa-0.2.95.dist-info/METADATA,sha256=Tt_dtrzp2O9_bhBkhZAjId_k_kRQI6z9ze6aQJhId_s,1760
|
|
110
|
+
homa-0.2.95.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
111
|
+
homa-0.2.95.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
|
|
112
|
+
homa-0.2.95.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
|
|
113
|
+
homa-0.2.95.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|