deeplotx 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -99,11 +99,14 @@ class BaseNeuralNetwork(nn.Module):
99
99
 
100
100
  def predict(self, x: torch.Tensor) -> torch.Tensor:
101
101
  x = self.ensure_device_and_dtype(x, device=self.device, dtype=self.dtype)
102
- __train = self.training
103
- self.training = False
102
+ training_state_dict = dict()
103
+ for m in self.modules():
104
+ training_state_dict[m] = m.training
105
+ m.training = False
104
106
  with torch.no_grad():
105
107
  res = self.forward(x)
106
- self.training = __train
108
+ for m, training_state in training_state_dict.items():
109
+ m.training = training_state
107
110
  return res
108
111
 
109
112
  def save(self, model_name: str | None = None, model_dir: str = '.', _suffix: str = DEFAULT_SUFFIX):
@@ -28,7 +28,7 @@ class FeedForwardUnit(BaseNeuralNetwork):
28
28
  x = self.layer_norm(x)
29
29
  x = self.up_proj(x)
30
30
  x = self.parametric_relu(x)
31
- if self._dropout_rate > .0:
31
+ if self._dropout_rate > .0 and self.training:
32
32
  x = torch.dropout(x, p=self._dropout_rate, train=self.training)
33
33
  return self.down_proj(x) + residual
34
34
 
@@ -41,9 +41,12 @@ class RecursiveSequential(BaseNeuralNetwork):
41
41
 
42
42
  @override
43
43
  def predict(self, x: torch.Tensor) -> torch.Tensor:
44
- __train = self.training
45
- self.training = False
44
+ training_state_dict = dict()
45
+ for m in self.modules():
46
+ training_state_dict[m] = m.training
47
+ m.training = False
46
48
  with torch.no_grad():
47
49
  res = self.forward(x.unsqueeze(0), self.initial_state(batch_size=1))[0]
48
- self.training = __train
50
+ for m, training_state in training_state_dict.items():
51
+ m.training = training_state
49
52
  return res
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.8.6
3
+ Version: 0.8.7
4
4
  Summary: Easy-2-use long text NLP toolkit.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -6,15 +6,15 @@ deeplotx/encoder/longformer_encoder.py,sha256=NNYLr5I9tdeh0C8Ir7QcbEMU9gDk6U7CiF
6
6
  deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
7
7
  deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
8
8
  deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
9
- deeplotx/nn/base_neural_network.py,sha256=FjQEDFH810fJS7JV3aLgJZnaMqC6DH--wlBvuj-ghTc,5900
10
- deeplotx/nn/feed_forward.py,sha256=4ozj7EDalO9pb6JUhZtsJqE0r8bIHFApHRt2zTrl4ho,2931
9
+ deeplotx/nn/base_neural_network.py,sha256=QCyB1dxOs4I8vpu6PCshrZs0infoHXS9IErw6tN-dhs,6060
10
+ deeplotx/nn/feed_forward.py,sha256=kGWEUo8J7jrhSSWlitNnj-AcitNiLz6eOCvUcEuWlVs,2949
11
11
  deeplotx/nn/linear_regression.py,sha256=LWrrdAIw32KIT1bdr7q6HczdpEiCgb-R8BCNXGywMxE,1763
12
12
  deeplotx/nn/logistic_regression.py,sha256=nipWD3ZPRub2Cx0rU2zxYQyG0COn3NJvew8b2gbJy24,998
13
13
  deeplotx/nn/long_context_auto_regression.py,sha256=uy0k_g8wEfMH5nd5HCfrHA8dgEsuWBA2x8U-g3h4vQc,1054
14
14
  deeplotx/nn/long_context_recursive_sequential.py,sha256=pcZfnrIHBqbp2BssfUTS1klpuykZwowikfAIaOnvRUI,2674
15
15
  deeplotx/nn/multi_head_attention.py,sha256=3z73uGbvy3jszRy1B9nxGOJjlttHpcpRF8Qd09OEams,2267
16
16
  deeplotx/nn/multi_head_feed_forward.py,sha256=hD9ScrVJZ9kNksoFASf0xaPgEnNgCeRivW-XjYOPjj8,1908
17
- deeplotx/nn/recursive_sequential.py,sha256=Nrnsx-AU68tz1vn8_uf5ZdC-r8vA_X4-p-DY2t8y8us,2768
17
+ deeplotx/nn/recursive_sequential.py,sha256=sNvAs9iVCuWIgx0_6TizDq41hJpFbfKT3kyDHE86wRM,2928
18
18
  deeplotx/nn/roformer_encoder.py,sha256=BAPAMS5-qiM3i2FUyIW-ZTc7og4gZzwlu5LniqzaymY,2432
19
19
  deeplotx/nn/rope.py,sha256=RTOjnllubktdy2rzFWxBfkuLuGjhEMyDd06uojdqPhM,1848
20
20
  deeplotx/nn/softmax_regression.py,sha256=xe2etxSfN0e9XZ4E6Uyz5ThWWzAdQVjYIvN24j8kfNY,1019
@@ -28,8 +28,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
28
28
  deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
29
29
  deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
30
30
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
31
- deeplotx-0.8.6.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
32
- deeplotx-0.8.6.dist-info/METADATA,sha256=9cUvV_kD2TMFotnw51j1hXvGqjm8MBAfm7nJG62174I,13138
33
- deeplotx-0.8.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
- deeplotx-0.8.6.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
35
- deeplotx-0.8.6.dist-info/RECORD,,
31
+ deeplotx-0.8.7.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
32
+ deeplotx-0.8.7.dist-info/METADATA,sha256=fGyVnmSy3YKst_ZpwtMQhCq_-yxp5pvf-4zcQlhxNBA,13138
33
+ deeplotx-0.8.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
+ deeplotx-0.8.7.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
35
+ deeplotx-0.8.7.dist-info/RECORD,,