deeplotx 0.8.8__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deeplotx/__init__.py CHANGED
@@ -4,6 +4,8 @@ import os
4
4
  __ROOT__ = os.path.dirname(os.path.abspath(__file__))
5
5
 
6
6
  from .encoder import Encoder, LongTextEncoder, LongformerEncoder
7
+ from .ner.n2g import Name2Gender, Gender
8
+ from .ner import BertNER, NamedEntity, NamedPerson
7
9
  from .nn import (
8
10
  FeedForward,
9
11
  MultiHeadFeedForward,
@@ -40,3 +42,5 @@ logger = logging.getLogger('deeplotx.trainer')
40
42
  logger.setLevel(logging.DEBUG)
41
43
  logger = logging.getLogger('deeplotx.embedding')
42
44
  logger.setLevel(logging.DEBUG)
45
+ logger = logging.getLogger('deeplotx.ner')
46
+ logger.setLevel(logging.DEBUG)
@@ -0,0 +1,3 @@
1
+ from .named_entity import NamedEntity, NamedPerson
2
+ from .base_ner import BaseNER
3
+ from .bert_ner import BertNER
@@ -0,0 +1,10 @@
1
+ from deeplotx.ner.named_entity import NamedEntity, NamedPerson
2
+
3
+
4
class BaseNER:
    """Common interface for named-entity recognizers.

    Subclasses implement :meth:`extract_entities`; calling an instance is a
    shorthand that forwards all arguments to it.
    """

    def __init__(self): ...

    def __call__(self, s: str, *args, **kwargs) -> 'list[NamedEntity | NamedPerson]':
        """Extract entities from ``s``; equivalent to ``self.extract_entities(s, *args, **kwargs)``."""
        # `s` must be passed positionally: `f(s=s, *args)` binds args[0] to `s`
        # as well and raises "got multiple values for argument 's'".
        return self.extract_entities(s, *args, **kwargs)

    def extract_entities(self, s: str, *args, **kwargs) -> 'list[NamedEntity | NamedPerson]':
        """Extract named entities from the text ``s``.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class (previously it silently
                returned ``None`` despite the ``list`` return annotation).
        """
        raise NotImplementedError(f'{type(self).__name__} must implement extract_entities().')
@@ -0,0 +1,91 @@
1
+ import logging
2
+ import os
3
+ from requests.exceptions import ConnectTimeout, SSLError
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
7
+
8
+ from deeplotx import __ROOT__
9
+ from deeplotx.ner.n2g import Name2Gender
10
+ from deeplotx.ner.base_ner import BaseNER
11
+ from deeplotx.ner.named_entity import NamedEntity, NamedPerson
12
+
13
CACHE_PATH = os.path.join(__ROOT__, '.cache')
DEFAULT_BERT_NER = 'Davlan/xlm-roberta-base-ner-hrl'
# Process-wide lazy cache holding at most one Name2Gender instance; it is
# filled on the first `extract_entities(..., with_gender=True)` call.
N2G_MODEL: list[Name2Gender] = []
logger = logging.getLogger('deeplotx.ner')
# BIO tag prefixes of token-classification labels such as 'B-PER' / 'I-LOC'.
_BEGIN, _INSIDE = 'B', 'I'


def _split_bio_label(label: str) -> tuple[str, str]:
    """Split a token label into ``(bio_prefix, entity_type)``.

    ``'B-PER'`` -> ``('B', 'PER')`` and ``'i-loc'`` -> ``('I', 'LOC')``. A label
    without a ``B-``/``I-`` prefix is returned unchanged with an empty prefix.
    """
    prefix, sep, entity_type = label.partition('-')
    if sep and prefix.upper() in (_BEGIN, _INSIDE):
        return prefix.upper(), entity_type.upper()
    return '', label


def _merge_bio_tokens(s: str, raw_entities: list[dict]) -> list[list]:
    """Collapse token-level NER pipeline output into entity spans in one pass.

    Each ``I-`` token is appended to the entity opened by the nearest preceding
    kept token when that token was ``B-`` or ``I-``; its score is multiplied
    into the entity's probability. An ``I-`` token with nothing to continue
    starts a new entity instead of looping forever. Whitespace-only tokens are
    dropped.

    Args:
        s: The exact string the pipeline ran on; token ``start``/``end`` index into it.
        raw_entities: Pipeline output dicts with ``start``, ``end``, ``entity``, ``score``.

    Returns:
        A list of ``[text, entity_type, probability]`` with stripped ``text``.
    """
    merged: list[list] = []
    open_entity = False  # whether merged[-1] may still absorb a following I- token
    for token in raw_entities:
        text = s[token['start']: token['end']]
        if not text.strip():
            continue
        prefix, entity_type = _split_bio_label(token['entity'])
        score = float(token['score'])
        if prefix == _INSIDE and open_entity:
            merged[-1][0] += text
            merged[-1][2] *= score
        else:
            merged.append([text, entity_type, score])
            open_entity = prefix in (_BEGIN, _INSIDE)
    for entity in merged:
        entity[0] = entity[0].strip()
    return merged


class BertNER(BaseNER):
    """Named-entity recognizer backed by a Hugging Face token-classification model.

    Recognized ``PER`` entities can optionally be enriched with a gender
    prediction from :class:`Name2Gender`.
    """

    def __init__(self, model_name_or_path: str = DEFAULT_BERT_NER, device: str | None = None):
        """Load the tokenizer and model, falling back to the local cache if the Hub is unreachable.

        Args:
            model_name_or_path: Hub model id or local path of a token-classification model.
            device: torch device string; defaults to ``'cuda'`` when available, else ``'cpu'``.
        """
        super().__init__()
        self.device = torch.device(device) if device is not None \
            else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            self.tokenizer, self.encoder = self._load_pretrained(model_name_or_path, local_files_only=False)
        except (ConnectTimeout, SSLError):
            # Hub unreachable (timeout or TLS failure): retry with locally cached files only.
            self.tokenizer, self.encoder = self._load_pretrained(model_name_or_path, local_files_only=True)
        # NOTE(review): this is the model's maximum sequence length, not a hidden
        # size, despite the attribute name — TODO confirm intended meaning.
        self.embed_dim = self.encoder.config.max_position_embeddings
        self._ner_pipeline = pipeline(task='ner', model=self.encoder, tokenizer=self.tokenizer, trust_remote_code=True)
        logger.debug(f'{type(self).__name__} initialized on device: {self.device}.')

    def _load_pretrained(self, model_name_or_path: str, local_files_only: bool):
        """Return ``(tokenizer, model)`` for ``model_name_or_path``, with the model moved to ``self.device``."""
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
                                                  cache_dir=CACHE_PATH, _from_auto=True,
                                                  trust_remote_code=True, local_files_only=local_files_only)
        model = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=model_name_or_path,
                                                                cache_dir=CACHE_PATH, _from_auto=True,
                                                                trust_remote_code=True,
                                                                local_files_only=local_files_only).to(self.device)
        return tokenizer, model

    def extract_entities(self, s: str, with_gender: bool = True, prob_threshold: float = .0,
                         *args, **kwargs) -> list[NamedEntity | NamedPerson]:
        """Extract named entities from ``s``.

        Args:
            s: Input text.
            with_gender: When True, each ``PER`` entity is returned as a
                :class:`NamedPerson` carrying a Name2Gender prediction.
            prob_threshold: Entities whose merged probability is below this are dropped.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Entities in pipeline order; ``NamedPerson`` for persons when ``with_gender``.

        Raises:
            AssertionError: if ``prob_threshold`` is larger than 1.
        """
        assert prob_threshold <= 1., f'prob_threshold ({prob_threshold}) cannot be larger than 1.'
        # NOTE(review): presumably prepended so the first word is tokenized like any
        # later word-initial token; the extra whitespace is stripped from entity text.
        s = ' ' + s
        entities = [NamedEntity(*ent) for ent in _merge_bio_tokens(s, self._ner_pipeline(s))
                    if ent[2] >= prob_threshold]
        if not with_gender:
            return entities
        if not N2G_MODEL:
            N2G_MODEL.append(Name2Gender())
        n2g_model = N2G_MODEL[0]
        for i, ent in enumerate(entities):
            if ent.type.upper() == 'PER':
                gender, gender_prob = n2g_model(ent.text, return_probability=True)
                entities[i] = NamedPerson(text=ent.text,
                                          type=ent.type,
                                          base_probability=ent.base_probability,
                                          gender=gender,
                                          gender_probability=gender_prob)
        return entities

    def __call__(self, s: str, with_gender: bool = True, prob_threshold: float = .0, *args, **kwargs):
        """Shorthand for :meth:`extract_entities`."""
        # Positional forwarding: keyword `s=`/`with_gender=` combined with *args
        # would raise "got multiple values" whenever extra positionals are given.
        return self.extract_entities(s, with_gender, prob_threshold, *args, **kwargs)
