nlpertools 1.0.4__py3-none-any.whl → 1.0.6.dev0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. nlpertools/__init__.py +24 -11
  2. nlpertools/algo/__init__.py +0 -0
  3. nlpertools/algo/ac.py +18 -0
  4. nlpertools/algo/bit_ops.py +28 -0
  5. nlpertools/algo/kmp.py +94 -0
  6. nlpertools/algo/num_ops.py +12 -0
  7. nlpertools/algo/template.py +116 -0
  8. nlpertools/algo/union.py +13 -0
  9. nlpertools/data_client.py +387 -0
  10. nlpertools/data_structure/__init__.py +0 -0
  11. nlpertools/data_structure/base_structure.py +109 -0
  12. nlpertools/dataprocess.py +611 -3
  13. nlpertools/default_db_config.yml +41 -0
  14. nlpertools/io/__init__.py +3 -3
  15. nlpertools/io/dir.py +54 -47
  16. nlpertools/io/file.py +277 -205
  17. nlpertools/ml.py +483 -317
  18. nlpertools/monitor/__init__.py +0 -0
  19. nlpertools/monitor/gpu.py +18 -0
  20. nlpertools/monitor/memory.py +24 -0
  21. nlpertools/movie.py +36 -0
  22. nlpertools/nlpertools_config.yml +1 -0
  23. nlpertools/{openApi.py → open_api.py} +65 -62
  24. nlpertools/other.py +364 -188
  25. nlpertools/pic.py +288 -0
  26. nlpertools/plugin.py +43 -34
  27. nlpertools/reminder.py +98 -15
  28. nlpertools/template/__init__.py +0 -0
  29. nlpertools/utils/__init__.py +3 -0
  30. nlpertools/utils/lazy.py +727 -0
  31. nlpertools/utils/log_util.py +20 -0
  32. nlpertools/utils/package.py +89 -0
  33. nlpertools/utils/package_v1.py +94 -0
  34. nlpertools/utils/package_v2.py +117 -0
  35. nlpertools/utils_for_nlpertools.py +93 -0
  36. nlpertools/vector_index_demo.py +108 -0
  37. nlpertools/wrapper.py +161 -0
  38. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/LICENSE +200 -200
  39. nlpertools-1.0.6.dev0.dist-info/METADATA +111 -0
  40. nlpertools-1.0.6.dev0.dist-info/RECORD +43 -0
  41. {nlpertools-1.0.4.dist-info → nlpertools-1.0.6.dev0.dist-info}/WHEEL +1 -1
  42. nlpertools-1.0.6.dev0.dist-info/top_level.txt +2 -0
  43. nlpertools_helper/__init__.py +10 -0
  44. nlpertools-1.0.4.dist-info/METADATA +0 -42
  45. nlpertools-1.0.4.dist-info/RECORD +0 -15
  46. nlpertools-1.0.4.dist-info/top_level.txt +0 -1
nlpertools/other.py CHANGED
@@ -1,188 +1,364 @@
1
- #!/usr/bin/python3.8
2
- # -*- coding: utf-8 -*-
3
- # @Author : youshu.Ji
4
- import string
5
- import time
6
- from concurrent.futures import ThreadPoolExecutor
7
- from functools import wraps
8
- import requests
9
- from tqdm import tqdm
10
-
11
- try:
12
- from elasticsearch import Elasticsearch
13
- import pyquery as pq
14
- from ltp import LTP
15
- except:
16
- pass
17
-
18
# Punctuation characters, one per list element.
# NOTE(review): several CHINESE_PUNCTUATION entries render as ASCII (',', ';', '(' ...);
# confirm whether full-width forms were intended.
CHINESE_PUNCTUATION = [*',。;:‘’“”!?《》「」【】<>()、']
ENGLISH_PUNCTUATION = [*',.;:\'"!?<>()']
20
- # other ----------------------------------------------------------------------
21
- # 统计词频
22
def calc_word_count(list_word, mode, path='tempcount.txt', sort_id=1, is_reverse=True):
    """Count how often each item occurs in ``list_word`` and emit the sorted counts.

    :param list_word: iterable of hashable items (e.g. tokens).
    :param mode: 'w' appends tab-separated "item count" lines to ``path`` via ``writetxt_a``;
        'p' prints them; 'u' returns the sorted list; any other value does nothing.
    :param path: output file for mode 'w' (appended to, not truncated).
    :param sort_id: 0 sorts by item, 1 (default) sorts by count.
    :param is_reverse: sort in descending order when True.
    :return: list of (item, count) tuples for mode 'u', otherwise None.
    """
    from collections import Counter

    # Counter keeps first-seen insertion order and sorted() is stable, so ties are
    # ordered exactly as with the previous manual dict counting.
    word_dict_sort = sorted(Counter(list_word).items(), key=lambda x: x[sort_id], reverse=is_reverse)
    if mode == 'w':
        # NOTE(review): writetxt_a is not imported in this version of the module.
        for word, count in word_dict_sort:
            writetxt_a(str(word) + '\t' + str(count) + '\n', path)
    elif mode == 'p':
        for word, count in word_dict_sort:
            print(str(word) + '\t' + str(count))
    elif mode == 'u':
        return word_dict_sort
