nlpertools 1.0.5__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl

Files changed (43)
  1. nlpertools/__init__.py +24 -20
  2. nlpertools/algo/ac.py +18 -0
  3. nlpertools/algo/bit_ops.py +28 -0
  4. nlpertools/algo/kmp.py +94 -55
  5. nlpertools/algo/num_ops.py +12 -0
  6. nlpertools/algo/template.py +116 -0
  7. nlpertools/algo/union.py +13 -0
  8. nlpertools/data_client.py +387 -257
  9. nlpertools/data_structure/base_structure.py +109 -13
  10. nlpertools/dataprocess.py +611 -3
  11. nlpertools/default_db_config.yml +41 -0
  12. nlpertools/io/__init__.py +3 -3
  13. nlpertools/io/dir.py +54 -36
  14. nlpertools/io/file.py +277 -222
  15. nlpertools/ml.py +483 -460
  16. nlpertools/monitor/__init__.py +0 -0
  17. nlpertools/monitor/gpu.py +18 -0
  18. nlpertools/monitor/memory.py +24 -0
  19. nlpertools/movie.py +36 -0
  20. nlpertools/nlpertools_config.yml +1 -0
  21. nlpertools/{openApi.py → open_api.py} +65 -65
  22. nlpertools/other.py +364 -249
  23. nlpertools/pic.py +288 -0
  24. nlpertools/plugin.py +43 -43
  25. nlpertools/reminder.py +98 -87
  26. nlpertools/utils/__init__.py +3 -3
  27. nlpertools/utils/lazy.py +727 -0
  28. nlpertools/utils/log_util.py +20 -0
  29. nlpertools/utils/package.py +89 -76
  30. nlpertools/utils/package_v1.py +94 -0
  31. nlpertools/utils/package_v2.py +117 -0
  32. nlpertools/utils_for_nlpertools.py +93 -93
  33. nlpertools/vector_index_demo.py +108 -0
  34. nlpertools/wrapper.py +161 -96
  35. {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
  36. nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
  37. nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
  38. {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
  39. nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
  40. nlpertools_helper/__init__.py +10 -0
  41. nlpertools-1.0.5.dist-info/METADATA +0 -85
  42. nlpertools-1.0.5.dist-info/RECORD +0 -25
  43. nlpertools-1.0.5.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,460 +1,483 @@
1
- # encoding=utf-8
2
- import codecs
3
- import os
4
- import random
5
-
6
- from .io.dir import j_mkdir
7
- from .io.file import readtxt_list_all_strip, writetxt_w_list
8
- # import numpy as np
9
- # import seaborn as sns
10
- # import torch
11
- # import torch.nn as nn
12
- # import xgboost as xgb
13
- # from matplotlib import pyplot as plt
14
- # from nltk.stem import WordNetLemmatizer
15
- # from sklearn import metrics
16
- # from transformers import BertTokenizer, BertForMaskedLM
17
- from .utils.package import *
18
-
19
-
20
- class DataAnalysis:
21
- @staticmethod
22
- def draw_pic(df, save_path):
23
- """
24
- 画直方图,对比两个不同类别差异
25
- :param df: pd.DataFrame
26
- :param save_path: str
27
- :return:
28
- """
29
- sns.distplot(df[df["label"] == 1]["feature"], label="label1")
30
- sns.distplot(df[df["label"] == 0]["feature"], label="label2")
31
- plt.legend()
32
- plt.savefig(save_path)
33
-
34
-
35
- class DataStructure:
36
- spo = {
37
- "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
38
- "triplets": [
39
- {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
40
- "p": {"text": "出版社", "l": 15, "r": 18},
41
- "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
42
- "source": "baidu"
43
- }
44
- ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
45
- ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
46
-
47
-
48
- def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
49
- # 两个句子的jacccard系数
50
- # 判断输入来重新定义ipt_level和sim_level
51
-
52
- # a = set(ipt1.split())
53
- # b = set(ipt2.split())
54
- a = set(ipt1)
55
- b = set(ipt2)
56
- c = a.intersection(b)
57
- # spical situation:
58
- if not ipt1 and not ipt2:
59
- return 0
60
- return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
61
-
62
-
63
- class STEM(object):
64
-
65
- def __init__(self, IPT_MODEL_PATH):
66
- self.ltp = LTP(IPT_MODEL_PATH)
67
-
68
- def start_by_dep(self, sentence):
69
- seg, hidden = self.ltp.seg([sentence])
70
- dep = self.ltp.dep(hidden) # , graph=False)
71
- seg, dep = seg[0], dep[0]
72
- for i in dep:
73
- # 主谓宾
74
- if 'SBV' == i[2]:
75
- subject = seg[i[0]]
76
- verb = seg[i[1]]
77
- if 'VOB' in i[2]:
78
- if seg[i[1]] == verb:
79
- object = seg[i[0]]
80
-
81
- return subject
82
-
83
- return None
84
-
85
- def start_by_srl(self, sentence):
86
- """
87
- 用语义角色标注工具
88
- :param sentence: "他叫汤姆去拿外衣。"
89
- :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
90
- """
91
- # 语义角色标注方法
92
- seg, hidden = self.ltp.seg([sentence])
93
- srl = self.ltp.srl(hidden)
94
- seg, srl = seg[0], srl[0]
95
- events = []
96
- for wdx, each_srl in enumerate(srl):
97
- if each_srl:
98
- args = []
99
- for arg in each_srl:
100
- args.extend(seg[arg[1]:arg[2] + 1])
101
- # 添加上谓词
102
- args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
103
- events.append(args)
104
- # print(events)
105
- return events
106
-
107
-
108
- # 这个是另一种
109
- # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
110
- def subject_object_labeling_new(spo_list, text):
111
- pass
112
-
113
-
114
- # 这个是传统格式的
115
- # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
116
- def subject_object_labeling(spo_list, text):
117
- # TODO
118
- '''
119
- 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
120
- :param spo_list:
121
- :param text:
122
- :return: labeling_list
123
- '''
124
-
125
- def _spo_list_to_spo_predicate_dict(spo_list):
126
- spo_predicate_dict = dict()
127
- for spo_item in spo_list:
128
- predicate = spo_item["predicate"]
129
- subject = spo_item["subject"]
130
- object = spo_item["object"]
131
- spo_predicate_dict.setdefault(predicate, []).append((subject, object))
132
- return spo_predicate_dict
133
-
134
- def _index_q_list_in_k_list(q_list, k_list):
135
- """Known q_list in k_list, find index(first time) of q_list in k_list"""
136
- q_list_length = len(q_list)
137
- k_list_length = len(k_list)
138
- for idx in range(k_list_length - q_list_length + 1):
139
- t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
140
- # print(idx, t)
141
- if all(t):
142
- # print(idx)
143
- idx_start = idx
144
- return idx_start
145
-
146
- def _labeling_type(spo, spo_type):
147
- idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
148
- labeling_list[idx_start] = 'B-' + spo_type
149
- if len(spo) == 2:
150
- labeling_list[idx_start + 1] = 'I-' + spo_type
151
- elif len(spo) >= 3:
152
- labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
153
- else:
154
- pass
155
-
156
- spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
157
- labeling_list = ['O'] * len(text)
158
- # count = 0
159
- for predicate, spo_list_form in spo_predicate_dict.items():
160
- if predicate in text:
161
- for (spo_subject, spo_object) in spo_list_form:
162
- # if predicate not in spo_subject and predicate not in spo_object:
163
- _labeling_type(spo_subject, 'SUB')
164
- _labeling_type(spo_object, 'OBJ')
165
- _labeling_type(predicate, 'PRE')
166
- # count += 1
167
- # print(count)
168
- # if count == 2:
169
- # print()
170
- if labeling_list != ['O'] * len(text):
171
- return labeling_list
172
- return None
173
-
174
-
175
- def label(text, labels):
176
- '''
177
- 返回两列的标记数据序列
178
- :param text:
179
- :param labels:
180
- :return:
181
- '''
182
- train_sequence = '\n'.join(
183
- ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
184
- return train_sequence
185
-
186
-
187
- def convert_crf_format_10_fold(corpus, objdir_path):
188
- '''
189
- 把已经是crf格式的数据,分成十折。
190
- para:
191
-
192
- '''
193
- # corpus = list(range(1,22))
194
- j_mkdir(objdir_path)
195
- split_position = int(len(corpus) / 10)
196
- for k in range(0, 10):
197
- if k == 9:
198
- dev_set = corpus[k * split_position:]
199
- train_set = corpus[:k * split_position]
200
- else:
201
- dev_set = corpus[k * split_position: (k + 1) * split_position]
202
- train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
203
- writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
204
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
205
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
206
-
207
-
208
- def read_seq_res(path, labels):
209
- '''
210
- 读序列标注三列数据的方法
211
- :param path:
212
- :param labels:
213
- :return:
214
- '''
215
- with codecs.open(path, 'r', 'utf-8') as rd:
216
- seqs_str = rd.read().strip()
217
- seqs_list = seqs_str.split('\n\n')
218
- text, raw_label, predict_label = [], [], []
219
- for seq in seqs_list:
220
- seq_split = seq.split('\n')
221
- text_tmp = ''
222
- raw_index_dict, pre_index_dict = {}, {}
223
- for label in labels:
224
- raw_index_dict.setdefault(label, [])
225
- pre_index_dict.setdefault(label, [])
226
- for idx, line in enumerate(seq_split):
227
- tmp = line.split('\t')
228
- text_tmp += tmp[0]
229
- if tmp[1] in labels:
230
- raw_index_dict[tmp[1]].append(idx)
231
- if tmp[2] in labels:
232
- pre_index_dict[tmp[2]].append(idx)
233
- text.append(text_tmp)
234
- raw_label.append(raw_index_dict)
235
- predict_label.append(pre_index_dict)
236
- return text, raw_label, predict_label
237
-
238
-
239
- def kfold(corpus, path, k=9, is_shuffle=True):
240
- '''
241
- k是10份中训练集占了几份
242
- '''
243
- j_mkdir(path)
244
- if is_shuffle:
245
- random.shuffle(corpus)
246
- split_position = int(len(corpus) / 10)
247
- train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
248
- writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
249
- writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
250
- writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
251
- """
252
- import pandas as pd
253
- from sklearn.model_selection import KFold
254
-
255
- df = pd.DataFrame({
256
- "text": ["text_{}".format(i) for i in range(100)],
257
- "labels": ["label_{}".format(i % 10) for i in range(100)]
258
- })
259
- train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
260
- test_idx, val_idx = KFold(n_splits=2, shuffle=True).split(df).__next__()
261
- df_train = df.iloc[train_idx]
262
- df_val = df.iloc[val_idx]
263
- df_test = df.iloc[test_idx]
264
- print(train_idx)
265
- print(val_idx)
266
- print(test_idx)
267
- """
268
-
269
-
270
- # 读取crf序列格式的数据
271
- def read_seq_data(path):
272
- content = readtxt_list_all_strip(path)
273
- lines = [i.split('\t') if i else '' for i in content]
274
- print(lines)
275
- sequences, labels, sequence, label = [], [], [], []
276
- for idx, line in enumerate(lines):
277
- if line == '':
278
- if sequence:
279
- sequences.append(sequence)
280
- labels.append(label)
281
- sequence, label = [], []
282
- else:
283
- sequence.append(line[0])
284
- label.append(line[1])
285
- if idx == len(lines) - 1 and sequence:
286
- sequences.append(sequence)
287
- labels.append(label)
288
- return sequences, labels
289
-
290
-
291
- def split_5_percent(lines, sample_precent=5):
292
- random.seed(8)
293
- # lines = list(range(1, 109))
294
- idx_lines = [(idx, i) for idx, i in enumerate(lines)]
295
- div = int(len(lines) / 100)
296
- sample_num = div * sample_precent
297
- sample = random.sample(idx_lines, sample_num)
298
- sorted_sample = sorted(sample, key=lambda x: x[0])
299
- remove_idx = [i[0] for i in sorted_sample]
300
- less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
301
- most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
302
- print(less_has_raw_line_info)
303
- print(most)
304
- return most, less_has_raw_line_info
305
-
306
-
307
- def split_sentences(sentences, mode='chinese'):
308
- # sentences->Str
309
- # example '12“345。”“6789”'
310
- if mode == 'chinese':
311
- split_signs = list('。!?…')
312
- other_sign = "”"
313
- elif mode == 'english':
314
- split_signs = list('.!?')
315
- other_sign = '"'
316
- else:
317
- print('暂时还没有')
318
- split_signs = list('.!?')
319
- other_sign = '"'
320
- splited_sentences = []
321
- start_idx = 0
322
- for idx, char in enumerate(sentences):
323
- if idx == len(sentences) - 1:
324
- if char in split_signs:
325
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
326
- start_idx = idx + 1
327
- else:
328
- splited_sentences.append(sentences[start_idx:].strip())
329
- else:
330
- if char in split_signs:
331
- if sentences[idx + 1] == other_sign:
332
- if idx < len(sentences) - 2:
333
- # 处理。”。
334
- if sentences[idx + 2] not in split_signs:
335
- splited_sentences.append(sentences[start_idx:idx + 2].strip())
336
- start_idx = idx + 2
337
- elif sentences[idx + 1] not in split_signs:
338
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
339
- start_idx = idx + 1
340
- return splited_sentences
341
-
342
-
343
- def pos_reduction():
344
- wnl = WordNetLemmatizer()
345
- # lemmatize nouns
346
- print(wnl.lemmatize('cars', 'n'))
347
- print(wnl.lemmatize('men', 'n'))
348
-
349
- # lemmatize verbs
350
- print(wnl.lemmatize('running', 'v'))
351
- print(wnl.lemmatize('ate', 'v'))
352
-
353
-
354
- class DataVisualization:
355
- # 和下面的类冲突了
356
- pass
357
-
358
-
359
- class CalcPPL(object):
360
- # ppl计算
361
- # https://www.scribendi.ai/comparing-bert-and-gpt-2-as-language-models-to-score-the-grammatical-correctness-of-a-sentence/
362
- def __init__(self, path):
363
- self.model = BertForMaskedLM.from_pretrained(path)
364
- self.model.eval()
365
- # Load pre-trained model tokenizer (vocabulary)
366
- self.tokenizer = BertTokenizer.from_pretrained(path)
367
-
368
- def ppl_1(self, sentence):
369
- tokenizer = self.tokenizer
370
- model = self.tokenizer
371
- tokenize_input = tokenizer.tokenize(sentence)
372
- tokenize_input = tokenize_input
373
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
374
- with torch.no_grad():
375
- loss = model(tensor_input, labels=tensor_input)[0]
376
- return np.exp(loss.detach().numpy())
377
-
378
- # [1] Salazar J, Liang D, Nguyen T Q, et al. Masked Language Model Scoring[C]//Proceedings of ACL. 2020: 2699-2712.
379
- def ppl_2(self, sentence):
380
- tokenizer = self.tokenizer
381
- model = self.tokenizer
382
- with torch.no_grad():
383
- tokenize_input = tokenizer.tokenize(sentence)
384
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
385
- sen_len = len(tokenize_input)
386
- sentence_loss = 0.
387
-
388
- for i, word in enumerate(tokenize_input):
389
- # add mask to i-th character of the sentence
390
- tokenize_input[i] = '[MASK]'
391
- mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
392
-
393
- output = model(mask_input)
394
-
395
- prediction_scores = output[0]
396
- softmax = nn.Softmax(dim=0)
397
- ps = softmax(prediction_scores[0, i]).log()
398
- word_loss = ps[tensor_input[0, i]]
399
- sentence_loss += word_loss.item()
400
-
401
- tokenize_input[i] = word
402
- ppl = np.exp(-sentence_loss / sen_len)
403
- # print("困惑度:", ppl)
404
- return ppl
405
-
406
- def test(self):
407
- sentence = "输入句子:"
408
- ppl = self.ppl_1(sentence)
409
- ppl2 = self.ppl_2(sentence)
410
- print(ppl)
411
- print(ppl2)
412
-
413
-
414
- class Evaluate():
415
- def __init__(self):
416
- pass
417
-
418
- def auc_metric(self, k):
419
- pass
420
-
421
- def map_metric(self):
422
- pass
423
-
424
- def ndcg(self, n, y_true, y_score):
425
- report = metrics.ndcg_score(y_true, y_score)
426
- return report
427
-
428
-
429
- class DecideTreeUtils:
430
- @staticmethod
431
- def draw(bst):
432
- # xgb 画图
433
- fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
434
- xgb.plot_tree(bst, ax=ax_tree)
435
- fig_tree.savefig('tree.png')
436
- plt.show()
437
-
438
-
439
- def seed_everything(seed=7777777) -> None:
440
- """
441
- 设置整个开发环境的seed
442
- :param seed:
443
- :param device:
444
- :return:
445
- """
446
- random.seed(seed)
447
- os.environ['PYTHONHASHSEED'] = str(seed)
448
- np.random.seed(seed)
449
- torch.manual_seed(seed) # CPU随机种子确定
450
- torch.cuda.manual_seed(seed)
451
- torch.cuda.manual_seed_all(seed)
452
- # some cudnn methods can be random even after fixing the seed
453
- # unless you tell it to be deterministic
454
- torch.backends.cudnn.deterministic = True
455
-
456
-
457
- if __name__ == '__main__':
458
- stem = STEM(IPT_MODEL_PATH)
459
- test_sentence = '美国袭击伊拉克'
460
- a = stem.start_by_srl(test_sentence)
1
+ # encoding=utf-8
2
+ import codecs
3
+ import os
4
+ import random
5
+
6
+ from .io.dir import j_mkdir
7
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
8
+ # import numpy as np
9
+ # import seaborn as sns
10
+ # import torch
11
+ # import torch.nn as nn
12
+ # import xgboost as xgb
13
+ # from matplotlib import pyplot as plt
14
+ # from nltk.stem import WordNetLemmatizer
15
+ # from sklearn import metrics
16
+ # from transformers import BertTokenizer, BertForMaskedLM
17
+ from .utils.package import *
18
+
19
+
20
+ def calc_llm_train_activation_memory(
21
+ model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
22
+ ):
23
+
24
+ """
25
+ return bytes
26
+
27
+ reference:
28
+ 1. https://zhuanlan.zhihu.com/p/665172400
29
+ 2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion 里面没有乘以层数就很怪
30
+ """
31
+ # reference1
32
+ # attention
33
+ # FFN
34
+ # Layer Norm
35
+ r1 = (
36
+ sequence_length
37
+ * batch_size
38
+ * hidden_dim
39
+ * lay_number
40
+ * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
41
+ )
42
+ # reference2
43
+ r2 = (
44
+ lay_number*(2 * sequence_length * attention_heads_num + 16 * hidden_dim)
45
+ * sequence_length
46
+ * batch_size
47
+ / gpu_num
48
+ )
49
+ print(r1)
50
+ print(r2)
51
+ return r1
52
+
53
+
54
+ class DataAnalysis:
55
+ @staticmethod
56
+ def draw_pic(df, save_path):
57
+ """
58
+ 画直方图,对比两个不同类别差异
59
+ :param df: pd.DataFrame
60
+ :param save_path: str
61
+ :return:
62
+ """
63
+ sns.distplot(df[df["label"] == 1]["feature"], label="label1")
64
+ sns.distplot(df[df["label"] == 0]["feature"], label="label2")
65
+ plt.legend()
66
+ plt.savefig(save_path)
67
+
68
+
69
+ class DataStructure:
70
+ spo = {
71
+ "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
72
+ "triplets": [
73
+ {
74
+ "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
75
+ "p": {"text": "出版社", "l": 15, "r": 18},
76
+ "o": {"text": "故宫出版社", "l": 13, "r": 18},
77
+ }
78
+ ],
79
+ "source": "baidu",
80
+ }
81
+ ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
82
+ ner_label_example = (
83
+ list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
84
+ )
85
+
86
+
87
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
88
+ # 两个句子的jacccard系数
89
+ # 判断输入来重新定义ipt_level和sim_level
90
+
91
+ # a = set(ipt1.split())
92
+ # b = set(ipt2.split())
93
+ a = set(ipt1)
94
+ b = set(ipt2)
95
+ c = a.intersection(b)
96
+ # spical situation:
97
+ if not ipt1 and not ipt2:
98
+ return 0
99
+ return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
100
+
101
+
102
+ class STEM(object):
103
+ def __init__(self, IPT_MODEL_PATH):
104
+ self.ltp = LTP(IPT_MODEL_PATH)
105
+
106
+ def start_by_dep(self, sentence):
107
+ seg, hidden = self.ltp.seg([sentence])
108
+ dep = self.ltp.dep(hidden) # , graph=False)
109
+ seg, dep = seg[0], dep[0]
110
+ for i in dep:
111
+ # 主谓宾
112
+ if "SBV" == i[2]:
113
+ subject = seg[i[0]]
114
+ verb = seg[i[1]]
115
+ if "VOB" in i[2]:
116
+ if seg[i[1]] == verb:
117
+ object = seg[i[0]]
118
+
119
+ return subject
120
+
121
+ return None
122
+
123
+ def start_by_srl(self, sentence):
124
+ """
125
+ 用语义角色标注工具
126
+ :param sentence: "他叫汤姆去拿外衣。"
127
+ :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
128
+ """
129
+ # 语义角色标注方法
130
+ seg, hidden = self.ltp.seg([sentence])
131
+ srl = self.ltp.srl(hidden)
132
+ seg, srl = seg[0], srl[0]
133
+ events = []
134
+ for wdx, each_srl in enumerate(srl):
135
+ if each_srl:
136
+ args = []
137
+ for arg in each_srl:
138
+ args.extend(seg[arg[1] : arg[2] + 1])
139
+ # 添加上谓词
140
+ args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
141
+ events.append(args)
142
+ # print(events)
143
+ return events
144
+
145
+
146
+ # 这个是另一种
147
+ # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
148
+ def subject_object_labeling_new(spo_list, text):
149
+ pass
150
+
151
+
152
+ # 这个是传统格式的
153
+ # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
154
+ def subject_object_labeling(spo_list, text):
155
+ # TODO
156
+ """
157
+ 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
158
+ :param spo_list:
159
+ :param text:
160
+ :return: labeling_list
161
+ """
162
+
163
+ def _spo_list_to_spo_predicate_dict(spo_list):
164
+ spo_predicate_dict = dict()
165
+ for spo_item in spo_list:
166
+ predicate = spo_item["predicate"]
167
+ subject = spo_item["subject"]
168
+ object = spo_item["object"]
169
+ spo_predicate_dict.setdefault(predicate, []).append((subject, object))
170
+ return spo_predicate_dict
171
+
172
+ def _index_q_list_in_k_list(q_list, k_list):
173
+ """Known q_list in k_list, find index(first time) of q_list in k_list"""
174
+ q_list_length = len(q_list)
175
+ k_list_length = len(k_list)
176
+ for idx in range(k_list_length - q_list_length + 1):
177
+ t = [q == k for q, k in zip(q_list, k_list[idx : idx + q_list_length])]
178
+ # print(idx, t)
179
+ if all(t):
180
+ # print(idx)
181
+ idx_start = idx
182
+ return idx_start
183
+
184
+ def _labeling_type(spo, spo_type):
185
+ idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
186
+ labeling_list[idx_start] = "B-" + spo_type
187
+ if len(spo) == 2:
188
+ labeling_list[idx_start + 1] = "I-" + spo_type
189
+ elif len(spo) >= 3:
190
+ labeling_list[idx_start + 1 : idx_start + len(spo)] = ["I-" + spo_type] * (
191
+ len(spo) - 1
192
+ )
193
+ else:
194
+ pass
195
+
196
+ spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
197
+ labeling_list = ["O"] * len(text)
198
+ # count = 0
199
+ for predicate, spo_list_form in spo_predicate_dict.items():
200
+ if predicate in text:
201
+ for (spo_subject, spo_object) in spo_list_form:
202
+ # if predicate not in spo_subject and predicate not in spo_object:
203
+ _labeling_type(spo_subject, "SUB")
204
+ _labeling_type(spo_object, "OBJ")
205
+ _labeling_type(predicate, "PRE")
206
+ # count += 1
207
+ # print(count)
208
+ # if count == 2:
209
+ # print()
210
+ if labeling_list != ["O"] * len(text):
211
+ return labeling_list
212
+ return None
213
+
214
+
215
+ def label(text, labels):
216
+ """
217
+ 返回两列的标记数据序列
218
+ :param text:
219
+ :param labels:
220
+ :return:
221
+ """
222
+ train_sequence = "\n".join(
223
+ [
224
+ "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
225
+ for i in zip(list(text), labels)
226
+ ]
227
+ )
228
+ return train_sequence
229
+
230
+
231
+ def convert_crf_format_10_fold(corpus, objdir_path):
232
+ """
233
+ 把已经是crf格式的数据,分成十折。
234
+ para:
235
+
236
+ """
237
+ # corpus = list(range(1,22))
238
+ j_mkdir(objdir_path)
239
+ split_position = int(len(corpus) / 10)
240
+ for k in range(0, 10):
241
+ if k == 9:
242
+ dev_set = corpus[k * split_position :]
243
+ train_set = corpus[: k * split_position]
244
+ else:
245
+ dev_set = corpus[k * split_position : (k + 1) * split_position]
246
+ train_set = (
247
+ corpus[: k * split_position] + corpus[(k + 1) * split_position :]
248
+ )
249
+ writetxt_w_list(
250
+ train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
251
+ )
252
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
253
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
254
+
255
+
256
+ def read_seq_res(path, labels):
257
+ """
258
+ 读序列标注三列数据的方法
259
+ :param path:
260
+ :param labels:
261
+ :return:
262
+ """
263
+ with codecs.open(path, "r", "utf-8") as rd:
264
+ seqs_str = rd.read().strip()
265
+ seqs_list = seqs_str.split("\n\n")
266
+ text, raw_label, predict_label = [], [], []
267
+ for seq in seqs_list:
268
+ seq_split = seq.split("\n")
269
+ text_tmp = ""
270
+ raw_index_dict, pre_index_dict = {}, {}
271
+ for label in labels:
272
+ raw_index_dict.setdefault(label, [])
273
+ pre_index_dict.setdefault(label, [])
274
+ for idx, line in enumerate(seq_split):
275
+ tmp = line.split("\t")
276
+ text_tmp += tmp[0]
277
+ if tmp[1] in labels:
278
+ raw_index_dict[tmp[1]].append(idx)
279
+ if tmp[2] in labels:
280
+ pre_index_dict[tmp[2]].append(idx)
281
+ text.append(text_tmp)
282
+ raw_label.append(raw_index_dict)
283
+ predict_label.append(pre_index_dict)
284
+ return text, raw_label, predict_label
285
+
286
+
287
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
288
+ """
289
+ k是10份中训练集占了几份
290
+ """
291
+ j_mkdir(path)
292
+ if is_shuffle:
293
+ random.shuffle(corpus)
294
+ split_position = int(len(corpus) / 10)
295
+ train_set, dev_set = corpus[: k * split_position], corpus[k * split_position :]
296
+ writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
297
+ writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
298
+ writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
299
+
300
+
301
+ def kfold_df(df, save_dir=None):
302
+ """
303
+ 划分train test val集, 写为windows可读的csv。
304
+ :param df:pd.DataFrame
305
+ :param save_dir:
306
+ :return:
307
+ """
308
+ from sklearn.model_selection import KFold
309
+ import pandas as pd
310
+
311
+ train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
312
+ df_test_and_val = df.iloc[test_and_val_idx]
313
+ test_idx, val_idx = (
314
+ KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
315
+ )
316
+ df_train = df.iloc[train_idx]
317
+ df_val = df.iloc[val_idx]
318
+ df_test = df.iloc[test_idx]
319
+ if save_dir:
320
+ j_mkdir(save_dir)
321
+ save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
322
+ save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
323
+ save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
324
+ return df_train, df_val, df_test
325
+
326
+
327
+ # 读取crf序列格式的数据
328
+ def read_seq_data(path):
329
+ content = readtxt_list_all_strip(path)
330
+ lines = [i.split("\t") if i else "" for i in content]
331
+ print(lines)
332
+ sequences, labels, sequence, label = [], [], [], []
333
+ for idx, line in enumerate(lines):
334
+ if line == "":
335
+ if sequence:
336
+ sequences.append(sequence)
337
+ labels.append(label)
338
+ sequence, label = [], []
339
+ else:
340
+ sequence.append(line[0])
341
+ label.append(line[1])
342
+ if idx == len(lines) - 1 and sequence:
343
+ sequences.append(sequence)
344
+ labels.append(label)
345
+ return sequences, labels
346
+
347
+
348
+ def split_5_percent(lines, sample_precent=5):
349
+ random.seed(8)
350
+ # lines = list(range(1, 109))
351
+ idx_lines = [(idx, i) for idx, i in enumerate(lines)]
352
+ div = int(len(lines) / 100)
353
+ sample_num = div * sample_precent
354
+ sample = random.sample(idx_lines, sample_num)
355
+ sorted_sample = sorted(sample, key=lambda x: x[0])
356
+ remove_idx = [i[0] for i in sorted_sample]
357
+ less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
358
+ most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
359
+ print(less_has_raw_line_info)
360
+ print(most)
361
+ return most, less_has_raw_line_info
362
+
363
+
364
+ def split_sentence(sentence, language="chinese", cross_line=True):
365
+ """
366
+ 分句,英文有nltk,中文怎么能没有好的分句工具呢
367
+ :param sentence:
368
+ :param language:
369
+ :param cross_line:
370
+ :return:
371
+ """
372
+ # sentences->Str
373
+ # example '12“345。”“6789”'
374
+ assert language in ["chinese", "english"], "unsupportable for other language"
375
+ sentence = sentence.replace("\r", "")
376
+ if language == "chinese":
377
+ split_signs = list("。!?…")
378
+ if cross_line:
379
+ split_signs.append("\n")
380
+ other_sign = "”"
381
+ elif language == "english":
382
+ split_signs = list(".!?")
383
+ other_sign = '"'
384
+ else:
385
+ split_signs = list(".!?")
386
+ other_sign = '"'
387
+ sentences = []
388
+ start_idx = 0
389
+ for idx, char in enumerate(sentence):
390
+ if idx == len(sentence) - 1:
391
+ if char in split_signs:
392
+ sentences.append(sentence[start_idx : idx + 1].strip())
393
+ start_idx = idx + 1
394
+ else:
395
+ sentences.append(sentence[start_idx:].strip())
396
+ else:
397
+ if char in split_signs:
398
+ if sentence[idx + 1] == other_sign:
399
+ if idx < len(sentence) - 2:
400
+ # 处理。”。
401
+ if sentence[idx + 2] not in split_signs:
402
+ sentences.append(sentence[start_idx : idx + 2].strip())
403
+ start_idx = idx + 2
404
+ elif sentence[idx + 1] not in split_signs:
405
+ sentences.append(sentence[start_idx : idx + 1].strip())
406
+ start_idx = idx + 1
407
+ return sentences
408
+
409
+
410
+ def pos_reduction():
411
+ wnl = WordNetLemmatizer()
412
+ # lemmatize nouns
413
+ print(wnl.lemmatize("cars", "n"))
414
+ print(wnl.lemmatize("men", "n"))
415
+
416
+ # lemmatize verbs
417
+ print(wnl.lemmatize("running", "v"))
418
+ print(wnl.lemmatize("ate", "v"))
419
+
420
+
421
+ class DataVisualization:
422
+ # 和下面的类冲突了
423
+ pass
424
+
425
+
426
+ class Evaluate:
427
+ def __init__(self):
428
+ pass
429
+
430
+ def auc_metric(self, k):
431
+ pass
432
+
433
+ def map_metric(self):
434
+ pass
435
+
436
+ def ndcg(self, n, y_true, y_score):
437
+ report = metrics.ndcg_score(y_true, y_score)
438
+ return report
439
+
440
+
441
+ class DecideTreeUtils:
442
+ @staticmethod
443
+ def draw(bst):
444
+ # xgb 画图
445
+ fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
446
+ xgb.plot_tree(bst, ax=ax_tree)
447
+ fig_tree.savefig("tree.png")
448
+ plt.show()
449
+
450
+
451
+ def seed_everything(seed=7777777) -> None:
452
+ """
453
+ 设置整个开发环境的seed
454
+ :param seed:
455
+ :param device:
456
+ :return:
457
+ """
458
+ random.seed(seed)
459
+ os.environ["PYTHONHASHSEED"] = str(seed)
460
+ np.random.seed(seed)
461
+ torch.manual_seed(seed) # CPU随机种子确定
462
+ torch.cuda.manual_seed(seed)
463
+ torch.cuda.manual_seed_all(seed)
464
+ # some cudnn methods can be random even after fixing the seed
465
+ # unless you tell it to be deterministic
466
+ torch.backends.cudnn.deterministic = True
467
+
468
+
469
+ if __name__ == "__main__":
470
+ # stem = STEM(IPT_MODEL_PATH)
471
+ # test_sentence = "美国袭击伊拉克"
472
+ # a = stem.start_by_srl(test_sentence)
473
+
474
+ res = calc_llm_train_activation_memory(
475
+ model_name="",
476
+ sequence_length=2048,
477
+ batch_size=1,
478
+ hidden_dim=4096,
479
+ lay_number=28,
480
+ attention_heads_num=32,
481
+ gpu_num=1
482
+ )
483
+ print(res, "G")
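
The new calc_llm_train_activation_memory helper returns r1 in bytes (its docstring says "return bytes"), so the trailing print(res, "G") in the __main__ block prints the raw byte count; dividing by 1024**3 gives the figure in GiB. A minimal sanity check of the reference-1 formula, re-evaluated outside the package with the same arguments as the __main__ block (an illustrative sketch, not code from nlpertools):

# Standalone re-evaluation of the reference-1 activation estimate
# (attention + FFN + layer norm terms) used by calc_llm_train_activation_memory.
seq, bs, hidden, layers, heads = 2048, 1, 4096, 28, 32
activation_bytes = seq * bs * hidden * layers * (34 + 5 * heads * seq / hidden)
print(activation_bytes)              # 26776436736.0  (bytes)
print(activation_bytes / 1024 ** 3)  # ~24.94, i.e. roughly 25 GiB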
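
split_sentence (which replaces the old split_sentences) keeps a closing quotation mark with the sentence it terminates, as in the example string given in its own comment. A small usage sketch, assuming nlpertools 1.0.6.dev0 is installed and the optional dependencies pulled in by nlpertools.ml resolve so the module imports cleanly:

from nlpertools.ml import split_sentence

# '。' ends the first sentence and the following '”' stays attached to it
print(split_sentence('12“345。”“6789”', language="chinese"))
# expected, by tracing the code above: ['12“345。”', '“6789”']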