@@ -0,0 +1,91 @@
1
+ import os
2
+
3
+ import requests
4
+ import gdown
5
+ import torch
6
+ from name4py import Gender
7
+
8
+ from deeplotx import __ROOT__
9
+ from deeplotx.encoder.encoder import Encoder
10
+ from deeplotx.nn.logistic_regression import LogisticRegression
11
+ from deeplotx.nn.base_neural_network import BaseNeuralNetwork
12
+
13
+
14
__CACHE_DIR__ = os.path.join(__ROOT__, '.cache', '.n2g')
# NOTE(review): created at import time, so importing this module eagerly loads
# the 'FacebookAI/xlm-roberta-base' encoder.
ENCODER = Encoder(model_name_or_path='FacebookAI/xlm-roberta-base')
BASE_MODEL = 'name2gender-base'
SMALL_MODEL = 'name2gender-small'
# Downloaded files smaller than this (5 KiB) are treated as failed downloads.
_MIN_FILE_SIZE = 1024 * 5
# Seconds to wait on the GitHub reachability probe before giving up.
_PROBE_TIMEOUT = 10
_TRUTHY_ENV_VALUES = frozenset({'1', 'true', 'yes', 'y', 'on'})


def _env_flag(name: str) -> bool:
    """Return True when environment variable ``name`` holds a truthy value like '1' or 'true'."""
    return os.getenv(name, '').strip().lower() in _TRUTHY_ENV_VALUES