38
-
39
- # 字典去重
40
def dupl_dict(dict_list, key):
    """Deduplicate a list of dicts by the value under ``key``, keeping first occurrences.

    :param dict_list: list of dicts; every item must contain ``key``.
    :param key: dict key whose value identifies duplicates.
    :return: new list holding the first dict seen for each distinct value, in input order.
    """
    new_dict_list = []
    seen = set()  # O(1) membership for hashable values (was an O(n) list scan)
    seen_unhashable = []  # fallback for unhashable values such as lists, compared with ==
    print('去重中...')
    for item in tqdm(dict_list):
        value = item[key]
        try:
            if value in seen:
                continue
            seen.add(value)
        except TypeError:
            if value in seen_unhashable:
                continue
            seen_unhashable.append(value)
        new_dict_list.append(item)
    return new_dict_list
48
-
49
-
50
def del_special_char(sentence):
    """Delete invisible characters that commonly pollute scraped text.

    Removes U+FEFF (BOM), U+00A0 (no-break space), U+3000 (ideographic space)
    and the private-use character U+E627.

    :param sentence: input string.
    :return: the string with those characters removed.
    """
    # A single str.translate pass replaces the chained replace() calls; mapping to
    # None deletes the character. (The old list also contained '\xa0' twice.)
    table = dict.fromkeys(map(ord, '\ufeff\xa0\u3000\ue627'))
    return sentence.translate(table)
55
-
56
-
57
def en_pun_2_zh_pun(sentence):
    """Placeholder for converting English punctuation in ``sentence`` to Chinese punctuation.

    Not implemented: ``sentence`` is ignored and None is returned.
    """
    # TODO: left unwritten because of quote handling (quotes need paired open/close forms).
    for _punct in ENGLISH_PUNCTUATION:
        continue
61
-
62
-
63
def _strip_report_suffix(text, suffix='举报/反馈'):
    """Remove a trailing "举报/反馈" (report/feedback) label from ``text`` if present."""
    # str.rstrip(chars) strips any trailing run of the given *characters*, which also
    # eats legitimate endings such as '反' or '/'; remove the exact suffix instead.
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text


def spider(url, timeout=None):
    """Fetch a Baijiahao article and return its title and body joined by a newline.

    :param url: article URL; only URLs containing 'baijiahao' are handled.
    :param timeout: optional ``requests`` timeout in seconds (None = no timeout).
    :return: "title" + newline + "body" text, or None for unsupported URLs.
    """
    if 'baijiahao' not in url:
        return None
    content = requests.get(url, timeout=timeout)
    html = pq.PyQuery(content.text)
    title = html('.index-module_articleTitle_28fPT').text()
    res = _strip_report_suffix(html('.index-module_articleWrap_2Zphx').text())
    return '{}\n{}'.format(title, res)
76
-
77
-
78
def eda(sentence):
    """POST ``sentence`` to the EDA HTTP service and return the response's "eda" field.

    NOTE(review): the URL is a placeholder and must be configured before use.
    """
    payload = {"sentence": sentence}
    response = requests.post('http://x.x.x.x:x/eda', json=payload)
    return response.json()['eda']
