deeplotx 0.9.0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeplotx/__init__.py +2 -1
- deeplotx/ner/__init__.py +1 -1
- deeplotx/ner/base_ner.py +2 -2
- deeplotx/ner/bert_ner.py +67 -4
- deeplotx/ner/n2g/__init__.py +91 -0
- deeplotx/ner/named_entity.py +9 -1
- {deeplotx-0.9.0.dist-info → deeplotx-0.9.3.dist-info}/METADATA +135 -74
- {deeplotx-0.9.0.dist-info → deeplotx-0.9.3.dist-info}/RECORD +11 -10
- {deeplotx-0.9.0.dist-info → deeplotx-0.9.3.dist-info}/WHEEL +0 -0
- {deeplotx-0.9.0.dist-info → deeplotx-0.9.3.dist-info}/licenses/LICENSE +0 -0
- {deeplotx-0.9.0.dist-info → deeplotx-0.9.3.dist-info}/top_level.txt +0 -0
deeplotx/__init__.py
CHANGED
@@ -4,7 +4,8 @@ import os
|
|
4
4
|
__ROOT__ = os.path.dirname(os.path.abspath(__file__))
|
5
5
|
|
6
6
|
from .encoder import Encoder, LongTextEncoder, LongformerEncoder
|
7
|
-
from .ner import
|
7
|
+
from .ner.n2g import Name2Gender, Gender
|
8
|
+
from .ner import BertNER, NamedEntity, NamedPerson
|
8
9
|
from .nn import (
|
9
10
|
FeedForward,
|
10
11
|
MultiHeadFeedForward,
|
deeplotx/ner/__init__.py
CHANGED
deeplotx/ner/base_ner.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
|
-
from deeplotx.ner.named_entity import NamedEntity
|
1
|
+
from deeplotx.ner.named_entity import NamedEntity, NamedPerson
|
2
2
|
|
3
3
|
|
4
4
|
class BaseNER:
    """Common interface for named-entity recognisers.

    Concrete recognisers (e.g. ``BertNER``) subclass this and implement
    ``__call__`` to map an input string to the entities found in it.
    """

    def __init__(self):
        """Take no configuration; subclasses set up their own state."""

    def __call__(self, s: str, *args, **kwargs) -> list[NamedEntity | NamedPerson]:
        """Return the entities recognised in ``s`` (placeholder here: returns ``None``)."""
|
deeplotx/ner/bert_ner.py
CHANGED
@@ -6,11 +6,13 @@ import torch
|
|
6
6
|
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
7
7
|
|
8
8
|
from deeplotx import __ROOT__
|
9
|
+
from deeplotx.ner.n2g import Name2Gender
|
9
10
|
from deeplotx.ner.base_ner import BaseNER
|
10
|
-
from deeplotx.ner.named_entity import NamedEntity
|
11
|
+
from deeplotx.ner.named_entity import NamedEntity, NamedPerson
|
11
12
|
|
12
13
|
# Cache directory under the package root.
CACHE_PATH = os.path.join(__ROOT__, '.cache')
# NER checkpoint identifier; presumably BertNER's default model (its use is outside this view — confirm).
DEFAULT_BERT_NER = 'Davlan/xlm-roberta-base-ner-hrl'
# Holder for one lazily created, shared Name2Gender instance; it is filled on the
# first extraction that asks for gender, so the gender model is only loaded when needed.
N2G_MODEL: list[Name2Gender] = list()
logger = logging.getLogger('deeplotx.ner')
|
15
17
|
|
16
18
|
|
@@ -44,9 +46,9 @@ class BertNER(BaseNER):
|
|
44
46
|
self._ner_pipeline = pipeline(task='ner', model=self.encoder, tokenizer=self.tokenizer, trust_remote_code=True)
|
45
47
|
logger.debug(f'{BaseNER.__name__} initialized on device: {self.device}.')
|
46
48
|
|
47
|
-
def
|
49
|
+
def _fast_extract(self, s: str, with_gender: bool = True, prob_threshold: float = .0) -> list[NamedEntity]:
|
48
50
|
assert prob_threshold <= 1., f'prob_threshold ({prob_threshold}) cannot be larger than 1.'
|
49
|
-
s = ' '
|
51
|
+
s = f' {s} '
|
50
52
|
raw_entities = self._ner_pipeline(s)
|
51
53
|
entities = []
|
52
54
|
for ent in raw_entities:
|
@@ -69,4 +71,65 @@ class BertNER(BaseNER):
|
|
69
71
|
ent[0] = ent[0].strip()
|
70
72
|
if ent[1].upper().startswith('B'):
|
71
73
|
ent[1] = ent[1].upper()[1:].strip('-')
|
72
|
-
|
74
|
+
entities = [NamedEntity(*_) for _ in entities if _[2] >= prob_threshold]
|
75
|
+
if not with_gender:
|
76
|
+
return entities
|
77
|
+
if len(N2G_MODEL) < 1:
|
78
|
+
N2G_MODEL.append(Name2Gender())
|
79
|
+
n2g_model = N2G_MODEL[0]
|
80
|
+
for i, ent in enumerate(entities):
|
81
|
+
if ent.type.upper() == 'PER':
|
82
|
+
gender, gender_prob = n2g_model(ent.text, return_probability=True)
|
83
|
+
entities[i] = NamedPerson(text=ent.text,
|
84
|
+
type=ent.type,
|
85
|
+
base_probability=ent.base_probability,
|
86
|
+
gender=gender,
|
87
|
+
gender_probability=gender_prob)
|
88
|
+
return entities
|
89
|
+
|
90
|
+
def _slow_extract(self, s: str, with_gender: bool = True, prob_threshold: float = .0, deduplicate: bool = True, window_size: int = 512) -> 'list[NamedEntity]':
    """Extract entities from text of any length using two window tilings.

    Texts shorter than ``window_size`` characters are handled by a single
    ``_fast_extract`` call. Longer texts are tiled twice, once with windows of
    ``window_size - window_size // 6`` characters and once with
    ``window_size + window_size // 6``. An entity cut at a window boundary in
    one tiling is therefore likely to be whole in the other.

    Candidates are then cleaned up:

    * an entity whose text is a strict substring of another candidate's text
      is dropped (the longer span wins);
    * entities whose text does not occur verbatim in ``s``, or is shorter than
      two characters, are dropped;
    * if ``deduplicate`` is true, only the entity with the highest
      ``base_probability`` is kept for each distinct text (first one on ties).

    :param s: Input text.
    :param with_gender: Forwarded to ``_fast_extract``.
    :param prob_threshold: Forwarded to ``_fast_extract``.
    :param deduplicate: Collapse entities that share the same text.
    :param window_size: Character length from which the text is windowed.
    :return: Entities ordered from longest to shortest text (stable).
    :raises ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f'window_size ({window_size}) must be at least 1.')
    if len(s) < window_size:
        candidates = self._fast_extract(s, with_gender=with_gender, prob_threshold=prob_threshold)
    else:
        candidates = []
        offset = window_size // 6
        for chunk_size in (window_size - offset, window_size + offset):
            # Stop at len(s): running to len(s) + chunk_size produced empty
            # trailing windows, each costing a pointless model call.
            for start in range(0, len(s), chunk_size):
                candidates.extend(self._fast_extract(s[start: start + chunk_size],
                                                     with_gender=with_gender,
                                                     prob_threshold=prob_threshold))
    ranked = sorted(candidates, key=lambda ent: len(ent.text), reverse=True)
    # Drop entities whose text is strictly contained in another candidate's text.
    ranked = [ent for ent in ranked
              if not any(len(ent.text) < len(other.text) and ent.text in other.text
                         for other in candidates)]
    # Window-local artefacts: spans that are not in the original text, or too short.
    ranked = [ent for ent in ranked if len(ent.text) >= 2 and ent.text in s]
    if not deduplicate:
        return ranked
    # Walk the ranked list itself, not a set of texts, so the output order is
    # deterministic (set iteration order varies with hash randomisation).
    best = {}
    for ent in ranked:
        kept = best.get(ent.text)
        if kept is None or ent.base_probability > kept.base_probability:
            best[ent.text] = ent
    return list(best.values())
|
130
|
+
|
131
|
+
def __call__(self, s: str, with_gender: bool = True, prob_threshold: float = .0, fast_mode: bool = False, *args, **kwargs):
    """Recognise named entities in ``s``.

    :param s: Input text.
    :param with_gender: Attach a predicted gender to ``PER`` entities.
    :param prob_threshold: Minimum NER confidence for an entity to be kept.
    :param fast_mode: Run a single ``_fast_extract`` pass instead of the
        windowed, de-duplicating ``_slow_extract``.
    :return: The extracted entities. Extra positional/keyword arguments are ignored.
    """
    shared = {'s': s, 'with_gender': with_gender, 'prob_threshold': prob_threshold}
    if not fast_mode:
        return self._slow_extract(deduplicate=True, **shared)
    return self._fast_extract(**shared)
|
@@ -0,0 +1,91 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import requests
|
4
|
+
import gdown
|
5
|
+
import torch
|
6
|
+
from name4py import Gender
|
7
|
+
|
8
|
+
from deeplotx import __ROOT__
|
9
|
+
from deeplotx.encoder.encoder import Encoder
|
10
|
+
from deeplotx.nn.logistic_regression import LogisticRegression
|
11
|
+
from deeplotx.nn.base_neural_network import BaseNeuralNetwork
|
12
|
+
|
13
|
+
|
14
|
+
# Local directory holding downloaded Name2Gender weight files.
__CACHE_DIR__ = os.path.join(__ROOT__, '.cache', '.n2g')
# Release asset names of the two classifier sizes (the '.dlx' suffix is added on download).
BASE_MODEL = 'name2gender-base'
SMALL_MODEL = 'name2gender-small'
# Downloads smaller than 5 KiB are treated as failures (see download_model).
_MIN_FILE_SIZE = 5 * 1024
# Shared encoder that embeds names before classification.
# NOTE(review): built at import time, so importing this module constructs the
# encoder eagerly (presumably loading its weights) — confirm whether lazy init is wanted.
ENCODER = Encoder(model_name_or_path='FacebookAI/xlm-roberta-base')
|
19
|
+
|
20
|
+
|
21
|
+
def download_model(model_name: str, timeout: float | None = 30.):
    """Download a Name2Gender weight file into the local cache if it is missing.

    The asset ``<model_name>.dlx`` is fetched from the ``RESOURCE`` release of
    the Name2Gender GitHub repository into ``__CACHE_DIR__``. Nothing is
    downloaded when the file is already cached.

    Environment variables:
        QUIET_DOWNLOAD: ``1``/``true``/``yes``/``y``/``on`` (case-insensitive)
            silences gdown's progress output; any other value (including
            ``0``/``false``) leaves it on.
        HTTP_PROXY / HTTPS_PROXY (or their lower-case variants): proxies used
            for the reachability check and the download.

    :param model_name: Asset name without the ``.dlx`` suffix.
    :param timeout: Seconds to wait for the GitHub reachability check
        (``None`` waits indefinitely).
    :return: Path of the cached model file.
    :raises ConnectionError: If the repository page does not answer with HTTP 200.
    :raises FileNotFoundError: If the downloaded file is implausibly small
        (the asset most likely does not exist).
    """
    # bool(os.getenv(...)) was True for any non-empty value, including "0" and "false".
    quiet = os.getenv('QUIET_DOWNLOAD', '').strip().lower() in {'1', 'true', 'yes', 'y', 'on'}
    os.makedirs(__CACHE_DIR__, exist_ok=True)
    _proxies = {
        'http': os.getenv('HTTP_PROXY', os.getenv('http_proxy')),
        'https': os.getenv('HTTPS_PROXY', os.getenv('https_proxy'))
    }
    model_name = f'{model_name}.dlx'
    model_path = os.path.join(__CACHE_DIR__, model_name)
    if os.path.exists(model_path):
        return model_path
    base_url = 'https://github.com/vortezwohl/Name2Gender'
    url = f'{base_url}/releases/download/RESOURCE/{model_name}'
    # Without a timeout an unresponsive network blocks this call forever.
    if requests.get(url=base_url, proxies=_proxies, timeout=timeout).status_code != 200:
        raise ConnectionError(f'Failed to download model {model_name}.')
    try:
        gdown.download(
            url=url,
            output=model_path,
            quiet=quiet,
            proxy=_proxies.get('https'),
            speed=8192 * 1024,
            resume=True
        )
        if os.path.getsize(model_path) < _MIN_FILE_SIZE:
            raise FileNotFoundError(f"Model \"{model_name}\" doesn't exists.")
    except Exception:
        # Discard any partial file so the next call retries. The file may not
        # exist if the download failed early; removing it unconditionally would
        # raise FileNotFoundError and mask the original error.
        if os.path.exists(model_path):
            os.remove(model_path)
        raise
    return model_path
|
50
|
+
|
51
|
+
|
52
|
+
def load_model(model_name: str = 'name2gender-small', dtype: torch.dtype | None = torch.float16) -> BaseNeuralNetwork:
|
53
|
+
n2g_model = None
|
54
|
+
match model_name:
|
55
|
+
case 'name2gender-base' | 'n2g-base' | 'base':
|
56
|
+
download_model(BASE_MODEL)
|
57
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
58
|
+
num_heads=12, num_layers=4,
|
59
|
+
head_layers=1, expansion_factor=2,
|
60
|
+
model_name=BASE_MODEL, dtype=dtype)
|
61
|
+
case 'name2gender-small' | 'n2g-base' | 'small':
|
62
|
+
download_model(SMALL_MODEL)
|
63
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
64
|
+
num_heads=6, num_layers=2,
|
65
|
+
head_layers=1, expansion_factor=1.5,
|
66
|
+
model_name=SMALL_MODEL, dtype=dtype)
|
67
|
+
case _:
|
68
|
+
download_model(SMALL_MODEL)
|
69
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
70
|
+
num_heads=6, num_layers=2,
|
71
|
+
head_layers=1, expansion_factor=1.5,
|
72
|
+
model_name=SMALL_MODEL, dtype=dtype)
|
73
|
+
return n2g_model.load(model_dir=__CACHE_DIR__)
|
74
|
+
|
75
|
+
|
76
|
+
class Name2Gender:
    """Predict the gender associated with a personal name.

    The name is embedded with the shared ``ENCODER``. A classifier then scores
    it, and scores at or above ``threshold`` are read as male.
    """

    def __init__(self, model: BaseNeuralNetwork | None = None):
        """:param model: Pre-loaded classifier; when omitted, the small model is loaded."""
        super().__init__()
        self._model = load_model(SMALL_MODEL) if model is None else model

    def __call__(self, name: str, return_probability: bool = False, threshold: float = .5) -> tuple[Gender, float] | Gender:
        """Classify ``name``.

        :param name: Non-empty name; its first character is upper-cased before encoding.
        :param return_probability: Also return the probability of the predicted gender.
        :param threshold: Male score at or above which the prediction is male.
        :return: ``Gender``, or ``(Gender, probability)`` if ``return_probability``.
        """
        assert len(name) > 0, f'name ({name}) cannot be empty.'
        capitalised = name[0].upper() + name[1:]
        male_score = self._model.predict(ENCODER.encode(capitalised)).item()
        is_male = male_score >= threshold
        gender = Gender.Male if is_male else Gender.Female
        if not return_probability:
            return gender
        # Report confidence in the predicted class, not the raw male score.
        return gender, (male_score if is_male else 1. - male_score)
|
deeplotx/ner/named_entity.py
CHANGED
@@ -1,8 +1,16 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
|
3
|
+
from deeplotx.ner.n2g import Gender
|
4
|
+
|
3
5
|
|
4
6
|
@dataclass
class NamedEntity:
    """An entity span produced by a NER model."""

    text: str  # surface text of the entity
    type: str  # entity label, e.g. 'PER' or 'LOC'
    base_probability: float  # NER confidence score; BertNER compares it against prob_threshold


@dataclass
class NamedPerson(NamedEntity):
    """A ``PER`` entity enriched with a gender predicted by Name2Gender."""

    gender: Gender  # predicted gender
    gender_probability: float  # probability of the predicted gender
|
@@ -1,7 +1,7 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: deeplotx
|
3
|
-
Version: 0.9.
|
4
|
-
Summary:
|
3
|
+
Version: 0.9.3
|
4
|
+
Summary: An out-of-the-box long-text NLP framework.
|
5
5
|
Requires-Python: >=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
7
7
|
License-File: LICENSE
|
@@ -16,121 +16,198 @@ Requires-Dist: torch
|
|
16
16
|
Requires-Dist: transformers
|
17
17
|
Requires-Dist: typing-extensions
|
18
18
|
Requires-Dist: vortezwohl>=0.0.8
|
19
|
-
Requires-Dist:
|
19
|
+
Requires-Dist: name4py>=0.1.4
|
20
20
|
Dynamic: license-file
|
21
21
|
|
22
22
|
[](https://deepwiki.com/vortezwohl/DeepLoTX)
|
23
23
|
|
24
|
-
# Deep Long Text Learning
|
24
|
+
# *Deep Long Text Learning*
|
25
25
|
|
26
|
-
|
26
|
+
*An out-of-the-box long-text NLP framework.*
|
27
27
|
|
28
|
-
|
28
|
+
> Author: [vortezwohl](https://github.com/vortezwohl)
|
29
29
|
|
30
|
-
##
|
30
|
+
## Installation
|
31
31
|
|
32
|
-
-
|
32
|
+
- **With pip**
|
33
33
|
|
34
34
|
```
|
35
35
|
pip install -U deeplotx
|
36
36
|
```
|
37
37
|
|
38
|
-
-
|
38
|
+
- **With uv (recommended)**
|
39
39
|
|
40
40
|
```
|
41
41
|
uv add -U deeplotx
|
42
42
|
```
|
43
43
|
|
44
|
-
-
|
44
|
+
- **Get the latest features from GitHub**
|
45
45
|
|
46
46
|
```
|
47
47
|
pip install -U git+https://github.com/vortezwohl/DeepLoTX.git
|
48
48
|
```
|
49
49
|
|
50
|
-
##
|
50
|
+
## Quick start
|
51
51
|
|
52
|
-
- ###
|
52
|
+
- ### Named entity recognition
|
53
53
|
|
54
|
-
|
54
|
+
> *Multilingual is supported.*
|
55
|
+
|
56
|
+
> *Gender recognition is supported.*
|
57
|
+
|
58
|
+
Import dependencies
|
59
|
+
|
60
|
+
```python
|
61
|
+
from deeplotx import BertNER
|
62
|
+
|
63
|
+
ner = BertNER()
|
64
|
+
```
|
65
|
+
|
66
|
+
```python
|
67
|
+
ner('你好, 我的名字是吴子豪, 来自福建福州.')
|
68
|
+
```
|
69
|
+
|
70
|
+
stdout:
|
71
|
+
|
72
|
+
```
|
73
|
+
[NamedPerson(text='吴子豪', type='PER', base_probability=0.9995428418719051, gender=<Gender.Male: 'male'>, gender_probability=0.9970703125),
|
74
|
+
NamedEntity(text='福建', type='LOC', base_probability=0.9986373782157898),
|
75
|
+
NamedEntity(text='福州', type='LOC', base_probability=0.9993632435798645)]
|
76
|
+
```
|
77
|
+
|
78
|
+
```python
|
79
|
+
ner("Hi, i'm Vortez Wohl, author of DeeploTX.")
|
80
|
+
```
|
81
|
+
|
82
|
+
stdout:
|
83
|
+
|
84
|
+
```
|
85
|
+
[NamedPerson(text='Vortez Wohl', type='PER', base_probability=0.9991965342072855, gender=<Gender.Male: 'male'>, gender_probability=0.87255859375)]
|
86
|
+
```
|
87
|
+
|
88
|
+
- ### Gender recognition
|
89
|
+
|
90
|
+
> *Multilingual is supported.*
|
91
|
+
|
92
|
+
> *Integrated from [Name2Gender](https://github.com/vortezwohl/Name2Gender)*
|
93
|
+
|
94
|
+
Import dependencies
|
95
|
+
|
96
|
+
```python
|
97
|
+
from deeplotx import Name2Gender
|
98
|
+
|
99
|
+
n2g = Name2Gender()
|
100
|
+
```
|
101
|
+
|
102
|
+
Recognize gender of "Elon Musk":
|
103
|
+
|
104
|
+
```python
|
105
|
+
n2g('Elon Musk')
|
106
|
+
```
|
107
|
+
|
108
|
+
stdout:
|
109
|
+
|
110
|
+
```
|
111
|
+
<Gender.Male: 'male'>
|
112
|
+
```
|
113
|
+
|
114
|
+
Recognize gender of "Anne Hathaway":
|
115
|
+
|
116
|
+
```python
|
117
|
+
n2g('Anne Hathaway')
|
118
|
+
```
|
119
|
+
|
120
|
+
stdout:
|
121
|
+
|
122
|
+
```
|
123
|
+
<Gender.Female: 'female'>
|
124
|
+
```
|
125
|
+
|
126
|
+
Recognize gender of "吴彦祖":
|
127
|
+
|
128
|
+
```python
|
129
|
+
n2g('吴彦祖', return_probability=True)
|
130
|
+
```
|
131
|
+
|
132
|
+
stdout:
|
133
|
+
|
134
|
+
```
|
135
|
+
(<Gender.Male: 'male'>, 1.0)
|
136
|
+
```
|
137
|
+
|
138
|
+
- ### Long text embedding
|
139
|
+
|
140
|
+
- **BERT based long text embedding**
|
55
141
|
|
56
142
|
```python
|
57
143
|
from deeplotx import LongTextEncoder
|
58
144
|
|
59
|
-
# 块大小为 448 个 tokens, 块间重叠部分为 32 个 tokens.
|
60
145
|
encoder = LongTextEncoder(
|
61
146
|
chunk_size=448,
|
62
147
|
overlapping=32
|
63
148
|
)
|
64
|
-
# 对 "我是吴子豪, 这是一个测试文本." 计算嵌入, 并堆叠.
|
65
149
|
encoder.encode('我是吴子豪, 这是一个测试文本.', flatten=False)
|
66
150
|
```
|
67
151
|
|
68
|
-
|
152
|
+
stdout:
|
69
153
|
```
|
70
154
|
tensor([ 2.2316e-01, 2.0300e-01, ..., 1.5578e-01, -6.6735e-02])
|
71
155
|
```
|
72
156
|
|
73
|
-
-
|
157
|
+
- **Longformer based long text embedding**
|
74
158
|
|
75
159
|
```python
|
76
160
|
from deeplotx import LongformerEncoder
|
77
161
|
|
78
162
|
encoder = LongformerEncoder()
|
79
|
-
encoder.encode('
|
163
|
+
encoder.encode('Thank you for using DeepLoTX.')
|
80
164
|
```
|
81
165
|
|
82
|
-
|
166
|
+
stdout:
|
83
167
|
```
|
84
168
|
tensor([-2.7490e-02, 6.6503e-02, ..., -6.5937e-02, 6.7802e-03])
|
85
169
|
```
|
86
170
|
|
87
|
-
- ###
|
171
|
+
- ### Similarities calculation
|
88
172
|
|
89
|
-
-
|
173
|
+
- **Vector based**
|
90
174
|
|
91
175
|
```python
|
92
176
|
import deeplotx.similarity as sim
|
93
177
|
|
94
178
|
vector_0, vector_1 = [1, 2, 3, 4], [4, 3, 2, 1]
|
95
|
-
# 欧几里得距离
|
96
179
|
distance_0 = sim.euclidean_similarity(vector_0, vector_1)
|
97
180
|
print(distance_0)
|
98
|
-
# 余弦距离
|
99
181
|
distance_1 = sim.cosine_similarity(vector_0, vector_1)
|
100
182
|
print(distance_1)
|
101
|
-
# 切比雪夫距离
|
102
183
|
distance_2 = sim.chebyshev_similarity(vector_0, vector_1)
|
103
184
|
print(distance_2)
|
104
185
|
```
|
105
186
|
|
106
|
-
|
187
|
+
stdout:
|
107
188
|
```
|
108
189
|
4.47213595499958
|
109
190
|
0.33333333333333337
|
110
191
|
3
|
111
192
|
```
|
112
193
|
|
113
|
-
-
|
194
|
+
- **Set based**
|
114
195
|
|
115
196
|
```python
|
116
197
|
import deeplotx.similarity as sim
|
117
198
|
|
118
199
|
set_0, set_1 = {1, 2, 3, 4}, {4, 5, 6, 7}
|
119
|
-
# 杰卡德距离
|
120
200
|
distance_0 = sim.jaccard_similarity(set_0, set_1)
|
121
201
|
print(distance_0)
|
122
|
-
# Ochiai 距离
|
123
202
|
distance_1 = sim.ochiai_similarity(set_0, set_1)
|
124
203
|
print(distance_1)
|
125
|
-
# Dice 系数
|
126
204
|
distance_2 = sim.dice_coefficient(set_0, set_1)
|
127
205
|
print(distance_2)
|
128
|
-
# Overlap 系数
|
129
206
|
distance_3 = sim.overlap_coefficient(set_0, set_1)
|
130
207
|
print(distance_3)
|
131
208
|
```
|
132
209
|
|
133
|
-
|
210
|
+
stdout:
|
134
211
|
```
|
135
212
|
0.1428571428572653
|
136
213
|
0.2500000000001875
|
@@ -138,27 +215,23 @@ Dynamic: license-file
|
|
138
215
|
0.2500000000001875
|
139
216
|
```
|
140
217
|
|
141
|
-
-
|
218
|
+
- **Distribution based**
|
142
219
|
|
143
220
|
```python
|
144
221
|
import deeplotx.similarity as sim
|
145
222
|
|
146
223
|
dist_0, dist_1 = [0.3, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.4]
|
147
|
-
# 交叉熵
|
148
224
|
distance_0 = sim.cross_entropy(dist_0, dist_1)
|
149
225
|
print(distance_0)
|
150
|
-
# KL 散度
|
151
226
|
distance_1 = sim.kl_divergence(dist_0, dist_1)
|
152
227
|
print(distance_1)
|
153
|
-
# JS 散度
|
154
228
|
distance_2 = sim.js_divergence(dist_0, dist_1)
|
155
229
|
print(distance_2)
|
156
|
-
# Hellinger 距离
|
157
230
|
distance_3 = sim.hellinger_distance(dist_0, dist_1)
|
158
231
|
print(distance_3)
|
159
232
|
```
|
160
233
|
|
161
|
-
|
234
|
+
stdout:
|
162
235
|
```
|
163
236
|
0.3575654913778237
|
164
237
|
0.15040773967762736
|
@@ -166,27 +239,27 @@ Dynamic: license-file
|
|
166
239
|
0.20105866986400994
|
167
240
|
```
|
168
241
|
|
169
|
-
- ###
|
242
|
+
- ### Pre-defined neural networks
|
170
243
|
|
171
244
|
```python
|
172
245
|
from deeplotx import (
|
173
|
-
FeedForward,
|
174
|
-
MultiHeadFeedForward,
|
175
|
-
LinearRegression,
|
176
|
-
LogisticRegression,
|
177
|
-
SoftmaxRegression,
|
178
|
-
RecursiveSequential,
|
179
|
-
LongContextRecursiveSequential,
|
180
|
-
RoPE,
|
181
|
-
Attention,
|
182
|
-
MultiHeadAttention,
|
183
|
-
RoFormerEncoder,
|
184
|
-
AutoRegression,
|
185
|
-
LongContextAutoRegression
|
246
|
+
FeedForward,
|
247
|
+
MultiHeadFeedForward,
|
248
|
+
LinearRegression,
|
249
|
+
LogisticRegression,
|
250
|
+
SoftmaxRegression,
|
251
|
+
RecursiveSequential,
|
252
|
+
LongContextRecursiveSequential,
|
253
|
+
RoPE,
|
254
|
+
Attention,
|
255
|
+
MultiHeadAttention,
|
256
|
+
RoFormerEncoder,
|
257
|
+
AutoRegression,
|
258
|
+
LongContextAutoRegression
|
186
259
|
)
|
187
260
|
```
|
188
261
|
|
189
|
-
|
262
|
+
The fundamental FFN (MLPs):
|
190
263
|
|
191
264
|
```python
|
192
265
|
from typing_extensions import override
|
@@ -244,7 +317,7 @@ Dynamic: license-file
|
|
244
317
|
return x
|
245
318
|
```
|
246
319
|
|
247
|
-
|
320
|
+
Attention:
|
248
321
|
|
249
322
|
```python
|
250
323
|
from typing_extensions import override
|
@@ -297,46 +370,34 @@ Dynamic: license-file
|
|
297
370
|
return torch.matmul(self._attention(x, y, mask), v)
|
298
371
|
```
|
299
372
|
|
300
|
-
- ###
|
373
|
+
- ### Text binary classification task with predefined trainer
|
301
374
|
|
302
375
|
```python
|
303
376
|
from deeplotx import TextBinaryClassifierTrainer, LongTextEncoder
|
304
377
|
from deeplotx.util import get_files, read_file
|
305
378
|
|
306
|
-
# 定义向量编码策略 (默认使用 FacebookAI/xlm-roberta-base 作为嵌入模型)
|
307
379
|
long_text_encoder = LongTextEncoder(
|
308
|
-
max_length=2048,
|
309
|
-
chunk_size=448,
|
310
|
-
overlapping=32,
|
311
|
-
cache_capacity=512
|
380
|
+
max_length=2048,
|
381
|
+
chunk_size=448,
|
382
|
+
overlapping=32,
|
383
|
+
cache_capacity=512
|
312
384
|
)
|
313
|
-
|
314
385
|
trainer = TextBinaryClassifierTrainer(
|
315
386
|
long_text_encoder=long_text_encoder,
|
316
387
|
batch_size=2,
|
317
|
-
train_ratio=0.9
|
388
|
+
train_ratio=0.9
|
318
389
|
)
|
319
|
-
|
320
|
-
# 读取数据
|
321
390
|
pos_data_path = 'path/to/pos_dir'
|
322
391
|
neg_data_path = 'path/to/neg_dir'
|
323
392
|
pos_data = [read_file(x) for x in get_files(pos_data_path)]
|
324
393
|
neg_data = [read_file(x) for x in get_files(neg_data_path)]
|
325
|
-
|
326
|
-
# 开始训练
|
327
394
|
model = trainer.train(pos_data, neg_data,
|
328
395
|
num_epochs=36, learning_rate=2e-5,
|
329
396
|
balancing_dataset=True, alpha=1e-4,
|
330
|
-
rho=.2, encoder_layers=2,
|
331
|
-
attn_heads=8,
|
332
|
-
recursive_layers=2)
|
333
|
-
|
334
|
-
# 保存模型权重
|
397
|
+
rho=.2, encoder_layers=2,
|
398
|
+
attn_heads=8,
|
399
|
+
recursive_layers=2)
|
335
400
|
model.save(model_name='test_model', model_dir='model')
|
336
|
-
|
337
|
-
# 加载已保存的模型
|
338
401
|
model = model.load(model_name='test_model', model_dir='model')
|
339
|
-
|
340
|
-
# 使用训练好的模型进行预测
|
341
402
|
model.predict(long_text_encoder.encode('这是一个测试文本.', flatten=False))
|
342
403
|
```
|
@@ -1,12 +1,13 @@
|
|
1
|
-
deeplotx/__init__.py,sha256=
|
1
|
+
deeplotx/__init__.py,sha256=x4CbJuW20al6S5KkKyrReeuwNGv04JGoqtGUyx-ACtg,1356
|
2
2
|
deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
|
3
3
|
deeplotx/encoder/encoder.py,sha256=wVRl3p_7eg7qT_tJEit5qnmZx7dXkMVLxAtao5vImkk,4201
|
4
4
|
deeplotx/encoder/long_text_encoder.py,sha256=4oRa9FqfGNZ8-gq14UKuhDkZC0A1Xi-wKmbQsn-uZ58,3966
|
5
5
|
deeplotx/encoder/longformer_encoder.py,sha256=7Lm65AUD3qwbrzrhJ3dPZkyHeNRSapga3f-5QJCxV5A,3538
|
6
|
-
deeplotx/ner/__init__.py,sha256=
|
7
|
-
deeplotx/ner/base_ner.py,sha256=
|
8
|
-
deeplotx/ner/bert_ner.py,sha256=
|
9
|
-
deeplotx/ner/named_entity.py,sha256=
|
6
|
+
deeplotx/ner/__init__.py,sha256=Rss1pup9HzHZCG8U9ub8niWa9zRjWCy3Z7zg378KZQg,114
|
7
|
+
deeplotx/ner/base_ner.py,sha256=pZTl50OrHH_FJm4rKp9iuixeOE6FX_AzgDXD32aXsN0,204
|
8
|
+
deeplotx/ner/bert_ner.py,sha256=I8yFsarsLEQv0vcnNU2JIc0-LuPJcxaO-mLhDFCh1PI,7704
|
9
|
+
deeplotx/ner/named_entity.py,sha256=c6XufIwH6yloJ-ccUjagf4mBl1XbbYDT8xyEJJ_-ZNs,269
|
10
|
+
deeplotx/ner/n2g/__init__.py,sha256=b6fOWJVLaOCtoz8Qlp8NWQbL5lUSbn6H3-8fnVNIPi0,3940
|
10
11
|
deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
|
11
12
|
deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
|
12
13
|
deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
|
@@ -32,8 +33,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
|
|
32
33
|
deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
|
33
34
|
deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
|
34
35
|
deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
|
35
|
-
deeplotx-0.9.
|
36
|
-
deeplotx-0.9.
|
37
|
-
deeplotx-0.9.
|
38
|
-
deeplotx-0.9.
|
39
|
-
deeplotx-0.9.
|
36
|
+
deeplotx-0.9.3.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
37
|
+
deeplotx-0.9.3.dist-info/METADATA,sha256=Fg0KzWIxFcMtuTfmuQ9BBJDFXjNTWtl9l3Cuuc1sX3I,13472
|
38
|
+
deeplotx-0.9.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
39
|
+
deeplotx-0.9.3.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
|
40
|
+
deeplotx-0.9.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|