nlpertools 1.0.5__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nlpertools/__init__.py +24 -20
- nlpertools/algo/ac.py +18 -0
- nlpertools/algo/bit_ops.py +28 -0
- nlpertools/algo/kmp.py +94 -55
- nlpertools/algo/num_ops.py +12 -0
- nlpertools/algo/template.py +116 -0
- nlpertools/algo/union.py +13 -0
- nlpertools/data_client.py +387 -257
- nlpertools/data_structure/base_structure.py +109 -13
- nlpertools/dataprocess.py +611 -3
- nlpertools/default_db_config.yml +41 -0
- nlpertools/io/__init__.py +3 -3
- nlpertools/io/dir.py +54 -36
- nlpertools/io/file.py +277 -222
- nlpertools/ml.py +483 -460
- nlpertools/monitor/__init__.py +0 -0
- nlpertools/monitor/gpu.py +18 -0
- nlpertools/monitor/memory.py +24 -0
- nlpertools/movie.py +36 -0
- nlpertools/nlpertools_config.yml +1 -0
- nlpertools/{openApi.py → open_api.py} +65 -65
- nlpertools/other.py +364 -249
- nlpertools/pic.py +288 -0
- nlpertools/plugin.py +43 -43
- nlpertools/reminder.py +98 -87
- nlpertools/utils/__init__.py +3 -3
- nlpertools/utils/lazy.py +727 -0
- nlpertools/utils/log_util.py +20 -0
- nlpertools/utils/package.py +89 -76
- nlpertools/utils/package_v1.py +94 -0
- nlpertools/utils/package_v2.py +117 -0
- nlpertools/utils_for_nlpertools.py +93 -93
- nlpertools/vector_index_demo.py +108 -0
- nlpertools/wrapper.py +161 -96
- {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
- nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
- nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
- {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
- nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
- nlpertools_helper/__init__.py +10 -0
- nlpertools-1.0.5.dist-info/METADATA +0 -85
- nlpertools-1.0.5.dist-info/RECORD +0 -25
- nlpertools-1.0.5.dist-info/top_level.txt +0 -1
nlpertools/ml.py
CHANGED
@@ -1,460 +1,483 @@
- # encoding=utf-8
- import codecs
- import os
- import random
-
- from .io.dir import j_mkdir
- from .io.file import readtxt_list_all_strip, writetxt_w_list
- # import numpy as np
- # import seaborn as sns
- # import torch
- # import torch.nn as nn
- # import xgboost as xgb
- # from matplotlib import pyplot as plt
- # from nltk.stem import WordNetLemmatizer
- # from sklearn import metrics
- # from transformers import BertTokenizer, BertForMaskedLM
- from .utils.package import *
- … (removed lines 18-460 of the old ml.py were not rendered in this diff view)
+ # encoding=utf-8
+ import codecs
+ import os
+ import random
+
+ from .io.dir import j_mkdir
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
+ # import numpy as np
+ # import seaborn as sns
+ # import torch
+ # import torch.nn as nn
+ # import xgboost as xgb
+ # from matplotlib import pyplot as plt
+ # from nltk.stem import WordNetLemmatizer
+ # from sklearn import metrics
+ # from transformers import BertTokenizer, BertForMaskedLM
+ from .utils.package import *
+
+
+ def calc_llm_train_activation_memory(
+     model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
+ ):
+
+     """
+     return bytes
+
+     reference:
+     1. https://zhuanlan.zhihu.com/p/665172400
+     2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion (oddly, the formula there is not multiplied by the layer count)
+     """
+     # reference1
+     # attention
+     # FFN
+     # Layer Norm
+     r1 = (
+         sequence_length
+         * batch_size
+         * hidden_dim
+         * lay_number
+         * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
+     )
+     # reference2
+     r2 = (
+         lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
+         * sequence_length
+         * batch_size
+         / gpu_num
+     )
+     print(r1)
+     print(r2)
+     return r1
+
+
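A rough worked example of the reference-1 formula above, using the parameters from the __main__ block at the end of this file (illustrative, not part of ml.py):

    r1 = 2048 * 1 * 4096 * 28 * (34 + 5 * 32 * 2048 / 4096)
    #  = 2048 * 4096 * 28 * 114
    #  ≈ 2.68e10 bytes, i.e. roughly 25 GiB of activation memory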
+ class DataAnalysis:
+     @staticmethod
+     def draw_pic(df, save_path):
+         """
+         Plot histograms comparing the feature distributions of two classes.
+         :param df: pd.DataFrame
+         :param save_path: str
+         :return:
+         """
+         sns.distplot(df[df["label"] == 1]["feature"], label="label1")
+         sns.distplot(df[df["label"] == 0]["feature"], label="label2")
+         plt.legend()
+         plt.savefig(save_path)
+
+
+ class DataStructure:
+     spo = {
+         "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
+         "triplets": [
+             {
+                 "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
+                 "p": {"text": "出版社", "l": 15, "r": 18},
+                 "o": {"text": "故宫出版社", "l": 13, "r": 18},
+             }
+         ],
+         "source": "baidu",
+     }
+     ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
+     ner_label_example = (
+         list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
+     )
+
+
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
+     # Jaccard coefficient of two sentences
+     # decide ipt_level and sim_level from the inputs
+
+     # a = set(ipt1.split())
+     # b = set(ipt2.split())
+     a = set(ipt1)
+     b = set(ipt2)
+     c = a.intersection(b)
+     # special situation:
+     if not ipt1 and not ipt2:
+         return 0
+     return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
+
+
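A minimal character-level usage sketch of text_jaccard (illustrative, not part of ml.py):

    text_jaccard("abcd", "cdef")
    # intersection = {"c", "d"}, so int(100 * 2 / (4 + 4 - 2)) == 33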
+ class STEM(object):
+     def __init__(self, IPT_MODEL_PATH):
+         self.ltp = LTP(IPT_MODEL_PATH)
+
+     def start_by_dep(self, sentence):
+         seg, hidden = self.ltp.seg([sentence])
+         dep = self.ltp.dep(hidden)  # , graph=False)
+         seg, dep = seg[0], dep[0]
+         for i in dep:
+             # subject-verb-object (SVO)
+             if "SBV" == i[2]:
+                 subject = seg[i[0]]
+                 verb = seg[i[1]]
+             if "VOB" in i[2]:
+                 if seg[i[1]] == verb:
+                     object = seg[i[0]]
+
+         return subject
+
+         return None
+
+     def start_by_srl(self, sentence):
+         """
+         Extract events with the semantic role labeling (SRL) tool.
+         :param sentence: "他叫汤姆去拿外衣。"
+         :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
+         """
+         # semantic role labeling approach
+         seg, hidden = self.ltp.seg([sentence])
+         srl = self.ltp.srl(hidden)
+         seg, srl = seg[0], srl[0]
+         events = []
+         for wdx, each_srl in enumerate(srl):
+             if each_srl:
+                 args = []
+                 for arg in each_srl:
+                     args.extend(seg[arg[1] : arg[2] + 1])
+                 # insert the predicate
+                 args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
+                 events.append(args)
+         # print(events)
+         return events
+
+
+ # This one handles a different data format.
+ # Example record: {"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
+ def subject_object_labeling_new(spo_list, text):
+     pass
+
+
+ # This one handles the traditional format.
+ # Example record: {"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
+ def subject_object_labeling(spo_list, text):
+     # TODO
+     """
+     Turn Baidu-style data that carries an spo dict into sequence labels. (Author's note: hard to follow; need to find where this is used.)
+     :param spo_list:
+     :param text:
+     :return: labeling_list
+     """
+
+     def _spo_list_to_spo_predicate_dict(spo_list):
+         spo_predicate_dict = dict()
+         for spo_item in spo_list:
+             predicate = spo_item["predicate"]
+             subject = spo_item["subject"]
+             object = spo_item["object"]
+             spo_predicate_dict.setdefault(predicate, []).append((subject, object))
+         return spo_predicate_dict
+
+     def _index_q_list_in_k_list(q_list, k_list):
+         """Known q_list in k_list, find index(first time) of q_list in k_list"""
+         q_list_length = len(q_list)
+         k_list_length = len(k_list)
+         for idx in range(k_list_length - q_list_length + 1):
+             t = [q == k for q, k in zip(q_list, k_list[idx : idx + q_list_length])]
+             # print(idx, t)
+             if all(t):
+                 # print(idx)
+                 idx_start = idx
+                 return idx_start
+
+     def _labeling_type(spo, spo_type):
+         idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
+         labeling_list[idx_start] = "B-" + spo_type
+         if len(spo) == 2:
+             labeling_list[idx_start + 1] = "I-" + spo_type
+         elif len(spo) >= 3:
+             labeling_list[idx_start + 1 : idx_start + len(spo)] = ["I-" + spo_type] * (
+                 len(spo) - 1
+             )
+         else:
+             pass
+
+     spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
+     labeling_list = ["O"] * len(text)
+     # count = 0
+     for predicate, spo_list_form in spo_predicate_dict.items():
+         if predicate in text:
+             for (spo_subject, spo_object) in spo_list_form:
+                 # if predicate not in spo_subject and predicate not in spo_object:
+                 _labeling_type(spo_subject, "SUB")
+                 _labeling_type(spo_object, "OBJ")
+                 _labeling_type(predicate, "PRE")
+                 # count += 1
+                 # print(count)
+                 # if count == 2:
+                 # print()
+     if labeling_list != ["O"] * len(text):
+         return labeling_list
+     return None
+
+
+ def label(text, labels):
+     """
+     Return the labeled data as a two-column sequence (one token<TAB>label pair per line).
+     :param text:
+     :param labels:
+     :return:
+     """
+     train_sequence = "\n".join(
+         [
+             "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
+             for i in zip(list(text), labels)
+         ]
+     )
+     return train_sequence
+
+
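A small usage sketch of label() (illustrative, not part of ml.py); spaces in the text are written out as the [null] placeholder:

    label("大 象", ["B-s", "O", "I-s"])
    # -> "大\tB-s\n[null]\tO\n象\tI-s"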
+ def convert_crf_format_10_fold(corpus, objdir_path):
+     """
+     Split data that is already in CRF format into ten folds.
+     para:
+
+     """
+     # corpus = list(range(1,22))
+     j_mkdir(objdir_path)
+     split_position = int(len(corpus) / 10)
+     for k in range(0, 10):
+         if k == 9:
+             dev_set = corpus[k * split_position :]
+             train_set = corpus[: k * split_position]
+         else:
+             dev_set = corpus[k * split_position : (k + 1) * split_position]
+             train_set = (
+                 corpus[: k * split_position] + corpus[(k + 1) * split_position :]
+             )
+         writetxt_w_list(
+             train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
+         )
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
+
+
+ def read_seq_res(path, labels):
+     """
+     Read three-column sequence-labeling data (token, gold label, predicted label).
+     :param path:
+     :param labels:
+     :return:
+     """
+     with codecs.open(path, "r", "utf-8") as rd:
+         seqs_str = rd.read().strip()
+     seqs_list = seqs_str.split("\n\n")
+     text, raw_label, predict_label = [], [], []
+     for seq in seqs_list:
+         seq_split = seq.split("\n")
+         text_tmp = ""
+         raw_index_dict, pre_index_dict = {}, {}
+         for label in labels:
+             raw_index_dict.setdefault(label, [])
+             pre_index_dict.setdefault(label, [])
+         for idx, line in enumerate(seq_split):
+             tmp = line.split("\t")
+             text_tmp += tmp[0]
+             if tmp[1] in labels:
+                 raw_index_dict[tmp[1]].append(idx)
+             if tmp[2] in labels:
+                 pre_index_dict[tmp[2]].append(idx)
+         text.append(text_tmp)
+         raw_label.append(raw_index_dict)
+         predict_label.append(pre_index_dict)
+     return text, raw_label, predict_label
+
+
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
+     """
+     k is how many of the ten folds go to the training set.
+     """
+     j_mkdir(path)
+     if is_shuffle:
+         random.shuffle(corpus)
+     split_position = int(len(corpus) / 10)
+     train_set, dev_set = corpus[: k * split_position], corpus[k * split_position :]
+     writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
+     writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
+     writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
+
+
+ def kfold_df(df, save_dir=None):
+     """
+     Split into train/test/val sets and write them as Windows-readable CSV files.
+     :param df: pd.DataFrame
+     :param save_dir:
+     :return:
+     """
+     from sklearn.model_selection import KFold
+     import pandas as pd
+
+     train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
+     df_test_and_val = df.iloc[test_and_val_idx]
+     test_idx, val_idx = (
+         KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
+     )
+     df_train = df.iloc[train_idx]
+     df_val = df.iloc[val_idx]
+     df_test = df.iloc[test_idx]
+     if save_dir:
+         j_mkdir(save_dir)
+         save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
+         save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
+         save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
+     return df_train, df_val, df_test
+
+
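A usage sketch of kfold_df (illustrative, not part of ml.py): the outer KFold holds out one eighth of the rows, and the inner 2-fold split halves that, so val and test each end up with roughly 1/16 of the data:

    df_train, df_val, df_test = kfold_df(df, save_dir="splits")
    # len(df_train) ≈ 7/8 * len(df); len(df_val) ≈ len(df_test) ≈ 1/16 * len(df)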
+ # Read data in CRF sequence format.
+ def read_seq_data(path):
+     content = readtxt_list_all_strip(path)
+     lines = [i.split("\t") if i else "" for i in content]
+     print(lines)
+     sequences, labels, sequence, label = [], [], [], []
+     for idx, line in enumerate(lines):
+         if line == "":
+             if sequence:
+                 sequences.append(sequence)
+                 labels.append(label)
+             sequence, label = [], []
+         else:
+             sequence.append(line[0])
+             label.append(line[1])
+         if idx == len(lines) - 1 and sequence:
+             sequences.append(sequence)
+             labels.append(label)
+     return sequences, labels
+
+
+ def split_5_percent(lines, sample_precent=5):
+     random.seed(8)
+     # lines = list(range(1, 109))
+     idx_lines = [(idx, i) for idx, i in enumerate(lines)]
+     div = int(len(lines) / 100)
+     sample_num = div * sample_precent
+     sample = random.sample(idx_lines, sample_num)
+     sorted_sample = sorted(sample, key=lambda x: x[0])
+     remove_idx = [i[0] for i in sorted_sample]
+     less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
+     most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
+     print(less_has_raw_line_info)
+     print(most)
+     return most, less_has_raw_line_info
+
+
+ def split_sentence(sentence, language="chinese", cross_line=True):
+     """
+     Sentence splitting. English has nltk; Chinese deserves a decent sentence splitter too.
+     :param sentence:
+     :param language:
+     :param cross_line:
+     :return:
+     """
+     # sentences->Str
+     # example '12“345。”“6789”'
+     assert language in ["chinese", "english"], "unsupportable for other language"
+     sentence = sentence.replace("\r", "")
+     if language == "chinese":
+         split_signs = list("。!?…")
+         if cross_line:
+             split_signs.append("\n")
+         other_sign = "”"
+     elif language == "english":
+         split_signs = list(".!?")
+         other_sign = '"'
+     else:
+         split_signs = list(".!?")
+         other_sign = '"'
+     sentences = []
+     start_idx = 0
+     for idx, char in enumerate(sentence):
+         if idx == len(sentence) - 1:
+             if char in split_signs:
+                 sentences.append(sentence[start_idx : idx + 1].strip())
+                 start_idx = idx + 1
+             else:
+                 sentences.append(sentence[start_idx:].strip())
+         else:
+             if char in split_signs:
+                 if sentence[idx + 1] == other_sign:
+                     if idx < len(sentence) - 2:
+                         # handle the "。”。" pattern
+                         if sentence[idx + 2] not in split_signs:
+                             sentences.append(sentence[start_idx : idx + 2].strip())
+                             start_idx = idx + 2
+                 elif sentence[idx + 1] not in split_signs:
+                     sentences.append(sentence[start_idx : idx + 1].strip())
+                     start_idx = idx + 1
+     return sentences
+
+
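A quick usage sketch of split_sentence (illustrative, not part of ml.py):

    split_sentence("今天天气很好。明天下雨!")
    # -> ["今天天气很好。", "明天下雨!"]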
+ def pos_reduction():
+     wnl = WordNetLemmatizer()
+     # lemmatize nouns
+     print(wnl.lemmatize("cars", "n"))
+     print(wnl.lemmatize("men", "n"))
+
+     # lemmatize verbs
+     print(wnl.lemmatize("running", "v"))
+     print(wnl.lemmatize("ate", "v"))
+
+
+ class DataVisualization:
+     # conflicts with the class below
+     pass
+
+
+ class Evaluate:
+     def __init__(self):
+         pass
+
+     def auc_metric(self, k):
+         pass
+
+     def map_metric(self):
+         pass
+
+     def ndcg(self, n, y_true, y_score):
+         report = metrics.ndcg_score(y_true, y_score)
+         return report
+
+
+ class DecideTreeUtils:
+     @staticmethod
+     def draw(bst):
+         # plot the xgboost trees
+         fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
+         xgb.plot_tree(bst, ax=ax_tree)
+         fig_tree.savefig("tree.png")
+         plt.show()
+
+
+ def seed_everything(seed=7777777) -> None:
+     """
+     Set the random seed for the whole development environment.
+     :param seed:
+     :param device:
+     :return:
+     """
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)  # fix the CPU random seed
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     # some cudnn methods can be random even after fixing the seed
+     # unless you tell it to be deterministic
+     torch.backends.cudnn.deterministic = True
+
+
+ if __name__ == "__main__":
+     # stem = STEM(IPT_MODEL_PATH)
+     # test_sentence = "美国袭击伊拉克"
+     # a = stem.start_by_srl(test_sentence)
+
+     res = calc_llm_train_activation_memory(
+         model_name="",
+         sequence_length=2048,
+         batch_size=1,
+         hidden_dim=4096,
+         lay_number=28,
+         attention_heads_num=32,
+         gpu_num=1
+     )
+     print(res, "G")