nlpertools 1.0.4__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. nlpertools/__init__.py +24 -11
  2. nlpertools/algo/__init__.py +0 -0
  3. nlpertools/algo/ac.py +18 -0
  4. nlpertools/algo/bit_ops.py +28 -0
  5. nlpertools/algo/kmp.py +94 -0
  6. nlpertools/algo/num_ops.py +12 -0
  7. nlpertools/algo/template.py +116 -0
  8. nlpertools/algo/union.py +13 -0
  9. nlpertools/data_client.py +387 -0
  10. nlpertools/data_structure/__init__.py +0 -0
  11. nlpertools/data_structure/base_structure.py +109 -0
  12. nlpertools/dataprocess.py +611 -3
  13. nlpertools/default_db_config.yml +41 -0
  14. nlpertools/io/__init__.py +3 -3
  15. nlpertools/io/dir.py +54 -47
  16. nlpertools/io/file.py +277 -205
  17. nlpertools/ml.py +483 -317
  18. nlpertools/monitor/__init__.py +0 -0
  19. nlpertools/monitor/gpu.py +18 -0
  20. nlpertools/monitor/memory.py +24 -0
  21. nlpertools/movie.py +36 -0
  22. nlpertools/nlpertools_config.yml +1 -0
  23. nlpertools/{openApi.py → open_api.py} +65 -62
  24. nlpertools/other.py +364 -188
  25. nlpertools/pic.py +288 -0
  26. nlpertools/plugin.py +43 -34
  27. nlpertools/reminder.py +98 -15
  28. nlpertools/template/__init__.py +0 -0
  29. nlpertools/utils/__init__.py +3 -0
  30. nlpertools/utils/lazy.py +727 -0
  31. nlpertools/utils/log_util.py +20 -0
  32. nlpertools/utils/package.py +89 -0
  33. nlpertools/utils/package_v1.py +94 -0
  34. nlpertools/utils/package_v2.py +117 -0
  35. nlpertools/utils_for_nlpertools.py +93 -0
  36. nlpertools/vector_index_demo.py +108 -0
  37. nlpertools/wrapper.py +161 -0
  38. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
  39. nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
  40. nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
  41. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
  42. nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
  43. nlpertools_helper/__init__.py +10 -0
  44. nlpertools-1.0.4.dist-info/METADATA +0 -42
  45. nlpertools-1.0.4.dist-info/RECORD +0 -15
  46. nlpertools-1.0.4.dist-info/top_level.txt +0 -1
nlpertools/ml.py CHANGED
@@ -1,317 +1,483 @@
- # encoding=utf-8
- from .io import *
- import random
-
-
- class DataStructure:
-     spo = {
-         "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
-         "triplets": [
-             {"s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
-              "p": {"text": "出版社", "l": 15, "r": 18},
-              "o": {"text": "故宫出版社", "l": 13, "r": 18}}],
-         "source": "baidu"
-     }
-     ner_input_example = '这句话一共有两个实体分别为大象和老鼠。'
-     ner_label_example = list('OOOOOOOOOOOOO') + ['B-s', 'I-s'] + ['O'] + ['B-o', 'I-o'] + ['O']
-
-
- '''
- try:
-     from ltp import LTP
- except:
-     pass
-
- class STEM(object):
-     def __init__(self, IPT_MODEL_PATH):
-         self.ltp = LTP(IPT_MODEL_PATH)
-
-     def start(self, sentence):
-         seg, hidden = self.ltp.seg([sentence])
-         dep = self.ltp.dep(hidden)  # , graph=False)
-         seg, dep = seg[0], dep[0]
-         for i in dep:
-             # subject-verb-object
-             if 'SBV' == i[2]:
-                 subject = seg[i[0]]
-                 verb = seg[i[1]]
-             if 'VOB' in i[2]:
-                 if seg[i[1]] == verb:
-                     object = seg[i[]]
-
-                     return subject
-
-         return None
- class STEM(object):
-     def __init__(self, IPT_MODEL_PATH):
-         self.ltp = LTP(IPT_MODEL_PATH)
-
-     def start(self, sentence):
-         """
-         Extracts events with a semantic role labeling tool
-         :param sentence: "他叫汤姆去拿外衣。"
-         :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
-         """
-         # semantic role labeling approach
-         seg, hidden = self.ltp.seg([sentence])
-         srl = self.ltp.srl(hidden)
-         seg, srl = seg[0], srl[0]
-         events = []
-         for wdx, each_srl in enumerate(srl):
-             if each_srl:
-                 args = []
-                 for arg in each_srl:
-                     args.extend(seg[arg[1]:arg[2] + 1])
-                 # insert the predicate
-                 args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
-                 events.append(args)
-         # print(events)
-         return events
-
-     def start_dep_method(self, sentence):
-         # seg, hidden = self.ltp.seg([sentence])
-         # dep = self.ltp.dep(hidden)  # , graph=False)
-         # seg, dep = seg[0], dep[0]
-         # for i in dep:
-         #     # subject-verb-object
-         #     if 'SBV' == i[2]:
-         #         subject = seg[i[0]]
-         #         verb = seg[i[1]]
-         #     if 'VOB' in i[2]:
-         #         if seg[i[1]] == verb:
-         #             object = seg[i]
-         #     return subject
-         return None
-
- IPT_MODEL_PATH = './tiny'
- stem = STEM(IPT_MODEL_PATH)
- sentence = '美国袭击伊拉克'
- a = stem.start(sentence)
- '''
-
-
- # This is another format
- # Example data: {"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
- def subject_object_labeling_new(spo_list, text):
-     pass
-
-
- # This is the traditional format
- # Example data format: {"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
- def subject_object_labeling(spo_list, text):
-     # TODO
-     '''
-     Labels Baidu-style data that comes with an spo dict. (Hard to follow; still need to find where this is used.)
-     :param spo_list:
-     :param text:
-     :return: labeling_list
-     '''
-
-     def _spo_list_to_spo_predicate_dict(spo_list):
-         spo_predicate_dict = dict()
-         for spo_item in spo_list:
-             predicate = spo_item["predicate"]
-             subject = spo_item["subject"]
-             object = spo_item["object"]
-             spo_predicate_dict.setdefault(predicate, []).append((subject, object))
-         return spo_predicate_dict
-
-     def _index_q_list_in_k_list(q_list, k_list):
-         """Known q_list in k_list, find index (first time) of q_list in k_list"""
-         q_list_length = len(q_list)
-         k_list_length = len(k_list)
-         for idx in range(k_list_length - q_list_length + 1):
-             t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])]
-             # print(idx, t)
-             if all(t):
-                 # print(idx)
-                 idx_start = idx
-                 return idx_start
-
-     def _labeling_type(spo, spo_type):
-         idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
-         labeling_list[idx_start] = 'B-' + spo_type
-         if len(spo) == 2:
-             labeling_list[idx_start + 1] = 'I-' + spo_type
-         elif len(spo) >= 3:
-             labeling_list[idx_start + 1: idx_start + len(spo)] = ['I-' + spo_type] * (len(spo) - 1)
-         else:
-             pass
-
-     spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
-     labeling_list = ['O'] * len(text)
-     # count = 0
-     for predicate, spo_list_form in spo_predicate_dict.items():
-         if predicate in text:
-             for (spo_subject, spo_object) in spo_list_form:
-                 # if predicate not in spo_subject and predicate not in spo_object:
-                 _labeling_type(spo_subject, 'SUB')
-                 _labeling_type(spo_object, 'OBJ')
-                 _labeling_type(predicate, 'PRE')
-                 # count += 1
-                 # print(count)
-                 # if count == 2:
-                 #     print()
-     if labeling_list != ['O'] * len(text):
-         return labeling_list
-     return None
-
-
- def label(text, labels):
-     '''
-     Returns a two-column labeled data sequence
-     :param text:
-     :param labels:
-     :return:
-     '''
-     train_sequence = '\n'.join(
-         ['\t'.join(i) if i[0] != ' ' else '[null]\t{}'.format(i[1]) for i in zip(list(text), labels)])
-     return train_sequence
-
-
- def convert_crf_format_10_fold(corpus, objdir_path):
-     '''
-     Splits data that is already in CRF format into ten folds.
-     para:
-
-     '''
-     # corpus = list(range(1,22))
-     j_mkdir(objdir_path)
-     split_position = int(len(corpus) / 10)
-     for k in range(0, 10):
-         if k == 9:
-             dev_set = corpus[k * split_position:]
-             train_set = corpus[:k * split_position]
-         else:
-             dev_set = corpus[k * split_position: (k + 1) * split_position]
-             train_set = corpus[:k * split_position] + corpus[(k + 1) * split_position:]
-         writetxt_w_list(train_set, os.path.join(objdir_path, 'train{}.txt'.format(k + 1)))
-         writetxt_w_list(dev_set, os.path.join(objdir_path, 'test{}.txt'.format(k + 1)))
-         writetxt_w_list(dev_set, os.path.join(objdir_path, 'dev{}.txt'.format(k + 1)))
-
-
- def read_seq_res(path, labels):
-     '''
-     Reads three-column sequence-labeling data
-     :param path:
-     :param labels:
-     :return:
-     '''
-     with codecs.open(path, 'r', 'utf-8') as rd:
-         seqs_str = rd.read().strip()
-     seqs_list = seqs_str.split('\n\n')
-     text, raw_label, predict_label = [], [], []
-     for seq in seqs_list:
-         seq_split = seq.split('\n')
-         text_tmp = ''
-         raw_index_dict, pre_index_dict = {}, {}
-         for label in labels:
-             raw_index_dict.setdefault(label, [])
-             pre_index_dict.setdefault(label, [])
-         for idx, line in enumerate(seq_split):
-             tmp = line.split('\t')
-             text_tmp += tmp[0]
-             if tmp[1] in labels:
-                 raw_index_dict[tmp[1]].append(idx)
-             if tmp[2] in labels:
-                 pre_index_dict[tmp[2]].append(idx)
-         text.append(text_tmp)
-         raw_label.append(raw_index_dict)
-         predict_label.append(pre_index_dict)
-     return text, raw_label, predict_label
-
-
- def kfold(corpus, path, k=9, is_shuffle=True):
-     '''
-     k is how many of the 10 folds go to the training set
-     '''
-     j_mkdir(path)
-     if is_shuffle:
-         random.shuffle(corpus)
-     split_position = int(len(corpus) / 10)
-     train_set, dev_set = corpus[:k * split_position], corpus[k * split_position:]
-     writetxt_w_list(train_set, os.path.join(path, 'train.tsv'), num_lf=1)
-     writetxt_w_list(dev_set, os.path.join(path, 'test.tsv'), num_lf=1)
-     writetxt_w_list(dev_set, os.path.join(path, 'dev.tsv'), num_lf=1)
-
-
- # read CRF sequence-format data
- def read_seq_data(path):
-     content = readtxt_list_all_strip(path)
-     lines = [i.split('\t') if i else '' for i in content]
-     print(lines)
-     sequences, labels, sequence, label = [], [], [], []
-     for idx, line in enumerate(lines):
-         if line == '':
-             if sequence:
-                 sequences.append(sequence)
-                 labels.append(label)
-                 sequence, label = [], []
-         else:
-             sequence.append(line[0])
-             label.append(line[1])
-         if idx == len(lines) - 1 and sequence:
-             sequences.append(sequence)
-             labels.append(label)
-     return sequences, labels
-
-
- def split_5_percent(lines, sample_precent=5):
-     random.seed(8)
-     # lines = list(range(1, 109))
-     idx_lines = [(idx, i) for idx, i in enumerate(lines)]
-     div = int(len(lines) / 100)
-     sample_num = div * sample_precent
-     sample = random.sample(idx_lines, sample_num)
-     sorted_sample = sorted(sample, key=lambda x: x[0])
-     remove_idx = [i[0] for i in sorted_sample]
-     less_has_raw_line_info = [str(i[0] + 1) + '\t' + str(i[1]) for i in sorted_sample]
-     most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
-     print(less_has_raw_line_info)
-     print(most)
-     return most, less_has_raw_line_info
- def split_sentences(sentences, mode='chinese'):
-     # sentences->Str
-     # example '12“345。”“6789”'
-     if mode == 'chinese':
-         split_signs = list('。!?…')
-         other_sign = "”"
-     elif mode == 'english':
-         split_signs = list('.!?')
-         other_sign = '"'
-     else:
-         print('not supported yet')
-         split_signs = list('.!?')
-         other_sign = '"'
-     splited_sentences = []
-     start_idx = 0
-     for idx, char in enumerate(sentences):
-         if idx == len(sentences) - 1:
-             if char in split_signs:
-                 splited_sentences.append(sentences[start_idx:idx + 1].strip())
-                 start_idx = idx + 1
-             else:
-                 splited_sentences.append(sentences[start_idx:].strip())
-         else:
-             if char in split_signs:
-                 if sentences[idx + 1] == other_sign:
-                     if idx < len(sentences) - 2:
-                         # handle the 。”。 pattern
-                         if sentences[idx + 2] not in split_signs:
-                             splited_sentences.append(sentences[start_idx:idx + 2].strip())
-                             start_idx = idx + 2
-                 elif sentences[idx + 1] not in split_signs:
-                     splited_sentences.append(sentences[start_idx:idx + 1].strip())
-                     start_idx = idx + 1
-     return splited_sentences
- def pos_huanyuan():
-     from nltk.stem import WordNetLemmatizer
-     data = nlpertools.readtxt_list_all_strip('ie-selfmedia/')
-     wnl = WordNetLemmatizer()
-     # lemmatize nouns
-     print(wnl.lemmatize('cars', 'n'))
-     print(wnl.lemmatize('men', 'n'))
-
-     # lemmatize verbs
-     print(wnl.lemmatize('running', 'v'))
-     print(wnl.lemmatize('ate', 'v'))
+ # encoding=utf-8
+ import codecs
+ import os
+ import random
+
+ from .io.dir import j_mkdir
+ from .io.file import readtxt_list_all_strip, writetxt_w_list, save_to_csv
+ # import numpy as np
+ # import seaborn as sns
+ # import torch
+ # import torch.nn as nn
+ # import xgboost as xgb
+ # from matplotlib import pyplot as plt
+ # from nltk.stem import WordNetLemmatizer
+ # from sklearn import metrics
+ # from transformers import BertTokenizer, BertForMaskedLM
+ from .utils.package import *
+
+
+ def calc_llm_train_activation_memory(
+     model_name, sequence_length, batch_size, hidden_dim, lay_number, attention_heads_num, gpu_num=1
+ ):
+     """
+     return bytes
+
+     reference:
+     1. https://zhuanlan.zhihu.com/p/665172400
+     2. https://deepspeed.readthedocs.io/en/latest/memory.html#discussion (oddly, this one does not multiply by the layer count)
+     """
+     # reference 1:
+     # attention
+     # FFN
+     # Layer Norm
+     r1 = (
+         sequence_length
+         * batch_size
+         * hidden_dim
+         * lay_number
+         * (34 + 5 * attention_heads_num * sequence_length / hidden_dim)
+     )
+     # reference 2:
+     r2 = (
+         lay_number * (2 * sequence_length * attention_heads_num + 16 * hidden_dim)
+         * sequence_length
+         * batch_size
+         / gpu_num
+     )
+     print(r1)
+     print(r2)
+     return r1
+
+
+ class DataAnalysis:
+     @staticmethod
+     def draw_pic(df, save_path):
+         """
+         Draws histograms to compare the feature distributions of two classes
+         :param df: pd.DataFrame
+         :param save_path: str
+         :return:
+         """
+         sns.distplot(df[df["label"] == 1]["feature"], label="label1")
+         sns.distplot(df[df["label"] == 0]["feature"], label="label2")
+         plt.legend()
+         plt.savefig(save_path)
+
+
+ class DataStructure:
+     spo = {
+         "sentence": "内容简介《宜兴紫砂图典》由故宫出版社出版",
+         "triplets": [
+             {
+                 "s": {"text": "宜兴紫砂图典", "l": 5, "r": 11},
+                 "p": {"text": "出版社", "l": 15, "r": 18},
+                 "o": {"text": "故宫出版社", "l": 13, "r": 18},
+             }
+         ],
+         "source": "baidu",
+     }
+     ner_input_example = "这句话一共有两个实体分别为大象和老鼠。"
+     ner_label_example = (
+         list("OOOOOOOOOOOOO") + ["B-s", "I-s"] + ["O"] + ["B-o", "I-o"] + ["O"]
+     )
+
+
+ def text_jaccard(ipt1, ipt2, ipt_level="char", sim_level="char"):
+     # Jaccard coefficient of the two sentences
+     # inspect the inputs to redefine ipt_level and sim_level
+
+     # a = set(ipt1.split())
+     # b = set(ipt2.split())
+     a = set(ipt1)
+     b = set(ipt2)
+     c = a.intersection(b)
+     # special case:
+     if not ipt1 and not ipt2:
+         return 0
+     return int(100 * float(len(c)) / (len(a) + len(b) - len(c)))
+
+
+ class STEM(object):
+     def __init__(self, IPT_MODEL_PATH):
+         self.ltp = LTP(IPT_MODEL_PATH)
+
+     def start_by_dep(self, sentence):
+         seg, hidden = self.ltp.seg([sentence])
+         dep = self.ltp.dep(hidden)  # , graph=False)
+         seg, dep = seg[0], dep[0]
+         for i in dep:
+             # subject-verb-object
+             if "SBV" == i[2]:
+                 subject = seg[i[0]]
+                 verb = seg[i[1]]
+             if "VOB" in i[2]:
+                 if seg[i[1]] == verb:
+                     object = seg[i[0]]
+
+                     return subject
+
+         return None
+
+     def start_by_srl(self, sentence):
+         """
+         Extracts events with a semantic role labeling tool
+         :param sentence: "他叫汤姆去拿外衣。"
+         :return: events: [['他', '叫', '汤姆', '去', '拿', '外衣'], ['汤姆', '拿', '外衣']]
+         """
+         # semantic role labeling approach
+         seg, hidden = self.ltp.seg([sentence])
+         srl = self.ltp.srl(hidden)
+         seg, srl = seg[0], srl[0]
+         events = []
+         for wdx, each_srl in enumerate(srl):
+             if each_srl:
+                 args = []
+                 for arg in each_srl:
+                     args.extend(seg[arg[1] : arg[2] + 1])
+                 # insert the predicate
+                 args.insert(each_srl[0][2] - each_srl[0][1] + 1, seg[wdx])
+                 events.append(args)
+         # print(events)
+         return events
+
+
+ # This is another format
+ # Example data: {"sentence": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "triplets": [{"s": {"text": "兴族闪蝶", "l": 0, "r": 4}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蛱蝶科", "l": 62, "r": 65}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "蝴蝶", "l": 72, "r": 74}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}, {"s": {"text": "闪蝶属", "l": 66, "r": 69}, "p": {"text": "目", "l": 60, "r": 61}, "o": {"text": "鳞翅目", "l": 58, "r": 61}}], "source": "baidu"}
+ def subject_object_labeling_new(spo_list, text):
+     pass
+
+
+ # This is the traditional format
+ # Example data format: {"postag": [{"word": "兴族闪蝶", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "Morpho achilles patroclus", "pos": "nz"}, {"word": ",", "pos": "w"}, {"word": "节肢动物门", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "昆虫纲", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "鳞翅目", "pos": "n"}, {"word": "、", "pos": "w"}, {"word": "蛱蝶科", "pos": "nz"}, {"word": "、", "pos": "w"}, {"word": "闪蝶属", "pos": "nz"}, {"word": "的", "pos": "u"}, {"word": "一种", "pos": "m"}, {"word": "蝴蝶", "pos": "n"}], "text": "兴族闪蝶,Morpho patroclus,Morpho achilles patroclus,节肢动物门、昆虫纲、鳞翅目、蛱蝶科、闪蝶属的一种蝴蝶", "spo_list": [{"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "兴族闪蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蛱蝶科"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "蝴蝶"}, {"predicate": "目", "object_type": "目", "subject_type": "生物", "object": "鳞翅目", "subject": "闪蝶属"}]}
+ def subject_object_labeling(spo_list, text):
+     # TODO
+     """
+     Labels Baidu-style data that comes with an spo dict. (Hard to follow; still need to find where this is used.)
+     :param spo_list:
+     :param text:
+     :return: labeling_list
+     """
+
+     def _spo_list_to_spo_predicate_dict(spo_list):
+         spo_predicate_dict = dict()
+         for spo_item in spo_list:
+             predicate = spo_item["predicate"]
+             subject = spo_item["subject"]
+             object = spo_item["object"]
+             spo_predicate_dict.setdefault(predicate, []).append((subject, object))
+         return spo_predicate_dict
+
+     def _index_q_list_in_k_list(q_list, k_list):
+         """Known q_list in k_list, find index (first time) of q_list in k_list"""
+         q_list_length = len(q_list)
+         k_list_length = len(k_list)
+         for idx in range(k_list_length - q_list_length + 1):
+             t = [q == k for q, k in zip(q_list, k_list[idx : idx + q_list_length])]
+             # print(idx, t)
+             if all(t):
+                 # print(idx)
+                 idx_start = idx
+                 return idx_start
+
+     def _labeling_type(spo, spo_type):
+         idx_start = _index_q_list_in_k_list(q_list=spo, k_list=text)
+         labeling_list[idx_start] = "B-" + spo_type
+         if len(spo) == 2:
+             labeling_list[idx_start + 1] = "I-" + spo_type
+         elif len(spo) >= 3:
+             labeling_list[idx_start + 1 : idx_start + len(spo)] = ["I-" + spo_type] * (
+                 len(spo) - 1
+             )
+         else:
+             pass
+
+     spo_predicate_dict = _spo_list_to_spo_predicate_dict(spo_list)
+     labeling_list = ["O"] * len(text)
+     # count = 0
+     for predicate, spo_list_form in spo_predicate_dict.items():
+         if predicate in text:
+             for (spo_subject, spo_object) in spo_list_form:
+                 # if predicate not in spo_subject and predicate not in spo_object:
+                 _labeling_type(spo_subject, "SUB")
+                 _labeling_type(spo_object, "OBJ")
+                 _labeling_type(predicate, "PRE")
+                 # count += 1
+                 # print(count)
+                 # if count == 2:
+                 #     print()
+     if labeling_list != ["O"] * len(text):
+         return labeling_list
+     return None
+
+
+ def label(text, labels):
+     """
+     Returns a two-column labeled data sequence
+     :param text:
+     :param labels:
+     :return:
+     """
+     train_sequence = "\n".join(
+         [
+             "\t".join(i) if i[0] != " " else "[null]\t{}".format(i[1])
+             for i in zip(list(text), labels)
+         ]
+     )
+     return train_sequence
+
+
+ def convert_crf_format_10_fold(corpus, objdir_path):
+     """
+     Splits data that is already in CRF format into ten folds.
+     para:
+
+     """
+     # corpus = list(range(1,22))
+     j_mkdir(objdir_path)
+     split_position = int(len(corpus) / 10)
+     for k in range(0, 10):
+         if k == 9:
+             dev_set = corpus[k * split_position :]
+             train_set = corpus[: k * split_position]
+         else:
+             dev_set = corpus[k * split_position : (k + 1) * split_position]
+             train_set = (
+                 corpus[: k * split_position] + corpus[(k + 1) * split_position :]
+             )
+         writetxt_w_list(
+             train_set, os.path.join(objdir_path, "train{}.txt".format(k + 1))
+         )
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "test{}.txt".format(k + 1)))
+         writetxt_w_list(dev_set, os.path.join(objdir_path, "dev{}.txt".format(k + 1)))
+
+
+ def read_seq_res(path, labels):
+     """
+     Reads three-column sequence-labeling data
+     :param path:
+     :param labels:
+     :return:
+     """
+     with codecs.open(path, "r", "utf-8") as rd:
+         seqs_str = rd.read().strip()
+     seqs_list = seqs_str.split("\n\n")
+     text, raw_label, predict_label = [], [], []
+     for seq in seqs_list:
+         seq_split = seq.split("\n")
+         text_tmp = ""
+         raw_index_dict, pre_index_dict = {}, {}
+         for label in labels:
+             raw_index_dict.setdefault(label, [])
+             pre_index_dict.setdefault(label, [])
+         for idx, line in enumerate(seq_split):
+             tmp = line.split("\t")
+             text_tmp += tmp[0]
+             if tmp[1] in labels:
+                 raw_index_dict[tmp[1]].append(idx)
+             if tmp[2] in labels:
+                 pre_index_dict[tmp[2]].append(idx)
+         text.append(text_tmp)
+         raw_label.append(raw_index_dict)
+         predict_label.append(pre_index_dict)
+     return text, raw_label, predict_label
+
+
+ def kfold_txt(corpus, path, k=9, is_shuffle=True):
+     """
+     k is how many of the 10 folds go to the training set
+     """
+     j_mkdir(path)
+     if is_shuffle:
+         random.shuffle(corpus)
+     split_position = int(len(corpus) / 10)
+     train_set, dev_set = corpus[: k * split_position], corpus[k * split_position :]
+     writetxt_w_list(train_set, os.path.join(path, "train.tsv"), num_lf=1)
+     writetxt_w_list(dev_set, os.path.join(path, "test.tsv"), num_lf=1)
+     writetxt_w_list(dev_set, os.path.join(path, "dev.tsv"), num_lf=1)
+
+
+ def kfold_df(df, save_dir=None):
+     """
+     Splits into train/test/val sets and writes Windows-readable CSV files.
+     :param df: pd.DataFrame
+     :param save_dir:
+     :return:
+     """
+     from sklearn.model_selection import KFold
+     import pandas as pd
+
+     train_idx, test_and_val_idx = KFold(n_splits=8, shuffle=True).split(df).__next__()
+     df_test_and_val = df.iloc[test_and_val_idx]
+     test_idx, val_idx = (
+         KFold(n_splits=2, shuffle=True).split(df_test_and_val).__next__()
+     )
+     df_train = df.iloc[train_idx]
+     df_val = df.iloc[val_idx]
+     df_test = df.iloc[test_idx]
+     if save_dir:
+         j_mkdir(save_dir)
+         save_to_csv(df_train, os.path.join(save_dir, "train.csv"))
+         save_to_csv(df_test, os.path.join(save_dir, "test.csv"))
+         save_to_csv(df_val, os.path.join(save_dir, "val.csv"))
+     return df_train, df_val, df_test
+
+
+ # read CRF sequence-format data
+ def read_seq_data(path):
+     content = readtxt_list_all_strip(path)
+     lines = [i.split("\t") if i else "" for i in content]
+     print(lines)
+     sequences, labels, sequence, label = [], [], [], []
+     for idx, line in enumerate(lines):
+         if line == "":
+             if sequence:
+                 sequences.append(sequence)
+                 labels.append(label)
+                 sequence, label = [], []
+         else:
+             sequence.append(line[0])
+             label.append(line[1])
+         if idx == len(lines) - 1 and sequence:
+             sequences.append(sequence)
+             labels.append(label)
+     return sequences, labels
+
+
+ def split_5_percent(lines, sample_precent=5):
+     random.seed(8)
+     # lines = list(range(1, 109))
+     idx_lines = [(idx, i) for idx, i in enumerate(lines)]
+     div = int(len(lines) / 100)
+     sample_num = div * sample_precent
+     sample = random.sample(idx_lines, sample_num)
+     sorted_sample = sorted(sample, key=lambda x: x[0])
+     remove_idx = [i[0] for i in sorted_sample]
+     less_has_raw_line_info = [str(i[0] + 1) + "\t" + str(i[1]) for i in sorted_sample]
+     most = [i for idx, i in enumerate(lines) if not idx in remove_idx]
+     print(less_has_raw_line_info)
+     print(most)
+     return most, less_has_raw_line_info
+
+
+ def split_sentence(sentence, language="chinese", cross_line=True):
+     """
+     Sentence splitting. English has nltk; Chinese deserves a good splitter too.
+     :param sentence:
+     :param language:
+     :param cross_line:
+     :return:
+     """
+     # sentence -> str
+     # example: '12“345。”“6789”'
+     assert language in ["chinese", "english"], "unsupportable for other language"
+     sentence = sentence.replace("\r", "")
+     if language == "chinese":
+         split_signs = list("。!?…")
+         if cross_line:
+             split_signs.append("\n")
+         other_sign = "”"
+     elif language == "english":
+         split_signs = list(".!?")
+         other_sign = '"'
+     else:
+         split_signs = list(".!?")
+         other_sign = '"'
+     sentences = []
+     start_idx = 0
+     for idx, char in enumerate(sentence):
+         if idx == len(sentence) - 1:
+             if char in split_signs:
+                 sentences.append(sentence[start_idx : idx + 1].strip())
+                 start_idx = idx + 1
+             else:
+                 sentences.append(sentence[start_idx:].strip())
+         else:
+             if char in split_signs:
+                 if sentence[idx + 1] == other_sign:
+                     if idx < len(sentence) - 2:
+                         # handle the 。”。 pattern
+                         if sentence[idx + 2] not in split_signs:
+                             sentences.append(sentence[start_idx : idx + 2].strip())
+                             start_idx = idx + 2
+                 elif sentence[idx + 1] not in split_signs:
+                     sentences.append(sentence[start_idx : idx + 1].strip())
+                     start_idx = idx + 1
+     return sentences
+
+
+ def pos_reduction():
+     wnl = WordNetLemmatizer()
+     # lemmatize nouns
+     print(wnl.lemmatize("cars", "n"))
+     print(wnl.lemmatize("men", "n"))
+
+     # lemmatize verbs
+     print(wnl.lemmatize("running", "v"))
+     print(wnl.lemmatize("ate", "v"))
+
+
+ class DataVisualization:
+     # conflicts with the class below
+     pass
+
+
+ class Evaluate:
+     def __init__(self):
+         pass
+
+     def auc_metric(self, k):
+         pass
+
+     def map_metric(self):
+         pass
+
+     def ndcg(self, n, y_true, y_score):
+         report = metrics.ndcg_score(y_true, y_score)
+         return report
+
+
+ class DecideTreeUtils:
+     @staticmethod
+     def draw(bst):
+         # plot an xgboost tree
+         fig_tree, ax_tree = plt.subplots(figsize=(200, 200))
+         xgb.plot_tree(bst, ax=ax_tree)
+         fig_tree.savefig("tree.png")
+         plt.show()
+
+
+ def seed_everything(seed=7777777) -> None:
+     """
+     Sets the seed for the whole development environment
+     :param seed:
+     :param device:
+     :return:
+     """
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)  # fix the CPU random seed
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     # some cudnn methods can be random even after fixing the seed
+     # unless you tell it to be deterministic
+     torch.backends.cudnn.deterministic = True
+
+
+ if __name__ == "__main__":
+     # stem = STEM(IPT_MODEL_PATH)
+     # test_sentence = "美国袭击伊拉克"
+     # a = stem.start_by_srl(test_sentence)
+
+     res = calc_llm_train_activation_memory(
+         model_name="",
+         sequence_length=2048,
+         batch_size=1,
+         hidden_dim=4096,
+         lay_number=28,
+         attention_heads_num=32,
+         gpu_num=1
+     )
+     print(res, "G")
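
The activation-memory estimate in the new calc_llm_train_activation_memory is easy to check by hand. Below is a minimal standalone sketch (the short names s, b, h, L, a are shorthand introduced here, not names from the package) that evaluates the reference-1 formula for the exact configuration used in the __main__ block:

    # Reference-1 formula: s * b * h * L * (34 + 5 * a * s / h), in bytes.
    # Config from __main__: sequence_length=2048, batch_size=1,
    # hidden_dim=4096, lay_number=28, attention_heads_num=32.
    s, b, h, L, a = 2048, 1, 4096, 28, 32
    activation_bytes = s * b * h * L * (34 + 5 * a * s / h)
    print(activation_bytes)              # 26776436736.0
    print(activation_bytes / 1024 ** 3)  # ~24.94 GiB

Since the docstring says the function returns bytes, the final print(res, "G") in __main__ prints a byte count rather than gigabytes; dividing by 1024 ** 3 as above gives the value in GiB.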
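The new text_jaccard is a set-based Jaccard similarity over characters, scaled to an integer percentage: 100 * |a ∩ b| / |a ∪ b|. A quick sketch of what it computes, with made-up inputs:

    # char sets: a = {'a','b','c','d'}, b = {'b','c','d','e'}, shared = {'b','c','d'}
    a, b = set("abcd"), set("bcde")
    c = a.intersection(b)
    print(int(100 * float(len(c)) / (len(a) + len(b) - len(c))))  # 60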
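Several of the helpers (label, convert_crf_format_10_fold, read_seq_data) revolve around the two-column CRF format: one character per line, a tab, then its tag, with a bare space emitted as [null]. A sketch of what label produces, using a made-up four-character sentence in the spirit of ner_input_example:

    text = "大象吃草"
    labels = ["B-s", "I-s", "O", "O"]
    # same expression as in label()
    rows = ["\t".join(pair) if pair[0] != " " else "[null]\t{}".format(pair[1])
            for pair in zip(list(text), labels)]
    print("\n".join(rows))
    # 大	B-s
    # 象	I-s
    # 吃	O
    # 草	O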
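The new kfold_df keeps the first fold of an 8-way KFold as train (7/8 of the rows) and halves the held-out eighth, so the expected proportions are roughly 87.5% / 6.25% / 6.25%. A usage sketch with a throwaway frame (the column name is arbitrary, and the import path assumes the wheel is installed):

    import pandas as pd
    from nlpertools.ml import kfold_df  # assumed import path

    df = pd.DataFrame({"text": ["sample %d" % i for i in range(160)]})
    df_train, df_val, df_test = kfold_df(df)  # no save_dir, so nothing is written
    print(len(df_train), len(df_val), len(df_test))  # 140 10 10

One caveat: test_idx and val_idx are positions within df_test_and_val, but the function applies them to df itself, so the val/test rows are drawn from the head of df rather than from the held-out fold (the sizes above are unaffected).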
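Finally, split_sentence keeps a closing quote attached to the sentence it ends, which is exactly what the example string in its comments exercises. A usage sketch, assuming the package is importable as nlpertools:

    from nlpertools.ml import split_sentence  # assumed import path

    print(split_sentence('12“345。”“6789”'))
    # expected: ['12“345。”', '“6789”'], i.e. the 。” pair stays with the first sentence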