deeplotx 0.4.12b3__py3-none-any.whl → 0.4.12b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,7 +24,7 @@ class BaseNeuralNetwork(nn.Module):
24
24
 
25
25
  def l1(self, _lambda: float = 1e-4) -> torch.Tensor:
26
26
  def _l1() -> torch.Tensor:
27
- l2_reg = torch.tensor(0.)
27
+ l2_reg = torch.tensor(0., device=self.device, dtype=self.dtype)
28
28
  for param in self.parameters():
29
29
  l2_reg += (torch.abs(param)).sum()
30
30
  return l2_reg
@@ -32,7 +32,7 @@ class BaseNeuralNetwork(nn.Module):
32
32
 
33
33
  def l2(self, _lambda: float = 1e-4) -> torch.Tensor:
34
34
  def _l2() -> torch.Tensor:
35
- l2_reg = torch.tensor(0.)
35
+ l2_reg = torch.tensor(0., device=self.device, dtype=self.dtype)
36
36
  for param in self.parameters():
37
37
  l2_reg += (torch.pow(param, exponent=2.)).sum()
38
38
  return l2_reg
@@ -45,6 +45,7 @@ class BaseNeuralNetwork(nn.Module):
45
45
  def forward(self, *args, **kwargs) -> torch.Tensor: ...
46
46
 
47
47
  def predict(self, x) -> torch.Tensor:
48
+ x = self.ensure_device_and_dtype(x, device=self.device, dtype=self.dtype)
48
49
  __train = self.training
49
50
  self.training = False
50
51
  with torch.no_grad():
@@ -34,3 +34,13 @@ class RecursiveSequential(BaseNeuralNetwork):
34
34
  x, (hidden_state, cell_state) = self.lstm(x, state)
35
35
  x = self.regressive_head(x[:, -1, :])
36
36
  return x, (hidden_state, cell_state)
37
+
38
+ @override
39
+ def predict(self, x, batch_size: int | None = None) -> torch.Tensor:
40
+ _batch_size = batch_size if batch_size is not None else x.shape[0]
41
+ __train = self.training
42
+ self.training = False
43
+ with torch.no_grad():
44
+ res = self.forward(x, _batch_size)[0]
45
+ self.training = __train
46
+ return res
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.4.12b3
3
+ Version: 0.4.12b5
4
4
  Summary: Easy-2-use long text NLP toolkit.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -5,10 +5,10 @@ deeplotx/encoder/long_text_encoder.py,sha256=hl_O8kR9o1kcII9YfSx2rf_Pk0l_Rv7LNbs
5
5
  deeplotx/encoder/longformer_encoder.py,sha256=A8FXqd4mdHxSn_o_R689XtpT73ISDT788EgMQRGLC2g,1822
6
6
  deeplotx/nn/__init__.py,sha256=oQ-vYXyuaGelfCOs2im_gZXAiiBlCCVXh1uw9yjvRMs,253
7
7
  deeplotx/nn/auto_regression.py,sha256=7P63opWCWMqE2DigwbsL6kfXtFtJPz00Yo1RqflBz4A,572
8
- deeplotx/nn/base_neural_network.py,sha256=yEyF5C-Z3bp4Ddx6GbvqpBxXyFdbSChmP6SgyTzjQmM,2180
8
+ deeplotx/nn/base_neural_network.py,sha256=ufA0QOFFXaz4RLqjqx9N6VY-mDwWOe9Y35u2vsh_NFc,2339
9
9
  deeplotx/nn/linear_regression.py,sha256=_LQFrOKBbQxvuNzb_B8Mr6PAQJUg-pFeu3h7_jQz04o,2166
10
10
  deeplotx/nn/logistic_regression.py,sha256=j8QGe0e7In97RMOXApJRID85qf1rOUCOk3V368CBfqs,653
11
- deeplotx/nn/recursive_sequential.py,sha256=YCQUUcTBsZUeyO7CLjUO1EISYX1SXPnW6asR6ZBQAb4,1926
11
+ deeplotx/nn/recursive_sequential.py,sha256=4NeSE11ZsZ4YBpjee2mRn3a3hqU4xw_k6t09zTl-tTg,2292
12
12
  deeplotx/nn/softmax_regression.py,sha256=SlhvHho-Oufp7adAjm1t1ygidu-FrnHQ9aleMXyS_s8,674
13
13
  deeplotx/similarity/__init__.py,sha256=s3u-KSgxjnMcWpIItKgXNltFMPQ7YY3CqsqHI-5F1c8,724
14
14
  deeplotx/similarity/distribution.py,sha256=wQGouuuW531pZeBRKBujXsdsoz4fDnPw7_GW81jwepc,1066
@@ -20,8 +20,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=Wq_pGO78zgdXxFeBja
20
20
  deeplotx/util/__init__.py,sha256=JxqAK_WOOHcYVSTHBT1-WuBwWrPEVDTV3titeVWvNUM,74
21
21
  deeplotx/util/hash.py,sha256=wwsC6kOQvbpuvwKsNQOARd78_wePmW9i3oaUuXRUnpc,352
22
22
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
23
- deeplotx-0.4.12b3.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
24
- deeplotx-0.4.12b3.dist-info/METADATA,sha256=HZY8s697peX3onvVr5bjdHqwKJk2YyMOadMLGk-QvSc,6287
25
- deeplotx-0.4.12b3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
26
- deeplotx-0.4.12b3.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
27
- deeplotx-0.4.12b3.dist-info/RECORD,,
23
+ deeplotx-0.4.12b5.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
24
+ deeplotx-0.4.12b5.dist-info/METADATA,sha256=xA4HtA5PMPO5goVZITcOJs6hWrFD1DNXDZASs5dcD_o,6287
25
+ deeplotx-0.4.12b5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
26
+ deeplotx-0.4.12b5.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
27
+ deeplotx-0.4.12b5.dist-info/RECORD,,