deeplotx 0.4.12b3__py3-none-any.whl → 0.4.12b5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeplotx/nn/base_neural_network.py +3 -2
- deeplotx/nn/recursive_sequential.py +10 -0
- {deeplotx-0.4.12b3.dist-info → deeplotx-0.4.12b5.dist-info}/METADATA +1 -1
- {deeplotx-0.4.12b3.dist-info → deeplotx-0.4.12b5.dist-info}/RECORD +7 -7
- {deeplotx-0.4.12b3.dist-info → deeplotx-0.4.12b5.dist-info}/WHEEL +0 -0
- {deeplotx-0.4.12b3.dist-info → deeplotx-0.4.12b5.dist-info}/licenses/LICENSE +0 -0
- {deeplotx-0.4.12b3.dist-info → deeplotx-0.4.12b5.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,7 @@ class BaseNeuralNetwork(nn.Module):
|
|
24
24
|
|
25
25
|
def l1(self, _lambda: float = 1e-4) -> torch.Tensor:
|
26
26
|
def _l1() -> torch.Tensor:
|
27
|
-
l2_reg = torch.tensor(0.)
|
27
|
+
l2_reg = torch.tensor(0., device=self.device, dtype=self.dtype)
|
28
28
|
for param in self.parameters():
|
29
29
|
l2_reg += (torch.abs(param)).sum()
|
30
30
|
return l2_reg
|
@@ -32,7 +32,7 @@ class BaseNeuralNetwork(nn.Module):
|
|
32
32
|
|
33
33
|
def l2(self, _lambda: float = 1e-4) -> torch.Tensor:
|
34
34
|
def _l2() -> torch.Tensor:
|
35
|
-
l2_reg = torch.tensor(0.)
|
35
|
+
l2_reg = torch.tensor(0., device=self.device, dtype=self.dtype)
|
36
36
|
for param in self.parameters():
|
37
37
|
l2_reg += (torch.pow(param, exponent=2.)).sum()
|
38
38
|
return l2_reg
|
@@ -45,6 +45,7 @@ class BaseNeuralNetwork(nn.Module):
|
|
45
45
|
def forward(self, *args, **kwargs) -> torch.Tensor: ...
|
46
46
|
|
47
47
|
def predict(self, x) -> torch.Tensor:
|
48
|
+
x = self.ensure_device_and_dtype(x, device=self.device, dtype=self.dtype)
|
48
49
|
__train = self.training
|
49
50
|
self.training = False
|
50
51
|
with torch.no_grad():
|
@@ -34,3 +34,13 @@ class RecursiveSequential(BaseNeuralNetwork):
|
|
34
34
|
x, (hidden_state, cell_state) = self.lstm(x, state)
|
35
35
|
x = self.regressive_head(x[:, -1, :])
|
36
36
|
return x, (hidden_state, cell_state)
|
37
|
+
|
38
|
+
@override
|
39
|
+
def predict(self, x, batch_size: int | None = None) -> torch.Tensor:
|
40
|
+
_batch_size = batch_size if batch_size is not None else x.shape[0]
|
41
|
+
__train = self.training
|
42
|
+
self.training = False
|
43
|
+
with torch.no_grad():
|
44
|
+
res = self.forward(x, _batch_size)[0]
|
45
|
+
self.training = __train
|
46
|
+
return res
|
@@ -5,10 +5,10 @@ deeplotx/encoder/long_text_encoder.py,sha256=hl_O8kR9o1kcII9YfSx2rf_Pk0l_Rv7LNbs
|
|
5
5
|
deeplotx/encoder/longformer_encoder.py,sha256=A8FXqd4mdHxSn_o_R689XtpT73ISDT788EgMQRGLC2g,1822
|
6
6
|
deeplotx/nn/__init__.py,sha256=oQ-vYXyuaGelfCOs2im_gZXAiiBlCCVXh1uw9yjvRMs,253
|
7
7
|
deeplotx/nn/auto_regression.py,sha256=7P63opWCWMqE2DigwbsL6kfXtFtJPz00Yo1RqflBz4A,572
|
8
|
-
deeplotx/nn/base_neural_network.py,sha256=
|
8
|
+
deeplotx/nn/base_neural_network.py,sha256=ufA0QOFFXaz4RLqjqx9N6VY-mDwWOe9Y35u2vsh_NFc,2339
|
9
9
|
deeplotx/nn/linear_regression.py,sha256=_LQFrOKBbQxvuNzb_B8Mr6PAQJUg-pFeu3h7_jQz04o,2166
|
10
10
|
deeplotx/nn/logistic_regression.py,sha256=j8QGe0e7In97RMOXApJRID85qf1rOUCOk3V368CBfqs,653
|
11
|
-
deeplotx/nn/recursive_sequential.py,sha256=
|
11
|
+
deeplotx/nn/recursive_sequential.py,sha256=4NeSE11ZsZ4YBpjee2mRn3a3hqU4xw_k6t09zTl-tTg,2292
|
12
12
|
deeplotx/nn/softmax_regression.py,sha256=SlhvHho-Oufp7adAjm1t1ygidu-FrnHQ9aleMXyS_s8,674
|
13
13
|
deeplotx/similarity/__init__.py,sha256=s3u-KSgxjnMcWpIItKgXNltFMPQ7YY3CqsqHI-5F1c8,724
|
14
14
|
deeplotx/similarity/distribution.py,sha256=wQGouuuW531pZeBRKBujXsdsoz4fDnPw7_GW81jwepc,1066
|
@@ -20,8 +20,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=Wq_pGO78zgdXxFeBja
|
|
20
20
|
deeplotx/util/__init__.py,sha256=JxqAK_WOOHcYVSTHBT1-WuBwWrPEVDTV3titeVWvNUM,74
|
21
21
|
deeplotx/util/hash.py,sha256=wwsC6kOQvbpuvwKsNQOARd78_wePmW9i3oaUuXRUnpc,352
|
22
22
|
deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
|
23
|
-
deeplotx-0.4.
|
24
|
-
deeplotx-0.4.
|
25
|
-
deeplotx-0.4.
|
26
|
-
deeplotx-0.4.
|
27
|
-
deeplotx-0.4.
|
23
|
+
deeplotx-0.4.12b5.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
24
|
+
deeplotx-0.4.12b5.dist-info/METADATA,sha256=xA4HtA5PMPO5goVZITcOJs6hWrFD1DNXDZASs5dcD_o,6287
|
25
|
+
deeplotx-0.4.12b5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
26
|
+
deeplotx-0.4.12b5.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
|
27
|
+
deeplotx-0.4.12b5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|