homa 0.2.9__py3-none-any.whl → 0.2.95__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -12,5 +12,5 @@ class AOAF(AdaptiveActivationFunction, ChannelBased):
12
12
 
13
13
  def forward(self, x: torch.Tensor) -> torch.Tensor:
14
14
  self.initialize(x, "a")
15
- a = self.a.view(self.parameter_view(x))
15
+ a = self.a.view(self.parameter_shape(x))
16
16
  return torch.relu(x - self.b * a) + self.c * a
@@ -1,5 +1,6 @@
1
1
  import torch
2
2
  from ..AdaptiveActivationFunction import AdaptiveActivationFunction
3
+ from ...device import get_device
3
4
 
4
5
 
5
6
  class AReLU(AdaptiveActivationFunction):
@@ -7,10 +8,12 @@ class AReLU(AdaptiveActivationFunction):
7
8
  super(AReLU, self).__init__()
8
9
  self.a = torch.nn.Parameter(torch.tensor(0.9, requires_grad=True))
9
10
  self.b = torch.nn.Parameter(torch.tensor(2.0, requires_grad=True))
11
+ self.a.to(get_device())
12
+ self.b.to(get_device())
10
13
 
11
- def forward(self, z):
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
12
15
  negative_slope = torch.clamp(self.a, 0.01, 0.99)
13
16
  positive_slope = 1 + torch.sigmoid(self.b)
14
- positive = positive_slope * torch.relu(z)
15
- negative = negative_slope * (-torch.relu(-z))
17
+ positive = positive_slope * torch.relu(x)
18
+ negative = negative_slope * (-torch.relu(-x))
16
19
  return positive + negative
@@ -3,7 +3,7 @@ from ..AdaptiveActivationFunction import AdaptiveActivationFunction
3
3
  from .concerns import ChannelBased
4
4
 
5
5
 
6
- class DualLine(AdaptiveActivationFunction, ChannelBased):
6
+ class PiLU(AdaptiveActivationFunction, ChannelBased):
7
7
  def __init__(self):
8
8
  super().__init__()
9
9
  self.a = None
@@ -1,10 +1,10 @@
1
- from .StarReLU import StarReLU
2
1
  from .DualLine import DualLine
3
2
  from .LeLeLU import LeLeLU
4
3
  from .AReLU import AReLU
5
4
  from .PERU import PERU
6
5
  from .ShiLU import ShiLU
6
+ from .StarReLU import StarReLU
7
7
  from .DPReLU import DPReLU
8
- from .PiLU import DualLine
8
+ from .PiLU import PiLU
9
9
  from .FReLU import FReLU
10
10
  from .AOAF import AOAF
@@ -21,12 +21,14 @@ class ChannelBased:
21
21
  attrs = [attrs]
22
22
 
23
23
  self.num_channels = x.shape[1]
24
+ device = x.device
24
25
  for index, attr in enumerate(attrs):
25
26
  if index < len(values) and values[index] is not None:
26
27
  default_value = float(values[index])
27
28
  else:
28
29
  default_value = 1.0
29
30
  param = torch.nn.Parameter(torch.full((self.num_channels,), default_value))
31
+ param = param.to(device)
30
32
  setattr(self, attr, param)
31
33
  self._initialized = True
32
34
 
@@ -1,5 +1,5 @@
1
1
  import torch
2
- from copy import deepcopy
2
+ import io
3
3
  from typing import List
4
4
  from ...vision import Model
5
5
 
@@ -12,11 +12,17 @@ class StoresModels:
12
12
  def record(self, model: Model | torch.nn.Module):
13
13
  model_: torch.nn.Module | None = None
14
14
  if isinstance(model, Model):
15
- model_ = deepcopy(model.network)
15
+ model_ = model.network
16
16
  elif isinstance(model, torch.nn.Module):
17
- model_ = deepcopy(model)
17
+ model_ = model
18
18
  else:
19
19
  raise TypeError("Wrong input to ensemble record")
