nlpertools 1.0.4__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl

Files changed (46)
  1. nlpertools/__init__.py +24 -11
  2. nlpertools/algo/__init__.py +0 -0
  3. nlpertools/algo/ac.py +18 -0
  4. nlpertools/algo/bit_ops.py +28 -0
  5. nlpertools/algo/kmp.py +94 -0
  6. nlpertools/algo/num_ops.py +12 -0
  7. nlpertools/algo/template.py +116 -0
  8. nlpertools/algo/union.py +13 -0
  9. nlpertools/data_client.py +387 -0
  10. nlpertools/data_structure/__init__.py +0 -0
  11. nlpertools/data_structure/base_structure.py +109 -0
  12. nlpertools/dataprocess.py +611 -3
  13. nlpertools/default_db_config.yml +41 -0
  14. nlpertools/io/__init__.py +3 -3
  15. nlpertools/io/dir.py +54 -47
  16. nlpertools/io/file.py +277 -205
  17. nlpertools/ml.py +483 -317
  18. nlpertools/monitor/__init__.py +0 -0
  19. nlpertools/monitor/gpu.py +18 -0
  20. nlpertools/monitor/memory.py +24 -0
  21. nlpertools/movie.py +36 -0
  22. nlpertools/nlpertools_config.yml +1 -0
  23. nlpertools/{openApi.py → open_api.py} +65 -62
  24. nlpertools/other.py +364 -188
  25. nlpertools/pic.py +288 -0
  26. nlpertools/plugin.py +43 -34
  27. nlpertools/reminder.py +98 -15
  28. nlpertools/template/__init__.py +0 -0
  29. nlpertools/utils/__init__.py +3 -0
  30. nlpertools/utils/lazy.py +727 -0
  31. nlpertools/utils/log_util.py +20 -0
  32. nlpertools/utils/package.py +89 -0
  33. nlpertools/utils/package_v1.py +94 -0
  34. nlpertools/utils/package_v2.py +117 -0
  35. nlpertools/utils_for_nlpertools.py +93 -0
  36. nlpertools/vector_index_demo.py +108 -0
  37. nlpertools/wrapper.py +161 -0
  38. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
  39. nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
  40. nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
  41. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
  42. nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
  43. nlpertools_helper/__init__.py +10 -0
  44. nlpertools-1.0.4.dist-info/METADATA +0 -42
  45. nlpertools-1.0.4.dist-info/RECORD +0 -15
  46. nlpertools-1.0.4.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,317 +1,483 @@
