deeplotx 0.9.12__py3-none-any.whl → 0.9.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deeplotx/__init__.py CHANGED
@@ -15,6 +15,7 @@ from .nn import (
15
15
  RecursiveSequential,
16
16
  LongContextRecursiveSequential,
17
17
  RoPE,
18
+ LoRA,
18
19
  Attention,
19
20
  MultiHeadAttention,
20
21
  RoFormerEncoder,
deeplotx/nn/__init__.py CHANGED
@@ -7,6 +7,7 @@ from .softmax_regression import SoftmaxRegression
7
7
  from .recursive_sequential import RecursiveSequential
8
8
  from .long_context_recursive_sequential import LongContextRecursiveSequential
9
9
  from .rope import RoPE
10
+ from .lora import LoRA
10
11
  from .attention import Attention
11
12
  from .multi_head_attention import MultiHeadAttention
12
13
  from .roformer_encoder import RoFormerEncoder
@@ -10,13 +10,15 @@ DEFAULT_SUFFIX = 'dlx'
10
10
 
11
11
  class BaseNeuralNetwork(nn.Module):
12
12
  def __init__(self, in_features: int, out_features: int, model_name: str | None = None,
13
- device: str | None = None, dtype: torch.dtype | None = None):
13
+ device: str | torch.device | None = None, dtype: torch.dtype | None = None):
14
14
  super().__init__()
15
15
  self._model_name = model_name \
16
16
  if model_name is not None \
17
17
  else self.__class__.__name__
18
- self.device = torch.device(device) if device is not None \
19
- else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ self.device = device if isinstance(device, torch.device) else None
19
+ if self.device is None:
20
+ self.device = torch.device(device) if device is not None \
21
+ else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
22
  self.dtype = dtype if dtype is not None else torch.float32
21
23
  self._in_features = in_features
22
24
  self._out_features = out_features
deeplotx/nn/lora.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing_extensions import override
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from deeplotx.nn.base_neural_network import BaseNeuralNetwork
7
+
8
+
9
class LoRA(BaseNeuralNetwork):
    """Low-Rank Adaptation (LoRA) adapter wrapped around a frozen linear-like layer.

    The wrapped layer is stored as ``w0`` and computes the original projection; the
    adapter adds a trainable low-rank update on top of it::

        y = w0(x) + lora_B(lora_A(dropout(x))) * (alpha / rank)

    ``lora_A`` is initialised from N(0, 0.01^2) and ``lora_B`` with zeros, so a freshly
    mounted adapter leaves the wrapped layer's output unchanged.

    Instances are normally created by :meth:`apply_to`, which also assigns ``w0``.
    A ``LoRA`` constructed directly must have ``w0`` set to an ``nn.Module`` before
    ``forward`` is called.

    Args:
        input_dim: ``in_features`` of the wrapped layer (input size of ``lora_A``).
        output_dim: ``out_features`` of the wrapped layer (output size of ``lora_B``).
        rank: Rank ``r`` of the low-rank update; must be positive.
        alpha: Scaling numerator; the update is multiplied by ``alpha / rank``.
        dropout_rate: Dropout probability applied to the adapter input; ``0`` disables it.
        model_name: Forwarded to ``BaseNeuralNetwork``.
        device: Device for the adapter weights (forwarded to ``BaseNeuralNetwork``).
        dtype: dtype for the adapter weights (forwarded to ``BaseNeuralNetwork``).

    Raises:
        ValueError: If ``rank`` is not positive.
    """

    def __init__(self, input_dim: int, output_dim: int, rank: int = 8, alpha: int = 16,
                 dropout_rate: float = .0, model_name: str | None = None, device: str | torch.device | None = None,
                 dtype: torch.dtype | None = None):
        # rank == 0 would divide by zero in the scaling factor and build an empty bottleneck.
        if rank <= 0:
            raise ValueError(f'rank must be a positive integer, got {rank}.')
        super().__init__(in_features=input_dim, out_features=output_dim, model_name=model_name,
                         device=device, dtype=dtype)
        self._rank = rank
        self._alpha = alpha
        self._scaling = self._alpha / self._rank
        self._dropout = nn.Dropout(p=dropout_rate) if dropout_rate > .0 else nn.Identity()
        self.lora_A = nn.Linear(in_features=input_dim, out_features=rank, bias=False,
                                device=self.device, dtype=self.dtype)
        self.lora_B = nn.Linear(in_features=rank, out_features=output_dim, bias=False,
                                device=self.device, dtype=self.dtype)
        nn.init.normal_(self.lora_A.weight, mean=.0, std=.01)
        # Zero-initialised B makes the initial update exactly zero.
        nn.init.zeros_(self.lora_B.weight)
        # The layer being adapted; assigned by apply_to (or by the caller) after construction.
        self.w0 = None

    @override
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``w0(x)`` plus the scaled low-rank update of ``x``.

        Raises:
            ValueError: If no base layer has been mounted as ``w0``.
        """
        if not isinstance(self.w0, nn.Module):
            raise ValueError('LoRA adapter was not mounted successfully.')
        original_out = self.w0(x)
        # Dropout acts on the adapter's input, as in the reference LoRA implementation
        # (loralib) and PEFT; the frozen base path is never dropped.
        lora_out = self.lora_B(self.lora_A(self._dropout(x))) * self._scaling
        return original_out + lora_out

    @staticmethod
    def apply_to(model: nn.Module, target_modules: list[str] | str, rank: int = 8, alpha: int = 16,
                 dropout_rate: float = .0) -> nn.Module:
        """Wrap matching linear-like layers of ``model`` in LoRA adapters, in place.

        A submodule is targeted when the last component of its qualified name contains any
        entry of ``target_modules`` as a substring (e.g. ``'q_proj'`` matches
        ``layers.0.attn.q_proj``). It must also be linear-like: it exposes ``in_features`` and
        ``out_features`` and owns at least one parameter. The following are skipped rather
        than crashing:

        - name matches that are not linear-like;
        - existing ``LoRA`` adapters;
        - the children of existing adapters.

        After wrapping, every parameter of ``model`` is frozen except the ``lora_A`` /
        ``lora_B`` weights of the ``LoRA`` adapters contained in ``model``.

        Args:
            model: Model to adapt; it is modified in place.
            target_modules: Name fragment(s) selecting the layers to adapt.
            rank: Rank of each adapter.
            alpha: Scaling numerator of each adapter.
            dropout_rate: Adapter dropout probability.

        Returns:
            ``model`` itself, after modification.

        Raises:
            ValueError: If no submodule of ``model`` could be adapted.
        """
        if isinstance(target_modules, str):
            target_modules = [target_modules]

        # Snapshot the tree once, before any mutation. This avoids editing the hierarchy while
        # named_modules() is walking it, avoids re-scanning the model per match, and keeps
        # parent lookups pointing at the original parents even after a sibling/ancestor is wrapped.
        modules_by_name = dict(model.named_modules())
        targets = []
        for layer_name, module in modules_by_name.items():
            if not layer_name:
                # The root model itself has no parent to re-attach an adapter to.
                continue
            parent_name, _, child_name = layer_name.rpartition('.')
            parent_module = modules_by_name[parent_name]  # '' maps to the root model
            if not any(_name in child_name for _name in target_modules):
                continue
            if isinstance(module, LoRA) or isinstance(parent_module, LoRA):
                continue
            if not (hasattr(module, 'in_features') and hasattr(module, 'out_features')):
                continue
            reference = next(module.parameters(), None)
            if reference is None:
                continue
            targets.append((parent_module, child_name, module, reference))

        if not targets:
            raise ValueError(f'No adaptable linear layers matched target_modules={target_modules}.')

        for parent_module, child_name, module, reference in targets:
            lora = LoRA(input_dim=module.in_features, output_dim=module.out_features,
                        rank=rank, alpha=alpha, dropout_rate=dropout_rate,
                        device=reference.device,
                        # Integer (e.g. quantized) weights cannot back nn.Linear adapters;
                        # fall back to BaseNeuralNetwork's default floating dtype.
                        dtype=reference.dtype if reference.is_floating_point() else None)
            lora.w0 = module
            setattr(parent_module, child_name, lora)

        for param in model.parameters():
            param.requires_grad = False
        # Re-enable only adapter matrices, identified by module type rather than by
        # parameter-name substrings, so unrelated parameters are never unfrozen.
        for module in model.modules():
            if isinstance(module, LoRA):
                module.lora_A.weight.requires_grad = True
                module.lora_B.weight.requires_grad = True
        return model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.9.12
3
+ Version: 0.9.15
4
4
  Summary: An out-of-the-box long-text NLP framework.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
@@ -15,7 +15,7 @@ Requires-Dist: tiktoken
15
15
  Requires-Dist: torch
16
16
  Requires-Dist: transformers
17
17
  Requires-Dist: typing-extensions
18
- Requires-Dist: vortezwohl>=0.0.10
18
+ Requires-Dist: vortezwohl>=0.0.17
19
19
  Requires-Dist: name4py>=0.1.4
20
20
  Dynamic: license-file
21
21
 
@@ -157,6 +157,21 @@ year = {2025}
157
157
  (<Gender.Male: 'male'>, 1.0)
158
158
  ```
159
159
 
160
+ - ### Apply LoRA to a model
161
+
162
+ Import dependencies
163
+
164
+ ```python
165
+ from deeplotx import LoRA
166
+ ```
167
+
168
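+ Assuming `model` has already been loaded, wrap its target layers with LoRA adapters. Note that `LoRA.apply_to` modifies `model` in place and freezes every parameter except the adapter weights: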
169
+
170
+ ```python
171
+ model = ... # Maybe an LLM or some other deep neural network models
172
+ lora_model = LoRA.apply_to(model, target_modules=['q_proj'], rank=16, alpha=32, dropout_rate=.05)
173
+ ```
174
+
160
175
  - ### Long text embedding
161
176
 
162
177
  - **BERT based long text embedding**
@@ -273,6 +288,7 @@ year = {2025}
273
288
  RecursiveSequential,
274
289
  LongContextRecursiveSequential,
275
290
  RoPE,
291
+ LoRA,
276
292
  Attention,
277
293
  MultiHeadAttention,
278
294
  RoFormerEncoder,
@@ -391,35 +407,3 @@ year = {2025}
391
407
  v = self.v_proj(y)
392
408
  return torch.matmul(self._attention(x, y, mask), v)
393
409
  ```
394
-
395
- - ### Text binary classification task with predefined trainer
396
-
397
- ```python
398
- from deeplotx import TextBinaryClassifierTrainer, LongTextEncoder
399
- from deeplotx.util import get_files, read_file
400
-
401
- long_text_encoder = LongTextEncoder(
402
- max_length=2048,
403
- chunk_size=448,
404
- overlapping=32,
405
- cache_capacity=512
406
- )
407
- trainer = TextBinaryClassifierTrainer(
408
- long_text_encoder=long_text_encoder,
409
- batch_size=2,
410
- train_ratio=0.9
411
- )
412
- pos_data_path = 'path/to/pos_dir'
413
- neg_data_path = 'path/to/neg_dir'
414
- pos_data = [read_file(x) for x in get_files(pos_data_path)]
415
- neg_data = [read_file(x) for x in get_files(neg_data_path)]
416
- model = trainer.train(pos_data, neg_data,
417
- num_epochs=36, learning_rate=2e-5,
418
- balancing_dataset=True, alpha=1e-4,
419
- rho=.2, encoder_layers=2,
420
- attn_heads=8,
421
- recursive_layers=2)
422
- model.save(model_name='test_model', model_dir='model')
423
- model = model.load(model_name='test_model', model_dir='model')
424
- model.predict(long_text_encoder.encode('这是一个测试文本.', flatten=False))
425
- ```
@@ -1,4 +1,4 @@
1
- deeplotx/__init__.py,sha256=0OWLsgXlStzwm0m9ScaoZvBnsx3a0xTmlzYBUgarl-g,1306
1
+ deeplotx/__init__.py,sha256=9vGW9dz8aY9GMR0ctU0mBXvzhZR4a7P0CUG8fB1YQRc,1317
2
2
  deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
3
3
  deeplotx/encoder/encoder.py,sha256=uJswUSrYVDWP84HeCD40R9KgXGPUEa080konv7jEp8I,3539
4
4
  deeplotx/encoder/long_text_encoder.py,sha256=eabhgTMhJrAvRC7YyXAAR_LRv9-LULzo8S0uybWTcwM,3961
@@ -8,15 +8,16 @@ deeplotx/ner/base_ner.py,sha256=pZTl50OrHH_FJm4rKp9iuixeOE6FX_AzgDXD32aXsN0,204
8
8
  deeplotx/ner/bert_ner.py,sha256=tfbM3CQBEpZsD0KYVA7GVNJax-7kzNOQgwlpo2S8h-c,8986
9
9
  deeplotx/ner/named_entity.py,sha256=c6XufIwH6yloJ-ccUjagf4mBl1XbbYDT8xyEJJ_-ZNs,269
10
10
  deeplotx/ner/n2g/__init__.py,sha256=L1IJ8W1nApzqHx2u7JMtPCLfABm5qKJvh_bHMWdvdLY,3538
11
- deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
11
+ deeplotx/nn/__init__.py,sha256=9v-cftkJe7MSRkqfvBwu1kZx4nEosT0d36xi3KQtzXw,734
12
12
  deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
13
13
  deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
14
- deeplotx/nn/base_neural_network.py,sha256=QCyB1dxOs4I8vpu6PCshrZs0infoHXS9IErw6tN-dhs,6060
14
+ deeplotx/nn/base_neural_network.py,sha256=w98m1qOSw4C9zFu0uI4k3z8zeXrn5NiAZruyCe2ZQjs,6192
15
15
  deeplotx/nn/feed_forward.py,sha256=kGWEUo8J7jrhSSWlitNnj-AcitNiLz6eOCvUcEuWlVs,2949
16
16
  deeplotx/nn/linear_regression.py,sha256=LWrrdAIw32KIT1bdr7q6HczdpEiCgb-R8BCNXGywMxE,1763
17
17
  deeplotx/nn/logistic_regression.py,sha256=nipWD3ZPRub2Cx0rU2zxYQyG0COn3NJvew8b2gbJy24,998
18
18
  deeplotx/nn/long_context_auto_regression.py,sha256=uy0k_g8wEfMH5nd5HCfrHA8dgEsuWBA2x8U-g3h4vQc,1054
19
19
  deeplotx/nn/long_context_recursive_sequential.py,sha256=pcZfnrIHBqbp2BssfUTS1klpuykZwowikfAIaOnvRUI,2674
20
+ deeplotx/nn/lora.py,sha256=oy3HltuDily4NLSbUTNvS0i4Kq2kiwl92rBsIt8UOIA,2919
20
21
  deeplotx/nn/multi_head_attention.py,sha256=3z73uGbvy3jszRy1B9nxGOJjlttHpcpRF8Qd09OEams,2267
21
22
  deeplotx/nn/multi_head_feed_forward.py,sha256=hD9ScrVJZ9kNksoFASf0xaPgEnNgCeRivW-XjYOPjj8,1908
22
23
  deeplotx/nn/recursive_sequential.py,sha256=sNvAs9iVCuWIgx0_6TizDq41hJpFbfKT3kyDHE86wRM,2928
@@ -28,8 +29,8 @@ deeplotx/similarity/distribution.py,sha256=wQGouuuW531pZeBRKBujXsdsoz4fDnPw7_GW8
28
29
  deeplotx/similarity/set.py,sha256=zhGFxtSIXlWqvipBYzoiPahp4g0boAIoUiMfG0wl07A,686
29
30
  deeplotx/similarity/vector.py,sha256=WVbDHqykt-fvuILVrhUCtIFAOEjY_zvttrXGM9eylG0,1125
30
31
  deeplotx/util/__init__.py,sha256=d1qelOGVTLSsHp1R_gsP_FSMAtAxUxWMwiPrTS58RSg,66
31
- deeplotx-0.9.12.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
32
- deeplotx-0.9.12.dist-info/METADATA,sha256=dDoMn1gtn1dTfc7KhfGmyuZ1i_YycYVmh0QxnRh5g5I,14444
33
- deeplotx-0.9.12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
- deeplotx-0.9.12.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
35
- deeplotx-0.9.12.dist-info/RECORD,,
32
+ deeplotx-0.9.15.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
33
+ deeplotx-0.9.15.dist-info/METADATA,sha256=8y2RddiUJGYpRITudVOnD4KnynqJNKfIEKYR6qietWI,13573
34
+ deeplotx-0.9.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
35
+ deeplotx-0.9.15.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
36
+ deeplotx-0.9.15.dist-info/RECORD,,