83
-
84
-
85
def find_language(text):
    """Heuristically decide whether ``text`` is English.

    Looks at (at most) the first 50 characters and reports "en" when more than half
    of them are ASCII letters.

    :param text: input string.
    :return: "en" or "not en" ("not en" for empty text).
    """
    # TODO: replace with an open-source language-identification package.
    letters = set(string.ascii_letters)
    passage = text[:50]  # previously undefined for texts of 50 characters or fewer
    if not passage:
        return "not en"
    count = sum(1 for c in passage if c in letters)
    return "en" if count / len(passage) > 0.5 else "not en"
101
-
102
-
103
def print_prf():
    """Demo: print per-class precision, recall, F1 and support for a fixed toy example."""
    from sklearn.metrics import precision_recall_fscore_support
    y_true = [0, 1, 2, 1, 1, 2, 3, 1, 1, 1]
    y_pred = [0, 1, 2, 1, 1, 2, 3, 1, 1, 1]
    scores = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred)
    # scores is (precision, recall, fscore, support), each an array indexed by class.
    for name, values in zip(("p", "r", "f", "s"), scores):
        print("{}\t{}".format(name, values))
121
-
122
-
123
def print_cpu():
    """Print the CPU count reported by ``psutil.cpu_count()``."""
    import psutil
    # The unused psutil.Process() handle was dropped; only the CPU count is printed.
    print(psutil.cpu_count())
128
-
129
-
130
def stress_test(func, ipts):
    """Run ``func`` over every item of ``ipts`` in a thread pool and print the elapsed time.

    Timing is done inline rather than with a decorator: ``nlpertools`` is not imported
    in this module and ``fn_timer`` is defined further down, so ``@nlpertools.fn_timer``
    raised NameError at import time.

    :param func: callable applied to each input.
    :param ipts: inputs; must support ``len()`` for the tqdm progress bar.
    :return: list of results in input order.
    """
    start = time.time()
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(func, ipts), total=len(ipts)))
    print('[finished {func_name} in {time:.2f}s]'.format(func_name='stress_test', time=time.time() - start))
    return results
135
-
136
- # 定义装饰器
137
def fn_timer(function):
    """Decorator: print the wall-clock seconds each call of ``function`` takes.

    The wrapped function keeps its name and docstring (``functools.wraps``) and its
    return value is passed through unchanged.
    """
    @wraps(function)
    def _timed(*args, **kwargs):
        started = time.time()
        outcome = function(*args, **kwargs)
        elapsed = time.time() - started
        print('[finished {func_name} in {time:.2f}s]'.format(func_name=function.__name__, time=elapsed))
        return outcome

    return _timed
147
-
148
-
149
def get_substring_loc(text, subtext):
    """Return the (start, end) span of the first literal occurrence of ``subtext`` in ``text``.

    ``end`` is exclusive, so ``text[start:end] == subtext``.

    :raises IndexError: if ``subtext`` does not occur in ``text`` (same exception type
        as the earlier regex-based implementation).
    """
    # A plain substring search replaces hand-rolled regex escaping, which missed
    # metacharacters such as '.', '*', '^', '$', '{' and '|'.
    start = text.find(subtext)
    if start == -1:
        raise IndexError('substring not found: {!r}'.format(subtext))
    return start, start + len(subtext)
157
-
158
-
159
def tf_idf(corpus, save_path):
    """Score each word by its TF-IDF weight summed over all documents and save the ranking.

    :param corpus: list of document strings, tokenized by CountVectorizer's default settings.
    :param save_path: output file; one "word score" line per word, highest score first.
    """
    from sklearn.feature_extraction.text import TfidfTransformer
    from sklearn.feature_extraction.text import CountVectorizer
    vectorizer = CountVectorizer()  # counts matrix: a[i][j] = frequency of word j in document i
    transformer = TfidfTransformer()  # converts the counts into tf-idf weights
    tfidf = transformer.fit_transform(vectorizer.fit_transform(corpus))
    # get_feature_names() was removed in scikit-learn 1.2; use the *_out variant when present.
    if hasattr(vectorizer, 'get_feature_names_out'):
        word = vectorizer.get_feature_names_out()
    else:
        word = vectorizer.get_feature_names()
    weight = tfidf.toarray()  # dense copy: weight[i][j] = tf-idf of word j in document i
    tfidfdict = {}
    for row in weight:
        for getword, getvalue in zip(word, row):
            if getvalue != 0:  # skip words absent from this document
                # Accumulate as Python floats so every stored value has the same type.
                tfidfdict[getword] = tfidfdict.get(getword, 0.0) + float(getvalue)
    sorted_tfidf = sorted(tfidfdict.items(), key=lambda d: d[1], reverse=True)
    to_write = ['{} {}'.format(w, v) for w, v in sorted_tfidf]
    # NOTE(review): writetxt_w_list is not imported in this version of the module.
    writetxt_w_list(to_write, save_path, num_lf=1)
