nlpertools 1.0.5__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff reflects changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (49)
  1. nlpertools/__init__.py +23 -20
  2. nlpertools/algo/ac.py +18 -0
  3. nlpertools/algo/bit_ops.py +28 -0
  4. nlpertools/algo/kmp.py +94 -55
  5. nlpertools/algo/num_ops.py +12 -0
  6. nlpertools/algo/template.py +116 -0
  7. nlpertools/algo/union.py +13 -0
  8. nlpertools/cli.py +87 -0
  9. nlpertools/data_client.py +426 -257
  10. nlpertools/data_structure/base_structure.py +109 -13
  11. nlpertools/dataprocess.py +627 -3
  12. nlpertools/default_db_config.yml +41 -0
  13. nlpertools/draw/__init__.py +0 -0
  14. nlpertools/draw/draw.py +83 -0
  15. nlpertools/draw/math_func.py +33 -0
  16. nlpertools/get_2fa.py +0 -0
  17. nlpertools/io/__init__.py +3 -3
  18. nlpertools/io/dir.py +86 -36
  19. nlpertools/io/file.py +283 -222
  20. nlpertools/ml.py +511 -460
  21. nlpertools/monitor/__init__.py +0 -0
  22. nlpertools/monitor/gpu.py +18 -0
  23. nlpertools/monitor/memory.py +24 -0
  24. nlpertools/movie.py +36 -0
  25. nlpertools/nlpertools_config.yml +1 -0
  26. nlpertools/{openApi.py → open_api.py} +65 -65
  27. nlpertools/other.py +475 -249
  28. nlpertools/pic.py +288 -0
  29. nlpertools/plugin.py +43 -43
  30. nlpertools/reminder.py +98 -87
  31. nlpertools/utils/__init__.py +3 -3
  32. nlpertools/utils/lazy.py +727 -0
  33. nlpertools/utils/log_util.py +20 -0
  34. nlpertools/utils/package.py +89 -76
  35. nlpertools/utils/package_v1.py +94 -0
  36. nlpertools/utils/package_v2.py +117 -0
  37. nlpertools/utils_for_nlpertools.py +93 -93
  38. nlpertools/vector_index_demo.py +108 -0
  39. nlpertools/wrapper.py +161 -96
  40. {nlpertools-1.0.5.dist-info → nlpertools-1.0.8.dist-info}/LICENSE +200 -200
  41. nlpertools-1.0.8.dist-info/METADATA +132 -0
  42. nlpertools-1.0.8.dist-info/RECORD +49 -0
  43. {nlpertools-1.0.5.dist-info → nlpertools-1.0.8.dist-info}/WHEEL +1 -1
  44. nlpertools-1.0.8.dist-info/entry_points.txt +2 -0
  45. nlpertools-1.0.8.dist-info/top_level.txt +2 -0
  46. nlpertools_helper/__init__.py +10 -0
  47. nlpertools-1.0.5.dist-info/METADATA +0 -85
  48. nlpertools-1.0.5.dist-info/RECORD +0 -25
  49. nlpertools-1.0.5.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,460 +1,511 @@
1
- # encoding=utf-8
2
- import codecs
3
- import os
4
- import random
5
-
6
- from .io.dir import j_mkdir
7
- from .io.file import readtxt_list_all_strip, writetxt_w_list
8
- # import numpy as np
9
- # import seaborn as sns
10
- # import torch
11
- # import torch.nn as nn
12
- # import xgboost as xgb
13
- # from matplotlib import pyplot as plt
14
- # from nltk.stem import WordNetLemmatizer
15
- # from sklearn import metrics
16
- # from transformers import BertTokenizer, BertForMaskedLM
17
- from .utils.package import *
18
-
19
-
20
- class DataAnalysis:
21
- @staticmethod
22
- def draw_pic(df, save_path):
23
- """
24
- 画直方图,对比两个不同类别差异
25
- :param df: pd.DataFrame
26
- :param save_path: str
27
- :return:
28
- """
29
- sns.distplot(df[df["label"] == 1]["feature"], label="label1")
30
- sns.distplot(df[df["label"] == 0]["feature"], label="label2")
31
- plt.legend()
32
- plt.savefig(save_path)
33
-
34
-
35
- class DataStructure:
36
- spo = {
37
- "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
38
- "triplets": [
39
- {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
40
- "p": {"text": "出版社", "l": 15, "r": 18},
41
- "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
42
- "source": "baidu"
43
- }
44
- ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
45
- ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
46
-
47
-
48
- def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
49
- # 两个句子的jacccard系数
50
- # 判断输入来重新定义ipt_level和sim_level
51
-
52
- # a = set(ipt1.split())
53
- # b = set(ipt2.split())
54
- a = set(ipt1)
55
- b = set(ipt2)
56
- c = a.intersection(b)
57
- # spical situation:
58
- if not ipt1 and not ipt2:
59
- return 0
60
- return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
61
-
62
-
63
- class STEM(object):
64
-
65
- def __init__(self, IPT_MODEL_PATH):
66
- self.ltp = LTP(IPT_MODEL_PATH)
67
-
68
- def start_by_dep(self, sentence):
69
- seg, hidden = self.ltp.seg([sentence])
70
- dep = self.ltp.dep(hidden) # , graph=False)
71
- seg, dep = seg[0], dep[0]
72
- for i in dep:
73
- # 主谓宾
74
- if 'SBV' == i[2]:
75
- subject = seg[i[0]]
76
- verb = seg[i[1]]
77
- if 'VOB' in i[2]:
78
- if seg[i[1]] == verb:
79
- object = seg[i[0]]
80
-
81
- return subject
82
-
83
- return None
84
-
85
- def start_by_srl(self, sentence):
86
- """
87
- 用语义角色标注工具
88
- :param sentence: "他叫汤姆去拿外衣。"
89
- :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
90
- """
91
- # 语义角色标注方法
92
- seg, hidden = self.ltp.seg([sentence])
93
- srl = self.ltp.srl(hidden)
94
- seg, srl = seg[0], srl[0]
95
- events = []
96
- for wdx, each_srl in enumerate(srl):
97
- if each_srl:
98
- args = []
99
- for arg in each_srl:
100
- args.extend(seg[arg[1]:arg[2] + 1])
101
- # 添加上谓词
102
- args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
103
- events.append(args)
104
- # print(events)
105
- return events
106
-
107
-
108
- # 这个是另一种
109
- # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
110
- def subject_object_labeling_new(spo_list, text):
111
- pass
112
-
113
-
114
- # 这个是传统格式的
115
- # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
116
- def subject_object_labeling(spo_list, text):
117
- # TODO
118
- '''
119
- 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
120
- :param spo_list:
121
- :param text:
122
- :return: labeling_list
123
- '''
124
-
125
- def _spo_list_to_spo_predicate_dict(spo_list):
126
- spo_predicate_dict = dict()
127
- for spo_item in spo_list:
128
- predicate = spo_item["predicate"]
129
- subject = spo_item["subject"]
130
- object = spo_item["object"]
131
- spo_predicate_dict.setdefault(predicate, []).append((subject, object))
132
- return spo_predicate_dict
133
-
134
- def _index_q_list_in_k_list(q_list, k_list):
135
- """Known q_list in k_list, find index(first time) of q_list in k_list"""
136
- q_list_length = len(q_list)
137
- k_list_length = len(k_list)
138
- for idx in range(k_list_length - q_list_length + 1):
139
- t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
140
- # print(idx, t)
141
- if all(t):
142
- # print(idx)
143
- idx_start = idx
144
- return idx_start
145
-
146
- def _labeling_type(spo, spo_type):
147
- idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
148
- labeling_list[idx_start] = 'B-' + spo_type
149
- if len(spo) == 2:
150
- labeling_list[idx_start + 1] = 'I-' + spo_type
151
- elif len(spo) >= 3:
152
- labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
153
- else:
154
- pass
155
-
156
- spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
157
- labeling_list = ['O'] * len(text)
158
- # count = 0
159
- for predicate, spo_list_form in spo_predicate_dict.items():
160
- if predicate in text:
161
- for (spo_subject, spo_object) in spo_list_form:
162
- # if predicate not in spo_subject and predicate not in spo_object:
163
- _labeling_type(spo_subject, 'SUB')
164
- _labeling_type(spo_object, 'OBJ')
165
- _labeling_type(predicate, 'PRE')
166
- # count += 1
167
- # print(count)
168
- # if count == 2:
169
- # print()
170
- if labeling_list != ['O'] * len(text):
171
- return labeling_list
172
- return None
173
-
174
-
175
- def label(text, labels):
176
- '''
177
- 返回两列的标记数据序列
178
- :param text:
179
- :param labels:
180
- :return:
181
- '''
182
- train_sequence = '\n'.join(
183
- ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
184
- return train_sequence
185
-
186
-
187
- def convert_crf_format_10_fold(corpus, objdir_path):
188
- '''
189
- 把已经是crf格式的数据,分成十折。
190
- para:
191
-
192
- '''
193
- # corpus = list(range(1,22))
194
- j_mkdir(objdir_path)
195
- split_position = int(len(corpus) / 10)
196
- for k in range(0, 10):
197
- if k == 9:
198
- dev_set = corpus[k * split_position:]
199
- train_set = corpus[:k * split_position]
200
- else:
201
- dev_set = corpus[k * split_position: (k + 1) * split_position]
202
- train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
203
- writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
204
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
205
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
206
-
207
-
208
- def read_seq_res(path, labels):
209
- '''
210
- 读序列标注三列数据的方法
211
- :param path:
212
- :param labels:
213
- :return:
214
- '''
215
- with codecs.open(path, 'r', 'utf-8') as rd:
216
- seqs_str = rd.read().strip()
217
- seqs_list = seqs_str.split('\n\n')
218
- text, raw_label, predict_label = [], [], []
219
- for seq in seqs_list:
220
- seq_split = seq.split('\n')
221
- text_tmp = ''
222
- raw_index_dict, pre_index_dict = {}, {}
223
- for label in labels:
224
- raw_index_dict.setdefault(label, [])
225
- pre_index_dict.setdefault(label, [])
226
- for idx, line in enumerate(seq_split):
227
- tmp = line.split('\t')
228
- text_tmp += tmp[0]
229
- if tmp[1] in labels:
230
- raw_index_dict[tmp[1]].append(idx)
231
- if tmp[2] in labels:
232
- pre_index_dict[tmp[2]].append(idx)
233
- text.append(text_tmp)
234
- raw_label.append(raw_index_dict)
235
- predict_label.append(pre_index_dict)
236
- return text, raw_label, predict_label
237
-
238
-
239
- def kfold(corpus, path, k=9, is_shuffle=True):
240
- '''
241
- k是10份中训练集占了几份
242
- '''
243
- j_mkdir(path)
244
- if is_shuffle:
245
- random.shuffle(corpus)
246
- split_position = int(len(corpus) / 10)
247
- train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
248
- writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
249
- writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
250
- writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
251
- """
252
- import pandas as pd
253
- from sklearn.model_selection import KFold
254
-
255
- df = pd.DataFrame({
256
- "text": ["text_{}".format(i) for i in range(100)],
257
- "labels": ["label_{}".format(i % 10) for i in range(100)]
258
- })
259
- train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
260
- test_idx, val_idx = KFold(n_splits=2, shuffle=True).split(df).__next__()
261
- df_train = df.iloc[train_idx]
262
- df_val = df.iloc[val_idx]
263
- df_test = df.iloc[test_idx]
264
- print(train_idx)
265
- print(val_idx)
266
- print(test_idx)
267
- """
268
-
269
-
270
- # 读取crf序列格式的数据
271
- def read_seq_data(path):
272
- content = readtxt_list_all_strip(path)
273
- lines = [i.split('\t') if i else '' for i in content]
274
- print(lines)
275
- sequences, labels, sequence, label = [], [], [], []
276
- for idx, line in enumerate(lines):
277
- if line == '':
278
- if sequence:
279
- sequences.append(sequence)
280
- labels.append(label)
281
- sequence, label = [], []
282
- else:
283
- sequence.append(line[0])
284
- label.append(line[1])
285
- if idx == len(lines) - 1 and sequence:
286
- sequences.append(sequence)
287
- labels.append(label)
288
- return sequences, labels
289
-
290
-
291
- def split_5_percent(lines, sample_precent=5):
292
- random.seed(8)
293
- # lines = list(range(1, 109))
294
- idx_lines = [(idx, i) for idx, i in enumerate(lines)]
295
- div = int(len(lines) / 100)
296
- sample_num = div * sample_precent
297
- sample = random.sample(idx_lines, sample_num)
298
- sorted_sample = sorted(sample, key=lambda x: x[0])
299
- remove_idx = [i[0] for i in sorted_sample]
300
- less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
301
- most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
302
- print(less_has_raw_line_info)
303
- print(most)
304
- return most, less_has_raw_line_info
305
-
306
-
307
- def split_sentences(sentences, mode='chinese'):
308
- # sentences->Str
309
- # example '12“345。”“6789”'
310
- if mode == 'chinese':
311
- split_signs = list('。!?…')
312
- other_sign = "”"
313
- elif mode == 'english':
314
- split_signs = list('.!?')
315
- other_sign = '"'
316
- else:
317
- print('暂时还没有')
318
- split_signs = list('.!?')
319
- other_sign = '"'
320
- splited_sentences = []
321
- start_idx = 0
322
- for idx, char in enumerate(sentences):
323
- if idx == len(sentences) - 1:
324
- if char in split_signs:
325
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
326
- start_idx = idx + 1
327
- else:
328
- splited_sentences.append(sentences[start_idx:].strip())
329
- else:
330
- if char in split_signs:
331
- if sentences[idx + 1] == other_sign:
332
- if idx < len(sentences) - 2:
333
- # 处理。”。
334
- if sentences[idx + 2] not in split_signs:
335
- splited_sentences.append(sentences[start_idx:idx + 2].strip())
336
- start_idx = idx + 2
337
- elif sentences[idx + 1] not in split_signs:
338
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
339
- start_idx = idx + 1
340
- return splited_sentences
341
-
342
-
343
- def pos_reduction():
344
- wnl = WordNetLemmatizer()
345
- # lemmatize nouns
346
- print(wnl.lemmatize('cars', 'n'))
347
- print(wnl.lemmatize('men', 'n'))
348
-
349
- # lemmatize verbs
350
- print(wnl.lemmatize('running', 'v'))
351
- print(wnl.lemmatize('ate', 'v'))
352
-
353
-
354
- class DataVisualization:
355
- # 和下面的类冲突了
356
- pass
357
-
358
-
359
- class CalcPPL(object):
360
- # ppl计算
361
- # https://www.scribendi.ai/comparing-bert-and-gpt-2-as-language-models-to-score-the-grammatical-correctness-of-a-sentence/
362
- def __init__(self, path):
363
- self.model = BertForMaskedLM.from_pretrained(path)
364
- self.model.eval()
365
- # Load pre-trained model tokenizer (vocabulary)
366
- self.tokenizer = BertTokenizer.from_pretrained(path)
367
-
368
- def ppl_1(self, sentence):
369
- tokenizer = self.tokenizer
370
- model = self.tokenizer
371
- tokenize_input = tokenizer.tokenize(sentence)
372
- tokenize_input = tokenize_input
373
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
374
- with torch.no_grad():
375
- loss = model(tensor_input, labels=tensor_input)[0]
376
- return np.exp(loss.detach().numpy())
377
-
378
- # [1] Salazar J, Liang D, Nguyen T Q, et al. Masked Language Model Scoring[C]//Proceedings of ACL. 2020: 2699-2712.
379
- def ppl_2(self, sentence):
380
- tokenizer = self.tokenizer
381
- model = self.tokenizer
382
- with torch.no_grad():
383
- tokenize_input = tokenizer.tokenize(sentence)
384
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
385
- sen_len = len(tokenize_input)
386
- sentence_loss = 0.
387
-
388
- for i, word in enumerate(tokenize_input):
389
- # add mask to i-th character of the sentence
390
- tokenize_input[i] = '[MASK]'
391
- mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
392
-
393
- output = model(mask_input)
394
-
395
- prediction_scores = output[0]
396
- softmax = nn.Softmax(dim=0)
397
- ps = softmax(prediction_scores[0, i]).log()
398
- word_loss = ps[tensor_input[0, i]]
399
- sentence_loss += word_loss.item()
400
-
401
- tokenize_input[i] = word
402
- ppl = np.exp(-sentence_loss / sen_len)
403
- # print("困惑度:", ppl)
404
- return ppl
405
-
406
- def test(self):
407
- sentence = "输入句子:"
408
- ppl = self.ppl_1(sentence)
409
- ppl2 = self.ppl_2(sentence)
410
- print(ppl)
411
- print(ppl2)
412
-
413
-
414
- class Evaluate():
415
- def __init__(self):
416
- pass
417
-
418
- def auc_metric(self, k):
419
- pass
420
-
421
- def map_metric(self):
422
- pass
423
-
424
- def ndcg(self, n, y_true, y_score):
425
- report = metrics.ndcg_score(y_true, y_score)
426
- return report
427
-
428
-
429
- class DecideTreeUtils:
430
- @staticmethod
431
- def draw(bst):
432
- # xgb 画图
433
- fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
434
- xgb.plot_tree(bst, ax=ax_tree)
435
- fig_tree.savefig('tree.png')
436
- plt.show()
437
-
438
-
439
- def seed_everything(seed=7777777) -> None:
440
- """
441
- 设置整个开发环境的seed
442
- :param seed:
443
- :param device:
444
- :return:
445
- """
446
- random.seed(seed)
447
- os.environ['PYTHONHASHSEED'] = str(seed)
448
- np.random.seed(seed)
449
- torch.manual_seed(seed) # CPU随机种子确定
450
- torch.cuda.manual_seed(seed)
451
- torch.cuda.manual_seed_all(seed)
452
- # some cudnn methods can be random even after fixing the seed
453
- # unless you tell it to be deterministic
454
- torch.backends.cudnn.deterministic = True
455
-
456
-
457
- if __name__ == '__main__':
458
- stem = STEM(IPT_MODEL_PATH)
459
- test_sentence = '美国袭击伊拉克'
460
- a = stem.start_by_srl(test_sentence)
1
+ # encoding=utf-8
2
+ import codecs
3
+ import os
4
+ import random
5
+
6
+ from .io.dir import j_mkdir
7
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
8
+ # import numpy as np
9
+ # import seaborn as sns
10
+ # import torch
11
+ # import torch.nn as nn
12
+ # import xgboost as xgb
13
+ # from matplotlib import pyplot as plt
14
+ # from nltk.stem import WordNetLemmatizer
15
+ # from sklearn import metrics
16
+ # from transformers import BertTokenizer, BertForMaskedLM
17
+ from .utils.package import *
18
+
19
+
20
+ def calc_llm_train_activation_memory(
21
+ model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
22
+ ):
23
+ """
24
+ return bytes
25
+
26
+ reference:
27
+ 1. https://zhuanlan.zhihu.com/p/665172400
28
+ 2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion 里面没有乘以层数就很怪
29
+ """
30
+ # reference1
31
+ # attention
32
+ # FFN
33
+ # Layer Norm
34
+ r1 = (
35
+ sequence_length
36
+ * batch_size
37
+ * hidden_dim
38
+ * lay_number
39
+ * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
40
+ )
41
+ # reference2
42
+ r2 = (
43
+ lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
44
+ * sequence_length
45
+ * batch_size
46
+ / gpu_num
47
+ )
48
+ print(r1)
49
+ print(r2)
50
+ return r1
51
+
52
+
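The new calc_llm_train_activation_memory helper estimates activation memory for transformer training from the two formulas cited in its docstring. The primary estimate it returns (r1, in bytes) follows reference 1, with s = sequence_length, b = batch_size, h = hidden_dim, L = lay_number and a = attention_heads_num:

    r1 = s * b * h * L * (34 + 5 * a * s / h)

The second estimate (r2) is only printed for comparison.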
53
+ class DataAnalysis:
54
+ @staticmethod
55
+ def draw_pic(df, save_path):
56
+ """
57
+ 画直方图,对比两个不同类别差异
58
+ :param df: pd.DataFrame
59
+ :param save_path: str
60
+ :return:
61
+ """
62
+ sns.distplot(df[df["label"] == 1]["feature"], label="label1")
63
+ sns.distplot(df[df["label"] == 0]["feature"], label="label2")
64
+ plt.legend()
65
+ plt.savefig(save_path)
66
+
67
+
68
+ class DataStructure:
69
+ spo = {
70
+ "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
71
+ "triplets": [
72
+ {
73
+ "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
74
+ "p": {"text": "出版社", "l": 15, "r": 18},
75
+ "o": {"text": "故宫出版社", "l": 13, "r": 18},
76
+ }
77
+ ],
78
+ "source": "baidu",
79
+ }
80
+ ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
81
+ ner_label_example = (
82
+ list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
83
+ )
84
+
85
+
86
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
87
+ # 两个句子的jacccard系数
88
+ # 判断输入来重新定义ipt_level和sim_level
89
+
90
+ # a = set(ipt1.split())
91
+ # b = set(ipt2.split())
92
+ a = set(ipt1)
93
+ b = set(ipt2)
94
+ c = a.intersection(b)
95
+ # spical situation:
96
+ if not ipt1 and not ipt2:
97
+ return 0
98
+ return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
99
+
100
+
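A quick character-level example of text_jaccard (the score is the Jaccard coefficient scaled to 0-100 and truncated to an int):

    text_jaccard("abcd", "abce")  # intersection 3, union 5 -> int(100 * 3 / 5) == 60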
101
+ class STEM(object):
102
+ def __init__(self, IPT_MODEL_PATH):
103
+ self.ltp = LTP(IPT_MODEL_PATH)
104
+
105
+ def start_by_dep(self, sentence):
106
+ seg, hidden = self.ltp.seg([sentence])
107
+ dep = self.ltp.dep(hidden) # , graph=False)
108
+ seg, dep = seg[0], dep[0]
109
+ for i in dep:
110
+ # 主谓宾
111
+ if "SBV" == i[2]:
112
+ subject = seg[i[0]]
113
+ verb = seg[i[1]]
114
+ if "VOB" in i[2]:
115
+ if seg[i[1]] == verb:
116
+ object = seg[i[0]]
117
+
118
+ return subject
119
+
120
+ return None
121
+
122
+ def start_by_srl(self, sentence):
123
+ """
124
+ 用语义角色标注工具
125
+ :param sentence: "他叫汤姆去拿外衣。"
126
+ :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
127
+ """
128
+ # 语义角色标注方法
129
+ seg, hidden = self.ltp.seg([sentence])
130
+ srl = self.ltp.srl(hidden)
131
+ seg, srl = seg[0], srl[0]
132
+ events = []
133
+ for wdx, each_srl in enumerate(srl):
134
+ if each_srl:
135
+ args = []
136
+ for arg in each_srl:
137
+ args.extend(seg[arg[1]: arg[2] + 1])
138
+ # 添加上谓词
139
+ args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
140
+ events.append(args)
141
+ # print(events)
142
+ return events
143
+
144
+
145
+ # 这个是另一种
146
+ # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
147
+ def subject_object_labeling_new(spo_list, text):
148
+ pass
149
+
150
+
151
+ # 这个是传统格式的
152
+ # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
153
+ def subject_object_labeling(spo_list, text):
154
+ # TODO
155
+ """
156
+ 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
157
+ :param spo_list:
158
+ :param text:
159
+ :return: labeling_list
160
+ """
161
+
162
+ def _spo_list_to_spo_predicate_dict(spo_list):
163
+ spo_predicate_dict = dict()
164
+ for spo_item in spo_list:
165
+ predicate = spo_item["predicate"]
166
+ subject = spo_item["subject"]
167
+ object = spo_item["object"]
168
+ spo_predicate_dict.setdefault(predicate, []).append((subject, object))
169
+ return spo_predicate_dict
170
+
171
+ def _index_q_list_in_k_list(q_list, k_list):
172
+ """Known q_list in k_list, find index(first time) of q_list in k_list"""
173
+ q_list_length = len(q_list)
174
+ k_list_length = len(k_list)
175
+ for idx in range(k_list_length - q_list_length + 1):
176
+ t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
177
+ # print(idx, t)
178
+ if all(t):
179
+ # print(idx)
180
+ idx_start = idx
181
+ return idx_start
182
+
183
+ def _labeling_type(spo, spo_type):
184
+ idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
185
+ labeling_list[idx_start] = "B-" + spo_type
186
+ if len(spo) == 2:
187
+ labeling_list[idx_start + 1] = "I-" + spo_type
188
+ elif len(spo) >= 3:
189
+ labeling_list[idx_start + 1: idx_start + len(spo)] = ["I-" + spo_type] * (
190
+ len(spo) - 1
191
+ )
192
+ else:
193
+ pass
194
+
195
+ spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
196
+ labeling_list = ["O"] * len(text)
197
+ # count = 0
198
+ for predicate, spo_list_form in spo_predicate_dict.items():
199
+ if predicate in text:
200
+ for (spo_subject, spo_object) in spo_list_form:
201
+ # if predicate not in spo_subject and predicate not in spo_object:
202
+ _labeling_type(spo_subject, "SUB")
203
+ _labeling_type(spo_object, "OBJ")
204
+ _labeling_type(predicate, "PRE")
205
+ # count += 1
206
+ # print(count)
207
+ # if count == 2:
208
+ # print()
209
+ if labeling_list != ["O"] * len(text):
210
+ return labeling_list
211
+ return None
212
+
213
+
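A minimal, hypothetical call to subject_object_labeling, following the traditional format described in the comment above (the sentence and triple below are made up for illustration):

    text = "鲁迅的小说《呐喊》由新潮社出版"
    spo_list = [{"predicate": "出版", "subject": "呐喊", "object": "新潮社"}]
    labels = subject_object_labeling(spo_list, text)
    # labels is a character-level list such as ['O', ..., 'B-SUB', 'I-SUB', ..., 'B-OBJ', ...],
    # or None if nothing could be labeled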
214
+ def label(text, labels):
215
+ """
216
+ 返回两列的标记数据序列
217
+ :param text:
218
+ :param labels:
219
+ :return:
220
+ """
221
+ train_sequence = "\n".join(
222
+ [
223
+ "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
224
+ for i in zip(list(text), labels)
225
+ ]
226
+ )
227
+ return train_sequence
228
+
229
+
230
+ def convert_crf_format_10_fold(corpus, objdir_path):
231
+ """
232
+ 把已经是crf格式的数据,分成十折。
233
+ para:
234
+
235
+ """
236
+ # corpus = list(range(1,22))
237
+ j_mkdir(objdir_path)
238
+ split_position = int(len(corpus) / 10)
239
+ for k in range(0, 10):
240
+ if k == 9:
241
+ dev_set = corpus[k * split_position:]
242
+ train_set = corpus[: k * split_position]
243
+ else:
244
+ dev_set = corpus[k * split_position: (k + 1) * split_position]
245
+ train_set = (
246
+ corpus[: k * split_position] + corpus[(k + 1) * split_position:]
247
+ )
248
+ writetxt_w_list(
249
+ train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
250
+ )
251
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
252
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
253
+
254
+
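Usage of convert_crf_format_10_fold is straightforward: pass the already CRF-formatted lines and a target directory, and ten train/test/dev splits (train1.txt … train10.txt, plus matching test/dev files) are written there. A minimal sketch, assuming the corpus is read with the package's own reader and "crf_corpus.txt" is a placeholder path:

    corpus = readtxt_list_all_strip("crf_corpus.txt")
    convert_crf_format_10_fold(corpus, "crf_10fold")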
255
+ def read_seq_res(path, labels):
256
+ """
257
+ 读序列标注三列数据的方法
258
+ :param path:
259
+ :param labels:
260
+ :return:
261
+ """
262
+ with codecs.open(path, "r", "utf-8") as rd:
263
+ seqs_str = rd.read().strip()
264
+ seqs_list = seqs_str.split("\n\n")
265
+ text, raw_label, predict_label = [], [], []
266
+ for seq in seqs_list:
267
+ seq_split = seq.split("\n")
268
+ text_tmp = ""
269
+ raw_index_dict, pre_index_dict = {}, {}
270
+ for label in labels:
271
+ raw_index_dict.setdefault(label, [])
272
+ pre_index_dict.setdefault(label, [])
273
+ for idx, line in enumerate(seq_split):
274
+ tmp = line.split("\t")
275
+ text_tmp += tmp[0]
276
+ if tmp[1] in labels:
277
+ raw_index_dict[tmp[1]].append(idx)
278
+ if tmp[2] in labels:
279
+ pre_index_dict[tmp[2]].append(idx)
280
+ text.append(text_tmp)
281
+ raw_label.append(raw_index_dict)
282
+ predict_label.append(pre_index_dict)
283
+ return text, raw_label, predict_label
284
+
285
+
286
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
287
+ """
288
+ k是10份中训练集占了几份
289
+ """
290
+ j_mkdir(path)
291
+ if is_shuffle:
292
+ random.shuffle(corpus)
293
+ split_position = int(len(corpus) / 10)
294
+ train_set, dev_set = corpus[: k * split_position], corpus[k * split_position:]
295
+ writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
296
+ writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
297
+ writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
298
+
299
+
300
+ def sample():
301
+ import pandas as pd
302
+ from sklearn.model_selection import StratifiedShuffleSplit
303
+
304
+ # 假设 df 是你的 DataFrame
305
+
306
+ df = pd.DataFrame({
307
+ "count_line": [i for i in range(100)],
308
+ "x": [i for i in range(100)],
309
+ "y": [i // 10 for i in range(100)],
310
+ })
311
+ print(df)
312
+ # count_line 是用于分层抽样的字段
313
+
314
+ # 创建 StratifiedShuffleSplit 对象,设置测试集比例为 0.1
315
+ split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
316
+
317
+ # 获取训练集和测试集的索引
318
+ train_index, test_index = next(split.split(df, df['y']))
319
+
320
+ # 根据索引划分训练集和测试集
321
+ train_df = df.loc[train_index]
322
+ test_df = df.loc[test_index]
323
+
324
+ # 打印训练集和测试集的行数
325
+ print("训练集行数:", len(train_df))
326
+ print("测试集行数:", len(test_df))
327
+
328
+
329
+ def kfold_df(df, save_dir=None):
330
+ """
331
+ 划分train test val集, 写为windows可读的csv。
332
+ :param df:pd.DataFrame
333
+ :param save_dir:
334
+ :return:
335
+ """
336
+ from sklearn.model_selection import KFold
337
+ import pandas as pd
338
+
339
+ train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
340
+ df_test_and_val = df.iloc[test_and_val_idx]
341
+ test_idx, val_idx = (
342
+ KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
343
+ )
344
+ df_train = df.iloc[train_idx]
345
+ df_val = df.iloc[val_idx]
346
+ df_test = df.iloc[test_idx]
347
+ if save_dir:
348
+ j_mkdir(save_dir)
349
+ save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
350
+ save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
351
+ save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
352
+ return df_train, df_val, df_test
353
+
354
+
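A small sketch for kfold_df (the DataFrame below is hypothetical). It returns train/val/test frames and, when save_dir is given, also writes train.csv, test.csv and val.csv via save_to_csv; with n_splits=8 followed by n_splits=2, roughly 7/8 of the rows go to train and the rest are split between val and test:

    import pandas as pd

    df = pd.DataFrame({"text": [f"text_{i}" for i in range(100)],
                       "label": [i % 2 for i in range(100)]})
    df_train, df_val, df_test = kfold_df(df, save_dir=None)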
355
+ # 读取crf序列格式的数据
356
+ def read_seq_data(path):
357
+ content = readtxt_list_all_strip(path)
358
+ lines = [i.split("\t") if i else "" for i in content]
359
+ print(lines)
360
+ sequences, labels, sequence, label = [], [], [], []
361
+ for idx, line in enumerate(lines):
362
+ if line == "":
363
+ if sequence:
364
+ sequences.append(sequence)
365
+ labels.append(label)
366
+ sequence, label = [], []
367
+ else:
368
+ sequence.append(line[0])
369
+ label.append(line[1])
370
+ if idx == len(lines) - 1 and sequence:
371
+ sequences.append(sequence)
372
+ labels.append(label)
373
+ return sequences, labels
374
+
375
+
376
+ def split_5_percent(lines, sample_precent=5):
377
+ random.seed(8)
378
+ # lines = list(range(1, 109))
379
+ idx_lines = [(idx, i) for idx, i in enumerate(lines)]
380
+ div = int(len(lines) / 100)
381
+ sample_num = div * sample_precent
382
+ sample = random.sample(idx_lines, sample_num)
383
+ sorted_sample = sorted(sample, key=lambda x: x[0])
384
+ remove_idx = [i[0] for i in sorted_sample]
385
+ less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
386
+ most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
387
+ print(less_has_raw_line_info)
388
+ print(most)
389
+ return most, less_has_raw_line_info
390
+
391
+
392
+ def split_sentence(sentence, language="chinese", cross_line=True):
393
+ """
394
+ 分句,英文有nltk,中文怎么能没有好的分句工具呢
395
+ :param sentence:
396
+ :param language:
397
+ :param cross_line:
398
+ :return:
399
+ """
400
+ # sentences->Str
401
+ # example '12“345。”“6789”'
402
+ assert language in ["chinese", "english"], "unsupportable for other language"
403
+ sentence = sentence.replace("\r", "")
404
+ if language == "chinese":
405
+ split_signs = list("。!?…")
406
+ if cross_line:
407
+ split_signs.append("\n")
408
+ other_sign = "”"
409
+ elif language == "english":
410
+ split_signs = list(".!?")
411
+ other_sign = '"'
412
+ else:
413
+ split_signs = list(".!?")
414
+ other_sign = '"'
415
+ sentences = []
416
+ start_idx = 0
417
+ for idx, char in enumerate(sentence):
418
+ if idx == len(sentence) - 1:
419
+ if char in split_signs:
420
+ sentences.append(sentence[start_idx: idx + 1].strip())
421
+ start_idx = idx + 1
422
+ else:
423
+ sentences.append(sentence[start_idx:].strip())
424
+ else:
425
+ if char in split_signs:
426
+ if sentence[idx + 1] == other_sign:
427
+ if idx < len(sentence) - 2:
428
+ # 处理。”。
429
+ if sentence[idx + 2] not in split_signs:
430
+ sentences.append(sentence[start_idx: idx + 2].strip())
431
+ start_idx = idx + 2
432
+ elif sentence[idx + 1] not in split_signs:
433
+ sentences.append(sentence[start_idx: idx + 1].strip())
434
+ start_idx = idx + 1
435
+ return sentences
436
+
437
+
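A quick check of split_sentence on the example string from its comments; a closing quote that follows a terminal mark stays attached to the preceding sentence:

    split_sentence('12“345。”“6789”')
    # -> ['12“345。”', '“6789”']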
438
+ def pos_reduction():
439
+ wnl = WordNetLemmatizer()
440
+ # lemmatize nouns
441
+ print(wnl.lemmatize("cars", "n"))
442
+ print(wnl.lemmatize("men", "n"))
443
+
444
+ # lemmatize verbs
445
+ print(wnl.lemmatize("running", "v"))
446
+ print(wnl.lemmatize("ate", "v"))
447
+
448
+
449
+ class DataVisualization:
450
+ # 和下面的类冲突了
451
+ pass
452
+
453
+
454
+ class Evaluate:
455
+ def __init__(self):
456
+ pass
457
+
458
+ def auc_metric(self, k):
459
+ pass
460
+
461
+ def map_metric(self):
462
+ pass
463
+
464
+ def ndcg(self, n, y_true, y_score):
465
+ report = metrics.ndcg_score(y_true, y_score)
466
+ return report
467
+
468
+
469
+ class DecideTreeUtils:
470
+ @staticmethod
471
+ def draw(bst):
472
+ # xgb 画图
473
+ fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
474
+ xgb.plot_tree(bst, ax=ax_tree)
475
+ fig_tree.savefig("tree.png")
476
+ plt.show()
477
+
478
+
479
+ def seed_everything(seed=7777777) -> None:
480
+ """
481
+ 设置整个开发环境的seed
482
+ :param seed:
483
+ :param device:
484
+ :return:
485
+ """
486
+ random.seed(seed)
487
+ os.environ["PYTHONHASHSEED"] = str(seed)
488
+ np.random.seed(seed)
489
+ torch.manual_seed(seed) # CPU随机种子确定
490
+ torch.cuda.manual_seed(seed)
491
+ torch.cuda.manual_seed_all(seed)
492
+ # some cudnn methods can be random even after fixing the seed
493
+ # unless you tell it to be deterministic
494
+ torch.backends.cudnn.deterministic = True
495
+
496
+
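seed_everything pins the Python, NumPy and torch (CPU and CUDA) random states and forces deterministic cuDNN kernels. For fully repeatable runs it is common to additionally disable the cuDNN autotuner, which seed_everything does not do:

    torch.backends.cudnn.benchmark = False  # optional companion setting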
497
+ if __name__ == "__main__":
498
+ # stem = STEM(IPT_MODEL_PATH)
499
+ # test_sentence = "美国袭击伊拉克"
500
+ # a = stem.start_by_srl(test_sentence)
501
+
502
+ res = calc_llm_train_activation_memory(
503
+ model_name="",
504
+ sequence_length=2048,
505
+ batch_size=1,
506
+ hidden_dim=4096,
507
+ lay_number=28,
508
+ attention_heads_num=32,
509
+ gpu_num=1
510
+ )
511
+ print(res, "G")
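For the parameters used in the demo above (sequence_length=2048, batch_size=1, hidden_dim=4096, lay_number=28, attention_heads_num=32), the reference-1 estimate works out to:

    5 * a * s / h = 5 * 32 * 2048 / 4096 = 80
    s * b * h * L = 2048 * 1 * 4096 * 28 = 234,881,024
    r1 = 234,881,024 * (34 + 80) = 26,776,436,736 bytes ≈ 24.9 GiB

Since the function returns bytes, dividing by 1024 ** 3 is needed before labelling the result with "G" as the demo does.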