def download_model(model_name: str):
    """Download ``<model_name>.dlx`` from the Name2Gender GitHub release into the cache.

    Does nothing when the file is already cached. Progress output is suppressed
    when ``QUIET_DOWNLOAD`` is set to a truthy value ('1', 'true', 'yes', ...).
    ``HTTP(S)_PROXY`` / ``http(s)_proxy`` are honoured.

    Raises:
        ConnectionError: if the GitHub repository page does not answer with HTTP 200.
        FileNotFoundError: if the downloaded file is missing or smaller than
            ``_MIN_FILE_SIZE`` (the partial file is removed).
    """
    quiet = _env_flag('QUIET_DOWNLOAD')
    os.makedirs(__CACHE_DIR__, exist_ok=True)
    _proxies = {
        'http': os.getenv('HTTP_PROXY', os.getenv('http_proxy')),
        'https': os.getenv('HTTPS_PROXY', os.getenv('https_proxy'))
    }
    model_name = f'{model_name}.dlx'
    model_path = os.path.join(__CACHE_DIR__, model_name)
    base_url = 'https://github.com/vortezwohl/Name2Gender'
    if os.path.exists(model_path):
        return
    url = f'{base_url}/releases/download/RESOURCE/{model_name}'
    # Bounded reachability probe so a dead network cannot hang the caller forever.
    if requests.get(url=base_url, proxies=_proxies, timeout=_PROBE_TIMEOUT).status_code != 200:
        raise ConnectionError(f'Failed to download model {model_name}.')
    try:
        gdown.download(
            url=url,
            output=model_path,
            quiet=quiet,
            proxy=_proxies.get('https'),
            speed=8192 * 1024,
            resume=True
        )
        if not os.path.exists(model_path) or os.path.getsize(model_path) < _MIN_FILE_SIZE:
            raise FileNotFoundError(f"Model \"{model_name}\" doesn't exists.")
    except Exception:
        # Drop any partial/invalid file so the next call retries, without letting a
        # cleanup FileNotFoundError mask the original error.
        if os.path.exists(model_path):
            os.remove(model_path)
        raise
