deeplotx 0.8.8__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deeplotx/__init__.py +4 -0
- deeplotx/ner/__init__.py +3 -0
- deeplotx/ner/base_ner.py +10 -0
- deeplotx/ner/bert_ner.py +91 -0
- deeplotx/ner/n2g/__init__.py +91 -0
- deeplotx/ner/named_entity.py +16 -0
- {deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/METADATA +136 -73
- {deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/RECORD +11 -6
- {deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/WHEEL +0 -0
- {deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/licenses/LICENSE +0 -0
- {deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/top_level.txt +0 -0
deeplotx/__init__.py
CHANGED
@@ -4,6 +4,8 @@ import os
|
|
4
4
|
__ROOT__ = os.path.dirname(os.path.abspath(__file__))
|
5
5
|
|
6
6
|
from .encoder import Encoder, LongTextEncoder, LongformerEncoder
|
7
|
+
from .ner.n2g import Name2Gender, Gender
|
8
|
+
from .ner import BertNER, NamedEntity, NamedPerson
|
7
9
|
from .nn import (
|
8
10
|
FeedForward,
|
9
11
|
MultiHeadFeedForward,
|
@@ -40,3 +42,5 @@ logger = logging.getLogger('deeplotx.trainer')
|
|
40
42
|
logger.setLevel(logging.DEBUG)
|
41
43
|
logger = logging.getLogger('deeplotx.embedding')
|
42
44
|
logger.setLevel(logging.DEBUG)
|
45
|
+
logger = logging.getLogger('deeplotx.ner')
|
46
|
+
logger.setLevel(logging.DEBUG)
|
deeplotx/ner/__init__.py
ADDED
deeplotx/ner/base_ner.py
ADDED
@@ -0,0 +1,10 @@
|
|
1
|
+
from deeplotx.ner.named_entity import NamedEntity, NamedPerson
|
2
|
+
|
3
|
+
|
4
|
+
class BaseNER:
|
5
|
+
def __init__(self): ...
|
6
|
+
|
7
|
+
def __call__(self, s: str, *args, **kwargs) -> list[NamedEntity | NamedPerson]:
|
8
|
+
return self.extract_entities(s=s, *args, **kwargs)
|
9
|
+
|
10
|
+
def extract_entities(self, s: str, *args, **kwargs) -> list[NamedEntity | NamedPerson]: ...
|
deeplotx/ner/bert_ner.py
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from requests.exceptions import ConnectTimeout, SSLError
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
|
7
|
+
|
8
|
+
from deeplotx import __ROOT__
|
9
|
+
from deeplotx.ner.n2g import Name2Gender
|
10
|
+
from deeplotx.ner.base_ner import BaseNER
|
11
|
+
from deeplotx.ner.named_entity import NamedEntity, NamedPerson
|
12
|
+
|
13
|
+
CACHE_PATH = os.path.join(__ROOT__, '.cache')
|
14
|
+
DEFAULT_BERT_NER = 'Davlan/xlm-roberta-base-ner-hrl'
|
15
|
+
N2G_MODEL: list[Name2Gender] = []
|
16
|
+
logger = logging.getLogger('deeplotx.ner')
|
17
|
+
|
18
|
+
|
19
|
+
class BertNER(BaseNER):
|
20
|
+
def __init__(self, model_name_or_path: str = DEFAULT_BERT_NER, device: str | None = None):
|
21
|
+
super().__init__()
|
22
|
+
self.device = torch.device(device) if device is not None \
|
23
|
+
else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
24
|
+
try:
|
25
|
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
26
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
27
|
+
trust_remote_code=True)
|
28
|
+
self.encoder = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
29
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
30
|
+
trust_remote_code=True).to(self.device)
|
31
|
+
except ConnectTimeout:
|
32
|
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
33
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
34
|
+
trust_remote_code=True, local_files_only=True)
|
35
|
+
self.encoder = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
36
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
37
|
+
trust_remote_code=True, local_files_only=True).to(self.device)
|
38
|
+
except SSLError:
|
39
|
+
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
40
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
41
|
+
trust_remote_code=True, local_files_only=True)
|
42
|
+
self.encoder = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
|
43
|
+
cache_dir=CACHE_PATH, _from_auto=True,
|
44
|
+
trust_remote_code=True, local_files_only=True).to(self.device)
|
45
|
+
self.embed_dim = self.encoder.config.max_position_embeddings
|
46
|
+
self._ner_pipeline = pipeline(task='ner', model=self.encoder, tokenizer=self.tokenizer, trust_remote_code=True)
|
47
|
+
logger.debug(f'{BaseNER.__name__} initialized on device: {self.device}.')
|
48
|
+
|
49
|
+
def extract_entities(self, s: str, with_gender: bool = True, prob_threshold: float = .0, *args, **kwargs) -> list[NamedEntity]:
|
50
|
+
assert prob_threshold <= 1., f'prob_threshold ({prob_threshold}) cannot be larger than 1.'
|
51
|
+
s = ' ' + s
|
52
|
+
raw_entities = self._ner_pipeline(s)
|
53
|
+
entities = []
|
54
|
+
for ent in raw_entities:
|
55
|
+
entities.append([s[ent['start']: ent['end']], ent['entity'], ent['score'].item()])
|
56
|
+
while True:
|
57
|
+
for i, ent in enumerate(entities):
|
58
|
+
if len(ent[0].strip()) < 1:
|
59
|
+
del entities[i]
|
60
|
+
if ent[1].upper().startswith('I') and entities[i - 1][1].upper().startswith('B'):
|
61
|
+
entities[i - 1][0] += ent[0]
|
62
|
+
entities[i - 1][2] *= ent[2]
|
63
|
+
del entities[i]
|
64
|
+
_continue = False
|
65
|
+
for ent in entities:
|
66
|
+
if ent[1].upper().startswith('I'):
|
67
|
+
_continue = True
|
68
|
+
if not _continue:
|
69
|
+
break
|
70
|
+
for ent in entities:
|
71
|
+
ent[0] = ent[0].strip()
|
72
|
+
if ent[1].upper().startswith('B'):
|
73
|
+
ent[1] = ent[1].upper()[1:].strip('-')
|
74
|
+
entities = [NamedEntity(*_) for _ in entities if _[2] >= prob_threshold]
|
75
|
+
if not with_gender:
|
76
|
+
return entities
|
77
|
+
if len(N2G_MODEL) < 1:
|
78
|
+
N2G_MODEL.append(Name2Gender())
|
79
|
+
n2g_model = N2G_MODEL[0]
|
80
|
+
for i, ent in enumerate(entities):
|
81
|
+
if ent.type.upper() == 'PER':
|
82
|
+
gender, gender_prob = n2g_model(ent.text, return_probability=True)
|
83
|
+
entities[i] = NamedPerson(text=ent.text,
|
84
|
+
type=ent.type,
|
85
|
+
base_probability=ent.base_probability,
|
86
|
+
gender=gender,
|
87
|
+
gender_probability=gender_prob)
|
88
|
+
return entities
|
89
|
+
|
90
|
+
def __call__(self, s: str, with_gender: bool = True, prob_threshold: float = .0, *args, **kwargs):
|
91
|
+
return self.extract_entities(s=s, with_gender=with_gender, prob_threshold=prob_threshold, *args, **kwargs)
|
deeplotx/ner/n2g/__init__.py
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
import requests
|
4
|
+
import gdown
|
5
|
+
import torch
|
6
|
+
from name4py import Gender
|
7
|
+
|
8
|
+
from deeplotx import __ROOT__
|
9
|
+
from deeplotx.encoder.encoder import Encoder
|
10
|
+
from deeplotx.nn.logistic_regression import LogisticRegression
|
11
|
+
from deeplotx.nn.base_neural_network import BaseNeuralNetwork
|
12
|
+
|
13
|
+
|
14
|
+
__CACHE_DIR__ = os.path.join(__ROOT__, '.cache', '.n2g')
|
15
|
+
ENCODER = Encoder(model_name_or_path='FacebookAI/xlm-roberta-base')
|
16
|
+
BASE_MODEL = 'name2gender-base'
|
17
|
+
SMALL_MODEL = 'name2gender-small'
|
18
|
+
_MIN_FILE_SIZE = 1024 * 5
|
19
|
+
|
20
|
+
|
21
|
+
def download_model(model_name: str):
|
22
|
+
quiet = bool(os.getenv('QUIET_DOWNLOAD', False))
|
23
|
+
os.makedirs(__CACHE_DIR__, exist_ok=True)
|
24
|
+
_proxies = {
|
25
|
+
'http': os.getenv('HTTP_PROXY', os.getenv('http_proxy')),
|
26
|
+
'https': os.getenv('HTTPS_PROXY', os.getenv('https_proxy'))
|
27
|
+
}
|
28
|
+
model_name = f'{model_name}.dlx'
|
29
|
+
model_path = os.path.join(__CACHE_DIR__, model_name)
|
30
|
+
base_url = 'https://github.com/vortezwohl/Name2Gender'
|
31
|
+
if not os.path.exists(model_path):
|
32
|
+
url = f'{base_url}/releases/download/RESOURCE/{model_name}'
|
33
|
+
if requests.get(url=base_url, proxies=_proxies).status_code == 200:
|
34
|
+
try:
|
35
|
+
gdown.download(
|
36
|
+
url=url,
|
37
|
+
output=model_path,
|
38
|
+
quiet=quiet,
|
39
|
+
proxy=_proxies.get('https'),
|
40
|
+
speed=8192 * 1024,
|
41
|
+
resume=True
|
42
|
+
)
|
43
|
+
if os.path.getsize(model_path) < _MIN_FILE_SIZE:
|
44
|
+
raise FileNotFoundError(f"Model \"{model_name}\" doesn't exists.")
|
45
|
+
except Exception as e:
|
46
|
+
os.remove(model_path)
|
47
|
+
raise e
|
48
|
+
else:
|
49
|
+
raise ConnectionError(f'Failed to download model {model_name}.')
|
50
|
+
|
51
|
+
|
52
|
+
def load_model(model_name: str = 'name2gender-small', dtype: torch.dtype | None = torch.float16) -> BaseNeuralNetwork:
|
53
|
+
n2g_model = None
|
54
|
+
match model_name:
|
55
|
+
case 'name2gender-base' | 'n2g-base' | 'base':
|
56
|
+
download_model(BASE_MODEL)
|
57
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
58
|
+
num_heads=12, num_layers=4,
|
59
|
+
head_layers=1, expansion_factor=2,
|
60
|
+
model_name=BASE_MODEL, dtype=dtype)
|
61
|
+
case 'name2gender-small' | 'n2g-base' | 'small':
|
62
|
+
download_model(SMALL_MODEL)
|
63
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
64
|
+
num_heads=6, num_layers=2,
|
65
|
+
head_layers=1, expansion_factor=1.5,
|
66
|
+
model_name=SMALL_MODEL, dtype=dtype)
|
67
|
+
case _:
|
68
|
+
download_model(SMALL_MODEL)
|
69
|
+
n2g_model = LogisticRegression(input_dim=768, output_dim=1,
|
70
|
+
num_heads=6, num_layers=2,
|
71
|
+
head_layers=1, expansion_factor=1.5,
|
72
|
+
model_name=SMALL_MODEL, dtype=dtype)
|
73
|
+
return n2g_model.load(model_dir=__CACHE_DIR__)
|
74
|
+
|
75
|
+
|
76
|
+
class Name2Gender:
|
77
|
+
def __init__(self, model: BaseNeuralNetwork | None = None):
|
78
|
+
super().__init__()
|
79
|
+
if model is None:
|
80
|
+
model = load_model(SMALL_MODEL)
|
81
|
+
self._model = model
|
82
|
+
|
83
|
+
def __call__(self, name: str, return_probability: bool = False, threshold: float = .5) -> tuple[Gender, float] | Gender:
|
84
|
+
assert len(name) > 0, f'name ({name}) cannot be empty.'
|
85
|
+
name = f'{name[0].upper()}{name[1:]}'
|
86
|
+
emb = ENCODER.encode(name)
|
87
|
+
prob = self._model.predict(emb).item()
|
88
|
+
gender = Gender.Male if prob >= threshold else Gender.Female
|
89
|
+
if return_probability:
|
90
|
+
return gender, prob if gender == Gender.Male else (1. - prob)
|
91
|
+
return gender
|
deeplotx/ner/named_entity.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from deeplotx.ner.n2g import Gender
|
4
|
+
|
5
|
+
|
6
|
+
@dataclass
|
7
|
+
class NamedEntity:
|
8
|
+
text: str
|
9
|
+
type: str
|
10
|
+
base_probability: float
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
|
14
|
+
class NamedPerson(NamedEntity):
|
15
|
+
gender: Gender
|
16
|
+
gender_probability: float
|
{deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/METADATA
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: deeplotx
|
3
|
-
Version: 0.
|
4
|
-
Summary:
|
3
|
+
Version: 0.9.2
|
4
|
+
Summary: An out-of-the-box long-text NLP framework.
|
5
5
|
Requires-Python: >=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
7
7
|
License-File: LICENSE
|
@@ -10,125 +10,204 @@ Requires-Dist: jupyter
|
|
10
10
|
Requires-Dist: numpy
|
11
11
|
Requires-Dist: protobuf
|
12
12
|
Requires-Dist: python-dotenv
|
13
|
+
Requires-Dist: sentencepiece
|
13
14
|
Requires-Dist: tiktoken
|
14
15
|
Requires-Dist: torch
|
15
16
|
Requires-Dist: transformers
|
16
17
|
Requires-Dist: typing-extensions
|
17
18
|
Requires-Dist: vortezwohl>=0.0.8
|
19
|
+
Requires-Dist: name4py>=0.1.4
|
18
20
|
Dynamic: license-file
|
19
21
|
|
20
22
|
[](https://deepwiki.com/vortezwohl/DeepLoTX)
|
21
23
|
|
22
|
-
# Deep Long Text Learning
|
24
|
+
# *Deep Long Text Learning*
|
23
25
|
|
24
|
-
|
26
|
+
*An out-of-the-box long-text NLP framework.*
|
25
27
|
|
26
|
-
|
28
|
+
> Author: [vortezwohl](https://github.com/vortezwohl)
|
27
29
|
|
28
|
-
##
|
30
|
+
## Installation
|
29
31
|
|
30
|
-
-
|
32
|
+
- **With pip**
|
31
33
|
|
32
34
|
```
|
33
35
|
pip install -U deeplotx
|
34
36
|
```
|
35
37
|
|
36
|
-
-
|
38
|
+
- **With uv (recommended)**
|
37
39
|
|
38
40
|
```
|
39
41
|
uv add -U deeplotx
|
40
42
|
```
|
41
43
|
|
42
|
-
-
|
44
|
+
- **Get the latest features from GitHub**
|
43
45
|
|
44
46
|
```
|
45
47
|
pip install -U git+https://github.com/vortezwohl/DeepLoTX.git
|
46
48
|
```
|
47
49
|
|
48
|
-
##
|
50
|
+
## Quick start
|
49
51
|
|
50
|
-
- ###
|
52
|
+
- ### Named entity recognition
|
51
53
|
|
52
|
-
|
54
|
+
> *Multilingual is supported.*
|
55
|
+
|
56
|
+
> *Gender recognition is supported.*
|
57
|
+
|
58
|
+
Import dependencies
|
59
|
+
|
60
|
+
```python
|
61
|
+
from deeplotx import BertNER
|
62
|
+
|
63
|
+
ner = BertNER()
|
64
|
+
```
|
65
|
+
|
66
|
+
```python
|
67
|
+
ner('你好, 我的名字是吴子豪, 来自福建福州.')
|
68
|
+
```
|
69
|
+
|
70
|
+
stdout:
|
71
|
+
|
72
|
+
```
|
73
|
+
[NamedPerson(text='吴子豪', type='PER', base_probability=0.9995428418719051, gender=<Gender.Male: 'male'>, gender_probability=0.9970703125),
|
74
|
+
NamedEntity(text='福建', type='LOC', base_probability=0.9986373782157898),
|
75
|
+
NamedEntity(text='福州', type='LOC', base_probability=0.9993632435798645)]
|
76
|
+
```
|
77
|
+
|
78
|
+
```python
|
79
|
+
ner("Hi, i'm Vortez Wohl, author of DeeploTX.")
|
80
|
+
```
|
81
|
+
|
82
|
+
stdout:
|
83
|
+
|
84
|
+
```
|
85
|
+
[NamedPerson(text='Vortez Wohl', type='PER', base_probability=0.9991965342072855, gender=<Gender.Male: 'male'>, gender_probability=0.87255859375)]
|
86
|
+
```
|
87
|
+
|
88
|
+
- ### Gender recognition
|
89
|
+
|
90
|
+
> *Multilingual is supported.*
|
91
|
+
|
92
|
+
> *Integrated from [Name2Gender](https://github.com/vortezwohl/Name2Gender)*
|
93
|
+
|
94
|
+
Import dependencies
|
95
|
+
|
96
|
+
```python
|
97
|
+
from deeplotx import Name2Gender
|
98
|
+
|
99
|
+
n2g = Name2Gender()
|
100
|
+
```
|
101
|
+
|
102
|
+
Recognize gender of "Elon Musk":
|
103
|
+
|
104
|
+
```python
|
105
|
+
n2g('Elon Musk')
|
106
|
+
```
|
107
|
+
|
108
|
+
stdout:
|
109
|
+
|
110
|
+
```
|
111
|
+
<Gender.Male: 'male'>
|
112
|
+
```
|
113
|
+
|
114
|
+
Recognize gender of "Anne Hathaway":
|
115
|
+
|
116
|
+
```python
|
117
|
+
n2g('Anne Hathaway')
|
118
|
+
```
|
119
|
+
|
120
|
+
stdout:
|
121
|
+
|
122
|
+
```
|
123
|
+
<Gender.Female: 'female'>
|
124
|
+
```
|
125
|
+
|
126
|
+
Recognize gender of "吴彦祖":
|
127
|
+
|
128
|
+
```python
|
129
|
+
n2g('吴彦祖', return_probability=True)
|
130
|
+
```
|
131
|
+
|
132
|
+
stdout:
|
133
|
+
|
134
|
+
```
|
135
|
+
(<Gender.Male: 'male'>, 1.0)
|
136
|
+
```
|
137
|
+
|
138
|
+
- ### Long text embedding
|
139
|
+
|
140
|
+
- **BERT based long text embedding**
|
53
141
|
|
54
142
|
```python
|
55
143
|
from deeplotx import LongTextEncoder
|
56
144
|
|
57
|
-
# 块大小为 448 个 tokens, 块间重叠部分为 32 个 tokens.
|
58
145
|
encoder = LongTextEncoder(
|
59
146
|
chunk_size=448,
|
60
147
|
overlapping=32
|
61
148
|
)
|
62
|
-
# 对 "我是吴子豪, 这是一个测试文本." 计算嵌入, 并堆叠.
|
63
149
|
encoder.encode('我是吴子豪, 这是一个测试文本.', flatten=False)
|
64
150
|
```
|
65
151
|
|
66
|
-
|
152
|
+
stdout:
|
67
153
|
```
|
68
154
|
tensor([ 2.2316e-01, 2.0300e-01, ..., 1.5578e-01, -6.6735e-02])
|
69
155
|
```
|
70
156
|
|
71
|
-
-
|
157
|
+
- **Longformer based long text embedding**
|
72
158
|
|
73
159
|
```python
|
74
160
|
from deeplotx import LongformerEncoder
|
75
161
|
|
76
162
|
encoder = LongformerEncoder()
|
77
|
-
encoder.encode('
|
163
|
+
encoder.encode('Thank you for using DeepLoTX.')
|
78
164
|
```
|
79
165
|
|
80
|
-
|
166
|
+
stdout:
|
81
167
|
```
|
82
168
|
tensor([-2.7490e-02, 6.6503e-02, ..., -6.5937e-02, 6.7802e-03])
|
83
169
|
```
|
84
170
|
|
85
|
-
- ###
|
171
|
+
- ### Similarities calculation
|
86
172
|
|
87
|
-
-
|
173
|
+
- **Vector based**
|
88
174
|
|
89
175
|
```python
|
90
176
|
import deeplotx.similarity as sim
|
91
177
|
|
92
178
|
vector_0, vector_1 = [1, 2, 3, 4], [4, 3, 2, 1]
|
93
|
-
# 欧几里得距离
|
94
179
|
distance_0 = sim.euclidean_similarity(vector_0, vector_1)
|
95
180
|
print(distance_0)
|
96
|
-
# 余弦距离
|
97
181
|
distance_1 = sim.cosine_similarity(vector_0, vector_1)
|
98
182
|
print(distance_1)
|
99
|
-
# 切比雪夫距离
|
100
183
|
distance_2 = sim.chebyshev_similarity(vector_0, vector_1)
|
101
184
|
print(distance_2)
|
102
185
|
```
|
103
186
|
|
104
|
-
|
187
|
+
stdout:
|
105
188
|
```
|
106
189
|
4.47213595499958
|
107
190
|
0.33333333333333337
|
108
191
|
3
|
109
192
|
```
|
110
193
|
|
111
|
-
-
|
194
|
+
- **Set based**
|
112
195
|
|
113
196
|
```python
|
114
197
|
import deeplotx.similarity as sim
|
115
198
|
|
116
199
|
set_0, set_1 = {1, 2, 3, 4}, {4, 5, 6, 7}
|
117
|
-
# 杰卡德距离
|
118
200
|
distance_0 = sim.jaccard_similarity(set_0, set_1)
|
119
201
|
print(distance_0)
|
120
|
-
# Ochiai 距离
|
121
202
|
distance_1 = sim.ochiai_similarity(set_0, set_1)
|
122
203
|
print(distance_1)
|
123
|
-
# Dice 系数
|
124
204
|
distance_2 = sim.dice_coefficient(set_0, set_1)
|
125
205
|
print(distance_2)
|
126
|
-
# Overlap 系数
|
127
206
|
distance_3 = sim.overlap_coefficient(set_0, set_1)
|
128
207
|
print(distance_3)
|
129
208
|
```
|
130
209
|
|
131
|
-
|
210
|
+
stdout:
|
132
211
|
```
|
133
212
|
0.1428571428572653
|
134
213
|
0.2500000000001875
|
@@ -136,27 +215,23 @@ Dynamic: license-file
|
|
136
215
|
0.2500000000001875
|
137
216
|
```
|
138
217
|
|
139
|
-
-
|
218
|
+
- **Distribution based**
|
140
219
|
|
141
220
|
```python
|
142
221
|
import deeplotx.similarity as sim
|
143
222
|
|
144
223
|
dist_0, dist_1 = [0.3, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.4]
|
145
|
-
# 交叉熵
|
146
224
|
distance_0 = sim.cross_entropy(dist_0, dist_1)
|
147
225
|
print(distance_0)
|
148
|
-
# KL 散度
|
149
226
|
distance_1 = sim.kl_divergence(dist_0, dist_1)
|
150
227
|
print(distance_1)
|
151
|
-
# JS 散度
|
152
228
|
distance_2 = sim.js_divergence(dist_0, dist_1)
|
153
229
|
print(distance_2)
|
154
|
-
# Hellinger 距离
|
155
230
|
distance_3 = sim.hellinger_distance(dist_0, dist_1)
|
156
231
|
print(distance_3)
|
157
232
|
```
|
158
233
|
|
159
|
-
|
234
|
+
stdout:
|
160
235
|
```
|
161
236
|
0.3575654913778237
|
162
237
|
0.15040773967762736
|
@@ -164,27 +239,27 @@ Dynamic: license-file
|
|
164
239
|
0.20105866986400994
|
165
240
|
```
|
166
241
|
|
167
|
-
- ###
|
242
|
+
- ### Pre-defined neural networks
|
168
243
|
|
169
244
|
```python
|
170
245
|
from deeplotx import (
|
171
|
-
FeedForward,
|
172
|
-
MultiHeadFeedForward,
|
173
|
-
LinearRegression,
|
174
|
-
LogisticRegression,
|
175
|
-
SoftmaxRegression,
|
176
|
-
RecursiveSequential,
|
177
|
-
LongContextRecursiveSequential,
|
178
|
-
RoPE,
|
179
|
-
Attention,
|
180
|
-
MultiHeadAttention,
|
181
|
-
RoFormerEncoder,
|
182
|
-
AutoRegression,
|
183
|
-
LongContextAutoRegression
|
246
|
+
FeedForward,
|
247
|
+
MultiHeadFeedForward,
|
248
|
+
LinearRegression,
|
249
|
+
LogisticRegression,
|
250
|
+
SoftmaxRegression,
|
251
|
+
RecursiveSequential,
|
252
|
+
LongContextRecursiveSequential,
|
253
|
+
RoPE,
|
254
|
+
Attention,
|
255
|
+
MultiHeadAttention,
|
256
|
+
RoFormerEncoder,
|
257
|
+
AutoRegression,
|
258
|
+
LongContextAutoRegression
|
184
259
|
)
|
185
260
|
```
|
186
261
|
|
187
|
-
|
262
|
+
The fundamental FFN (MLPs):
|
188
263
|
|
189
264
|
```python
|
190
265
|
from typing_extensions import override
|
@@ -242,7 +317,7 @@ Dynamic: license-file
|
|
242
317
|
return x
|
243
318
|
```
|
244
319
|
|
245
|
-
|
320
|
+
Attention:
|
246
321
|
|
247
322
|
```python
|
248
323
|
from typing_extensions import override
|
@@ -295,46 +370,34 @@ Dynamic: license-file
|
|
295
370
|
return torch.matmul(self._attention(x, y, mask), v)
|
296
371
|
```
|
297
372
|
|
298
|
-
- ###
|
373
|
+
- ### Text binary classification task with predefined trainer
|
299
374
|
|
300
375
|
```python
|
301
376
|
from deeplotx import TextBinaryClassifierTrainer, LongTextEncoder
|
302
377
|
from deeplotx.util import get_files, read_file
|
303
378
|
|
304
|
-
# 定义向量编码策略 (默认使用 FacebookAI/xlm-roberta-base 作为嵌入模型)
|
305
379
|
long_text_encoder = LongTextEncoder(
|
306
|
-
max_length=2048,
|
307
|
-
chunk_size=448,
|
308
|
-
overlapping=32,
|
309
|
-
cache_capacity=512
|
380
|
+
max_length=2048,
|
381
|
+
chunk_size=448,
|
382
|
+
overlapping=32,
|
383
|
+
cache_capacity=512
|
310
384
|
)
|
311
|
-
|
312
385
|
trainer = TextBinaryClassifierTrainer(
|
313
386
|
long_text_encoder=long_text_encoder,
|
314
387
|
batch_size=2,
|
315
|
-
train_ratio=0.9
|
388
|
+
train_ratio=0.9
|
316
389
|
)
|
317
|
-
|
318
|
-
# 读取数据
|
319
390
|
pos_data_path = 'path/to/pos_dir'
|
320
391
|
neg_data_path = 'path/to/neg_dir'
|
321
392
|
pos_data = [read_file(x) for x in get_files(pos_data_path)]
|
322
393
|
neg_data = [read_file(x) for x in get_files(neg_data_path)]
|
323
|
-
|
324
|
-
# 开始训练
|
325
394
|
model = trainer.train(pos_data, neg_data,
|
326
395
|
num_epochs=36, learning_rate=2e-5,
|
327
396
|
balancing_dataset=True, alpha=1e-4,
|
328
|
-
rho=.2, encoder_layers=2,
|
329
|
-
attn_heads=8,
|
330
|
-
recursive_layers=2)
|
331
|
-
|
332
|
-
# 保存模型权重
|
397
|
+
rho=.2, encoder_layers=2,
|
398
|
+
attn_heads=8,
|
399
|
+
recursive_layers=2)
|
333
400
|
model.save(model_name='test_model', model_dir='model')
|
334
|
-
|
335
|
-
# 加载已保存的模型
|
336
401
|
model = model.load(model_name='test_model', model_dir='model')
|
337
|
-
|
338
|
-
# 使用训练好的模型进行预测
|
339
402
|
model.predict(long_text_encoder.encode('这是一个测试文本.', flatten=False))
|
340
403
|
```
|
{deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/RECORD
CHANGED
@@ -1,8 +1,13 @@
|
|
1
|
-
deeplotx/__init__.py,sha256=
|
1
|
+
deeplotx/__init__.py,sha256=x4CbJuW20al6S5KkKyrReeuwNGv04JGoqtGUyx-ACtg,1356
|
2
2
|
deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
|
3
3
|
deeplotx/encoder/encoder.py,sha256=wVRl3p_7eg7qT_tJEit5qnmZx7dXkMVLxAtao5vImkk,4201
|
4
4
|
deeplotx/encoder/long_text_encoder.py,sha256=4oRa9FqfGNZ8-gq14UKuhDkZC0A1Xi-wKmbQsn-uZ58,3966
|
5
5
|
deeplotx/encoder/longformer_encoder.py,sha256=7Lm65AUD3qwbrzrhJ3dPZkyHeNRSapga3f-5QJCxV5A,3538
|
6
|
+
deeplotx/ner/__init__.py,sha256=Rss1pup9HzHZCG8U9ub8niWa9zRjWCy3Z7zg378KZQg,114
|
7
|
+
deeplotx/ner/base_ner.py,sha256=bAp7R6mawsfO7owBONXtbPN0rzMSltMJVEGGNKhi41A,359
|
8
|
+
deeplotx/ner/bert_ner.py,sha256=RkqHVBY4SBJtHHR0YuR006v5gFmAaKqJCCKkOOs9ulY,5458
|
9
|
+
deeplotx/ner/named_entity.py,sha256=c6XufIwH6yloJ-ccUjagf4mBl1XbbYDT8xyEJJ_-ZNs,269
|
10
|
+
deeplotx/ner/n2g/__init__.py,sha256=b6fOWJVLaOCtoz8Qlp8NWQbL5lUSbn6H3-8fnVNIPi0,3940
|
6
11
|
deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
|
7
12
|
deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
|
8
13
|
deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
|
@@ -28,8 +33,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
|
|
28
33
|
deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
|
29
34
|
deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
|
30
35
|
deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
|
31
|
-
deeplotx-0.
|
32
|
-
deeplotx-0.
|
33
|
-
deeplotx-0.
|
34
|
-
deeplotx-0.
|
35
|
-
deeplotx-0.
|
36
|
+
deeplotx-0.9.2.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
|
37
|
+
deeplotx-0.9.2.dist-info/METADATA,sha256=lA_h92G6v6cT3ff94pmxVAi0LLj-qO2qrEjAVLFTYHw,13472
|
38
|
+
deeplotx-0.9.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
39
|
+
deeplotx-0.9.2.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
|
40
|
+
deeplotx-0.9.2.dist-info/RECORD,,
|
{deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/WHEEL
File without changes
|
{deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/licenses/LICENSE
File without changes
|
{deeplotx-0.8.8.dist-info → deeplotx-0.9.2.dist-info}/top_level.txt
File without changes
|