deeplotx 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeplotx/nn/base_neural_network.py +6 -3
- deeplotx/nn/feed_forward.py +1 -1
- deeplotx/nn/recursive_sequential.py +6 -3
- {deeplotx-0.8.6.dist-info → deeplotx-0.8.7.dist-info}/METADATA +1 -1
- {deeplotx-0.8.6.dist-info → deeplotx-0.8.7.dist-info}/RECORD +8 -8
- {deeplotx-0.8.6.dist-info → deeplotx-0.8.7.dist-info}/WHEEL +0 -0
- {deeplotx-0.8.6.dist-info → deeplotx-0.8.7.dist-info}/licenses/LICENSE +0 -0
- {deeplotx-0.8.6.dist-info → deeplotx-0.8.7.dist-info}/top_level.txt +0 -0
@@ -99,11 +99,14 @@ class BaseNeuralNetwork(nn.Module):
|
|
99
99
|
|
100
100
|
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
101
101
|
x = self.ensure_device_and_dtype(x, device=self.device, dtype=self.dtype)
|
102
|
-
|
103
|
-
self.
|
102
|
+
training_state_dict = dict()
|
103
|
+
for m in self.modules():
|
104
|
+
training_state_dict[m] = m.training
|
105
|
+
m.training = False
|
104
106
|
with torch.no_grad():
|
105
107
|
res = self.forward(x)
|
106
|
-
|
108
|
+
for m, training_state in training_state_dict.items():
|
109
|
+
m.training = training_state
|
107
110
|
return res
|
108
111
|
|
109
112
|
def save(self, model_name: str | None = None, model_dir: str = '.', _suffix: str = DEFAULT_SUFFIX):
|
deeplotx/nn/feed_forward.py
CHANGED
@@ -28,7 +28,7 @@ class FeedForwardUnit(BaseNeuralNetwork):
|
|
28
28
|
x = self.layer_norm(x)
|
29
29
|
x = self.up_proj(x)
|
30
30
|
x = self.parametric_relu(x)
|
31
|
-
if self._dropout_rate > .0:
|
31
|
+
if self._dropout_rate > .0 and self.training:
|
32
32
|
x = torch.dropout(x, p=self._dropout_rate, train=self.training)
|
33
33
|
return self.down_proj(x) + residual
|
34
34
|
|
@@ -41,9 +41,12 @@ class RecursiveSequential(BaseNeuralNetwork):
|
|
41
41
|
|
42
42
|
@override
|
43
43
|
def predict(self, x: torch.Tensor) -> torch.Tensor:
|
44
|
-
|
45
|
-
self.
|
44
|
+
training_state_dict = dict()
|
45
|
+
for m in self.modules():
|
46
|
+
training_state_dict[m] = m.training
|
47
|
+
m.training = False
|
46
48
|
with torch.no_grad():
|
47
49
|
res = self.forward(x.unsqueeze(0), self.initial_state(batch_size=1))[0]
|
48
|
-
|
50
|
+
for m, training_state in training_state_dict.items():
|
51
|
+
m.training = training_state
|
49
52
|
return res
|
@@ -6,15 +6,15 @@ deeplotx/encoder/longformer_encoder.py,sha256=NNYLr5I9tdeh0C8Ir7QcbEMU9gDk6U7CiF
|
|
6
6
|
deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
|
7
7
|
deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
|
8
8
|
deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
|
9
|
-
deeplotx/nn/base_neural_network.py,sha256=
|
10
|
-
deeplotx/nn/feed_forward.py,sha256=
|
9
|
+
deeplotx/nn/base_neural_network.py,sha256=QCyB1dxOs4I8vpu6PCshrZs0infoHXS9IErw6tN-dhs,6060
|
10
|
+
deeplotx/nn/feed_forward.py,sha256=kGWEUo8J7jrhSSWlitNnj-AcitNiLz6eOCvUcEuWlVs,2949
|
11
11
|
deeplotx/nn/linear_regression.py,sha256=LWrrdAIw32KIT1bdr7q6HczdpEiCgb-R8BCNXGywMxE,1763
|
12
12
|
deeplotx/nn/logistic_regression.py,sha256=nipWD3ZPRub2Cx0rU2zxYQyG0COn3NJvew8b2gbJy24,998
|
13
13
|
deeplotx/nn/long_context_auto_regression.py,sha256=uy0k_g8wEfMH5nd5HCfrHA8dgEsuWBA2x8U-g3h4vQc,1054
|
14
14
|
deeplotx/nn/long_context_recursive_sequential.py,sha256=pcZfnrIHBqbp2BssfUTS1klpuykZwowikfAIaOnvRUI,2674
|
15
15
|
deeplotx/nn/multi_head_attention.py,sha256=3z73uGbvy3jszRy1B9nxGOJjlttHpcpRF8Qd09OEams,2267
|
16
16
|
deeplotx/nn/multi_head_feed_forward.py,sha256=hD9ScrVJZ9kNksoFASf0xaPgEnNgCeRivW-XjYOPjj8,1908
|
17
|
-
deeplotx/nn/recursive_sequential.py,sha256=
|
17
|
+
deeplotx/nn/recursive_sequential.py,sha256=sNvAs9iVCuWIgx0_6TizDq41hJpFbfKT3kyDHE86wRM,2928
|
18
18
|
deeplotx/nn/roformer_encoder.py,sha256=BAPAMS5-qiM3i2FUyIW-ZTc7og4gZzwlu5LniqzaymY,2432
|
19
19
|
deeplotx/nn/rope.py,sha256=RTOjnllubktdy2rzFWxBfkuLuGjhEMyDd06uojdqPhM,1848
|
20
20
|
deeplotx/nn/softmax_regression.py,sha256=xe2etxSfN0e9XZ4E6Uyz5ThWWzAdQVjYIvN24j8kfNY,1019
|
@@ -28,8 +28,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
|
|
28
28
|
deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
|
29
29
|
deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
|
30
30
|
deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
|
31
|
-
deeplotx-0.8.
|
32
|
-
deeplotx-0.8.
|
33
|
-
deeplotx-0.8.
|
34
|
-
deeplotx-0.8.
|
35
|
-
deeplotx-0.8.
|
31
|
+
deeplotx-0.8.7.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
32
|
+
deeplotx-0.8.7.dist-info/METADATA,sha256=fGyVnmSy3YKst_ZpwtMQhCq_-yxp5pvf-4zcQlhxNBA,13138
|
33
|
+
deeplotx-0.8.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
34
|
+
deeplotx-0.8.7.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
|
35
|
+
deeplotx-0.8.7.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|