50
+
51
+
52
+ def load_model(model_name: str = 'name2gender-small', dtype: torch.dtype | None = torch.float16) -> BaseNeuralNetwork:
53
+ n2g_model = None
54
+ match model_name:
55
+ case 'name2gender-base' | 'n2g-base' | 'base':
56
+ download_model(BASE_MODEL)
57
+ n2g_model = LogisticRegression(input_dim=768, output_dim=1,
58
+ num_heads=12, num_layers=4,
59
+ head_layers=1, expansion_factor=2,
60
+ model_name=BASE_MODEL, dtype=dtype)
61
+ case 'name2gender-small' | 'n2g-base' | 'small':
62
+ download_model(SMALL_MODEL)
63
+ n2g_model = LogisticRegression(input_dim=768, output_dim=1,
64
+ num_heads=6, num_layers=2,
65
+ head_layers=1, expansion_factor=1.5,
66
+ model_name=SMALL_MODEL, dtype=dtype)
67
+ case _:
68
+ download_model(SMALL_MODEL)
69
+ n2g_model = LogisticRegression(input_dim=768, output_dim=1,
70
+ num_heads=6, num_layers=2,
71
+ head_layers=1, expansion_factor=1.5,
72
+ model_name=SMALL_MODEL, dtype=dtype)
73
+ return n2g_model.load(model_dir=__CACHE_DIR__)
74
+
75
+
76
class Name2Gender:
    """Predict the gender associated with a personal name.

    The name is embedded with the module-level ``ENCODER`` and scored by a
    model whose output is compared against a threshold: at or above it means
    ``Gender.Male``, below it means ``Gender.Female``.
    """

    def __init__(self, model: BaseNeuralNetwork | None = None):
        """Use ``model`` if given, otherwise load the small Name2Gender model."""
        super().__init__()
        self._model = load_model(SMALL_MODEL) if model is None else model

    def __call__(self, name: str, return_probability: bool = False, threshold: float = .5) -> tuple[Gender, float] | Gender:
        """Classify ``name``.

        Args:
            name: Non-empty name; its first character is upper-cased before encoding.
            return_probability: Also return the probability of the predicted gender.
            threshold: Minimum model output for a ``Gender.Male`` prediction.

        Returns:
            ``gender``, or ``(gender, probability_of_that_gender)``.
        """
        assert len(name) > 0, f'name ({name}) cannot be empty.'
        capitalized = name[:1].upper() + name[1:]
        male_score = self._model.predict(ENCODER.encode(capitalized)).item()
        is_male = male_score >= threshold
        gender = Gender.Male if is_male else Gender.Female
        if not return_probability:
            return gender
        return gender, (male_score if is_male else 1. - male_score)
@@ -0,0 +1,16 @@
1
+ from dataclasses import dataclass
2
+
3
+ from deeplotx.ner.n2g import Gender
4
+
5
+
6
@dataclass(init=True, repr=True, eq=True)
class NamedEntity:
    """A span of text recognised as a named entity."""

    text: str  # surface text of the entity
    type: str  # entity type label, e.g. 'PER' or 'LOC'
    base_probability: float  # confidence score assigned by the recognizer


@dataclass(init=True, repr=True, eq=True)
class NamedPerson(NamedEntity):
    """A person entity that additionally carries a predicted gender."""

    gender: Gender  # predicted gender
    gender_probability: float  # probability of the predicted gender
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: deeplotx
3
- Version: 0.8.8
4
- Summary: Easy-2-use long text NLP toolkit.
3
+ Version: 0.9.2
4
+ Summary: An out-of-the-box long-text NLP framework.
5
5
  Requires-Python: >=3.10
6
6
  Description-Content-Type: text/markdown
7
7
  License-File: LICENSE
@@ -10,125 +10,204 @@ Requires-Dist: jupyter
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: protobuf
12
12
  Requires-Dist: python-dotenv
13
+ Requires-Dist: sentencepiece
13
14
  Requires-Dist: tiktoken
14
15
  Requires-Dist: torch
15
16
  Requires-Dist: transformers
16
17
  Requires-Dist: typing-extensions
17
18
  Requires-Dist: vortezwohl>=0.0.8
19
+ Requires-Dist: name4py>=0.1.4
18
20
  Dynamic: license-file
19
21
 
20
22
  [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/vortezwohl/DeepLoTX)
21
23
 
22
- # Deep Long Text Learning Kit
24
+ # *Deep Long Text Learning*
23
25
 
24
- > Author: 吴子豪
26
+ *An out-of-the-box long-text NLP framework.*
25
27
 
