nlpertools 1.0.5 → 1.0.8 (py3-none-any.whl)

Files changed (49)
  1. nlpertools/__init__.py +23 -20
  2. nlpertools/algo/ac.py +18 -0
  3. nlpertools/algo/bit_ops.py +28 -0
  4. nlpertools/algo/kmp.py +94 -55
  5. nlpertools/algo/num_ops.py +12 -0
  6. nlpertools/algo/template.py +116 -0
  7. nlpertools/algo/union.py +13 -0
  8. nlpertools/cli.py +87 -0
  9. nlpertools/data_client.py +426 -257
  10. nlpertools/data_structure/base_structure.py +109 -13
  11. nlpertools/dataprocess.py +627 -3
  12. nlpertools/default_db_config.yml +41 -0
  13. nlpertools/draw/__init__.py +0 -0
  14. nlpertools/draw/draw.py +83 -0
  15. nlpertools/draw/math_func.py +33 -0
  16. nlpertools/get_2fa.py +0 -0
  17. nlpertools/io/__init__.py +3 -3
  18. nlpertools/io/dir.py +86 -36
  19. nlpertools/io/file.py +283 -222
  20. nlpertools/ml.py +511 -460
  21. nlpertools/monitor/__init__.py +0 -0
  22. nlpertools/monitor/gpu.py +18 -0
  23. nlpertools/monitor/memory.py +24 -0
  24. nlpertools/movie.py +36 -0
  25. nlpertools/nlpertools_config.yml +1 -0
  26. nlpertools/{openApi.py → open_api.py} +65 -65
  27. nlpertools/other.py +475 -249
  28. nlpertools/pic.py +288 -0
  29. nlpertools/plugin.py +43 -43
  30. nlpertools/reminder.py +98 -87
  31. nlpertools/utils/__init__.py +3 -3
  32. nlpertools/utils/lazy.py +727 -0
  33. nlpertools/utils/log_util.py +20 -0
  34. nlpertools/utils/package.py +89 -76
  35. nlpertools/utils/package_v1.py +94 -0
  36. nlpertools/utils/package_v2.py +117 -0
  37. nlpertools/utils_for_nlpertools.py +93 -93
  38. nlpertools/vector_index_demo.py +108 -0
  39. nlpertools/wrapper.py +161 -96
  40. {nlpertools-1.0.5.dist-info → nlpertools-1.0.8.dist-info}/LICENSE +200 -200
  41. nlpertools-1.0.8.dist-info/METADATA +132 -0
  42. nlpertools-1.0.8.dist-info/RECORD +49 -0
  43. {nlpertools-1.0.5.dist-info → nlpertools-1.0.8.dist-info}/WHEEL +1 -1
  44. nlpertools-1.0.8.dist-info/entry_points.txt +2 -0
  45. nlpertools-1.0.8.dist-info/top_level.txt +2 -0
  46. nlpertools_helper/__init__.py +10 -0
  47. nlpertools-1.0.5.dist-info/METADATA +0 -85
  48. nlpertools-1.0.5.dist-info/RECORD +0 -25
  49. nlpertools-1.0.5.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,460 +1,511 @@
1
- # encoding=utf-8
2
- import codecs
3
- import os
4
- import random
5
-
6
- from .io.dir import j_mkdir
7
- from .io.file import readtxt_list_all_strip, writetxt_w_list
8
- # import numpy as np
9
- # import seaborn as sns
10
- # import torch
11
- # import torch.nn as nn
12
- # import xgboost as xgb
13
- # from matplotlib import pyplot as plt
14
- # from nltk.stem import WordNetLemmatizer
15
- # from sklearn import metrics
16
- # from transformers import BertTokenizer, BertForMaskedLM
17
- from .utils.package import *
18
-
19
-
20
- class DataAnalysis:
21
- @staticmethod
22
- def draw_pic(df, save_path):
23
- """
24
- 画直方图,对比两个不同类别差异
25
- :param df: pd.DataFrame
26
- :param save_path: str
27
- :return:
28
- """
29
- sns.distplot(df[df["label"] == 1]["feature"], label="label1")
30
- sns.distplot(df[df["label"] == 0]["feature"], label="label2")
31
- plt.legend()
32
- plt.savefig(save_path)
33
-
34
-
35
- class DataStructure:
36
- spo = {
37
- "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
38
- "triplets": [
39
- {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
40
- "p": {"text": "出版社", "l": 15, "r": 18},
41
- "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
42
- "source": "baidu"
43
- }
44
- ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
45
- ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
46
-
47
-
48
- def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
49
- # 两个句子的jacccard系数
50
- # 判断输入来重新定义ipt_level和sim_level
51
-
52
- # a = set(ipt1.split())
53
- # b = set(ipt2.split())
54
- a = set(ipt1)
55
- b = set(ipt2)
56
- c = a.intersection(b)
57
- # spical situation:
58
- if not ipt1 and not ipt2:
59
- return 0
60
- return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
61
-
62
-
63
- class STEM(object):
64
-
65
- def __init__(self, IPT_MODEL_PATH):
66
- self.ltp = LTP(IPT_MODEL_PATH)
67
-
68
- def start_by_dep(self, sentence):
69
- seg, hidden = self.ltp.seg([sentence])
70
- dep = self.ltp.dep(hidden) # , graph=False)
71
- seg, dep = seg[0], dep[0]
72
- for i in dep:
73
- # 主谓宾
74
- if 'SBV' == i[2]:
75
- subject = seg[i[0]]
76
- verb = seg[i[1]]
77
- if 'VOB' in i[2]:
78
- if seg[i[1]] == verb:
79
- object = seg[i[0]]
80
-
81
- return subject
82
-
83
- return None
84
-
85
- def start_by_srl(self, sentence):
86
- """
87
- 用语义角色标注工具
88
- :param sentence: "他叫汤姆去拿外衣。"
89
- :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
90
- """
91
- # 语义角色标注方法
92
- seg, hidden = self.ltp.seg([sentence])
93
- srl = self.ltp.srl(hidden)
94
- seg, srl = seg[0], srl[0]
95
- events = []
96
- for wdx, each_srl in enumerate(srl):
97
- if each_srl:
98
- args = []
99
- for arg in each_srl:
100
- args.extend(seg[arg[1]:arg[2] + 1])
101
- # 添加上谓词
102
- args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
103
- events.append(args)
104
- # print(events)
105
- return events
106
-
107
-
108
- # 这个是另一种
109
- # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
110
- def subject_object_labeling_new(spo_list, text):
111
- pass
112
-
113
-
114
- # 这个是传统格式的
115
- # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
116
- def subject_object_labeling(spo_list, text):
117
- # TODO
118
- '''
119
- 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
120
- :param spo_list:
121
- :param text:
122
- :return: labeling_list
123
- '''
124
-
125
- def _spo_list_to_spo_predicate_dict(spo_list):
126
- spo_predicate_dict = dict()
127
- for spo_item in spo_list:
128
- predicate = spo_item["predicate"]
129
- subject = spo_item["subject"]
130
- object = spo_item["object"]
131
- spo_predicate_dict.setdefault(predicate, []).append((subject, object))
132
- return spo_predicate_dict
133
-
134
- def _index_q_list_in_k_list(q_list, k_list):
135
- """Known q_list in k_list, find index(first time) of q_list in k_list"""
136
- q_list_length = len(q_list)
137
- k_list_length = len(k_list)
138
- for idx in range(k_list_length - q_list_length + 1):
139
- t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
140
- # print(idx, t)
141
- if all(t):
142
- # print(idx)
143
- idx_start = idx
144
- return idx_start
145
-
146
- def _labeling_type(spo, spo_type):
147
- idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
148
- labeling_list[idx_start] = 'B-' + spo_type
149
- if len(spo) == 2:
150
- labeling_list[idx_start + 1] = 'I-' + spo_type
151
- elif len(spo) >= 3:
152
- labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
153
- else:
154
- pass
155
-
156
- spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
157
- labeling_list = ['O'] * len(text)
158
- # count = 0
159
- for predicate, spo_list_form in spo_predicate_dict.items():
160
- if predicate in text:
161
- for (spo_subject, spo_object) in spo_list_form:
162
- # if predicate not in spo_subject and predicate not in spo_object:
163
- _labeling_type(spo_subject, 'SUB')
164
- _labeling_type(spo_object, 'OBJ')
165
- _labeling_type(predicate, 'PRE')
166
- # count += 1
167
- # print(count)
168
- # if count == 2:
169
- # print()
170
- if labeling_list != ['O'] * len(text):
171
- return labeling_list
172
- return None
173
-
174
-
175
- def label(text, labels):
176
- '''
177
- 返回两列的标记数据序列
178
- :param text:
179
- :param labels:
180
- :return:
181
- '''
182
- train_sequence = '\n'.join(
183
- ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
184
- return train_sequence
185
-
186
-
187
- def convert_crf_format_10_fold(corpus, objdir_path):
188
- '''
189
- 把已经是crf格式的数据,分成十折。
190
- para:
191
-
192
- '''
193
- # corpus = list(range(1,22))
194
- j_mkdir(objdir_path)
195
- split_position = int(len(corpus) / 10)
196
- for k in range(0, 10):
197
- if k == 9:
198
- dev_set = corpus[k * split_position:]
199
- train_set = corpus[:k * split_position]
200
- else:
201
- dev_set = corpus[k * split_position: (k + 1) * split_position]
202
- train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
203
- writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
204
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
205
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
206
-
207
-
208
- def read_seq_res(path, labels):
209
- '''
210
- 读序列标注三列数据的方法
211
- :param path:
212
- :param labels:
213
- :return:
214
- '''
215
- with codecs.open(path, 'r', 'utf-8') as rd:
216
- seqs_str = rd.read().strip()
217
- seqs_list = seqs_str.split('\n\n')
218
- text, raw_label, predict_label = [], [], []
219
- for seq in seqs_list:
220
- seq_split = seq.split('\n')
221
- text_tmp = ''
222
- raw_index_dict, pre_index_dict = {}, {}
223
- for label in labels:
224
- raw_index_dict.setdefault(label, [])
225
- pre_index_dict.setdefault(label, [])
226
- for idx, line in enumerate(seq_split):
227
- tmp = line.split('\t')
228
- text_tmp += tmp[0]
229
- if tmp[1] in labels:
230
- raw_index_dict[tmp[1]].append(idx)
231
- if tmp[2] in labels:
232
- pre_index_dict[tmp[2]].append(idx)
233
- text.append(text_tmp)
234
- raw_label.append(raw_index_dict)
235
- predict_label.append(pre_index_dict)
236
- return text, raw_label, predict_label
237
-
238
-
239
- def kfold(corpus, path, k=9, is_shuffle=True):
240
- '''
241
- k是10份中训练集占了几份
242
- '''
243
- j_mkdir(path)
244
- if is_shuffle:
245
- random.shuffle(corpus)
246
- split_position = int(len(corpus) / 10)
247
- train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
248
- writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
249
- writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
250
- writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
251
- """
252
- import pandas as pd
253
- from sklearn.model_selection import KFold
254
-
255
- df = pd.DataFrame({
256
- "text": ["text_{}".format(i) for i in range(100)],
257
- "labels": ["label_{}".format(i % 10) for i in range(100)]
258
- })
259
- train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
260
- test_idx, val_idx = KFold(n_splits=2, shuffle=True).split(df).__next__()
261
- df_train = df.iloc[train_idx]
262
- df_val = df.iloc[val_idx]
263
- df_test = df.iloc[test_idx]
264
- print(train_idx)
265
- print(val_idx)
266
- print(test_idx)
267
- """
268
-
269
-
270
- # 读取crf序列格式的数据
271
- def read_seq_data(path):
272
- content = readtxt_list_all_strip(path)
273
- lines = [i.split('\t') if i else '' for i in content]
274
- print(lines)
275
- sequences, labels, sequence, label = [], [], [], []
276
- for idx, line in enumerate(lines):
277
- if line == '':
278
- if sequence:
279
- sequences.append(sequence)
280
- labels.append(label)
281
- sequence, label = [], []
282
- else:
283
- sequence.append(line[0])
284
- label.append(line[1])
285
- if idx == len(lines) - 1 and sequence:
286
- sequences.append(sequence)
287
- labels.append(label)
288
- return sequences, labels
289
-
290
-
291
- def split_5_percent(lines, sample_precent=5):
292
- random.seed(8)
293
- # lines = list(range(1, 109))
294
- idx_lines = [(idx, i) for idx, i in enumerate(lines)]
295
- div = int(len(lines) / 100)
296
- sample_num = div * sample_precent
297
- sample = random.sample(idx_lines, sample_num)
298
- sorted_sample = sorted(sample, key=lambda x: x[0])
299
- remove_idx = [i[0] for i in sorted_sample]
300
- less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
301
- most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
302
- print(less_has_raw_line_info)
303
- print(most)
304
- return most, less_has_raw_line_info
305
-
306
-
307
- def split_sentences(sentences, mode='chinese'):
308
- # sentences->Str
309
- # example '12“345。”“6789”'
310
- if mode == 'chinese':
311
- split_signs = list('。!?…')
312
- other_sign = "”"
313
- elif mode == 'english':
314
- split_signs = list('.!?')
315
- other_sign = '"'
316
- else:
317
- print('暂时还没有')
318
- split_signs = list('.!?')
319
- other_sign = '"'
320
- splited_sentences = []
321
- start_idx = 0
322
- for idx, char in enumerate(sentences):
323
- if idx == len(sentences) - 1:
324
- if char in split_signs:
325
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
326
- start_idx = idx + 1
327
- else:
328
- splited_sentences.append(sentences[start_idx:].strip())
329
- else:
330
- if char in split_signs:
331
- if sentences[idx + 1] == other_sign:
332
- if idx < len(sentences) - 2:
333
- # 处理。”。
334
- if sentences[idx + 2] not in split_signs:
335
- splited_sentences.append(sentences[start_idx:idx + 2].strip())
336
- start_idx = idx + 2
337
- elif sentences[idx + 1] not in split_signs:
338
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
339
- start_idx = idx + 1
340
- return splited_sentences
341
-
342
-
343
- def pos_reduction():
344
- wnl = WordNetLemmatizer()
345
- # lemmatize nouns
346
- print(wnl.lemmatize('cars', 'n'))
347
- print(wnl.lemmatize('men', 'n'))
348
-
349
- # lemmatize verbs
350
- print(wnl.lemmatize('running', 'v'))
351
- print(wnl.lemmatize('ate', 'v'))
352
-
353
-
354
- class DataVisualization:
355
- # 和下面的类冲突了
356
- pass
357
-
358
-
359
- class CalcPPL(object):
360
- # ppl计算
361
- # https://www.scribendi.ai/comparing-bert-and-gpt-2-as-language-models-to-score-the-grammatical-correctness-of-a-sentence/
362
- def __init__(self, path):
363
- self.model = BertForMaskedLM.from_pretrained(path)
364
- self.model.eval()
365
- # Load pre-trained model tokenizer (vocabulary)
366
- self.tokenizer = BertTokenizer.from_pretrained(path)
367
-
368
- def ppl_1(self, sentence):
369
- tokenizer = self.tokenizer
370
- model = self.tokenizer
371
- tokenize_input = tokenizer.tokenize(sentence)
372
- tokenize_input = tokenize_input
373
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
374
- with torch.no_grad():
375
- loss = model(tensor_input, labels=tensor_input)[0]
376
- return np.exp(loss.detach().numpy())
377
-
378
- # [1] Salazar J, Liang D, Nguyen T Q, et al. Masked Language Model Scoring[C]//Proceedings of ACL. 2020: 2699-2712.
379
- def ppl_2(self, sentence):
380
- tokenizer = self.tokenizer
381
- model = self.tokenizer
382
- with torch.no_grad():
383
- tokenize_input = tokenizer.tokenize(sentence)
384
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
385
- sen_len = len(tokenize_input)
386
- sentence_loss = 0.
387
-
388
- for i, word in enumerate(tokenize_input):
389
- # add mask to i-th character of the sentence
390
- tokenize_input[i] = '[MASK]'
391
- mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
392
-
393
- output = model(mask_input)
394
-
395
- prediction_scores = output[0]
396
- softmax = nn.Softmax(dim=0)
397
- ps = softmax(prediction_scores[0, i]).log()
398
- word_loss = ps[tensor_input[0, i]]
399
- sentence_loss += word_loss.item()
400
-
401
- tokenize_input[i] = word
402
- ppl = np.exp(-sentence_loss / sen_len)
403
- # print("困惑度:", ppl)
404
- return ppl
405
-
406
- def test(self):
407
- sentence = "输入句子:"
408
- ppl = self.ppl_1(sentence)
409
- ppl2 = self.ppl_2(sentence)
410
- print(ppl)
411
- print(ppl2)
412
-
413
-
414
- class Evaluate():
415
- def __init__(self):
416
- pass
417
-
418
- def auc_metric(self, k):
419
- pass
420
-
421
- def map_metric(self):
422
- pass
423
-
424
- def ndcg(self, n, y_true, y_score):
425
- report = metrics.ndcg_score(y_true, y_score)
426
- return report
427
-
428
-
429
- class DecideTreeUtils:
430
- @staticmethod
431
- def draw(bst):
432
- # xgb 画图
433
- fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
434
- xgb.plot_tree(bst, ax=ax_tree)
435
- fig_tree.savefig('tree.png')
436
- plt.show()
437
-
438
-
439
- def seed_everything(seed=7777777) -> None:
440
- """
441
- 设置整个开发环境的seed
442
- :param seed:
443
- :param device:
444
- :return:
445
- """
446
- random.seed(seed)
447
- os.environ['PYTHONHASHSEED'] = str(seed)
448
- np.random.seed(seed)
449
- torch.manual_seed(seed) # CPU随机种子确定
450
- torch.cuda.manual_seed(seed)
451
- torch.cuda.manual_seed_all(seed)
452
- # some cudnn methods can be random even after fixing the seed
453
- # unless you tell it to be deterministic
454
- torch.backends.cudnn.deterministic = True
455
-
456
-
457
- if __name__ == '__main__':
458
- stem = STEM(IPT_MODEL_PATH)
459
- test_sentence = '美国袭击伊拉克'
460
- a = stem.start_by_srl(test_sentence)
1
+ # encoding=utf-8
2
+ import codecs
3
+ import os
4
+ import random
5
+
6
+ from .io.dir import j_mkdir
7
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
8
+ # import numpy as np
9
+ # import seaborn as sns
10
+ # import torch
11
+ # import torch.nn as nn
12
+ # import xgboost as xgb
13
+ # from matplotlib import pyplot as plt
14
+ # from nltk.stem import WordNetLemmatizer
15
+ # from sklearn import metrics
16
+ # from transformers import BertTokenizer, BertForMaskedLM
17
+ from .utils.package import *
18
+
19
+
20
+ def calc_llm_train_activation_memory(
21
+ model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
22
+ ):
23
+ """
24
+ return bytes
25
+
26
+ reference:
27
+ 1. https://zhuanlan.zhihu.com/p/665172400
28
+ 2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion 里面没有乘以层数就很怪
29
+ """
30
+ # reference1
31
+ # attention
32
+ # FFN
33
+ # Layer Norm
34
+ r1 = (
35
+ sequence_length
36
+ * batch_size
37
+ * hidden_dim
38
+ * lay_number
39
+ * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
40
+ )
41
+ # reference2
42
+ r2 = (
43
+ lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
44
+ * sequence_length
45
+ * batch_size
46
+ / gpu_num
47
+ )
48
+ print(r1)
49
+ print(r2)
50
+ return r1
51
+
52
+
53
+ class DataAnalysis:
54
+ @staticmethod
55
+ def draw_pic(df, save_path):
56
+ """
57
+ 画直方图,对比两个不同类别差异
58
+ :param df: pd.DataFrame
59
+ :param save_path: str
60
+ :return:
61
+ """
62
+ sns.distplot(df[df["label"] == 1]["feature"], label="label1")
63
+ sns.distplot(df[df["label"] == 0]["feature"], label="label2")
64
+ plt.legend()
65
+ plt.savefig(save_path)
66
+
67
+
68
+ class DataStructure:
69
+ spo = {
70
+ "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
71
+ "triplets": [
72
+ {
73
+ "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
74
+ "p": {"text": "出版社", "l": 15, "r": 18},
75
+ "o": {"text": "故宫出版社", "l": 13, "r": 18},
76
+ }
77
+ ],
78
+ "source": "baidu",
79
+ }
80
+ ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
81
+ ner_label_example = (
82
+ list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
83
+ )
84
+
85
+
86
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
87
+ # 两个句子的jacccard系数
88
+ # 判断输入来重新定义ipt_level和sim_level
89
+
90
+ # a = set(ipt1.split())
91
+ # b = set(ipt2.split())
92
+ a = set(ipt1)
93
+ b = set(ipt2)
94
+ c = a.intersection(b)
95
+ # spical situation:
96
+ if not ipt1 and not ipt2:
97
+ return 0
98
+ return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
99
+
100
+
101
+ class STEM(object):
102
+ def __init__(self, IPT_MODEL_PATH):
103
+ self.ltp = LTP(IPT_MODEL_PATH)
104
+
105
+ def start_by_dep(self, sentence):
106
+ seg, hidden = self.ltp.seg([sentence])
107
+ dep = self.ltp.dep(hidden) # , graph=False)
108
+ seg, dep = seg[0], dep[0]
109
+ for i in dep:
110
+ # 主谓宾
111
+ if "SBV" == i[2]:
112
+ subject = seg[i[0]]
113
+ verb = seg[i[1]]
114
+ if "VOB" in i[2]:
115
+ if seg[i[1]] == verb:
116
+ object = seg[i[0]]
117
+
118
+ return subject
119
+
120
+ return None
121
+
122
+ def start_by_srl(self, sentence):
123
+ """
124
+ 用语义角色标注工具
125
+ :param sentence: "他叫汤姆去拿外衣。"
126
+ :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
127
+ """
128
+ # 语义角色标注方法
129
+ seg, hidden = self.ltp.seg([sentence])
130
+ srl = self.ltp.srl(hidden)
131
+ seg, srl = seg[0], srl[0]
132
+ events = []
133
+ for wdx, each_srl in enumerate(srl):
134
+ if each_srl:
135
+ args = []
136
+ for arg in each_srl:
137
+ args.extend(seg[arg[1]: arg[2] + 1])
138
+ # 添加上谓词
139
+ args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
140
+ events.append(args)
141
+ # print(events)
142
+ return events
143
+
144
+
145
+ # 这个是另一种
146
+ # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
147
+ def subject_object_labeling_new(spo_list, text):
148
+ pass
149
+
150
+
151
+ # 这个是传统格式的
152
+ # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
153
+ def subject_object_labeling(spo_list, text):
154
+ # TODO
155
+ """
156
+ 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
157
+ :param spo_list:
158
+ :param text:
159
+ :return: labeling_list
160
+ """
161
+
162
+ def _spo_list_to_spo_predicate_dict(spo_list):
163
+ spo_predicate_dict = dict()
164
+ for spo_item in spo_list:
165
+ predicate = spo_item["predicate"]
166
+ subject = spo_item["subject"]
167
+ object = spo_item["object"]
168
+ spo_predicate_dict.setdefault(predicate, []).append((subject, object))
169
+ return spo_predicate_dict
170
+
171
+ def _index_q_list_in_k_list(q_list, k_list):
172
+ """Known q_list in k_list, find index(first time) of q_list in k_list"""
173
+ q_list_length = len(q_list)
174
+ k_list_length = len(k_list)
175
+ for idx in range(k_list_length - q_list_length + 1):
176
+ t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
177
+ # print(idx, t)
178
+ if all(t):
179
+ # print(idx)
180
+ idx_start = idx
181
+ return idx_start
182
+
183
+ def _labeling_type(spo, spo_type):
184
+ idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
185
+ labeling_list[idx_start] = "B-" + spo_type
186
+ if len(spo) == 2:
187
+ labeling_list[idx_start + 1] = "I-" + spo_type
188
+ elif len(spo) >= 3:
189
+ labeling_list[idx_start + 1: idx_start + len(spo)] = ["I-" + spo_type] * (
190
+ len(spo) - 1
191
+ )
192
+ else:
193
+ pass
194
+
195
+ spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
196
+ labeling_list = ["O"] * len(text)
197
+ # count = 0
198
+ for predicate, spo_list_form in spo_predicate_dict.items():
199
+ if predicate in text:
200
+ for (spo_subject, spo_object) in spo_list_form:
201
+ # if predicate not in spo_subject and predicate not in spo_object:
202
+ _labeling_type(spo_subject, "SUB")
203
+ _labeling_type(spo_object, "OBJ")
204
+ _labeling_type(predicate, "PRE")
205
+ # count += 1
206
+ # print(count)
207
+ # if count == 2:
208
+ # print()
209
+ if labeling_list != ["O"] * len(text):
210
+ return labeling_list
211
+ return None
212
+
213
+
214
+ def label(text, labels):
215
+ """
216
+ 返回两列的标记数据序列
217
+ :param text:
218
+ :param labels:
219
+ :return:
220
+ """
221
+ train_sequence = "\n".join(
222
+ [
223
+ "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
224
+ for i in zip(list(text), labels)
225
+ ]
226
+ )
227
+ return train_sequence
228
+
229
+
230
+ def convert_crf_format_10_fold(corpus, objdir_path):
231
+ """
232
+ 把已经是crf格式的数据,分成十折。
233
+ para:
234
+
235
+ """
236
+ # corpus = list(range(1,22))
237
+ j_mkdir(objdir_path)
238
+ split_position = int(len(corpus) / 10)
239
+ for k in range(0, 10):
240
+ if k == 9:
241
+ dev_set = corpus[k * split_position:]
242
+ train_set = corpus[: k * split_position]
243
+ else:
244
+ dev_set = corpus[k * split_position: (k + 1) * split_position]
245
+ train_set = (
246
+ corpus[: k * split_position] + corpus[(k + 1) * split_position:]
247
+ )
248
+ writetxt_w_list(
249
+ train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
250
+ )
251
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
252
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
253
+
254
+
255
+ def read_seq_res(path, labels):
256
+ """
257
+ 读序列标注三列数据的方法
258
+ :param path:
259
+ :param labels:
260
+ :return:
261
+ """
262
+ with codecs.open(path, "r", "utf-8") as rd:
263
+ seqs_str = rd.read().strip()
264
+ seqs_list = seqs_str.split("\n\n")
265
+ text, raw_label, predict_label = [], [], []
266
+ for seq in seqs_list:
267
+ seq_split = seq.split("\n")
268
+ text_tmp = ""
269
+ raw_index_dict, pre_index_dict = {}, {}
270
+ for label in labels:
271
+ raw_index_dict.setdefault(label, [])
272
+ pre_index_dict.setdefault(label, [])
273
+ for idx, line in enumerate(seq_split):
274
+ tmp = line.split("\t")
275
+ text_tmp += tmp[0]
276
+ if tmp[1] in labels:
277
+ raw_index_dict[tmp[1]].append(idx)
278
+ if tmp[2] in labels:
279
+ pre_index_dict[tmp[2]].append(idx)
280
+ text.append(text_tmp)
281
+ raw_label.append(raw_index_dict)
282
+ predict_label.append(pre_index_dict)
283
+ return text, raw_label, predict_label
284
+
285
+
286
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
287
+ """
288
+ k是10份中训练集占了几份
289
+ """
290
+ j_mkdir(path)
291
+ if is_shuffle:
292
+ random.shuffle(corpus)
293
+ split_position = int(len(corpus) / 10)
294
+ train_set, dev_set = corpus[: k * split_position], corpus[k * split_position:]
295
+ writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
296
+ writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
297
+ writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
298
+
299
+
300
+ def sample():
301
+ import pandas as pd
302
+ from sklearn.model_selection import StratifiedShuffleSplit
303
+
304
+ # 假设 df 是你的 DataFrame
305
+
306
+ df = pd.DataFrame({
307
+ "count_line": [i for i in range(100)],
308
+ "x": [i for i in range(100)],
309
+ "y": [i // 10 for i in range(100)],
310
+ })
311
+ print(df)
312
+ # count_line 是用于分层抽样的字段
313
+
314
+ # 创建 StratifiedShuffleSplit 对象,设置测试集比例为 0.1
315
+ split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
316
+
317
+ # 获取训练集和测试集的索引
318
+ train_index, test_index = next(split.split(df, df['y']))
319
+
320
+ # 根据索引划分训练集和测试集
321
+ train_df = df.loc[train_index]
322
+ test_df = df.loc[test_index]
323
+
324
+ # 打印训练集和测试集的行数
325
+ print("训练集行数:", len(train_df))
326
+ print("测试集行数:", len(test_df))
327
+
328
+
329
+ def kfold_df(df, save_dir=None):
330
+ """
331
+ 划分train test val集, 写为windows可读的csv。
332
+ :param df:pd.DataFrame
333
+ :param save_dir:
334
+ :return:
335
+ """
336
+ from sklearn.model_selection import KFold
337
+ import pandas as pd
338
+
339
+ train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
340
+ df_test_and_val = df.iloc[test_and_val_idx]
341
+ test_idx, val_idx = (
342
+ KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
343
+ )
344
+ df_train = df.iloc[train_idx]
345
+ df_val = df.iloc[val_idx]
346
+ df_test = df.iloc[test_idx]
347
+ if save_dir:
348
+ j_mkdir(save_dir)
349
+ save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
350
+ save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
351
+ save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
352
+ return df_train, df_val, df_test
353
+
354
+
355
+ # 读取crf序列格式的数据
356
+ def read_seq_data(path):
357
+ content = readtxt_list_all_strip(path)
358
+ lines = [i.split("\t") if i else "" for i in content]
359
+ print(lines)
360
+ sequences, labels, sequence, label = [], [], [], []
361
+ for idx, line in enumerate(lines):
362
+ if line == "":
363
+ if sequence:
364
+ sequences.append(sequence)
365
+ labels.append(label)
366
+ sequence, label = [], []
367
+ else:
368
+ sequence.append(line[0])
369
+ label.append(line[1])
370
+ if idx == len(lines) - 1 and sequence:
371
+ sequences.append(sequence)
372
+ labels.append(label)
373
+ return sequences, labels
374
+
375
+
376
+ def split_5_percent(lines, sample_precent=5):
377
+ random.seed(8)
378
+ # lines = list(range(1, 109))
379
+ idx_lines = [(idx, i) for idx, i in enumerate(lines)]
380
+ div = int(len(lines) / 100)
381
+ sample_num = div * sample_precent
382
+ sample = random.sample(idx_lines, sample_num)
383
+ sorted_sample = sorted(sample, key=lambda x: x[0])
384
+ remove_idx = [i[0] for i in sorted_sample]
385
+ less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
386
+ most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
387
+ print(less_has_raw_line_info)
388
+ print(most)
389
+ return most, less_has_raw_line_info
390
+
391
+
392
+ def split_sentence(sentence, language="chinese", cross_line=True):
393
+ """
394
+ 分句,英文有nltk,中文怎么能没有好的分句工具呢
395
+ :param sentence:
396
+ :param language:
397
+ :param cross_line:
398
+ :return:
399
+ """
400
+ # sentences->Str
401
+ # example '12“345。”“6789”'
402
+ assert language in ["chinese", "english"], "unsupportable for other language"
403
+ sentence = sentence.replace("\r", "")
404
+ if language == "chinese":
405
+ split_signs = list("。!?…")
406
+ if cross_line:
407
+ split_signs.append("\n")
408
+ other_sign = "”"
409
+ elif language == "english":
410
+ split_signs = list(".!?")
411
+ other_sign = '"'
412
+ else:
413
+ split_signs = list(".!?")
414
+ other_sign = '"'
415
+ sentences = []
416
+ start_idx = 0
417
+ for idx, char in enumerate(sentence):
418
+ if idx == len(sentence) - 1:
419
+ if char in split_signs:
420
+ sentences.append(sentence[start_idx: idx + 1].strip())
421
+ start_idx = idx + 1
422
+ else:
423
+ sentences.append(sentence[start_idx:].strip())
424
+ else:
425
+ if char in split_signs:
426
+ if sentence[idx + 1] == other_sign:
427
+ if idx < len(sentence) - 2:
428
+ # 处理。”。
429
+ if sentence[idx + 2] not in split_signs:
430
+ sentences.append(sentence[start_idx: idx + 2].strip())
431
+ start_idx = idx + 2
432
+ elif sentence[idx + 1] not in split_signs:
433
+ sentences.append(sentence[start_idx: idx + 1].strip())
434
+ start_idx = idx + 1
435
+ return sentences
436
+
437
+
438
+ def pos_reduction():
439
+ wnl = WordNetLemmatizer()
440
+ # lemmatize nouns
441
+ print(wnl.lemmatize("cars", "n"))
442
+ print(wnl.lemmatize("men", "n"))
443
+
444
+ # lemmatize verbs
445
+ print(wnl.lemmatize("running", "v"))
446
+ print(wnl.lemmatize("ate", "v"))
447
+
448
+
449
+ class DataVisualization:
450
+ # 和下面的类冲突了
451
+ pass
452
+
453
+
454
+ class Evaluate:
455
+ def __init__(self):
456
+ pass
457
+
458
+ def auc_metric(self, k):
459
+ pass
460
+
461
+ def map_metric(self):
462
+ pass
463
+
464
+ def ndcg(self, n, y_true, y_score):
465
+ report = metrics.ndcg_score(y_true, y_score)
466
+ return report
467
+
468
+
469
+ class DecideTreeUtils:
470
+ @staticmethod
471
+ def draw(bst):
472
+ # xgb 画图
473
+ fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
474
+ xgb.plot_tree(bst, ax=ax_tree)
475
+ fig_tree.savefig("tree.png")
476
+ plt.show()
477
+
478
+
479
+ def seed_everything(seed=7777777) -> None:
480
+ """
481
+ 设置整个开发环境的seed
482
+ :param seed:
483
+ :param device:
484
+ :return:
485
+ """
486
+ random.seed(seed)
487
+ os.environ["PYTHONHASHSEED"] = str(seed)
488
+ np.random.seed(seed)
489
+ torch.manual_seed(seed) # CPU随机种子确定
490
+ torch.cuda.manual_seed(seed)
491
+ torch.cuda.manual_seed_all(seed)
492
+ # some cudnn methods can be random even after fixing the seed
493
+ # unless you tell it to be deterministic
494
+ torch.backends.cudnn.deterministic = True
495
+
496
+
497
+ if __name__ == "__main__":
498
+ # stem = STEM(IPT_MODEL_PATH)
499
+ # test_sentence = "美国袭击伊拉克"
500
+ # a = stem.start_by_srl(test_sentence)
501
+
502
+ res = calc_llm_train_activation_memory(
503
+ model_name="",
504
+ sequence_length=2048,
505
+ batch_size=1,
506
+ hidden_dim=4096,
507
+ lay_number=28,
508
+ attention_heads_num=32,
509
+ gpu_num=1
510
+ )
511
+ print(res, "G")
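
A note on the calc_llm_train_activation_memory() helper introduced in this version of nlpertools/ml.py: per its docstring the return value is in bytes (the reference-1 estimate), while the __main__ block above prints the raw number followed by "G" without converting. Below is a minimal usage sketch, assuming nlpertools 1.0.8 is installed and that importing nlpertools.ml works in your environment (it pulls in nlpertools.utils.package); the conversion to GiB is done explicitly:

from nlpertools.ml import calc_llm_train_activation_memory

# Same arguments as the module's __main__ block; the return value is a byte count.
act_bytes = calc_llm_train_activation_memory(
    model_name="",
    sequence_length=2048,
    batch_size=1,
    hidden_dim=4096,
    lay_number=28,
    attention_heads_num=32,
    gpu_num=1,
)
print(f"{act_bytes / 1024 ** 3:.1f} GiB")  # roughly 24.9 GiB for these settings

The byte count comes from the first formula in the function, sequence_length * batch_size * hidden_dim * lay_number * (34 + 5 * attention_heads_num * sequence_length / hidden_dim); the second estimate is computed and printed but not returned.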
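
The sentence splitter also changed shape in this release: split_sentences(sentences, mode=...) is now split_sentence(sentence, language=..., cross_line=True), and with cross_line set, a newline is treated as an additional boundary in the Chinese branch. A short sketch of the new signature, under the same installation assumptions as above:

from nlpertools.ml import split_sentence

text = "12“345。”“6789”"  # the example string from the function's comments
print(split_sentence(text, language="chinese"))
# expected: ['12“345。”', '“6789”'] (the closing quote stays attached to its sentence)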