20
+
21
+ device = model_.device
22
+ buffer = io.BytesIO()
23
+ torch.save(model_.to("cpu"), buffer)
24
+ buffer.seek(0)
25
+ model_ = torch.load(buffer, map_location=device)
20
26
  self.models.append(model_)
21
27
 
22
28
  def push(self, *args, **kwargs):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: homa
3
- Version: 0.2.9
3
+ Version: 0.2.95
4
4
  Summary: A curated list of machine learning and deep learning helpers.
5
5
  Author-email: Taha Shieenavaz <tahashieenavaz@gmail.com>
6
6
  Requires-Python: >=3.7
@@ -51,18 +51,18 @@ homa/activations/TeLU.py,sha256=qU5x0EskjQs6d5rCtbL91C6cMAm8vjDnjQNMX0LcEt8,180
51
51
  homa/activations/TripleStateSwish.py,sha256=UG5BGY29wUEJaryClB2rDM90s0jt5vMJF9Kv-5M4Rgo,507
52
52
  homa/activations/WideMeLU.py,sha256=ieJjTjnK9JJtApPFGpmTynu3G8YlyH5jw6qnhkJkStI,421
53
53
  homa/activations/__init__.py,sha256=2GHNqrOp6WoLAtFFJcSj6j4GP-W8-YAYRZGX9vZbcmU,1659
54
- homa/activations/learnable/AOAF.py,sha256=KYtQtpLiupdyoumqNmz0kMTgRK66sSYiuLnpbr2H7Mw,509
55
- homa/activations/learnable/AReLU.py,sha256=-6kQ0mDGq3p9Xlg74waMa8xsTDALCtkE6pwx7DrTDeI,610
54
+ homa/activations/learnable/AOAF.py,sha256=1ArhgpI6PfCRePgvFq8VqKDQ9rDMHZb0bm6g4Tiz13s,510
55
+ homa/activations/learnable/AReLU.py,sha256=Pfyv_7EEwGgW4_UyKc8CiSg7lhTcO7LZ7uIUeVQWLpA,737
56
56
  homa/activations/learnable/DPReLU.py,sha256=xQhYTJ0-mfRGdld950xoTh8c9O08WIY50K0FjPtVVFs,507
57
57
  homa/activations/learnable/DualLine.py,sha256=cgqyE7dVqXflT8ulCuOyKQQa09FYSj8vJkeVUEOaeIU,600
58
58
  homa/activations/learnable/FReLU.py,sha256=qQ8GjjWWGeoE6qW9tw49mZPs29app0QK1AFOuMc5ASU,413
59
59
  homa/activations/learnable/LeLeLU.py,sha256=ya2m60QRcpVlTwMejJTgMTxM3RRHF0RgNe72_EdD1-U,425
60
60
  homa/activations/learnable/PERU.py,sha256=y2OxRLIA1HTUnFyRHs0zgLhLMJhQz9Q4F6QrqBSkQ00,513
61
- homa/activations/learnable/PiLU.py,sha256=p5FmWGJWlZEdLGVXmiXKg0rTxCVO-qn9bQIVcyAaa8U,616
61
+ homa/activations/learnable/PiLU.py,sha256=w7LkBBs_hr07pvizUie5Z49UkHg3O8LHA-wFK4hbnjE,612
62
62
  homa/activations/learnable/ShiLU.py,sha256=35VC1pCAWMaxHKWYBeXd2DrXn1tepvQaT7a-KwoNdHY,475
63
63
  homa/activations/learnable/StarReLU.py,sha256=hrscp-A0HnIvebFPLGr86K5Uf_U--EWtpNDqdNgonA0,485
64
- homa/activations/learnable/__init__.py,sha256=fcfm-GHEe4AQzEz9mXrWfSLkcgWaTg91ccByx7LxfX4,264
65
- homa/activations/learnable/concerns/ChannelBased.py,sha256=uK6FdC9mJRWSoXinjM8r5GJCZNWWxst7NMt8P6rnhKg,1143
64
+ homa/activations/learnable/__init__.py,sha256=yDzcgM_n5sNEU0kz9I0aVgGihpw_2RvtkCCylaTCPEQ,260
65
+ homa/activations/learnable/concerns/ChannelBased.py,sha256=pSKnWOKVOdb0GoiBobSSUANaZPGNwT9rxBnJUpZ9Eac,1206
66
66
  homa/activations/learnable/concerns/__init__.py,sha256=CubRRYQEQMAK2-igsYKD8tcyesPOYoZYF_IlHzRZXi4,39
