deeplotx 0.5.5__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,7 @@ class Encoder(nn.Module):
25
25
  self.embed_dim = self.encoder.config.max_position_embeddings
26
26
  logger.debug(f'{Encoder.__name__} initialized on device: {self.device}.')
27
27
 
28
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
28
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, *args, **kwargs) -> torch.Tensor:
29
29
  def _encoder(_input_tup: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
30
30
  return self.encoder.forward(_input_tup[0], attention_mask=_input_tup[1]).last_hidden_state[:, 0, :]
31
31
 
@@ -25,6 +25,10 @@ class LongTextEncoder(Encoder):
25
25
  def __chunk_embedding(self, idx: int, x: torch.Tensor, mask: torch.Tensor) -> tuple[int, torch.Tensor]:
26
26
  return idx, super().forward(x, attention_mask=mask)
27
27
 
28
+ @override
29
+ def forward(self, text: str, flatten: bool = False, *args, **kwargs) -> torch.Tensor:
30
+ return self.encode(text=text, flatten=flatten)
31
+
28
32
  @override
29
33
  def encode(self, text: str, flatten: bool = False) -> torch.Tensor:
30
34
  def postprocess(tensors: list[torch.Tensor], _flatten: bool) -> torch.Tensor:
@@ -31,7 +31,7 @@ class TextBinaryClassifierTrainer(BaseTrainer):
31
31
  positive_texts = positive_texts[:min_length]
32
32
  negative_texts = negative_texts[:min_length]
33
33
  all_texts = positive_texts + negative_texts
34
- text_embeddings = [self._long_text_encoder.encode(x, flatten=False, use_cache=True) for x in all_texts]
34
+ text_embeddings = [self._long_text_encoder.encode(x, flatten=False) for x in all_texts]
35
35
  feature_dim = text_embeddings[0].shape[-1]
36
36
  dtype = text_embeddings[0].dtype
37
37
  labels = ([torch.tensor([1.], dtype=dtype, device=self.device) for _ in range(len(positive_texts))]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.5.5
3
+ Version: 0.5.6
4
4
  Summary: Easy-2-use long text NLP toolkit.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -265,7 +265,8 @@ Dynamic: license-file
265
265
  long_text_encoder = LongTextEncoder(
266
266
  max_length=2048, # 最大文本大小, 超出截断
267
267
  chunk_size=448, # 块大小 (按 Token 计)
268
- overlapping=32 # 块间重叠大小 (按 Token 计)
268
+ overlapping=32, # 块间重叠大小 (按 Token 计)
269
+ cache_capacity=512 # 缓存大小
269
270
  )
270
271
 
271
272
  trainer = TextBinaryClassifierTrainer(
@@ -1,7 +1,7 @@
1
1
  deeplotx/__init__.py,sha256=6El66QXHDrgNMsNIG9bG97WO8BhPK5btXbTikzx2ce4,1087
2
2
  deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
3
- deeplotx/encoder/encoder.py,sha256=p1e4Dx3-Ghdl0MGNalr0D_OnafwaJnbhscEDVq-y73A,2400
4
- deeplotx/encoder/long_text_encoder.py,sha256=GatkOF1QQHLtvyuikfCP4xpzfDvszJJyonaS9f7wSxg,3401
3
+ deeplotx/encoder/encoder.py,sha256=2e1ZnZ37PkFQ5BePndmq42xmHp8YZh65Q1bd0dxejPI,2417
4
+ deeplotx/encoder/long_text_encoder.py,sha256=4445FdVwubvDiebCWoT9wAUpYlMj6Mmd0OBxbFZ3ZIo,3565
5
5
  deeplotx/encoder/longformer_encoder.py,sha256=A8FXqd4mdHxSn_o_R689XtpT73ISDT788EgMQRGLC2g,1822
6
6
  deeplotx/nn/__init__.py,sha256=CS0UwyYKa8wI6vu6FBIYxvm-HAmw39MTMFlZDtqi6UA,444
7
7
  deeplotx/nn/auto_regression.py,sha256=7P63opWCWMqE2DigwbsL6kfXtFtJPz00Yo1RqflBz4A,572
@@ -19,12 +19,12 @@ deeplotx/similarity/set.py,sha256=zhGFxtSIXlWqvipBYzoiPahp4g0boAIoUiMfG0wl07A,68
19
19
  deeplotx/similarity/vector.py,sha256=WVbDHqykt-fvuILVrhUCtIFAOEjY_zvttrXGM9eylG0,1125
20
20
  deeplotx/trainer/__init__.py,sha256=Fl5DR9UecQc5VtBcczU9sx_HtPNoFohpuELOh-Jrsks,77
21
21
  deeplotx/trainer/base_trainer.py,sha256=z0MeAT-rRYmjeBXt0ckt7J1itYArR0Cx02wHesXUoZE,385
22
- deeplotx/trainer/text_binary_classification_trainer.py,sha256=BNBQdpaD8nB1dQv8naHNIravNcQC8JjOMqD-WRSrUH0,4931
22
+ deeplotx/trainer/text_binary_classification_trainer.py,sha256=umuvikc09Op4SB43EqmYo8W3ung8DBjEOrMG3hCVFz8,4915
23
23
  deeplotx/util/__init__.py,sha256=JxqAK_WOOHcYVSTHBT1-WuBwWrPEVDTV3titeVWvNUM,74
24
24
  deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
25
25
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
26
- deeplotx-0.5.5.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
27
- deeplotx-0.5.5.dist-info/METADATA,sha256=QE1R1jodTrnPFY7cbu4mQNPt8_BgKNJuHoSDswopueo,10880
28
- deeplotx-0.5.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
- deeplotx-0.5.5.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
30
- deeplotx-0.5.5.dist-info/RECORD,,
26
+ deeplotx-0.5.6.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
27
+ deeplotx-0.5.6.dist-info/METADATA,sha256=vBUVgshgGG_vZmJT07C7CPEhMfBUmwbCtsIY06D_14g,10925
28
+ deeplotx-0.5.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
29
+ deeplotx-0.5.6.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
30
+ deeplotx-0.5.6.dist-info/RECORD,,