deeplotx 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,19 +13,20 @@ class BertEncoder(nn.Module):
13
13
  def __init__(self, model_name_or_path: str = DEFAULT_BERT):
14
14
  super().__init__()
15
15
  self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
16
- cache_dir=CACHE_PATH)
16
+ cache_dir=CACHE_PATH, _from_auto=True)
17
17
  self.bert = BertModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
18
- cache_dir=CACHE_PATH)
18
+ cache_dir=CACHE_PATH, _from_auto=True)
19
+ self.embed_dim = self.bert.config.max_position_embeddings
19
20
 
20
- def forward(self, input_ids, attention_mask: torch.Tensor) -> torch.Tensor:
21
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
21
22
  def _encoder(_input_tup: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
22
23
  return self.bert.forward(_input_tup[0], attention_mask=_input_tup[1]).last_hidden_state[:, 0, :]
23
24
 
24
- num_chunks = math.ceil(input_ids.shape[-1] / 512)
25
+ num_chunks = math.ceil(input_ids.shape[-1] / self.embed_dim)
25
26
  chunks = chunk_results = []
26
27
  for i in range(num_chunks):
27
- start_idx = i * 512
28
- end_idx = min(start_idx + 512, input_ids.shape[-1])
28
+ start_idx = i * self.embed_dim
29
+ end_idx = min(start_idx + self.embed_dim, input_ids.shape[-1])
29
30
  chunks.append((input_ids[:, start_idx: end_idx], attention_mask[:, start_idx: end_idx]))
30
31
  ori_mode = self.bert.training
31
32
  self.bert.eval()
@@ -24,7 +24,7 @@ class LongTextEncoder(BertEncoder):
24
24
  return input_tup[0], super().forward(input_tup[1], attention_mask=input_tup[2])
25
25
 
26
26
  @override
27
- def encode(self, text: str) -> torch.Tensor:
27
+ def encode(self, text: str, use_cache: bool = True) -> torch.Tensor:
28
28
  _text_to_show = text.replace("\n", str())
29
29
  logger.debug(f'Embedding \"{_text_to_show if len(_text_to_show) < 128 else _text_to_show[:128] + "..."}\".')
30
30
  # read cache
@@ -58,5 +58,6 @@ class LongTextEncoder(BertEncoder):
58
58
  fin_emb_tensor = torch.cat((fin_emb_tensor.detach().clone(), emb.detach().clone()), dim=-1)
59
59
  fin_emb_tensor = fin_emb_tensor.squeeze()
60
60
  # write cache
61
- self._cache[_text_hash] = fin_emb_tensor
61
+ if use_cache:
62
+ self._cache[_text_hash] = fin_emb_tensor
62
63
  return fin_emb_tensor
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from torch import nn
3
+ from transformers import LongformerTokenizer, LongformerModel
4
+
5
+ from deeplotx import __ROOT__
6
+
7
+ CACHE_PATH = f'{__ROOT__}\\.cache'
8
+ DEFAULT_LONGFORMER = 'allenai/longformer-base-4096'
9
+
10
+
11
+ class LongformerEncoder(nn.Module):
12
+ def __init__(self, model_name_or_path: str = DEFAULT_LONGFORMER):
13
+ super().__init__()
14
+ self.tokenizer = LongformerTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
15
+ cache_dir=CACHE_PATH, _from_auto=True)
16
+ self.bert = LongformerModel.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
17
+ cache_dir=CACHE_PATH, _from_auto=True)
18
+
19
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
20
+ ori_mode = self.bert.training
21
+ self.bert.eval()
22
+ with torch.no_grad():
23
+ res = self.bert.forward(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
24
+ self.bert.train(mode=ori_mode)
25
+ return res
26
+
27
+ def encode(self, text: str) -> torch.Tensor:
28
+ _input_ids = torch.tensor([self.tokenizer.encode(text)], dtype=torch.long)
29
+ _att_mask = torch.tensor([[1] * _input_ids.shape[-1]], dtype=torch.int)
30
+ return self.forward(_input_ids, _att_mask).squeeze()
@@ -1,3 +1,5 @@
1
+ from abc import abstractmethod
2
+
1
3
  import torch
2
4
  from torch import nn
3
5
 
@@ -28,6 +30,7 @@ class BaseNeuralNetwork(nn.Module):
28
30
  def elastic_net(self, alpha: float = 1e-4, rho: float = 0.5) -> torch.Tensor:
29
31
  return alpha * (rho * self.l1(_lambda=1.) + (1 - rho) * self.l2(_lambda=1.))
30
32
 
33
+ @abstractmethod
31
34
  def forward(self, x) -> torch.Tensor: ...
32
35
 
33
36
  def predict(self, x) -> torch.Tensor:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.2.20
3
+ Version: 0.2.21
4
4
  Summary: Easy-2-use long text classifier trainers.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -21,13 +21,19 @@ Dynamic: license-file
21
21
  - Install with pip
22
22
 
23
23
  ```
24
- pip install git+https://github.com/vortezwohl/DeepLoTX.git
24
+ pip install -U deeplotx
25
25
  ```
26
26
 
27
27
  - Install with uv
28
28
 
29
29
  ```
30
- uv add git+https://github.com/vortezwohl/DeepLoTX.git
30
+ uv add -U deeplotx
31
+ ```
32
+
33
+ - Install from github
34
+
35
+ ```
36
+ pip install -U git+https://github.com/vortezwohl/DeepLoTX.git
31
37
  ```
32
38
 
33
39
  ## Quick Start
@@ -1,9 +1,10 @@
1
1
  deeplotx/__init__.py,sha256=Bhxc6HRnuhPZCMNlBc6oKcFTpJbWRGrZmt00vVOsNf0,916
2
2
  deeplotx/encoder/__init__.py,sha256=x7k8IE0FXvDl7kCJGWPsetOHFdvNCiCXHbYOdvo7_JQ,87
3
- deeplotx/encoder/bert_encoder.py,sha256=rdT8YgZzvRoqYqtzPW95ilagSQTAQgUl7mMVetGKxCY,1822
4
- deeplotx/encoder/long_text_encoder.py,sha256=yEEtTVZYHJ0W3OSbh7BHm6xI33nJmVYlSrgD5RVcJLY,2967
3
+ deeplotx/encoder/bert_encoder.py,sha256=A-B7Gj94xv6UhvsFTBH7tnkAdGHRhfUZA2QjSnTKB6c,1970
4
+ deeplotx/encoder/long_text_encoder.py,sha256=V6VxaHW6bMMaZHgU1UZ8n19UfSIV2f2sarWXquiFffQ,3018
5
+ deeplotx/encoder/longformer_encoder.py,sha256=mZpC5TrGHQo98-ydGtVQQ9KRHgCGl1sRoxcQs7r4SSo,1409
5
6
  deeplotx/nn/__init__.py,sha256=9gh8rhKqVWtJyvryU_wHPTLEQIorwOBhAQRc0DtNamM,153
6
- deeplotx/nn/base_neural_network.py,sha256=MXuID5bagdHyrFOkoybW1oiXAY2d4FGnzZoR37LZfUI,1566
7
+ deeplotx/nn/base_neural_network.py,sha256=Rkwu58mXXcuusf-59yLX89MywQx-EvTsSXOvlzUptRE,1621
7
8
  deeplotx/nn/linear_regression.py,sha256=D4mEWVOq6q1Fm2otm57rgZ_E06HJLZBV5k636PprAf4,1520
8
9
  deeplotx/nn/logistic_regression.py,sha256=QAtZp2oyqOW8-1pJWVcahsSM83bzfA68EHObg-wSHHY,463
9
10
  deeplotx/nn/softmax_regression.py,sha256=eUn3mVNlye9ewVdw3McPHZuKbUvvaamsUgFIJMVMgBU,487
@@ -13,8 +14,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=5O-5dwVMCj5EDX9gjJ
13
14
  deeplotx/util/__init__.py,sha256=JxqAK_WOOHcYVSTHBT1-WuBwWrPEVDTV3titeVWvNUM,74
14
15
  deeplotx/util/hash.py,sha256=wwsC6kOQvbpuvwKsNQOARd78_wePmW9i3oaUuXRUnpc,352
15
16
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
16
- deeplotx-0.2.20.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
17
- deeplotx-0.2.20.dist-info/METADATA,sha256=NQgRWucDSAI4awAJNf9984IujFRo9PurR1qrqpmWIzA,1573
18
- deeplotx-0.2.20.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
19
- deeplotx-0.2.20.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
20
- deeplotx-0.2.20.dist-info/RECORD,,
17
+ deeplotx-0.2.21.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
18
+ deeplotx-0.2.21.dist-info/METADATA,sha256=mNUcUO4dSccX1Sz8868nrbq3qWo3cINJXPVv8XtVpzY,1617
19
+ deeplotx-0.2.21.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
20
+ deeplotx-0.2.21.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
21
+ deeplotx-0.2.21.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5