67
67
  homa/cli/HomaCommand.py,sha256=w-Dg6dFpoXbQx2tvWSLdND2pdhqB2cPSORyi4MfY8XY,307
68
68
  homa/cli/Commands/Command.py,sha256=DnmsEwpaxdQaLjzyYBO7qtIQTLwYzyhJS31YazA1IHg,24
@@ -81,7 +81,7 @@ homa/ensemble/concerns/ReportsEnsembleF1.py,sha256=hdtdCQrWaFJNUn1KP9cAmi_q_EA4F
81
81
  homa/ensemble/concerns/ReportsEnsembleKappa.py,sha256=ZRbtrFCTD84EDql6ZL1xeWtTLFxpO5Y5tQaUlR6_0jw,300
82
82
  homa/ensemble/concerns/ReportsLogits.py,sha256=vTGuC9NR4rno3Mkbm0MhL8f7YopuCErGyjIorxamKTM,461
83
83
  homa/ensemble/concerns/ReportsSize.py,sha256=S7lo_Wu6rDnuqyAcv6AI6jspaBhcpfsirpp9RVD8c20,238
84
- homa/ensemble/concerns/StoresModels.py,sha256=PNoaoAOx4v8rercxXHmf7zqVIPGYM4APzIHHEb3RwT0,850
84
+ homa/ensemble/concerns/StoresModels.py,sha256=tfql0sr_Y27cHEJxZkc9AUQYlQRe0HtbN4JD940lKqY,1001
85
85
  homa/ensemble/concerns/__init__.py,sha256=X0F_b2Jsv0XpiNhYwJsl-dfPsBOdEeW53LQPE4xQD0w,479
86
86
  homa/loss/LogitNormLoss.py,sha256=LJMzRA1WoJ7aDYTV-FYGhgo8DMkcpv7e8_74qiJ4zT8,386
87
87
  homa/loss/Loss.py,sha256=COUr_idShYgAP8xKCxcaXbyUyAoJg7IOON0ARTQykmQ,21
@@ -106,8 +106,8 @@ homa/vision/concerns/__init__.py,sha256=mrw1YvN-GpQPvMwDF00KxnFkksPKo23RWM4KRioU
106
106
  homa/vision/modules/ResnetModule.py,sha256=eFudBnILD6OmgQtcW_CQQ8aZ62NEa4HyZ15-lobTtt0,712
107
107
  homa/vision/modules/SwinModule.py,sha256=h7wq1YdKoN6-7C3FVFA0bpkAET_30002iTRbjZxziFQ,714
108
108
  homa/vision/modules/__init__.py,sha256=zVMYB9IAO_xZylC1-N3p8ymHgEkAE2sBbuVz8K5Y1kk,74
109
- homa-0.2.9.dist-info/METADATA,sha256=uqaBYePnoJwrTwJRFB47fx_vh073hlynKWA7JAU0hDs,1759
110
- homa-0.2.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
111
- homa-0.2.9.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
112
- homa-0.2.9.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
113
- homa-0.2.9.dist-info/RECORD,,
109
+ homa-0.2.95.dist-info/METADATA,sha256=Tt_dtrzp2O9_bhBkhZAjId_k_kRQI6z9ze6aQJhId_s,1760
110
+ homa-0.2.95.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
111
+ homa-0.2.95.dist-info/entry_points.txt,sha256=tJZzjs-f2QvFe3ES8Qta8IE5sAbeE8-cyZ_UtbgqG4s,51
112
+ homa-0.2.95.dist-info/top_level.txt,sha256=tmOfy2tuaAwc3W5-i6j61_vYJsXgR4ivBWkhJ3ZtJDc,5
113
+ homa-0.2.95.dist-info/RECORD,,
File without changes