nlpertools 1.0.5__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (43)
  1. nlpertools/__init__.py +24 -20
  2. nlpertools/algo/ac.py +18 -0
  3. nlpertools/algo/bit_ops.py +28 -0
  4. nlpertools/algo/kmp.py +94 -55
  5. nlpertools/algo/num_ops.py +12 -0
  6. nlpertools/algo/template.py +116 -0
  7. nlpertools/algo/union.py +13 -0
  8. nlpertools/data_client.py +387 -257
  9. nlpertools/data_structure/base_structure.py +109 -13
  10. nlpertools/dataprocess.py +611 -3
  11. nlpertools/default_db_config.yml +41 -0
  12. nlpertools/io/__init__.py +3 -3
  13. nlpertools/io/dir.py +54 -36
  14. nlpertools/io/file.py +277 -222
  15. nlpertools/ml.py +483 -460
  16. nlpertools/monitor/__init__.py +0 -0
  17. nlpertools/monitor/gpu.py +18 -0
  18. nlpertools/monitor/memory.py +24 -0
  19. nlpertools/movie.py +36 -0
  20. nlpertools/nlpertools_config.yml +1 -0
  21. nlpertools/{openApi.py → open_api.py} +65 -65
  22. nlpertools/other.py +364 -249
  23. nlpertools/pic.py +288 -0
  24. nlpertools/plugin.py +43 -43
  25. nlpertools/reminder.py +98 -87
  26. nlpertools/utils/__init__.py +3 -3
  27. nlpertools/utils/lazy.py +727 -0
  28. nlpertools/utils/log_util.py +20 -0
  29. nlpertools/utils/package.py +89 -76
  30. nlpertools/utils/package_v1.py +94 -0
  31. nlpertools/utils/package_v2.py +117 -0
  32. nlpertools/utils_for_nlpertools.py +93 -93
  33. nlpertools/vector_index_demo.py +108 -0
  34. nlpertools/wrapper.py +161 -96
  35. {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
  36. nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
  37. nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
  38. {nlpertools-1.0.5.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
  39. nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
  40. nlpertools_helper/__init__.py +10 -0
  41. nlpertools-1.0.5.dist-info/METADATA +0 -85
  42. nlpertools-1.0.5.dist-info/RECORD +0 -25
  43. nlpertools-1.0.5.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,460 +1,483 @@
1
- # encoding=utf-8
2
- import codecs
3
- import os
4
- import random
5
-
6
- from .io.dir import j_mkdir
7
- from .io.file import readtxt_list_all_strip, writetxt_w_list
8
- # import numpy as np
9
- # import seaborn as sns
10
- # import torch
11
- # import torch.nn as nn
12
- # import xgboost as xgb
13
- # from matplotlib import pyplot as plt
14
- # from nltk.stem import WordNetLemmatizer
15
- # from sklearn import metrics
16
- # from transformers import BertTokenizer, BertForMaskedLM
17
- from .utils.package import *
18
-
19
-
20
- class DataAnalysis:
21
- @staticmethod
22
- def draw_pic(df, save_path):
23
- """
24
- 画直方图,对比两个不同类别差异
25
- :param df: pd.DataFrame
26
- :param save_path: str
27
- :return:
28
- """
29
- sns.distplot(df[df["label"] == 1]["feature"], label="label1")
30
- sns.distplot(df[df["label"] == 0]["feature"], label="label2")
31
- plt.legend()
32
- plt.savefig(save_path)
33
-
34
-
35
- class DataStructure:
36
- spo = {
37
- "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
38
- "triplets": [
39
- {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
40
- "p": {"text": "出版社", "l": 15, "r": 18},
41
- "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
42
- "source": "baidu"
43
- }
44
- ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
45
- ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
46
-
47
-
48
- def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
49
- # 两个句子的jacccard系数
50
- # 判断输入来重新定义ipt_level和sim_level
51
-
52
- # a = set(ipt1.split())
53
- # b = set(ipt2.split())
54
- a = set(ipt1)
55
- b = set(ipt2)
56
- c = a.intersection(b)
57
- # spical situation:
58
- if not ipt1 and not ipt2:
59
- return 0
60
- return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
61
-
62
-
63
- class STEM(object):
64
-
65
- def __init__(self, IPT_MODEL_PATH):
66
- self.ltp = LTP(IPT_MODEL_PATH)
67
-
68
- def start_by_dep(self, sentence):
69
- seg, hidden = self.ltp.seg([sentence])
70
- dep = self.ltp.dep(hidden) # , graph=False)
71
- seg, dep = seg[0], dep[0]
72
- for i in dep:
73
- # 主谓宾
74
- if 'SBV' == i[2]:
75
- subject = seg[i[0]]
76
- verb = seg[i[1]]
77
- if 'VOB' in i[2]:
78
- if seg[i[1]] == verb:
79
- object = seg[i[0]]
80
-
81
- return subject
82
-
83
- return None
84
-
85
- def start_by_srl(self, sentence):
86
- """
87
- 用语义角色标注工具
88
- :param sentence: "他叫汤姆去拿外衣。"
89
- :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
90
- """
91
- # 语义角色标注方法
92
- seg, hidden = self.ltp.seg([sentence])
93
- srl = self.ltp.srl(hidden)
94
- seg, srl = seg[0], srl[0]
95
- events = []
96
- for wdx, each_srl in enumerate(srl):
97
- if each_srl:
98
- args = []
99
- for arg in each_srl:
100
- args.extend(seg[arg[1]:arg[2] + 1])
101
- # 添加上谓词
102
- args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
103
- events.append(args)
104
- # print(events)
105
- return events
106
-
107
-
108
- # 这个是另一种
109
- # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
110
- def subject_object_labeling_new(spo_list, text):
111
- pass
112
-
113
-
114
- # 这个是传统格式的
115
- # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
116
- def subject_object_labeling(spo_list, text):
117
- # TODO
118
- '''
119
- 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
120
- :param spo_list:
121
- :param text:
122
- :return: labeling_list
123
- '''
124
-
125
- def _spo_list_to_spo_predicate_dict(spo_list):
126
- spo_predicate_dict = dict()
127
- for spo_item in spo_list:
128
- predicate = spo_item["predicate"]
129
- subject = spo_item["subject"]
130
- object = spo_item["object"]
131
- spo_predicate_dict.setdefault(predicate, []).append((subject, object))
132
- return spo_predicate_dict
133
-
134
- def _index_q_list_in_k_list(q_list, k_list):
135
- """Known q_list in k_list, find index(first time) of q_list in k_list"""
136
- q_list_length = len(q_list)
137
- k_list_length = len(k_list)
138
- for idx in range(k_list_length - q_list_length + 1):
139
- t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
140
- # print(idx, t)
141
- if all(t):
142
- # print(idx)
143
- idx_start = idx
144
- return idx_start
145
-
146
- def _labeling_type(spo, spo_type):
147
- idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
148
- labeling_list[idx_start] = 'B-' + spo_type
149
- if len(spo) == 2:
150
- labeling_list[idx_start + 1] = 'I-' + spo_type
151
- elif len(spo) >= 3:
152
- labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
153
- else:
154
- pass
155
-
156
- spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
157
- labeling_list = ['O'] * len(text)
158
- # count = 0
159
- for predicate, spo_list_form in spo_predicate_dict.items():
160
- if predicate in text:
161
- for (spo_subject, spo_object) in spo_list_form:
162
- # if predicate not in spo_subject and predicate not in spo_object:
163
- _labeling_type(spo_subject, 'SUB')
164
- _labeling_type(spo_object, 'OBJ')
165
- _labeling_type(predicate, 'PRE')
166
- # count += 1
167
- # print(count)
168
- # if count == 2:
169
- # print()
170
- if labeling_list != ['O'] * len(text):
171
- return labeling_list
172
- return None
173
-
174
-
175
- def label(text, labels):
176
- '''
177
- 返回两列的标记数据序列
178
- :param text:
179
- :param labels:
180
- :return:
181
- '''
182
- train_sequence = '\n'.join(
183
- ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
184
- return train_sequence
185
-
186
-
187
- def convert_crf_format_10_fold(corpus, objdir_path):
188
- '''
189
- 把已经是crf格式的数据,分成十折。
190
- para:
191
-
192
- '''
193
- # corpus = list(range(1,22))
194
- j_mkdir(objdir_path)
195
- split_position = int(len(corpus) / 10)
196
- for k in range(0, 10):
197
- if k == 9:
198
- dev_set = corpus[k * split_position:]
199
- train_set = corpus[:k * split_position]
200
- else:
201
- dev_set = corpus[k * split_position: (k + 1) * split_position]
202
- train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
203
- writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
204
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
205
- writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
206
-
207
-
208
- def read_seq_res(path, labels):
209
- '''
210
- 读序列标注三列数据的方法
211
- :param path:
212
- :param labels:
213
- :return:
214
- '''
215
- with codecs.open(path, 'r', 'utf-8') as rd:
216
- seqs_str = rd.read().strip()
217
- seqs_list = seqs_str.split('\n\n')
218
- text, raw_label, predict_label = [], [], []
219
- for seq in seqs_list:
220
- seq_split = seq.split('\n')
221
- text_tmp = ''
222
- raw_index_dict, pre_index_dict = {}, {}
223
- for label in labels:
224
- raw_index_dict.setdefault(label, [])
225
- pre_index_dict.setdefault(label, [])
226
- for idx, line in enumerate(seq_split):
227
- tmp = line.split('\t')
228
- text_tmp += tmp[0]
229
- if tmp[1] in labels:
230
- raw_index_dict[tmp[1]].append(idx)
231
- if tmp[2] in labels:
232
- pre_index_dict[tmp[2]].append(idx)
233
- text.append(text_tmp)
234
- raw_label.append(raw_index_dict)
235
- predict_label.append(pre_index_dict)
236
- return text, raw_label, predict_label
237
-
238
-
239
- def kfold(corpus, path, k=9, is_shuffle=True):
240
- '''
241
- k是10份中训练集占了几份
242
- '''
243
- j_mkdir(path)
244
- if is_shuffle:
245
- random.shuffle(corpus)
246
- split_position = int(len(corpus) / 10)
247
- train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
248
- writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
249
- writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
250
- writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
251
- """
252
- import pandas as pd
253
- from sklearn.model_selection import KFold
254
-
255
- df = pd.DataFrame({
256
- "text": ["text_{}".format(i) for i in range(100)],
257
- "labels": ["label_{}".format(i % 10) for i in range(100)]
258
- })
259
- train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
260
- test_idx, val_idx = KFold(n_splits=2, shuffle=True).split(df).__next__()
261
- df_train = df.iloc[train_idx]
262
- df_val = df.iloc[val_idx]
263
- df_test = df.iloc[test_idx]
264
- print(train_idx)
265
- print(val_idx)
266
- print(test_idx)
267
- """
268
-
269
-
270
- # 读取crf序列格式的数据
271
- def read_seq_data(path):
272
- content = readtxt_list_all_strip(path)
273
- lines = [i.split('\t') if i else '' for i in content]
274
- print(lines)
275
- sequences, labels, sequence, label = [], [], [], []
276
- for idx, line in enumerate(lines):
277
- if line == '':
278
- if sequence:
279
- sequences.append(sequence)
280
- labels.append(label)
281
- sequence, label = [], []
282
- else:
283
- sequence.append(line[0])
284
- label.append(line[1])
285
- if idx == len(lines) - 1 and sequence:
286
- sequences.append(sequence)
287
- labels.append(label)
288
- return sequences, labels
289
-
290
-
291
- def split_5_percent(lines, sample_precent=5):
292
- random.seed(8)
293
- # lines = list(range(1, 109))
294
- idx_lines = [(idx, i) for idx, i in enumerate(lines)]
295
- div = int(len(lines) / 100)
296
- sample_num = div * sample_precent
297
- sample = random.sample(idx_lines, sample_num)
298
- sorted_sample = sorted(sample, key=lambda x: x[0])
299
- remove_idx = [i[0] for i in sorted_sample]
300
- less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
301
- most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
302
- print(less_has_raw_line_info)
303
- print(most)
304
- return most, less_has_raw_line_info
305
-
306
-
307
- def split_sentences(sentences, mode='chinese'):
308
- # sentences->Str
309
- # example '12“345。”“6789”'
310
- if mode == 'chinese':
311
- split_signs = list('。!?…')
312
- other_sign = "”"
313
- elif mode == 'english':
314
- split_signs = list('.!?')
315
- other_sign = '"'
316
- else:
317
- print('暂时还没有')
318
- split_signs = list('.!?')
319
- other_sign = '"'
320
- splited_sentences = []
321
- start_idx = 0
322
- for idx, char in enumerate(sentences):
323
- if idx == len(sentences) - 1:
324
- if char in split_signs:
325
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
326
- start_idx = idx + 1
327
- else:
328
- splited_sentences.append(sentences[start_idx:].strip())
329
- else:
330
- if char in split_signs:
331
- if sentences[idx + 1] == other_sign:
332
- if idx < len(sentences) - 2:
333
- # 处理。”。
334
- if sentences[idx + 2] not in split_signs:
335
- splited_sentences.append(sentences[start_idx:idx + 2].strip())
336
- start_idx = idx + 2
337
- elif sentences[idx + 1] not in split_signs:
338
- splited_sentences.append(sentences[start_idx:idx + 1].strip())
339
- start_idx = idx + 1
340
- return splited_sentences
341
-
342
-
343
- def pos_reduction():
344
- wnl = WordNetLemmatizer()
345
- # lemmatize nouns
346
- print(wnl.lemmatize('cars', 'n'))
347
- print(wnl.lemmatize('men', 'n'))
348
-
349
- # lemmatize verbs
350
- print(wnl.lemmatize('running', 'v'))
351
- print(wnl.lemmatize('ate', 'v'))
352
-
353
-
354
- class DataVisualization:
355
- # 和下面的类冲突了
356
- pass
357
-
358
-
359
- class CalcPPL(object):
360
- # ppl计算
361
- # https://www.scribendi.ai/comparing-bert-and-gpt-2-as-language-models-to-score-the-grammatical-correctness-of-a-sentence/
362
- def __init__(self, path):
363
- self.model = BertForMaskedLM.from_pretrained(path)
364
- self.model.eval()
365
- # Load pre-trained model tokenizer (vocabulary)
366
- self.tokenizer = BertTokenizer.from_pretrained(path)
367
-
368
- def ppl_1(self, sentence):
369
- tokenizer = self.tokenizer
370
- model = self.tokenizer
371
- tokenize_input = tokenizer.tokenize(sentence)
372
- tokenize_input = tokenize_input
373
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
374
- with torch.no_grad():
375
- loss = model(tensor_input, labels=tensor_input)[0]
376
- return np.exp(loss.detach().numpy())
377
-
378
- # [1] Salazar J, Liang D, Nguyen T Q, et al. Masked Language Model Scoring[C]//Proceedings of ACL. 2020: 2699-2712.
379
- def ppl_2(self, sentence):
380
- tokenizer = self.tokenizer
381
- model = self.tokenizer
382
- with torch.no_grad():
383
- tokenize_input = tokenizer.tokenize(sentence)
384
- tensor_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
385
- sen_len = len(tokenize_input)
386
- sentence_loss = 0.
387
-
388
- for i, word in enumerate(tokenize_input):
389
- # add mask to i-th character of the sentence
390
- tokenize_input[i] = '[MASK]'
391
- mask_input = torch.tensor([tokenizer.convert_tokens_to_ids(tokenize_input)])
392
-
393
- output = model(mask_input)
394
-
395
- prediction_scores = output[0]
396
- softmax = nn.Softmax(dim=0)
397
- ps = softmax(prediction_scores[0, i]).log()
398
- word_loss = ps[tensor_input[0, i]]
399
- sentence_loss += word_loss.item()
400
-
401
- tokenize_input[i] = word
402
- ppl = np.exp(-sentence_loss / sen_len)
403
- # print("困惑度:", ppl)
404
- return ppl
405
-
406
- def test(self):
407
- sentence = "输入句子:"
408
- ppl = self.ppl_1(sentence)
409
- ppl2 = self.ppl_2(sentence)
410
- print(ppl)
411
- print(ppl2)
412
-
413
-
414
- class Evaluate():
415
- def __init__(self):
416
- pass
417
-
418
- def auc_metric(self, k):
419
- pass
420
-
421
- def map_metric(self):
422
- pass
423
-
424
- def ndcg(self, n, y_true, y_score):
425
- report = metrics.ndcg_score(y_true, y_score)
426
- return report
427
-
428
-
429
- class DecideTreeUtils:
430
- @staticmethod
431
- def draw(bst):
432
- # xgb 画图
433
- fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
434
- xgb.plot_tree(bst, ax=ax_tree)
435
- fig_tree.savefig('tree.png')
436
- plt.show()
437
-
438
-
439
- def seed_everything(seed=7777777) -> None:
440
- """
441
- 设置整个开发环境的seed
442
- :param seed:
443
- :param device:
444
- :return:
445
- """
446
- random.seed(seed)
447
- os.environ['PYTHONHASHSEED'] = str(seed)
448
- np.random.seed(seed)
449
- torch.manual_seed(seed) # CPU随机种子确定
450
- torch.cuda.manual_seed(seed)
451
- torch.cuda.manual_seed_all(seed)
452
- # some cudnn methods can be random even after fixing the seed
453
- # unless you tell it to be deterministic
454
- torch.backends.cudnn.deterministic = True
455
-
456
-
457
- if __name__ == '__main__':
458
- stem = STEM(IPT_MODEL_PATH)
459
- test_sentence = '美国袭击伊拉克'
460
- a = stem.start_by_srl(test_sentence)
1
+ # encoding=utf-8
2
+ import codecs
3
+ import os
4
+ import random
5
+
6
+ from .io.dir import j_mkdir
7
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
8
+ # import numpy as np
9
+ # import seaborn as sns
10
+ # import torch
11
+ # import torch.nn as nn
12
+ # import xgboost as xgb
13
+ # from matplotlib import pyplot as plt
14
+ # from nltk.stem import WordNetLemmatizer
15
+ # from sklearn import metrics
16
+ # from transformers import BertTokenizer, BertForMaskedLM
17
+ from .utils.package import *
18
+
19
+
20
+ def calc_llm_train_activation_memory(
21
+ model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
22
+ ):
23
+
24
+ """
25
+ return bytes
26
+
27
+ reference:
28
+ 1. https://zhuanlan.zhihu.com/p/665172400
29
+ 2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion 里面没有乘以层数就很怪
30
+ """
31
+ # reference1
32
+ # attention
33
+ # FFN
34
+ # Layer Norm
35
+ r1 = (
36
+ sequence_length
37
+ * batch_size
38
+ * hidden_dim
39
+ * lay_number
40
+ * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
41
+ )
42
+ # reference2
43
+ r2 = (
44
+ lay_number*(2 * sequence_length * attention_heads_num + 16 * hidden_dim)
45
+ * sequence_length
46
+ * batch_size
47
+ / gpu_num
48
+ )
49
+ print(r1)
50
+ print(r2)
51
+ return r1
52
+
53
+
54
+ class DataAnalysis:
55
+ @staticmethod
56
+ def draw_pic(df, save_path):
57
+ """
58
+ 画直方图,对比两个不同类别差异
59
+ :param df: pd.DataFrame
60
+ :param save_path: str
61
+ :return:
62
+ """
63
+ sns.distplot(df[df["label"] == 1]["feature"], label="label1")
64
+ sns.distplot(df[df["label"] == 0]["feature"], label="label2")
65
+ plt.legend()
66
+ plt.savefig(save_path)
67
+
68
+
69
+ class DataStructure:
70
+ spo = {
71
+ "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
72
+ "triplets": [
73
+ {
74
+ "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
75
+ "p": {"text": "出版社", "l": 15, "r": 18},
76
+ "o": {"text": "故宫出版社", "l": 13, "r": 18},
77
+ }
78
+ ],
79
+ "source": "baidu",
80
+ }
81
+ ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
82
+ ner_label_example = (
83
+ list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
84
+ )
85
+
86
+
87
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
88
+ # 两个句子的jacccard系数
89
+ # 判断输入来重新定义ipt_level和sim_level
90
+
91
+ # a = set(ipt1.split())
92
+ # b = set(ipt2.split())
93
+ a = set(ipt1)
94
+ b = set(ipt2)
95
+ c = a.intersection(b)
96
+ # spical situation:
97
+ if not ipt1 and not ipt2:
98
+ return 0
99
+ return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
100
+
101
+
102
+ class STEM(object):
103
+ def __init__(self, IPT_MODEL_PATH):
104
+ self.ltp = LTP(IPT_MODEL_PATH)
105
+
106
+ def start_by_dep(self, sentence):
107
+ seg, hidden = self.ltp.seg([sentence])
108
+ dep = self.ltp.dep(hidden) # , graph=False)
109
+ seg, dep = seg[0], dep[0]
110
+ for i in dep:
111
+ # 主谓宾
112
+ if "SBV" == i[2]:
113
+ subject = seg[i[0]]
114
+ verb = seg[i[1]]
115
+ if "VOB" in i[2]:
116
+ if seg[i[1]] == verb:
117
+ object = seg[i[0]]
118
+
119
+ return subject
120
+
121
+ return None
122
+
123
+ def start_by_srl(self, sentence):
124
+ """
125
+ 用语义角色标注工具
126
+ :param sentence: "他叫汤姆去拿外衣。"
127
+ :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
128
+ """
129
+ # 语义角色标注方法
130
+ seg, hidden = self.ltp.seg([sentence])
131
+ srl = self.ltp.srl(hidden)
132
+ seg, srl = seg[0], srl[0]
133
+ events = []
134
+ for wdx, each_srl in enumerate(srl):
135
+ if each_srl:
136
+ args = []
137
+ for arg in each_srl:
138
+ args.extend(seg[arg[1] : arg[2] + 1])
139
+ # 添加上谓词
140
+ args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
141
+ events.append(args)
142
+ # print(events)
143
+ return events
144
+
145
+
146
+ # 这个是另一种
147
+ # 数据示例为:{"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
148
+ def subject_object_labeling_new(spo_list, text):
149
+ pass
150
+
151
+
152
+ # 这个是传统格式的
153
+ # 数据格式示例:{"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
154
+ def subject_object_labeling(spo_list, text):
155
+ # TODO
156
+ """
157
+ 百度那种有spo字典的数据,给标成。草,看不懂,得找找哪里用的
158
+ :param spo_list:
159
+ :param text:
160
+ :return: labeling_list
161
+ """
162
+
163
+ def _spo_list_to_spo_predicate_dict(spo_list):
164
+ spo_predicate_dict = dict()
165
+ for spo_item in spo_list:
166
+ predicate = spo_item["predicate"]
167
+ subject = spo_item["subject"]
168
+ object = spo_item["object"]
169
+ spo_predicate_dict.setdefault(predicate, []).append((subject, object))
170
+ return spo_predicate_dict
171
+
172
+ def _index_q_list_in_k_list(q_list, k_list):
173
+ """Known q_list in k_list, find index(first time) of q_list in k_list"""
174
+ q_list_length = len(q_list)
175
+ k_list_length = len(k_list)
176
+ for idx in range(k_list_length - q_list_length + 1):
177
+ t = [q == k for q, k in zip(q_list, k_list[idx : idx + q_list_length])]
178
+ # print(idx, t)
179
+ if all(t):
180
+ # print(idx)
181
+ idx_start = idx
182
+ return idx_start
183
+
184
+ def _labeling_type(spo, spo_type):
185
+ idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
186
+ labeling_list[idx_start] = "B-" + spo_type
187
+ if len(spo) == 2:
188
+ labeling_list[idx_start + 1] = "I-" + spo_type
189
+ elif len(spo) >= 3:
190
+ labeling_list[idx_start + 1 : idx_start + len(spo)] = ["I-" + spo_type] * (
191
+ len(spo) - 1
192
+ )
193
+ else:
194
+ pass
195
+
196
+ spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
197
+ labeling_list = ["O"] * len(text)
198
+ # count = 0
199
+ for predicate, spo_list_form in spo_predicate_dict.items():
200
+ if predicate in text:
201
+ for (spo_subject, spo_object) in spo_list_form:
202
+ # if predicate not in spo_subject and predicate not in spo_object:
203
+ _labeling_type(spo_subject, "SUB")
204
+ _labeling_type(spo_object, "OBJ")
205
+ _labeling_type(predicate, "PRE")
206
+ # count += 1
207
+ # print(count)
208
+ # if count == 2:
209
+ # print()
210
+ if labeling_list != ["O"] * len(text):
211
+ return labeling_list
212
+ return None
213
+
214
+
215
+ def label(text, labels):
216
+ """
217
+ 返回两列的标记数据序列
218
+ :param text:
219
+ :param labels:
220
+ :return:
221
+ """
222
+ train_sequence = "\n".join(
223
+ [
224
+ "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
225
+ for i in zip(list(text), labels)
226
+ ]
227
+ )
228
+ return train_sequence
229
+
230
+
231
+ def convert_crf_format_10_fold(corpus, objdir_path):
232
+ """
233
+ 把已经是crf格式的数据,分成十折。
234
+ para:
235
+
236
+ """
237
+ # corpus = list(range(1,22))
238
+ j_mkdir(objdir_path)
239
+ split_position = int(len(corpus) / 10)
240
+ for k in range(0, 10):
241
+ if k == 9:
242
+ dev_set = corpus[k * split_position :]
243
+ train_set = corpus[: k * split_position]
244
+ else:
245
+ dev_set = corpus[k * split_position : (k + 1) * split_position]
246
+ train_set = (
247
+ corpus[: k * split_position] + corpus[(k + 1) * split_position :]
248
+ )
249
+ writetxt_w_list(
250
+ train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
251
+ )
252
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
253
+ writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
254
+
255
+
256
+ def read_seq_res(path, labels):
257
+ """
258
+ 读序列标注三列数据的方法
259
+ :param path:
260
+ :param labels:
261
+ :return:
262
+ """
263
+ with codecs.open(path, "r", "utf-8") as rd:
264
+ seqs_str = rd.read().strip()
265
+ seqs_list = seqs_str.split("\n\n")
266
+ text, raw_label, predict_label = [], [], []
267
+ for seq in seqs_list:
268
+ seq_split = seq.split("\n")
269
+ text_tmp = ""
270
+ raw_index_dict, pre_index_dict = {}, {}
271
+ for label in labels:
272
+ raw_index_dict.setdefault(label, [])
273
+ pre_index_dict.setdefault(label, [])
274
+ for idx, line in enumerate(seq_split):
275
+ tmp = line.split("\t")
276
+ text_tmp += tmp[0]
277
+ if tmp[1] in labels:
278
+ raw_index_dict[tmp[1]].append(idx)
279
+ if tmp[2] in labels:
280
+ pre_index_dict[tmp[2]].append(idx)
281
+ text.append(text_tmp)
282
+ raw_label.append(raw_index_dict)
283
+ predict_label.append(pre_index_dict)
284
+ return text, raw_label, predict_label
285
+
286
+
287
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
288
+ """
289
+ k是10份中训练集占了几份
290
+ """
291
+ j_mkdir(path)
292
+ if is_shuffle:
293
+ random.shuffle(corpus)
294
+ split_position = int(len(corpus) / 10)
295
+ train_set, dev_set = corpus[: k * split_position], corpus[k * split_position :]
296
+ writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
297
+ writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
298
+ writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
299
+
300
+
301
+ def kfold_df(df, save_dir=None):
302
+ """
303
+ 划分train test val集, 写为windows可读的csv。
304
+ :param df:pd.DataFrame
305
+ :param save_dir:
306
+ :return:
307
+ """
308
+ from sklearn.model_selection import KFold
309
+ import pandas as pd
310
+
311
+ train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
312
+ df_test_and_val = df.iloc[test_and_val_idx]
313
+ test_idx, val_idx = (
314
+ KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
315
+ )
316
+ df_train = df.iloc[train_idx]
317
+ df_val = df.iloc[val_idx]
318
+ df_test = df.iloc[test_idx]
319
+ if save_dir:
320
+ j_mkdir(save_dir)
321
+ save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
322
+ save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
323
+ save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
324
+ return df_train, df_val, df_test
325
+
326
+
327
+ # 读取crf序列格式的数据
328
+ def read_seq_data(path):
329
+ content = readtxt_list_all_strip(path)
330
+ lines = [i.split("\t") if i else "" for i in content]
331
+ print(lines)
332
+ sequences, labels, sequence, label = [], [], [], []
333
+ for idx, line in enumerate(lines):
334
+ if line == "":
335
+ if sequence:
336
+ sequences.append(sequence)
337
+ labels.append(label)
338
+ sequence, label = [], []
339
+ else:
340
+ sequence.append(line[0])
341
+ label.append(line[1])
342
+ if idx == len(lines) - 1 and sequence:
343
+ sequences.append(sequence)
344
+ labels.append(label)
345
+ return sequences, labels
346
+
347
+
348
+ def split_5_percent(lines, sample_precent=5):
349
+ random.seed(8)
350
+ # lines = list(range(1, 109))
351
+ idx_lines = [(idx, i) for idx, i in enumerate(lines)]
352
+ div = int(len(lines) / 100)
353
+ sample_num = div * sample_precent
354
+ sample = random.sample(idx_lines, sample_num)
355
+ sorted_sample = sorted(sample, key=lambda x: x[0])
356
+ remove_idx = [i[0] for i in sorted_sample]
357
+ less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
358
+ most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
359
+ print(less_has_raw_line_info)
360
+ print(most)
361
+ return most, less_has_raw_line_info
362
+
363
+
364
+ def split_sentence(sentence, language="chinese", cross_line=True):
365
+ """
366
+ 分句,英文有nltk,中文怎么能没有好的分句工具呢
367
+ :param sentence:
368
+ :param language:
369
+ :param cross_line:
370
+ :return:
371
+ """
372
+ # sentences->Str
373
+ # example '12“345。”“6789”'
374
+ assert language in ["chinese", "english"], "unsupportable for other language"
375
+ sentence = sentence.replace("\r", "")
376
+ if language == "chinese":
377
+ split_signs = list("。!?…")
378
+ if cross_line:
379
+ split_signs.append("\n")
380
+ other_sign = "”"
381
+ elif language == "english":
382
+ split_signs = list(".!?")
383
+ other_sign = '"'
384
+ else:
385
+ split_signs = list(".!?")
386
+ other_sign = '"'
387
+ sentences = []
388
+ start_idx = 0
389
+ for idx, char in enumerate(sentence):
390
+ if idx == len(sentence) - 1:
391
+ if char in split_signs:
392
+ sentences.append(sentence[start_idx : idx + 1].strip())
393
+ start_idx = idx + 1
394
+ else:
395
+ sentences.append(sentence[start_idx:].strip())
396
+ else:
397
+ if char in split_signs:
398
+ if sentence[idx + 1] == other_sign:
399
+ if idx < len(sentence) - 2:
400
+ # 处理。”。
401
+ if sentence[idx + 2] not in split_signs:
402
+ sentences.append(sentence[start_idx : idx + 2].strip())
403
+ start_idx = idx + 2
404
+ elif sentence[idx + 1] not in split_signs:
405
+ sentences.append(sentence[start_idx : idx + 1].strip())
406
+ start_idx = idx + 1
407
+ return sentences
408
+
409
+
410
+ def pos_reduction():
411
+ wnl = WordNetLemmatizer()
412
+ # lemmatize nouns
413
+ print(wnl.lemmatize("cars", "n"))
414
+ print(wnl.lemmatize("men", "n"))
415
+
416
+ # lemmatize verbs
417
+ print(wnl.lemmatize("running", "v"))
418
+ print(wnl.lemmatize("ate", "v"))
419
+
420
+
421
+ class DataVisualization:
422
+ # 和下面的类冲突了
423
+ pass
424
+
425
+
426
+ class Evaluate:
427
+ def __init__(self):
428
+ pass
429
+
430
+ def auc_metric(self, k):
431
+ pass
432
+
433
+ def map_metric(self):
434
+ pass
435
+
436
+ def ndcg(self, n, y_true, y_score):
437
+ report = metrics.ndcg_score(y_true, y_score)
438
+ return report
439
+
440
+
441
+ class DecideTreeUtils:
442
+ @staticmethod
443
+ def draw(bst):
444
+ # xgb 画图
445
+ fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
446
+ xgb.plot_tree(bst, ax=ax_tree)
447
+ fig_tree.savefig("tree.png")
448
+ plt.show()
449
+
450
+
451
+ def seed_everything(seed=7777777) -> None:
452
+ """
453
+ 设置整个开发环境的seed
454
+ :param seed:
455
+ :param device:
456
+ :return:
457
+ """
458
+ random.seed(seed)
459
+ os.environ["PYTHONHASHSEED"] = str(seed)
460
+ np.random.seed(seed)
461
+ torch.manual_seed(seed) # CPU随机种子确定
462
+ torch.cuda.manual_seed(seed)
463
+ torch.cuda.manual_seed_all(seed)
464
+ # some cudnn methods can be random even after fixing the seed
465
+ # unless you tell it to be deterministic
466
+ torch.backends.cudnn.deterministic = True
467
+
468
+
469
+ if __name__ == "__main__":
470
+ # stem = STEM(IPT_MODEL_PATH)
471
+ # test_sentence = "美国袭击伊拉克"
472
+ # a = stem.start_by_srl(test_sentence)
473
+
474
+ res = calc_llm_train_activation_memory(
475
+ model_name="",
476
+ sequence_length=2048,
477
+ batch_size=1,
478
+ hidden_dim=4096,
479
+ lay_number=28,
480
+ attention_heads_num=32,
481
+ gpu_num=1
482
+ )
483
+ print(res, "G")
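For orientation, here is a minimal usage sketch of the new calc_llm_train_activation_memory helper added to nlpertools/ml.py in this release. It is not part of either package version: the import path and the GiB conversion are illustrative assumptions, and the parameter values are the ones used in the module's own __main__ block. The helper returns its reference-1 estimate in bytes, so the caller handles any unit conversion.

# Illustrative sketch (not from the package): estimate LLM training activation memory
# with the new helper and convert the raw byte count to GiB for readability.
from nlpertools.ml import calc_llm_train_activation_memory  # assumed import path

activation_bytes = calc_llm_train_activation_memory(
    model_name="",
    sequence_length=2048,
    batch_size=1,
    hidden_dim=4096,
    lay_number=28,
    attention_heads_num=32,
    gpu_num=1,
)
# r1 = 2048 * 1 * 4096 * 28 * (34 + 5 * 32 * 2048 / 4096) ≈ 2.68e10 bytes
print(f"{activation_bytes / 1024 ** 3:.2f} GiB")  # ≈ 24.94 GiB for this configuration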
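Similarly, a small sketch of split_sentence (reworked from the old split_sentences), using the example string from its own comment; the expected output is hand-traced from the code above, not taken from the package.

# Illustrative sketch (not from the package): Chinese sentence splitting that keeps a
# closing quotation mark attached to the sentence it terminates.
from nlpertools.ml import split_sentence  # assumed import path

parts = split_sentence('12“345。”“6789”', language="chinese")
print(parts)  # expected: ['12“345。”', '“6789”']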