deeplotx 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeplotx/encoder/encoder.py +1 -1
- deeplotx/encoder/long_text_encoder.py +4 -0
- deeplotx/trainer/text_binary_classification_trainer.py +1 -1
- {deeplotx-0.5.5.dist-info → deeplotx-0.5.6.dist-info}/METADATA +3 -2
- {deeplotx-0.5.5.dist-info → deeplotx-0.5.6.dist-info}/RECORD +8 -8
- {deeplotx-0.5.5.dist-info → deeplotx-0.5.6.dist-info}/WHEEL +0 -0
- {deeplotx-0.5.5.dist-info → deeplotx-0.5.6.dist-info}/licenses/LICENSE +0 -0
- {deeplotx-0.5.5.dist-info → deeplotx-0.5.6.dist-info}/top_level.txt +0 -0
deeplotx/encoder/encoder.py
CHANGED
@@ -25,7 +25,7 @@ class Encoder(nn.Module):
|
|
25
25
|
self.embed_dim = self.encoder.config.max_position_embeddings
|
26
26
|
logger.debug(f'{Encoder.__name__} initialized on device: {self.device}.')
|
27
27
|
|
28
|
-
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
28
|
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
29
29
|
def _encoder(_input_tup: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
30
30
|
return self.encoder.forward(_input_tup[0], attention_mask=_input_tup[1]).last_hidden_state[:, 0, :]
|
31
31
|
|
@@ -25,6 +25,10 @@ class LongTextEncoder(Encoder):
|
|
25
25
|
def __chunk_embedding(self, idx: int, x: torch.Tensor, mask: torch.Tensor) -> tuple[int, torch.Tensor]:
|
26
26
|
return idx, super().forward(x, attention_mask=mask)
|
27
27
|
|
28
|
+
@override
|
29
|
+
def forward(self, text: str, flatten: bool = False, *args, **kwargs) -> torch.Tensor:
|
30
|
+
return self.encode(text=text, flatten=flatten)
|
31
|
+
|
28
32
|
@override
|
29
33
|
def encode(self, text: str, flatten: bool = False) -> torch.Tensor:
|
30
34
|
def postprocess(tensors: list[torch.Tensor], _flatten: bool) -> torch.Tensor:
|
@@ -31,7 +31,7 @@ class TextBinaryClassifierTrainer(BaseTrainer):
|
|
31
31
|
positive_texts = positive_texts[:min_length]
|
32
32
|
negative_texts = negative_texts[:min_length]
|
33
33
|
all_texts = positive_texts + negative_texts
|
34
|
-
text_embeddings = [self._long_text_encoder.encode(x, flatten=False
|
34
|
+
text_embeddings = [self._long_text_encoder.encode(x, flatten=False) for x in all_texts]
|
35
35
|
feature_dim = text_embeddings[0].shape[-1]
|
36
36
|
dtype = text_embeddings[0].dtype
|
37
37
|
labels = ([torch.tensor([1.], dtype=dtype, device=self.device) for _ in range(len(positive_texts))]
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: deeplotx
|
3
|
-
Version: 0.5.5
|
3
|
+
Version: 0.5.6
|
4
4
|
Summary: Easy-2-use long text NLP toolkit.
|
5
5
|
Requires-Python: >=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
@@ -265,7 +265,8 @@ Dynamic: license-file
|
|
265
265
|
long_text_encoder = LongTextEncoder(
|
266
266
|
max_length=2048, # 最大文本大小, 超出截断
|
267
267
|
chunk_size=448, # 块大小 (按 Token 计)
|
268
|
-
overlapping=32 # 块间重叠大小 (按 Token 计)
|
268
|
+
overlapping=32, # 块间重叠大小 (按 Token 计)
|
269
|
+
cache_capacity=512 # 缓存大小
|
269
270
|
)
|
270
271
|
|
271
272
|
trainer = TextBinaryClassifierTrainer(
|
@@ -1,7 +1,7 @@
|
|
1
1
|
deeplotx/__init__.py,sha256=6El66QXHDrgNMsNIG9bG97WO8BhPK5btXbTikzx2ce4,1087
|
2
2
|
deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
|
3
|
-
deeplotx/encoder/encoder.py,sha256=
|
4
|
-
deeplotx/encoder/long_text_encoder.py,sha256=
|
3
|
+
deeplotx/encoder/encoder.py,sha256=2e1ZnZ37PkFQ5BePndmq42xmHp8YZh65Q1bd0dxejPI,2417
|
4
|
+
deeplotx/encoder/long_text_encoder.py,sha256=4445FdVwubvDiebCWoT9wAUpYlMj6Mmd0OBxbFZ3ZIo,3565
|
5
5
|
deeplotx/encoder/longformer_encoder.py,sha256=A8FXqd4mdHxSn_o_R689XtpT73ISDT788EgMQRGLC2g,1822
|
6
6
|
deeplotx/nn/__init__.py,sha256=CS0UwyYKa8wI6vu6FBIYxvm-HAmw39MTMFlZDtqi6UA,444
|
7
7
|
deeplotx/nn/auto_regression.py,sha256=7P63opWCWMqE2DigwbsL6kfXtFtJPz00Yo1RqflBz4A,572
|
@@ -19,12 +19,12 @@ deeplotx/similarity/set.py,sha256=zhGFxtSIXlWqvipBYzoiPahp4g0boAIoUiMfG0wl07A,68
|
|
19
19
|
deeplotx/similarity/vector.py,sha256=WVbDHqykt-fvuILVrhUCtIFAOEjY_zvttrXGM9eylG0,1125
|
20
20
|
deeplotx/trainer/__init__.py,sha256=Fl5DR9UecQc5VtBcczU9sx_HtPNoFohpuELOh-Jrsks,77
|
21
21
|
deeplotx/trainer/base_trainer.py,sha256=z0MeAT-rRYmjeBXt0ckt7J1itYArR0Cx02wHesXUoZE,385
|
22
|
-
deeplotx/trainer/text_binary_classification_trainer.py,sha256=
|
22
|
+
deeplotx/trainer/text_binary_classification_trainer.py,sha256=umuvikc09Op4SB43EqmYo8W3ung8DBjEOrMG3hCVFz8,4915
|
23
23
|
deeplotx/util/__init__.py,sha256=JxqAK_WOOHcYVSTHBT1-WuBwWrPEVDTV3titeVWvNUM,74
|
24
24
|
deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
|
25
25
|
deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
|
26
|
-
deeplotx-0.5.
|
27
|
-
deeplotx-0.5.
|
28
|
-
deeplotx-0.5.
|
29
|
-
deeplotx-0.5.
|
30
|
-
deeplotx-0.5.
|
26
|
+
deeplotx-0.5.6.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
27
|
+
deeplotx-0.5.6.dist-info/METADATA,sha256=vBUVgshgGG_vZmJT07C7CPEhMfBUmwbCtsIY06D_14g,10925
|
28
|
+
deeplotx-0.5.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
29
|
+
deeplotx-0.5.6.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
|
30
|
+
deeplotx-0.5.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|