- # encoding=utf-8
- from .io import *
- import random
-
-
- class DataStructure:
-     spo = {
-         "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
-         "triplets": [
-             {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
-              "p": {"text": "出版社", "l": 15, "r": 18},
-              "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
-         "source": "baidu"
-     }
-     ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
-     ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
-
-
- '''
- try:
-     from ltp import LTP
- except:
-     pass
-
- class STEM(object):
-     def __init__(self, IPT_MODEL_PATH):
-         self.ltp = LTP(IPT_MODEL_PATH)
-
-     def start(self, sentence):
-         seg, hidden = self.ltp.seg([sentence])
-         dep = self.ltp.dep(hidden)  # , graph=False)
-         seg, dep = seg[0], dep[0]
-         for i in dep:
-             # subject-verb-object
-             if 'SBV' == i[2]:
-                 subject = seg[i[0]]
-                 verb = seg[i[1]]
-             if 'VOB' in i[2]:
-                 if seg[i[1]] == verb:
-                     object = seg[i[]]
-
-             return subject
-
-         return None
- class STEM(object):
-     def __init__(self, IPT_MODEL_PATH):
-         self.ltp = LTP(IPT_MODEL_PATH)
-
-     def start(self, sentence):
-         """
-         Use a semantic role labeling tool.
-         :param sentence: "他叫汤姆去拿外衣。"
-         :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
-         """
-         # semantic role labeling approach
-         seg, hidden = self.ltp.seg([sentence])
-         srl = self.ltp.srl(hidden)
-         seg, srl = seg[0], srl[0]
-         events = []
-         for wdx, each_srl in enumerate(srl):
-             if each_srl:
-                 args = []
-                 for arg in each_srl:
-                     args.extend(seg[arg[1]:arg[2] + 1])
-                 # insert the predicate as well
-                 args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
-                 events.append(args)
-         # print(events)
-         return events
-
-     def start_dep_method(self, sentence):
-         # seg, hidden = self.ltp.seg([sentence])
-         # dep = self.ltp.dep(hidden)  # , graph=False)
-         # seg, dep = seg[0], dep[0]
-         # for i in dep:
-         #     # subject-verb-object
-         #     if 'SBV' == i[2]:
-         #         subject = seg[i[0]]
-         #         verb = seg[i[1]]
-         #     if 'VOB' in i[2]:
-         #         if seg[i[1]] == verb:
-         #             object = seg[i]
-         #     return subject
-         return None
-
- IPT_MODEL_PATH = './tiny'
- stem = STEM(IPT_MODEL_PATH)
- sentence = '美国袭击伊拉克'
- a = stem.start(sentence)
- '''
-
-
- # This one handles the other format
- # Example data: {"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
- def subject_object_labeling_new(spo_list, text):
-     pass
-
-
- # This one handles the traditional format
- # Example data format: {"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
- def subject_object_labeling(spo_list, text):
-     # TODO
-     '''
-     Label Baidu-style data that comes with an spo dict. (Hard to follow; need to find where this is actually used.)
-     :param spo_list:
-     :param text:
-     :return: labeling_list
-     '''
-
-     def _spo_list_to_spo_predicate_dict(spo_list):
-         spo_predicate_dict = dict()
-         for spo_item in spo_list:
-             predicate = spo_item["predicate"]
-             subject = spo_item["subject"]
-             object = spo_item["object"]
-             spo_predicate_dict.setdefault(predicate, []).append((subject, object))
-         return spo_predicate_dict
-
-     def _index_q_list_in_k_list(q_list, k_list):
-         """Known q_list in k_list, find index(first time) of q_list in k_list"""
-         q_list_length = len(q_list)
-         k_list_length = len(k_list)
-         for idx in range(k_list_length - q_list_length + 1):
-             t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
-             # print(idx, t)
-             if all(t):
-                 # print(idx)
-                 idx_start = idx
-                 return idx_start
-
-     def _labeling_type(spo, spo_type):
-         idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
-         labeling_list[idx_start] = 'B-' + spo_type
-         if len(spo) == 2:
-             labeling_list[idx_start + 1] = 'I-' + spo_type
-         elif len(spo) >= 3:
-             labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
-         else:
-             pass
-
-     spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
-     labeling_list = ['O'] * len(text)
-     # count = 0
-     for predicate, spo_list_form in spo_predicate_dict.items():
-         if predicate in text:
-             for (spo_subject, spo_object) in spo_list_form:
-                 # if predicate not in spo_subject and predicate not in spo_object:
-                 _labeling_type(spo_subject, 'SUB')
-                 _labeling_type(spo_object, 'OBJ')
-                 _labeling_type(predicate, 'PRE')
-                 # count += 1
-                 # print(count)
-                 # if count == 2:
-                 #     print()
-     if labeling_list != ['O'] * len(text):
-         return labeling_list
-     return None
-
-
- def label(text, labels):
-     '''
-     Return a two-column labeled data sequence.
-     :param text:
-     :param labels:
-     :return:
-     '''
-     train_sequence = '\n'.join(
-         ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
-     return train_sequence
-
-
- def convert_crf_format_10_fold(corpus, objdir_path):
-     '''
-     Split data that is already in CRF format into ten folds.
-     para:
-
-     '''
-     # corpus = list(range(1,22))
-     j_mkdir(objdir_path)
-     split_position = int(len(corpus) / 10)
-     for k in range(0, 10):
-         if k == 9:
-             dev_set = corpus[k * split_position:]
-             train_set = corpus[:k * split_position]
-         else:
-             dev_set = corpus[k * split_position: (k + 1) * split_position]
-             train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
-         writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
-         writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
-         writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
-
-
- def read_seq_res(path, labels):
-     '''
-     Read three-column sequence labeling data.
-     :param path:
-     :param labels:
-     :return:
-     '''
-     with codecs.open(path, 'r', 'utf-8') as rd:
-         seqs_str = rd.read().strip()
-     seqs_list = seqs_str.split('\n\n')
-     text, raw_label, predict_label = [], [], []
-     for seq in seqs_list:
-         seq_split = seq.split('\n')
-         text_tmp = ''
-         raw_index_dict, pre_index_dict = {}, {}
-         for label in labels:
-             raw_index_dict.setdefault(label, [])
-             pre_index_dict.setdefault(label, [])
-         for idx, line in enumerate(seq_split):
-             tmp = line.split('\t')
-             text_tmp += tmp[0]
-             if tmp[1] in labels:
-                 raw_index_dict[tmp[1]].append(idx)
-             if tmp[2] in labels:
-                 pre_index_dict[tmp[2]].append(idx)
-         text.append(text_tmp)
-         raw_label.append(raw_index_dict)
-         predict_label.append(pre_index_dict)
-     return text, raw_label, predict_label
-
-
- def kfold(corpus, path, k=9, is_shuffle=True):
-     '''
-     k is how many of the 10 folds go to the training set
-     '''
-     j_mkdir(path)
-     if is_shuffle:
-         random.shuffle(corpus)
-     split_position = int(len(corpus) / 10)
-     train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
-     writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
-     writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
-     writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
-
-
- # Read data in CRF sequence format
- def read_seq_data(path):
-     content = readtxt_list_all_strip(path)
-     lines = [i.split('\t') if i else '' for i in content]
-     print(lines)
-     sequences, labels, sequence, label = [], [], [], []
-     for idx, line in enumerate(lines):
-         if line == '':
-             if sequence:
-                 sequences.append(sequence)
-                 labels.append(label)
-                 sequence, label = [], []
-         else:
-             sequence.append(line[0])
-             label.append(line[1])
-         if idx == len(lines) - 1 and sequence:
-             sequences.append(sequence)
-             labels.append(label)
-     return sequences, labels
-
-
- def split_5_percent(lines, sample_precent=5):
-     random.seed(8)
-     # lines = list(range(1, 109))
-     idx_lines = [(idx, i) for idx, i in enumerate(lines)]
-     div = int(len(lines) / 100)
-     sample_num = div * sample_precent
-     sample = random.sample(idx_lines, sample_num)
-     sorted_sample = sorted(sample, key=lambda x: x[0])
-     remove_idx = [i[0] for i in sorted_sample]
-     less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
-     most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
-     print(less_has_raw_line_info)
-     print(most)
-     return most, less_has_raw_line_info
- def split_sentences(sentences, mode='chinese'):
-     # sentences -> str
-     # example: '12“345。”“6789”'
-     if mode == 'chinese':
-         split_signs = list('。!?…')
-         other_sign = "”"
-     elif mode == 'english':
-         split_signs = list('.!?')
-         other_sign = '"'
-     else:
-         print('暂时还没有')
-         split_signs = list('.!?')
-         other_sign = '"'
-     splited_sentences = []
-     start_idx = 0
-     for idx, char in enumerate(sentences):
-         if idx == len(sentences) - 1:
-             if char in split_signs:
-                 splited_sentences.append(sentences[start_idx:idx + 1].strip())
-                 start_idx = idx + 1
-             else:
-                 splited_sentences.append(sentences[start_idx:].strip())
-         else:
-             if char in split_signs:
-                 if sentences[idx + 1] == other_sign:
-                     if idx < len(sentences) - 2:
-                         # handle the '。”。' case
-                         if sentences[idx + 2] not in split_signs:
-                             splited_sentences.append(sentences[start_idx:idx + 2].strip())
-                             start_idx = idx + 2
-                 elif sentences[idx + 1] not in split_signs:
-                     splited_sentences.append(sentences[start_idx:idx + 1].strip())
-                     start_idx = idx + 1
-     return splited_sentences
- def pos_huanyuan():
-     from nltk.stem import WordNetLemmatizer
-     data = nlpertools.readtxt_list_all_strip('ie-selfmedia/')
-     wnl = WordNetLemmatizer()
-     # lemmatize nouns
-     print(wnl.lemmatize('cars', 'n'))
-     print(wnl.lemmatize('men', 'n'))
-
-     # lemmatize verbs
-     print(wnl.lemmatize('running', 'v'))
-     print(wnl.lemmatize('ate', 'v'))
+ # encoding=utf-8
+ import codecs
+ import os
+ import random
+
+ from .io.dir import j_mkdir
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
+ # import numpy as np
+ # import seaborn as sns
+ # import torch
+ # import torch.nn as nn
+ # import xgboost as xgb
+ # from matplotlib import pyplot as plt
+ # from nltk.stem import WordNetLemmatizer
+ # from sklearn import metrics
+ # from transformers import BertTokenizer, BertForMaskedLM
+ from .utils.package import *
+
+
+ def calc_llm_train_activation_memory(
+     model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
+ ):
+     """
+     return bytes
+
+     reference:
+     1. https://zhuanlan.zhihu.com/p/665172400
+     2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion (odd that this one does not multiply by the layer count)
+     """
+     # reference 1
+     # attention
+     # FFN
+     # Layer Norm
+     r1 = (
+         sequence_length
+         * batch_size
+         * hidden_dim
+         * lay_number
+         * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
+     )
+     # reference 2
+     r2 = (
+         lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
+         * sequence_length
+         * batch_size
+         / gpu_num
+     )
+     print(r1)
+     print(r2)
+     return r1
+
+
+ class DataAnalysis:
+     @staticmethod
+     def draw_pic(df, save_path):
+         """
+         Plot histograms to compare the distributions of two classes.
+         :param df: pd.DataFrame
+         :param save_path: str
+         :return:
+         """
+         sns.distplot(df[df["label"] == 1]["feature"], label="label1")
+         sns.distplot(df[df["label"] == 0]["feature"], label="label2")
+         plt.legend()
+         plt.savefig(save_path)
+
+
+ class DataStructure:
+     spo = {
+         "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
+         "triplets": [
+             {
+                 "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
+                 "p": {"text": "出版社", "l": 15, "r": 18},
+                 "o": {"text": "故宫出版社", "l": 13, "r": 18},
+             }
+         ],
+         "source": "baidu",
+     }
+     ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
+     ner_label_example = (
+         list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
+     )
+
+
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
+     # Jaccard coefficient of the two inputs
+     # TODO: inspect the inputs to redefine ipt_level and sim_level
+
+     # a = set(ipt1.split())
+     # b = set(ipt2.split())
+     a = set(ipt1)
+     b = set(ipt2)
+     c = a.intersection(b)
+     # special situation:
+     if not ipt1 and not ipt2:
+         return 0
+     return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
+
+
+ class STEM(object):
+     def __init__(self, IPT_MODEL_PATH):
+         self.ltp = LTP(IPT_MODEL_PATH)
+
+     def start_by_dep(self, sentence):
+         seg, hidden = self.ltp.seg([sentence])
+         dep = self.ltp.dep(hidden)  # , graph=False)
+         seg, dep = seg[0], dep[0]
+         for i in dep:
+             # subject-verb-object
+             if "SBV" == i[2]:
+                 subject = seg[i[0]]
+                 verb = seg[i[1]]
+             if "VOB" in i[2]:
+                 if seg[i[1]] == verb:
+                     object = seg[i[0]]
+
+             return subject
+
+         return None
+
+     def start_by_srl(self, sentence):
+         """
+         Use a semantic role labeling tool.
+         :param sentence: "他叫汤姆去拿外衣。"
+         :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
+         """
+         # semantic role labeling approach
+         seg, hidden = self.ltp.seg([sentence])
+         srl = self.ltp.srl(hidden)
+         seg, srl = seg[0], srl[0]
+         events = []
+         for wdx, each_srl in enumerate(srl):
+             if each_srl:
+                 args = []
+                 for arg in each_srl:
+                     args.extend(seg[arg[1] : arg[2] + 1])
+                 # insert the predicate as well
+                 args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
+                 events.append(args)
+         # print(events)
+         return events
+
+
+ # This one handles the other format
+ # Example data: {"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
+ def subject_object_labeling_new(spo_list, text):
+     pass
+
+
+ # This one handles the traditional format
+ # Example data format: {"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
+ def subject_object_labeling(spo_list, text):
+     # TODO
+     """
+     Label Baidu-style data that comes with an spo dict. (Hard to follow; need to find where this is actually used.)
+     :param spo_list:
+     :param text:
+     :return: labeling_list
+     """
+
+     def _spo_list_to_spo_predicate_dict(spo_list):
+         spo_predicate_dict = dict()
+         for spo_item in spo_list:
+             predicate = spo_item["predicate"]
+             subject = spo_item["subject"]
+             object = spo_item["object"]
+             spo_predicate_dict.setdefault(predicate, []).append((subject, object))
+         return spo_predicate_dict
+
+     def _index_q_list_in_k_list(q_list, k_list):
+         """Known q_list in k_list, find index(first time) of q_list in k_list"""
+         q_list_length = len(q_list)
+         k_list_length = len(k_list)
+         for idx in range(k_list_length - q_list_length + 1):
+             t = [q == k for q, k in zip(q_list, k_list[idx : idx + q_list_length])]
+             # print(idx, t)
+             if all(t):
+                 # print(idx)
+                 idx_start = idx
+                 return idx_start
+
+     def _labeling_type(spo, spo_type):
+         idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
+         labeling_list[idx_start] = "B-" + spo_type
+         if len(spo) == 2:
+             labeling_list[idx_start + 1] = "I-" + spo_type
+         elif len(spo) >= 3:
+             labeling_list[idx_start + 1 : idx_start + len(spo)] = ["I-" + spo_type] * (
+                 len(spo) - 1
+             )
+         else:
+             pass
+
+     spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
+     labeling_list = ["O"] * len(text)
+     # count = 0
+     for predicate, spo_list_form in spo_predicate_dict.items():
+         if predicate in text:
+             for (spo_subject, spo_object) in spo_list_form:
+                 # if predicate not in spo_subject and predicate not in spo_object:
+                 _labeling_type(spo_subject, "SUB")
+                 _labeling_type(spo_object, "OBJ")
+                 _labeling_type(predicate, "PRE")
+                 # count += 1
+                 # print(count)
+                 # if count == 2:
+                 #     print()
+     if labeling_list != ["O"] * len(text):
+         return labeling_list
+     return None
+
+
+ def label(text, labels):
+     """
+     Return a two-column labeled data sequence.
+     :param text:
+     :param labels:
+     :return:
+     """
+     train_sequence = "\n".join(
+         [
+             "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
+             for i in zip(list(text), labels)
+         ]
+     )
+     return train_sequence
+
+
+ def convert_crf_format_10_fold(corpus, objdir_path):
+     """
+     Split data that is already in CRF format into ten folds.
+     para:
+
+     """
+     # corpus = list(range(1,22))
+     j_mkdir(objdir_path)
+     split_position = int(len(corpus) / 10)
+     for k in range(0, 10):
+         if k == 9:
+             dev_set = corpus[k * split_position :]
+             train_set = corpus[: k * split_position]
+         else:
+             dev_set = corpus[k * split_position : (k + 1) * split_position]
+             train_set = (
+                 corpus[: k * split_position] + corpus[(k + 1) * split_position :]
+             )
+         writetxt_w_list(
+             train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
+         )
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
+
+
+ def read_seq_res(path, labels):
+     """
+     Read three-column sequence labeling data.
+     :param path:
+     :param labels:
+     :return:
+     """
+     with codecs.open(path, "r", "utf-8") as rd:
+         seqs_str = rd.read().strip()
+     seqs_list = seqs_str.split("\n\n")
+     text, raw_label, predict_label = [], [], []
+     for seq in seqs_list:
+         seq_split = seq.split("\n")
+         text_tmp = ""
+         raw_index_dict, pre_index_dict = {}, {}
+         for label in labels:
+             raw_index_dict.setdefault(label, [])
+             pre_index_dict.setdefault(label, [])
+         for idx, line in enumerate(seq_split):
+             tmp = line.split("\t")
+             text_tmp += tmp[0]
+             if tmp[1] in labels:
+                 raw_index_dict[tmp[1]].append(idx)
+             if tmp[2] in labels:
+                 pre_index_dict[tmp[2]].append(idx)
+         text.append(text_tmp)
+         raw_label.append(raw_index_dict)
+         predict_label.append(pre_index_dict)
+     return text, raw_label, predict_label
+
+
287
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
288
+ """
289
+ k是10份中训练集占了几份
290
+ """
291
+ j_mkdir(path)
292
+ if is_shuffle:
293
+ random.shuffle(corpus)
294
+ split_position = int(len(corpus) / 10)
295
+ train_set, dev_set = corpus[: k * split_position], corpus[k * split_position :]
296
+ writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
297
+ writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
298
+ writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
299
+
300
+
301
+ def kfold_df(df, save_dir=None):
302
+ """
303
+ 划分train test val集, 写为windows可读的csv。
304
+ :param df:pd.DataFrame
305
+ :param save_dir:
306
+ :return:
307
+ """
308
+ from sklearn.model_selection import KFold
309
+ import pandas as pd
310
+
311
+ train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
312
+ df_test_and_val = df.iloc[test_and_val_idx]
313
+ test_idx, val_idx = (
314
+ KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
315
+ )
316
+ df_train = df.iloc[train_idx]
317
+ df_val = df.iloc[val_idx]
318
+ df_test = df.iloc[test_idx]
319
+ if save_dir:
320
+ j_mkdir(save_dir)
321
+ save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
322
+ save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
323
+ save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
324
+ return df_train, df_val, df_test
325
+
326
+
327
+ # 读取crf序列格式的数据
328
+ def read_seq_data(path):
329
+ content = readtxt_list_all_strip(path)
330
+ lines = [i.split("\t") if i else "" for i in content]
331
+ print(lines)
332
+ sequences, labels, sequence, label = [], [], [], []
333
+ for idx, line in enumerate(lines):
334
+ if line == "":
335
+ if sequence:
336
+ sequences.append(sequence)
337
+ labels.append(label)
338
+ sequence, label = [], []
339
+ else:
340
+ sequence.append(line[0])
341
+ label.append(line[1])
342
+ if idx == len(lines) - 1 and sequence:
343
+ sequences.append(sequence)
344
+ labels.append(label)
345
+ return sequences, labels
346
+
347
+
348
+ def split_5_percent(lines, sample_precent=5):
349
+ random.seed(8)
350
+ # lines = list(range(1, 109))
351
+ idx_lines = [(idx, i) for idx, i in enumerate(lines)]
352
+ div = int(len(lines) / 100)
353
+ sample_num = div * sample_precent
354
+ sample = random.sample(idx_lines, sample_num)
355
+ sorted_sample = sorted(sample, key=lambda x: x[0])
356
+ remove_idx = [i[0] for i in sorted_sample]
357
+ less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
358
+ most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
359
+ print(less_has_raw_line_info)
360
+ print(most)
361
+ return most, less_has_raw_line_info
362
+
363
+
364
+ def split_sentence(sentence, language="chinese", cross_line=True):
365
+ """
366
+ 分句,英文有nltk,中文怎么能没有好的分句工具呢
367
+ :param sentence:
368
+ :param language:
369
+ :param cross_line:
370
+ :return:
371
+ """
372
+ # sentences->Str
373
+ # example '12“345。”“6789”'
374
+ assert language in ["chinese", "english"], "unsupportable for other language"
375
+ sentence = sentence.replace("\r", "")
376
+ if language == "chinese":
377
+ split_signs = list("。!?…")
378
+ if cross_line:
379
+ split_signs.append("\n")
380
+ other_sign = "”"
381
+ elif language == "english":
382
+ split_signs = list(".!?")
383
+ other_sign = '"'
384
+ else:
385
+ split_signs = list(".!?")
386
+ other_sign = '"'
387
+ sentences = []
388
+ start_idx = 0
389
+ for idx, char in enumerate(sentence):
390
+ if idx == len(sentence) - 1:
391
+ if char in split_signs:
392
+ sentences.append(sentence[start_idx : idx + 1].strip())
393
+ start_idx = idx + 1
394
+ else:
395
+ sentences.append(sentence[start_idx:].strip())
396
+ else:
397
+ if char in split_signs:
398
+ if sentence[idx + 1] == other_sign:
399
+ if idx < len(sentence) - 2:
400
+ # 处理。”。
401
+ if sentence[idx + 2] not in split_signs:
402
+ sentences.append(sentence[start_idx : idx + 2].strip())
403
+ start_idx = idx + 2
404
+ elif sentence[idx + 1] not in split_signs:
405
+ sentences.append(sentence[start_idx : idx + 1].strip())
406
+ start_idx = idx + 1
407
+ return sentences
408
+
409
+
410
+ def pos_reduction():
411
+ wnl = WordNetLemmatizer()
412
+ # lemmatize nouns
413
+ print(wnl.lemmatize("cars", "n"))
414
+ print(wnl.lemmatize("men", "n"))
415
+
416
+ # lemmatize verbs
417
+ print(wnl.lemmatize("running", "v"))
418
+ print(wnl.lemmatize("ate", "v"))
419
+
420
+
421
+ class DataVisualization:
422
+ # 和下面的类冲突了
423
+ pass
424
+
425
+
426
+ class Evaluate:
427
+ def __init__(self):
428
+ pass
429
+
430
+ def auc_metric(self, k):
431
+ pass
432
+
433
+ def map_metric(self):
434
+ pass
435
+
436
+ def ndcg(self, n, y_true, y_score):
437
+ report = metrics.ndcg_score(y_true, y_score)
438
+ return report
439
+
440
+
441
+ class DecideTreeUtils:
442
+ @staticmethod
443
+ def draw(bst):
444
+ # xgb 画图
445
+ fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
446
+ xgb.plot_tree(bst, ax=ax_tree)
447
+ fig_tree.savefig("tree.png")
448
+ plt.show()
449
+
450
+
451
+ def seed_everything(seed=7777777) -> None:
452
+ """
453
+ 设置整个开发环境的seed
454
+ :param seed:
455
+ :param device:
456
+ :return:
457
+ """
458
+ random.seed(seed)
459
+ os.environ["PYTHONHASHSEED"] = str(seed)
460
+ np.random.seed(seed)
461
+ torch.manual_seed(seed) # CPU随机种子确定
462
+ torch.cuda.manual_seed(seed)
463
+ torch.cuda.manual_seed_all(seed)
464
+ # some cudnn methods can be random even after fixing the seed
465
+ # unless you tell it to be deterministic
466
+ torch.backends.cudnn.deterministic = True
467
+
468
+
469
+ if __name__ == "__main__":
470
+ # stem = STEM(IPT_MODEL_PATH)
471
+ # test_sentence = "美国袭击伊拉克"
472
+ # a = stem.start_by_srl(test_sentence)
473
+
474
+ res = calc_llm_train_activation_memory(
475
+ model_name="",
476
+ sequence_length=2048,
477
+ batch_size=1,
478
+ hidden_dim=4096,
479
+ lay_number=28,
480
+ attention_heads_num=32,
481
+ gpu_num=1
482
+ )
483
+ print(res, "G")