26
- **开箱即用的长文本语义建模框架**
28
+ > Author: [vortezwohl](https://github.com/vortezwohl)
27
29
 
28
- ## 安装
30
+ ## Installation
29
31
 
30
- - 使用 pip
32
+ - **With pip**
31
33
 
32
34
  ```
33
35
  pip install -U deeplotx
34
36
  ```
35
37
 
36
- - 使用 uv (推荐)
38
+ - **With uv (recommended)**
37
39
 
38
40
  ```
39
41
  uv add -U deeplotx
40
42
  ```
41
43
 
42
- - github 安装最新特性
44
+ - **Get the latest features from GitHub**
43
45
 
44
46
  ```
45
47
  pip install -U git+https://github.com/vortezwohl/DeepLoTX.git
46
48
  ```
47
49
 
48
- ## 核心功能
50
+ ## Quick start
49
51
 
50
- - ### 长文本嵌入
52
+ - ### Named entity recognition
51
53
 
52
- - **基于通用 BERT 的长文本嵌入** (最大支持长度, 无限长, 可通过 max_length 限制长度)
54
+ > *Multiple languages are supported.*
55
+
56
+ > *Gender recognition is supported.*
57
+
58
+ Import dependencies
59
+
60
+ ```python
61
+ from deeplotx import BertNER
62
+
63
+ ner = BertNER()
64
+ ```
65
+
66
+ ```python
67
+ ner('你好, 我的名字是吴子豪, 来自福建福州.')
68
+ ```
69
+
70
+ stdout:
71
+
72
+ ```
73
+ [NamedPerson(text='吴子豪', type='PER', base_probability=0.9995428418719051, gender=<Gender.Male: 'male'>, gender_probability=0.9970703125),
74
+ NamedEntity(text='福建', type='LOC', base_probability=0.9986373782157898),
75
+ NamedEntity(text='福州', type='LOC', base_probability=0.9993632435798645)]
76
+ ```
77
+
78
+ ```python
79
+ ner("Hi, i'm Vortez Wohl, author of DeeploTX.")
80
+ ```
81
+
82
+ stdout:
83
+
84
+ ```
85
+ [NamedPerson(text='Vortez Wohl', type='PER', base_probability=0.9991965342072855, gender=<Gender.Male: 'male'>, gender_probability=0.87255859375)]
86
+ ```
87
+
88
+ - ### Gender recognition
89
+
90
+ > *Multiple languages are supported.*
91
+
92
+ > *Integrated from [Name2Gender](https://github.com/vortezwohl/Name2Gender)*
93
+
94
+ Import dependencies
95
+
96
+ ```python
97
+ from deeplotx import Name2Gender
98
+
99
+ n2g = Name2Gender()
100
+ ```
101
+
102
+ Recognize gender of "Elon Musk":
103
+
104
+ ```python
105
+ n2g('Elon Musk')
106
+ ```
107
+
108
+ stdout:
109
+
110
+ ```
111
+ <Gender.Male: 'male'>
112
+ ```
113
+
114
+ Recognize gender of "Anne Hathaway":
115
+
116
+ ```python
117
+ n2g('Anne Hathaway')
118
+ ```
119
+
120
+ stdout:
121
+
122
+ ```
123
+ <Gender.Female: 'female'>
124
+ ```
125
+
126
+ Recognize gender of "吴彦祖":
127
+
128
+ ```python
129
+ n2g('吴彦祖', return_probability=True)
130
+ ```
131
+
132
+ stdout:
133
+
134
+ ```
135
+ (<Gender.Male: 'male'>, 1.0)
136
+ ```
137
+
138
+ - ### Long text embedding
139
+
140
+ - **BERT-based long-text embedding**
53
141
 
54
142
  ```python
55
143
  from deeplotx import LongTextEncoder
56
144
 
57
- # 块大小为 448 个 tokens, 块间重叠部分为 32 个 tokens.
58
145
  encoder = LongTextEncoder(
59
146
  chunk_size=448,
60
147
  overlapping=32
61
148
  )
62
- # 对 "我是吴子豪, 这是一个测试文本." 计算嵌入, 并堆叠.
63
149
  encoder.encode('我是吴子豪, 这是一个测试文本.', flatten=False)
64
150
  ```
65
151
 
66
- 输出:
152
+ stdout:
67
153
  ```
68
154
  tensor([ 2.2316e-01, 2.0300e-01, ..., 1.5578e-01, -6.6735e-02])
69
155
  ```
70
156
 
71
- - **基于 Longformer 的长文本嵌入** (最大支持长度 4096 个 tokens)
157
+ - **Longformer-based long-text embedding**
72
158
 
73
159
  ```python
74
160
  from deeplotx import LongformerEncoder
75
161
 
76
162
  encoder = LongformerEncoder()
77
- encoder.encode('我是吴子豪, 这是一个测试文本.')
163
+ encoder.encode('Thank you for using DeepLoTX.')
78
164
  ```
79
165
 
80
- 输出:
166
+ stdout:
81
167
  ```
82
168
  tensor([-2.7490e-02, 6.6503e-02, ..., -6.5937e-02, 6.7802e-03])
83
169
  ```
84
170
 
85
- - ### 相似性计算
171
+ - ### Similarity calculation
86
172
 
87
- - **基于向量的相似性**
173
+ - **Vector based**
88
174
 
89
175
  ```python
90
176
  import deeplotx.similarity as sim
91
177
 
92
178
  vector_0, vector_1 = [1, 2, 3, 4], [4, 3, 2, 1]
93
- # 欧几里得距离
94
179
  distance_0 = sim.euclidean_similarity(vector_0, vector_1)
95
180
  print(distance_0)
96
- # 余弦距离
97
181
  distance_1 = sim.cosine_similarity(vector_0, vector_1)
98
182
  print(distance_1)
99
- # 切比雪夫距离
100
183
  distance_2 = sim.chebyshev_similarity(vector_0, vector_1)
101
184
  print(distance_2)
102
185
  ```
103
186
 
104
- 输出:
187
+ stdout:
105
188
  ```
106
189
  4.47213595499958
107
190
  0.33333333333333337
108
191
  3
109
192
  ```
110
193
 
111
- - **基于集合的相似性**
194
+ - **Set based**
112
195
 
113
196
  ```python
114
197
  import deeplotx.similarity as sim
115
198
 
116
199
  set_0, set_1 = {1, 2, 3, 4}, {4, 5, 6, 7}
117
- # 杰卡德距离
118
200
  distance_0 = sim.jaccard_similarity(set_0, set_1)
119
201
  print(distance_0)
120
- # Ochiai 距离
121
202
  distance_1 = sim.ochiai_similarity(set_0, set_1)
122
203
  print(distance_1)
123
- # Dice 系数
124
204
  distance_2 = sim.dice_coefficient(set_0, set_1)
125
205
  print(distance_2)
126
- # Overlap 系数
127
206
  distance_3 = sim.overlap_coefficient(set_0, set_1)
128
207
  print(distance_3)
129
208
  ```
130
209
 
131
- 输出:
210
+ stdout:
132
211
  ```
133
212
  0.1428571428572653
134
213
  0.2500000000001875
@@ -136,27 +215,23 @@ Dynamic: license-file
136
215
  0.2500000000001875
137
216
  ```
138
217
 
139
- - **基于概率分布的相似性**
218
+ - **Distribution based**
140
219
 
141
220
  ```python
142
221
  import deeplotx.similarity as sim
143
222
 
144
223
  dist_0, dist_1 = [0.3, 0.2, 0.1, 0.4], [0.2, 0.1, 0.3, 0.4]
145
- # 交叉熵
146
224
  distance_0 = sim.cross_entropy(dist_0, dist_1)
147
225
  print(distance_0)
148
- # KL 散度
149
226
  distance_1 = sim.kl_divergence(dist_0, dist_1)
150
227
  print(distance_1)
151
- # JS 散度
152
228
  distance_2 = sim.js_divergence(dist_0, dist_1)
153
229
  print(distance_2)
154
- # Hellinger 距离
155
230
  distance_3 = sim.hellinger_distance(dist_0, dist_1)
156
231
  print(distance_3)
157
232
  ```
158
233
 
159
- 输出:
234
+ stdout:
160
235
  ```
161
236
  0.3575654913778237
162
237
  0.15040773967762736
@@ -164,27 +239,27 @@ Dynamic: license-file
164
239
  0.20105866986400994
165
240
  ```
166
241
 
167
- - ### 预定义深度神经网络
242
+ - ### Pre-defined neural networks
168
243
 
169
244
  ```python
170
245
  from deeplotx import (
171
- FeedForward, # 前馈神经网络
172
- MultiHeadFeedForward, # 多头前馈神经网络
173
- LinearRegression, # 线性回归
174
- LogisticRegression, # 逻辑回归 / 二分类 / 多标签分类
175
- SoftmaxRegression, # Softmax 回归 / 多分类
176
- RecursiveSequential, # 序列模型 / 循环神经网络
177
- LongContextRecursiveSequential, # 长上下文序列模型 / 自注意力融合循环神经网络
178
- RoPE, # RoPE 位置编码
179
- Attention, # 自注意力 / 交叉注意力
180
- MultiHeadAttention, # 并行多头注意力
181
- RoFormerEncoder, # Roformer (Transformer + RoPE) 编码器模型
182
- AutoRegression, # 自回归模型 / 循环神经网络
183
- LongContextAutoRegression # 长上下文自回归模型 / 自注意力融合循环神经网络
246
+ FeedForward,
247
+ MultiHeadFeedForward,
248
+ LinearRegression,
249
+ LogisticRegression,
250
+ SoftmaxRegression,
251
+ RecursiveSequential,
252
+ LongContextRecursiveSequential,
253
+ RoPE,
254
+ Attention,
255
+ MultiHeadAttention,
256
+ RoFormerEncoder,
257
+ AutoRegression,
258
+ LongContextAutoRegression
184
259
  )
185
260
  ```
186
261
 
187
- 基础网络结构:
262
+ The fundamental FFN (MLPs):
188
263
 
189
264
  ```python
190
265
  from typing_extensions import override
@@ -242,7 +317,7 @@ Dynamic: license-file
242
317
  return x
243
318
  ```
244
319
 
245
- 注意力模块:
320
+ Attention:
246
321
 
247
322
  ```python
248
323
  from typing_extensions import override
@@ -295,46 +370,34 @@ Dynamic: license-file
295
370
  return torch.matmul(self._attention(x, y, mask), v)
296
371
  ```
297
372
 
298
- - ### 使用预定义训练器实现文本二分类任务
373
+ - ### Binary text classification with the predefined trainer
299
374
 
300
375
  ```python
301
376
  from deeplotx import TextBinaryClassifierTrainer, LongTextEncoder
302
377
  from deeplotx.util import get_files, read_file
303
378
 
304
- # 定义向量编码策略 (默认使用 FacebookAI/xlm-roberta-base 作为嵌入模型)
305
379
  long_text_encoder = LongTextEncoder(
306
- max_length=2048, # 最大文本大小, 超出截断
307
- chunk_size=448, # 块大小 (按 Token 计)
308
- overlapping=32, # 块间重叠大小 (按 Token 计)
309
- cache_capacity=512 # 缓存大小
380
+ max_length=2048,
381
+ chunk_size=448,
382
+ overlapping=32,
383
+ cache_capacity=512
310
384
  )
311
-
312
385
  trainer = TextBinaryClassifierTrainer(
313
386
  long_text_encoder=long_text_encoder,
314
387
  batch_size=2,
315
- train_ratio=0.9 # 训练集和验证集比例
388
+ train_ratio=0.9
316
389
  )
317
-
318
- # 读取数据
319
390
  pos_data_path = 'path/to/pos_dir'
320
391
  neg_data_path = 'path/to/neg_dir'
321
392
  pos_data = [read_file(x) for x in get_files(pos_data_path)]
322
393
  neg_data = [read_file(x) for x in get_files(neg_data_path)]
323
-
324
- # 开始训练
325
394
  model = trainer.train(pos_data, neg_data,
326
395
  num_epochs=36, learning_rate=2e-5,
327
396
  balancing_dataset=True, alpha=1e-4,
328
- rho=.2, encoder_layers=2, # 2 层 Roformer 编码器
329
- attn_heads=8, # 8 个注意力头
330
- recursive_layers=2) # 2 层 Bi-LSTM
331
-
332
- # 保存模型权重
397
+ rho=.2, encoder_layers=2,
398
+ attn_heads=8,
399
+ recursive_layers=2)
333
400
  model.save(model_name='test_model', model_dir='model')
334
-
335
- # 加载已保存的模型
336
401
  model = model.load(model_name='test_model', model_dir='model')
337
-
338
- # 使用训练好的模型进行预测
339
402
  model.predict(long_text_encoder.encode('这是一个测试文本.', flatten=False))
340
403
  ```
@@ -1,8 +1,13 @@
1
- deeplotx/__init__.py,sha256=xEq8WQ2LpEZoLX_Z464d0dy4aemFGrEV6ZMJr6ioFnQ,1186
1
+ deeplotx/__init__.py,sha256=x4CbJuW20al6S5KkKyrReeuwNGv04JGoqtGUyx-ACtg,1356
2
2
  deeplotx/encoder/__init__.py,sha256=BrsF5_4O-4pfihYF2wjExDOoAY-03kGJTH-Mhez4tsE,129
3
3
  deeplotx/encoder/encoder.py,sha256=wVRl3p_7eg7qT_tJEit5qnmZx7dXkMVLxAtao5vImkk,4201
4
4
  deeplotx/encoder/long_text_encoder.py,sha256=4oRa9FqfGNZ8-gq14UKuhDkZC0A1Xi-wKmbQsn-uZ58,3966
5
5
  deeplotx/encoder/longformer_encoder.py,sha256=7Lm65AUD3qwbrzrhJ3dPZkyHeNRSapga3f-5QJCxV5A,3538
6
+ deeplotx/ner/__init__.py,sha256=Rss1pup9HzHZCG8U9ub8niWa9zRjWCy3Z7zg378KZQg,114
7
+ deeplotx/ner/base_ner.py,sha256=bAp7R6mawsfO7owBONXtbPN0rzMSltMJVEGGNKhi41A,359
8
+ deeplotx/ner/bert_ner.py,sha256=RkqHVBY4SBJtHHR0YuR006v5gFmAaKqJCCKkOOs9ulY,5458
9
+ deeplotx/ner/named_entity.py,sha256=c6XufIwH6yloJ-ccUjagf4mBl1XbbYDT8xyEJJ_-ZNs,269
10
+ deeplotx/ner/n2g/__init__.py,sha256=b6fOWJVLaOCtoz8Qlp8NWQbL5lUSbn6H3-8fnVNIPi0,3940
6
11
  deeplotx/nn/__init__.py,sha256=YILwbxb-NHdiJjfOwBKH8F7PuZSDZSrGpTznPDucTro,710
7
12
  deeplotx/nn/attention.py,sha256=R-i-Rd7gnsh6hwXDeYfqLQOJvfSZIGfQbFzRlC91XLo,2879
8
13
  deeplotx/nn/auto_regression.py,sha256=j_R7WGPq9REngjpLuX5c0AaNqOpgGm2Vfrolw-XjWXw,877
@@ -28,8 +33,8 @@ deeplotx/trainer/text_binary_classification_trainer.py,sha256=TFxOX8rWU_zKliI9zm
28
33
  deeplotx/util/__init__.py,sha256=5CH4MTeSgsmCe3LPMfvKoSBpwh6jDSBuHVElJvzQzgs,90
29
34
  deeplotx/util/hash.py,sha256=qbNU3RLBWGQYFVte9WZBAkZ1BkdjCXiKLDaKPN54KFk,662
30
35
  deeplotx/util/read_file.py,sha256=ptzouvEQeeW8KU5BrWNJlXw-vFXVrpS9SkAUxsu6A8A,612
31
- deeplotx-0.8.8.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
32
- deeplotx-0.8.8.dist-info/METADATA,sha256=EhDWaoH6HmlnNga9c7VitZiBerZ2GAXXrz3BWG9latc,13163
33
- deeplotx-0.8.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
- deeplotx-0.8.8.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
35
- deeplotx-0.8.8.dist-info/RECORD,,
36
+ deeplotx-0.9.2.dist-info/licenses/LICENSE,sha256=IwGE9guuL-ryRPEKi6wFPI_zOhg7zDZbTYuHbSt_SAk,35823
37
+ deeplotx-0.9.2.dist-info/METADATA,sha256=lA_h92G6v6cT3ff94pmxVAi0LLj-qO2qrEjAVLFTYHw,13472
38
+ deeplotx-0.9.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
+ deeplotx-0.9.2.dist-info/top_level.txt,sha256=hKg4pVDXZ-WWxkRfJFczRIll1Sv7VyfKCmzHLXbuh1U,9
40
+ deeplotx-0.9.2.dist-info/RECORD,,