181
-
182
- # 常用函数参考
183
- # import tensorflow as tf
184
- #
185
- # gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
186
- # sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
187
- # for gpu in tf.config.experimental.list_physical_devices('GPU'):
188
- # tf.config.experimental.set_memory_growth(gpu, True)
1
+ #!/usr/bin/python3.8
2
+ # -*- coding: utf-8 -*-
3
+ # @Author : youshu.Ji
4
+ import itertools
5
+ import os
6
+ import re
7
+ import string
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from functools import reduce
10
+ import math
11
+ import datetime
12
+ import psutil
13
+ from .io.file import writetxt_w_list, writetxt_a
14
+ # import numpy as np
15
+ # import psutil
16
+ # import pyquery as pq
17
+ # import requests
18
+ # import torch
19
+ # from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
20
+ # from sklearn.metrics import precision_recall_fscore_support
21
+ # from tqdm import tqdm
22
+ # from win32evtlogutil import langid
23
+ from .utils.package import *
24
+
25
# Punctuation characters, one per list element.
# NOTE(review): several CHINESE_PUNCTUATION entries render as ASCII (',', ';', '(' ...);
# confirm whether full-width forms were intended.
CHINESE_PUNCTUATION = [*',。;:‘’“”!?《》「」【】<>()、']
ENGLISH_PUNCTUATION = [*',.;:\'"!?<>()']
OTHER_PUNCTUATION = [*'!@#$%^&*']
28
+
29
+
30
def seed_everything(seed=7777777):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    :param seed: seed value (default 7777777).
    """
    import random

    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # CPU RNG
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
38
+
39
+
40
def _build_email_message(sender, receiver, title, content, attach_path=None):
    """Build the MIME message used by ``sent_email``: plain-text body plus optional attachment."""
    import os
    from email.mime.application import MIMEApplication
    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText

    message = MIMEMultipart()
    message.attach(MIMEText(content, 'plain', 'utf-8'))
    if attach_path:
        with open(attach_path, 'rb') as f:
            # MIMEApplication already sets Content-Type: application/octet-stream;
            # assigning it again would add a second Content-Type header.
            attachment = MIMEApplication(f.read())
        # Expose only the file name (not the local path), RFC 2231-encoded as UTF-8 so
        # non-ASCII (e.g. Chinese) names survive.
        attachment.add_header('Content-Disposition', 'attachment',
                              filename=('utf-8', '', os.path.basename(attach_path)))
        message.attach(attachment)
    message['Subject'] = title
    message['From'] = sender
    message['To'] = receiver
    return message


def sent_email(mail_user, mail_pass, receiver, title, content, attach_path=None):
    """Send a plain-text email, optionally with one attachment, via smtp.qq.com:25.

    :param mail_user: sender address, also used as the SMTP login name.
    :param mail_pass: SMTP password / authorization code.
    :param receiver: recipient address.
    :param title: subject line.
    :param content: plain-text body.
    :param attach_path: optional path of a file to attach.

    SMTP errors are caught and printed rather than raised.
    """
    import smtplib

    mail_host = 'smtp.qq.com'
    sender = mail_user
    message = _build_email_message(sender, receiver, title, content, attach_path)
    try:
        # The context manager sends QUIT and closes the socket even if login/sending fails.
        with smtplib.SMTP(mail_host, 25) as smtp_obj:
            smtp_obj.login(mail_user, mail_pass)
            smtp_obj.sendmail(sender, receiver, message.as_string())
        print('send email success')
    except smtplib.SMTPException as e:
        print('send failed', e)
72
+
73
+
74
def convert_np_to_py(obj):
    """Recursively convert NumPy values inside dicts, lists and tuples to plain Python objects.

    NumPy scalars of any kind (float, int, bool, ...) become the matching Python scalar
    via ``.item()``; arrays become nested lists. Everything else is returned unchanged.
    Handy before ``json.dumps``.
    """
    import numpy as np

    if isinstance(obj, dict):
        return {k: convert_np_to_py(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_np_to_py(v) for v in obj]
    if type(obj) is tuple:  # exact tuple only, so namedtuples are left untouched
        return tuple(convert_np_to_py(v) for v in obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj
83
+
84
+
85
def git_push(max_retries=None):
    """Run ``git push --set-upstream origin main`` repeatedly until it succeeds.

    Pushing to GitHub from mainland China often fails transiently, so the push is
    retried. ``os.system`` returns an integer status (0 on success); the old check for
    a "fatal" prefix on that integer never matched, so no retry ever happened.

    :param max_retries: give up after this many retries; None (default) retries forever.
    :return: True on success, False when ``max_retries`` is exhausted.
    """
    num = 0
    while True:
        print("retry num: {}".format(num))
        info = os.system("git push --set-upstream origin main")
        print(str(info))
        if info == 0:
            print("success")
            return True
        if max_retries is not None and num >= max_retries:
            print("give up after {} retries".format(num))
            return False
        num += 1
98
+
99
+
100
def snake_to_camel(s: str) -> str:
    """
    author: u
    Convert snake_case to CamelCase, e.g. "hello_world" -> "HelloWorld".

    Each underscore-separated part gets an upper-case first letter and lower-case rest.
    Unlike ``str.title()``, digits do not start a new word ("layer_2norm" -> "Layer2norm").

    :param s: snake case variable
    :return: CamelCase string
    """
    return ''.join(part[:1].upper() + part[1:].lower() for part in s.split('_'))
108
+
109
+
110
def camel_to_snake(s: str) -> str:
    """
    Convert camelCase / CamelCase to snake_case, e.g. "CamelCase" -> "camel_case".

    An underscore is inserted before every upper-case character except the first one,
    then everything is lower-cased (so "HTTPServer" -> "h_t_t_p_server").

    :param s: camel case variable
    :return: snake case string ("" for empty input)
    """
    # enumerate-based join instead of functools.reduce, which raised TypeError on "".
    return ''.join('_' + ch if idx and ch.isupper() else ch for idx, ch in enumerate(s)).lower()
117
+
118
+
119
+ # other ----------------------------------------------------------------------
120
+ # 统计词频
121
def calc_word_count(list_word, mode, path='tempcount.txt', sort_id=1, is_reverse=True):
    """Count how often each item occurs in ``list_word`` and emit the sorted counts.

    :param list_word: iterable of hashable items (e.g. tokens).
    :param mode: 'w' appends tab-separated "item count" lines to ``path`` via ``writetxt_a``;
        'p' prints them; 'u' returns the sorted list; any other value does nothing.
    :param path: output file for mode 'w' (appended to, not truncated).
    :param sort_id: 0 sorts by item, 1 (default) sorts by count.
    :param is_reverse: sort in descending order when True.
    :return: list of (item, count) tuples for mode 'u', otherwise None.
    """
    from collections import Counter

    # Counter keeps first-seen insertion order and sorted() is stable, so ties are
    # ordered exactly as with the previous manual dict counting.
    word_dict_sort = sorted(Counter(list_word).items(), key=lambda x: x[sort_id], reverse=is_reverse)
    if mode == 'w':
        for word, count in word_dict_sort:
            writetxt_a(str(word) + '\t' + str(count) + '\n', path)
    elif mode == 'p':
        for word, count in word_dict_sort:
            print(str(word) + '\t' + str(count))
    elif mode == 'u':
        return word_dict_sort
137
+
138
+
139
+ # 字典去重
140
def dupl_dict(dict_list, key):
    """Deduplicate a list of dicts by the value under ``key``, keeping first occurrences.

    :param dict_list: list of dicts; every item must contain ``key``.
    :param key: dict key whose value identifies duplicates.
    :return: new list holding the first dict seen for each distinct value, in input order.
    """
    new_dict_list = []
    seen = set()  # O(1) membership for hashable values (was an O(n) list scan)
    seen_unhashable = []  # fallback for unhashable values such as lists, compared with ==
    print('去重中...')
    for item in tqdm(dict_list):
        value = item[key]
        try:
            if value in seen:
                continue
            seen.add(value)
        except TypeError:
            if value in seen_unhashable:
                continue
            seen_unhashable.append(value)
        new_dict_list.append(item)
    return new_dict_list
148
+
149
+
150
def multi_thread_run(_task, data):
    """Apply ``_task`` to every element of ``data`` in a thread pool, with a tqdm progress bar.

    :param _task: callable applied to each element.
    :param data: any iterable; it is materialized into a list when it has no ``len()``
        (e.g. a generator), so the progress bar can show a total.
    :return: list of results in input order.
    """
    if not hasattr(data, '__len__'):
        data = list(data)
    with ThreadPoolExecutor() as executor:
        result = list(tqdm(executor.map(_task, data), total=len(data)))
    return result
154
+
155
+
156
def del_special_char(sentence):
    """Delete invisible characters that commonly pollute scraped text.

    Removes U+FEFF (BOM), U+00A0 (no-break space), U+3000 (ideographic space)
    and the private-use character U+E627.

    :param sentence: input string.
    :return: the string with those characters removed.
    """
    # A single str.translate pass replaces the chained replace() calls; mapping to
    # None deletes the character. (The old list also contained '\xa0' twice.)
    table = dict.fromkeys(map(ord, '\ufeff\xa0\u3000\ue627'))
    return sentence.translate(table)
161
+
162
+
163
def en_pun_2_zh_pun(sentence):
    """Placeholder for converting English punctuation in ``sentence`` to Chinese punctuation.

    Not implemented: ``sentence`` is ignored and None is returned.
    """
    # TODO: left unwritten because of quote handling (quotes need paired open/close forms).
    for _punct in ENGLISH_PUNCTUATION:
        continue
167
+
168
+
169
def _strip_report_suffix(text, suffix='举报/反馈'):
    """Remove a trailing "举报/反馈" (report/feedback) label from ``text`` if present."""
    # str.rstrip(chars) strips any trailing run of the given *characters*, which also
    # eats legitimate endings such as '反' or '/'; remove the exact suffix instead.
    if suffix and text.endswith(suffix):
        return text[:-len(suffix)]
    return text


def spider(url, timeout=None):
    """Fetch a Baijiahao article and return its title and body joined by a newline.

    :param url: article URL; only URLs containing 'baijiahao' are handled.
    :param timeout: optional ``requests`` timeout in seconds (None = no timeout).
    :return: "title" + newline + "body" text, or None for unsupported URLs.
    """
    if 'baijiahao' not in url:
        return None
    content = requests.get(url, timeout=timeout)
    html = pq.PyQuery(content.text)
    title = html('.index-module_articleTitle_28fPT').text()
    res = _strip_report_suffix(html('.index-module_articleWrap_2Zphx').text())
    return '{}\n{}'.format(title, res)
182
+
183
+
184
def eda(sentence):
    """POST ``sentence`` to the EDA HTTP service and return the response's "eda" field.

    NOTE(review): the URL is a placeholder and must be configured before use.
    """
    payload = {"sentence": sentence}
    response = requests.post('https://x.x.x.x:x/eda', json=payload)
    return response.json()['eda']
189
+
190
+
191
def find_language(text):
    """Heuristically decide whether ``text`` is English.

    Looks at (at most) the first 50 characters and reports "en" when more than half
    of them are ASCII letters.

    :param text: input string.
    :return: "en" or "not en" ("not en" for empty text).
    """
    # TODO: replace with an open-source language-identification package.
    letters = set(string.ascii_letters)
    passage = text[:50]  # previously undefined for texts of 50 characters or fewer
    if not passage:
        return "not en"
    count = sum(1 for c in passage if c in letters)
    return "en" if count / len(passage) > 0.5 else "not en"
207
+
208
+
209
def print_prf(y_true, y_pred, label=None):
    """Print precision, recall and F1 for each label, one "label: p r f" line per label.

    :param y_true: gold labels.
    :param y_pred: predicted labels.
    :param label: labels to report, in order. When None, every label occurring in
        ``y_true`` or ``y_pred`` is reported in sorted order.
    """
    if label is None:
        # The None default used to crash on len(label).
        label = sorted(set(y_true) | set(y_pred))
    result = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, labels=label)
    # result is (precision, recall, fscore, support); only the first three are printed.
    for i, name in enumerate(label):
        res = ['%.5f' % k[i] for k in result[:3]]
        print('{}: {} {} {}'.format(name, *res))
224
+
225
+
226
def print_cpu():
    """Print the CPU count reported by ``psutil.cpu_count()``."""
    # The unused psutil.Process() handle was dropped; only the CPU count is printed.
    print(psutil.cpu_count())
230
+
231
+
232
def stress_test(func, ipts):
    """Call ``func`` on every item of ``ipts`` concurrently in a thread pool.

    A tqdm progress bar is shown, so ``ipts`` must support ``len()``.

    :return: list of results in input order.
    """
    with ThreadPoolExecutor() as pool:
        mapped = pool.map(func, ipts)
        return list(tqdm(mapped, total=len(ipts)))
236
+
237
+
238
def get_substring_loc(text, subtext):
    """Return the (start, end) span of the first literal occurrence of ``subtext`` in ``text``.

    ``end`` is exclusive, so ``text[start:end] == subtext``.

    :raises IndexError: if ``subtext`` does not occur in ``text`` (same exception type
        as the earlier regex-based implementation).
    """
    # A plain substring search replaces hand-rolled regex escaping, which missed
    # metacharacters such as '.', '*', '^', '$', '{' and '|'.
    start = text.find(subtext)
    if start == -1:
        raise IndexError('substring not found: {!r}'.format(subtext))
    return start, start + len(subtext)
245
+
246
+
247
def squeeze_list(high_dim_list):
    """Flatten one level of nesting, e.g. [[1, 2], [3]] -> [1, 2, 3]."""
    return [item for sub_list in high_dim_list for item in sub_list]
249
+
250
+
251
def unsqueeze_list(flatten_list, each_element_len):
    """Split ``flatten_list`` into consecutive chunks of ``each_element_len`` items.

    This is the inverse of ``squeeze_list`` for equally sized rows. When the length is
    not a multiple of ``each_element_len``, the trailing partial chunk is kept instead
    of being silently dropped.

    :param flatten_list: any sliceable sequence (list, str, ...).
    :param each_element_len: chunk size.
    :return: list of slices of ``flatten_list``.
    """
    step = each_element_len
    return [flatten_list[i:i + step] for i in range(0, len(flatten_list), step)]
255
+
256
+
257
def auto_close():
    """Keep WeCom (企业微信) from showing "away" and schedule a shutdown at 23:30.

    Registers a Windows scheduled task running ``shutdown -s -f`` at 23:30, then loops
    forever: move to screen position (970, 17), click, and sleep 840 s (14 min). This
    stays under WeCom's 15-minute idle threshold.
    """
    import os
    import time

    import pyautogui as pg

    os.system('schtasks /create /tn shut /tr "shutdown -s -f" /sc once /st 23:30')
    while True:
        pg.moveTo(970, 17, 2)
        pg.click()
        time.sleep(840)
270
+
271
+
272
def tf_idf(corpus, save_path):
    """Score each word by its TF-IDF weight summed over all documents and save the ranking.

    :param corpus: list of document strings, tokenized by CountVectorizer's default settings.
    :param save_path: output file; one "word score" line per word, highest score first.
    """
    vectorizer = CountVectorizer()  # counts matrix: a[i][j] = frequency of word j in document i
    transformer = TfidfTransformer()  # converts the counts into tf-idf weights
    tfidf = transformer.fit_transform(vectorizer.fit_transform(corpus))
    # get_feature_names() was removed in scikit-learn 1.2; use the *_out variant when present.
    if hasattr(vectorizer, 'get_feature_names_out'):
        word = vectorizer.get_feature_names_out()
    else:
        word = vectorizer.get_feature_names()
    weight = tfidf.toarray()  # dense copy: weight[i][j] = tf-idf of word j in document i
    tfidfdict = {}
    for row in weight:
        for getword, getvalue in zip(word, row):
            if getvalue != 0:  # skip words absent from this document
                # Accumulate as Python floats so every stored value has the same type.
                tfidfdict[getword] = tfidfdict.get(getword, 0.0) + float(getvalue)
    sorted_tfidf = sorted(tfidfdict.items(), key=lambda d: d[1], reverse=True)
    to_write = ['{} {}'.format(w, v) for w, v in sorted_tfidf]
    writetxt_w_list(to_write, save_path, num_lf=1)
292
+
293
+
294
class GaussDecay(object):
    """Gaussian time-decay scorer blended with a relevance score.

    Only the time task is implemented, and every setting uses its default value.
    The decay formula has the same form as Elasticsearch's ``gauss`` decay function.

    NOTE(review): ``origin``, ``scale`` and ``offset`` are accepted but not used for
    the computation. Distances are measured from ``datetime.now()``, and scale/offset
    are fixed at 180 and 5 days.
    """

    def __init__(self, origin='2022-08-02', scale='90d', offset='5d', decay=0.5, task="time"):
        self.origin = origin  # stored only; see the class note
        self.task = task  # must be set before translate(), which reads it
        self.scale, self.offset = self.translate(scale, offset)
        self.decay = decay  # decay factor at distance scale + offset; must be in (0, 1)
        # Final score = time_coefficient * decay_factor + related_coefficient * raw_score.
        self.time_coefficient = 0.6
        self.related_coefficient = 0.4

    def translate(self, scale, offset):
        """Convert the user-facing scale/offset into numbers of days.

        :return: (scale, offset); currently always (180, 5), whatever the arguments are.
        """
        scale_days, offset_days = 180, 5
        if self.task != "time":
            # Non-time tasks are not implemented and fall back to the same values.
            scale_days, offset_days = 180, 5
        return scale_days, offset_days

    @staticmethod
    def translated_minus(field_value):
        """Whole days from ``field_value`` (format '%Y-%m-%d %H:%M:%S') until now (naive local time)."""
        now = datetime.datetime.now()
        then = datetime.datetime.strptime(field_value, '%Y-%m-%d %H:%M:%S')
        return (now - then).days

    def calc_exp(self):
        """Exponential decay; not implemented (returns None)."""
        pass

    def calc_liner(self):
        """Linear decay; not implemented (returns None)."""
        pass

    def calc_gauss(self, raw_score, field_value):
        r"""Blend a Gaussian time-decay factor with ``raw_score``.

        $$S(doc)=exp(-\frac{max(0,|fieldvalues_{doc}-origin|-offset)^2}{2σ^2})$$
        $$σ^2=-scale^2/(2·ln(decay))$$

        :param raw_score: relevance score to blend in.
        :param field_value: timestamp string in '%Y-%m-%d %H:%M:%S' format.
        :return: round(time_coefficient * S + related_coefficient * raw_score, 7)
        """
        distance = abs(self.translated_minus(field_value))
        numerator = max(0, (distance - self.offset)) ** 2
        sigma_square = -1 * self.scale ** 2 / (2 * math.log(self.decay, math.e))
        decay_factor = math.exp(-1 * numerator / (2 * sigma_square))
        return round(self.time_coefficient * decay_factor + self.related_coefficient * raw_score, 7)
345
+
346
+
347
if __name__ == '__main__':
    # Demo: blended score for a 2021 timestamp with a perfect raw score.
    demo_decay = GaussDecay()
    print(demo_decay.calc_gauss(raw_score=1, field_value="2021-05-29 14:31:13"))
357
+
358
+ # 常用函数参考
359
+ # import tensorflow as tf
360
+ #
361
+ # gpu_options = tf.compat.v1.GPUOptions(allow_growth=True)
362
+ # sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))
363
+ # for gpu in tf.config.experimental.list_physical_devices('GPU'):
364
+ # tf.config.experimental.set_memory_growth(gpu, True)