PipeGraphPy 2.0.6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. PipeGraphPy/__init__.py +10 -0
  2. PipeGraphPy/common.py +4 -0
  3. PipeGraphPy/config/__init__.py +276 -0
  4. PipeGraphPy/config/custom.py +6 -0
  5. PipeGraphPy/config/default_settings.py +125 -0
  6. PipeGraphPy/constants.py +421 -0
  7. PipeGraphPy/core/__init__.py +2 -0
  8. PipeGraphPy/core/anchor.cp39-win_amd64.pyd +0 -0
  9. PipeGraphPy/core/edge.cp39-win_amd64.pyd +0 -0
  10. PipeGraphPy/core/graph.cp39-win_amd64.pyd +0 -0
  11. PipeGraphPy/core/graph_base.cp39-win_amd64.pyd +0 -0
  12. PipeGraphPy/core/modcls/__init__.py +3 -0
  13. PipeGraphPy/core/modcls/base.cp39-win_amd64.pyd +0 -0
  14. PipeGraphPy/core/modcls/branchselect.cp39-win_amd64.pyd +0 -0
  15. PipeGraphPy/core/modcls/classifier.cp39-win_amd64.pyd +0 -0
  16. PipeGraphPy/core/modcls/cluster.cp39-win_amd64.pyd +0 -0
  17. PipeGraphPy/core/modcls/datacharts.cp39-win_amd64.pyd +0 -0
  18. PipeGraphPy/core/modcls/deeplearning.cp39-win_amd64.pyd +0 -0
  19. PipeGraphPy/core/modcls/endscript.cp39-win_amd64.pyd +0 -0
  20. PipeGraphPy/core/modcls/ensemble.cp39-win_amd64.pyd +0 -0
  21. PipeGraphPy/core/modcls/evaluate.cp39-win_amd64.pyd +0 -0
  22. PipeGraphPy/core/modcls/exportdata.cp39-win_amd64.pyd +0 -0
  23. PipeGraphPy/core/modcls/handlescript.cp39-win_amd64.pyd +0 -0
  24. PipeGraphPy/core/modcls/importdata.cp39-win_amd64.pyd +0 -0
  25. PipeGraphPy/core/modcls/merge.cp39-win_amd64.pyd +0 -0
  26. PipeGraphPy/core/modcls/mergescript.cp39-win_amd64.pyd +0 -0
  27. PipeGraphPy/core/modcls/metrics.cp39-win_amd64.pyd +0 -0
  28. PipeGraphPy/core/modcls/postprocessor.cp39-win_amd64.pyd +0 -0
  29. PipeGraphPy/core/modcls/preprocessor.cp39-win_amd64.pyd +0 -0
  30. PipeGraphPy/core/modcls/pythonscript.cp39-win_amd64.pyd +0 -0
  31. PipeGraphPy/core/modcls/regressor.cp39-win_amd64.pyd +0 -0
  32. PipeGraphPy/core/modcls/selector.cp39-win_amd64.pyd +0 -0
  33. PipeGraphPy/core/modcls/selectscript.cp39-win_amd64.pyd +0 -0
  34. PipeGraphPy/core/modcls/special.cp39-win_amd64.pyd +0 -0
  35. PipeGraphPy/core/modcls/split.cp39-win_amd64.pyd +0 -0
  36. PipeGraphPy/core/modcls/splitscript.cp39-win_amd64.pyd +0 -0
  37. PipeGraphPy/core/modcls/startscript.cp39-win_amd64.pyd +0 -0
  38. PipeGraphPy/core/modcls/transformer.cp39-win_amd64.pyd +0 -0
  39. PipeGraphPy/core/module.cp39-win_amd64.pyd +0 -0
  40. PipeGraphPy/core/modules/__init__.py +65 -0
  41. PipeGraphPy/core/modules/classifier/__init__.py +2 -0
  42. PipeGraphPy/core/modules/cluster/__init__.py +0 -0
  43. PipeGraphPy/core/modules/custom/__init__.py +0 -0
  44. PipeGraphPy/core/modules/custom/classifier/__init__.py +0 -0
  45. PipeGraphPy/core/modules/datacharts/__init__.py +5 -0
  46. PipeGraphPy/core/modules/datacharts/dataview.py +28 -0
  47. PipeGraphPy/core/modules/deeplearning/__init__.py +0 -0
  48. PipeGraphPy/core/modules/ensemble/__init__.py +0 -0
  49. PipeGraphPy/core/modules/evaluate/__init__.py +0 -0
  50. PipeGraphPy/core/modules/exportdata/__init__.py +0 -0
  51. PipeGraphPy/core/modules/importdata/__init__.py +0 -0
  52. PipeGraphPy/core/modules/merge/__init__.py +0 -0
  53. PipeGraphPy/core/modules/model_selector/__init__.py +3 -0
  54. PipeGraphPy/core/modules/postprocessor/__init__.py +0 -0
  55. PipeGraphPy/core/modules/preprocessor/__init__.py +0 -0
  56. PipeGraphPy/core/modules/pythonscript/__init__.py +0 -0
  57. PipeGraphPy/core/modules/regressor/__init__.py +0 -0
  58. PipeGraphPy/core/modules/selector/__init__.py +0 -0
  59. PipeGraphPy/core/modules/special/__init__.py +0 -0
  60. PipeGraphPy/core/modules/split/__init__.py +0 -0
  61. PipeGraphPy/core/modules/transformer/__init__.py +0 -0
  62. PipeGraphPy/core/node.cp39-win_amd64.pyd +0 -0
  63. PipeGraphPy/core/pipegraph.cp39-win_amd64.pyd +0 -0
  64. PipeGraphPy/db/__init__.py +2 -0
  65. PipeGraphPy/db/models.cp39-win_amd64.pyd +0 -0
  66. PipeGraphPy/db/utils.py +106 -0
  67. PipeGraphPy/decorators.py +42 -0
  68. PipeGraphPy/logger.py +170 -0
  69. PipeGraphPy/plot/__init__.py +0 -0
  70. PipeGraphPy/plot/draw.py +424 -0
  71. PipeGraphPy/storage/__init__.py +10 -0
  72. PipeGraphPy/storage/base.py +2 -0
  73. PipeGraphPy/storage/dict_backend.py +102 -0
  74. PipeGraphPy/storage/file_backend.py +342 -0
  75. PipeGraphPy/storage/redis_backend.py +183 -0
  76. PipeGraphPy/tools.py +388 -0
  77. PipeGraphPy/utils/__init__.py +1 -0
  78. PipeGraphPy/utils/check.py +179 -0
  79. PipeGraphPy/utils/core.py +295 -0
  80. PipeGraphPy/utils/examine.py +259 -0
  81. PipeGraphPy/utils/file_operate.py +101 -0
  82. PipeGraphPy/utils/format.py +303 -0
  83. PipeGraphPy/utils/functional.py +422 -0
  84. PipeGraphPy/utils/handle_graph.py +31 -0
  85. PipeGraphPy/utils/lock.py +1 -0
  86. PipeGraphPy/utils/mq.py +54 -0
  87. PipeGraphPy/utils/osutil.py +29 -0
  88. PipeGraphPy/utils/redis_operate.py +195 -0
  89. PipeGraphPy/utils/str_handle.py +122 -0
  90. PipeGraphPy/utils/version.py +108 -0
  91. PipeGraphPy-2.0.6.dist-info/METADATA +17 -0
  92. PipeGraphPy-2.0.6.dist-info/RECORD +94 -0
  93. PipeGraphPy-2.0.6.dist-info/WHEEL +5 -0
  94. PipeGraphPy-2.0.6.dist-info/top_level.txt +1 -0
@@ -0,0 +1,342 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import shutil
5
+ import traceback
6
+ import os
7
+ from datetime import datetime, timedelta
8
+ from PipeGraphPy.storage.base import ParamsPoolBase
9
+ from collections import defaultdict
10
+ from PipeGraphPy.config import settings
11
+ from PipeGraphPy.utils.file_operate import pickle_dumps, pickle_loads
12
+ from PipeGraphPy.utils.str_handle import generate_random_str
13
+ from PipeGraphPy.constants import (
14
+ NODE_OUTPUT_KEY, MOD_RETURN_RESULT,
15
+ ALGO_MOD_TYPE, MODULES, SCENETYPE)
16
+ from PipeGraphPy.logger import log
17
+
18
ttl = 7  # retention period, presumably in days — not referenced anywhere in this module; TODO confirm usage
19
+
20
+
21
class ParamsPool(ParamsPoolBase):
    """File-backed parameter pool: stores and retrieves node outputs for a graph run."""

    def __init__(self, graph_id, object_id, is_predict=False):
        self.graph_id = graph_id      # id of the graph being run
        self.object_id = object_id    # id of the run object (sub-directory on disk)
        self.is_predict = is_predict  # True when running in predict mode

    def add_params(self, node, anchor, data):
        """Store one output value of *node* under the given output *anchor*."""
        save_params(self.graph_id, self.object_id, node.id, anchor, data, self.is_predict)

    def get_params(self, node, anchor):
        """Load the value previously stored for *node* at *anchor*."""
        return get_params(
            self.graph_id, self.object_id, node.id, anchor,
            is_pd=True, is_predict=self.is_predict,
        )

    def add_params_by_list(self, node, data_list):
        """Store a list of outputs, mapping each list position to an output anchor."""
        for position, value in enumerate(data_list):
            self.add_params(node, node.outidx_to_outanchor(position), value)

    def is_end(self, node):
        """A node counts as finished once its first output anchor has a stored value."""
        first_out_anchor = node.anchors[1][0].anc
        return self.get_params(node, first_out_anchor) is not None

    def check_pass_params(self, fathers):
        """Return True when every parent node has finished running."""
        return all(self.is_end(parent) for parent in fathers)

    def gen_pass_params(self, nodes_group, node, fathers):
        """Assemble the positional argument list for *node* from its parents' outputs."""
        # Map each input-anchor index to the (parent, output-anchor) pairs feeding it.
        feeders = defaultdict(list)
        for parent in fathers:
            out_anchors, in_anchors = nodes_group[(parent.id, node.id)]
            sources = [(parent, int(a)) for a in out_anchors]
            for in_anchor, source in zip(list(in_anchors), sources):
                feeders[int(in_anchor)].append(source)

        # Fetch values; an input fed by several parents becomes a list of values.
        gathered = []
        for in_idx, sources in feeders.items():
            if not sources:
                raise Exception("传参错误")
            if len(sources) == 1:
                gathered.append((in_idx, self.get_params(*sources[0])))
            else:
                gathered.append((in_idx, [self.get_params(*src) for src in sources]))

        # Order by input-anchor index, then keep only the values.
        gathered.sort(key=lambda pair: pair[0])
        args = [value for _, value in gathered]

        # Inputs declared as "list..." must always receive a list, even for one value.
        for idx, declared in enumerate(node.INPUT):
            if declared.startswith("list") and not isinstance(args[idx], list):
                args[idx] = [args[idx]]
        return args
87
+
88
+
89
def save_params(graph_id, object_id, node_id, anchor, data, is_predict=False):
    """Persist one node output to a pickle file under the run's model directory."""
    file_name = NODE_OUTPUT_KEY.format(graph_id=graph_id, node_id=node_id, anchor=anchor)
    run_dir = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
    pickle_dumps(os.path.join(run_dir, file_name), data)
94
+
95
+
96
def get_params(graph_id, object_id, node_id, anchor, is_pd=False, is_predict=False):
    """Load a node output pickle from the run's model directory."""
    file_name = NODE_OUTPUT_KEY.format(graph_id=graph_id, node_id=node_id, anchor=anchor)
    run_dir = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
    return pickle_loads(os.path.join(run_dir, file_name))
101
+
102
+
103
def delete_params(graph_id, object_id, is_predict=False):
    """Best-effort removal of every stored node-output file for this run."""
    try:
        # The key formatted with empty node/anchor, minus the trailing
        # separator, is the common prefix shared by all output files.
        prefix = NODE_OUTPUT_KEY.format(graph_id=graph_id, node_id="", anchor="")[:-1]
        run_dir = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
        for dir_path, _sub_dirs, file_names in os.walk(run_dir):
            for file_name in file_names:
                if prefix in file_name:
                    os.remove(os.path.join(dir_path, file_name))
    except Exception:
        # Deletion failures are logged, never raised.
        log.error(traceback.format_exc())
116
+
117
+
118
def save_graph(graph, file_path=False):
    """Pickle *graph* to disk, externalizing any deep-learning sub-models first.

    With an explicit *file_path* the graph is pickled there and its DL models
    are written under the SDK model directory; otherwise both the pickle and
    the DL models go under the run directory derived from graph metadata.
    """
    if file_path:
        # (Removed an unused `folder_path = os.path.dirname(file_path)` local.)
        dl_folder_path = os.path.join(
            settings.SDK_MODEL_SAVE_PATH,
            str(graph.id))
        if not os.path.exists(dl_folder_path):
            os.makedirs(dl_folder_path)
        # Swap in-memory DL models for file references before pickling.
        trans_deeplearning_algo_model(graph, dl_folder_path, is_load=False)
        pickle_dumps(file_path, graph)
    else:
        model_name = settings.MODEL_SAVE_NAME.format(
            graph_id=graph.id, object_id=graph.info["object_info"]["id"])
        # str() added for consistency with every other path builder in this
        # module (os.path.join fails on non-str components).
        folder_path = os.path.join(
            settings.RUN_MODEL_SAVE_PATH, str(graph.id), str(graph.info["object_id"]))
        file_path = os.path.join(folder_path, model_name)
        trans_deeplearning_algo_model(graph, folder_path, is_load=False)
        pickle_dumps(file_path, graph)
135
+
136
+
137
def load_graph(graph_id=None, object_id=None, file_path=None, use_predict_model=False):
    """Unpickle a graph and re-attach its externalized deep-learning models.

    Load order: an explicit *file_path* if given; otherwise the predict-model
    copy (when *use_predict_model* is set and the file exists, falling back on
    any error); otherwise the regular run-directory model identified by
    *graph_id*/*object_id*.
    """
    if file_path:
        graph = pickle_loads(file_path)
        dl_folder_path = os.path.join(settings.SDK_MODEL_SAVE_PATH, str(graph.id))
        trans_deeplearning_algo_model(graph, dl_folder_path, is_load=True)
        return graph
    else:
        # NOTE(review): assert is stripped under `python -O`; kept as-is for
        # interface compatibility, but a plain raise would be safer.
        assert graph_id, ValueError("转入图模型必须传graph_id")
        assert object_id, ValueError("转入图模型必须传object_id")
        model_name = settings.MODEL_SAVE_NAME.format(graph_id=graph_id, object_id=object_id)
        folder_path = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
        predict_folder_path = os.path.join(settings.PREDICT_MODEL_SAVE_PATH, str(graph_id))
        predict_file_path = os.path.join(predict_folder_path, model_name)
        if use_predict_model and os.path.exists(predict_file_path):
            try:
                # BUG FIX: the original loaded `file_path`, which is always
                # None on this branch, so the predict-model copy could never
                # actually be loaded.
                graph = pickle_loads(predict_file_path)
                trans_deeplearning_algo_model(graph, predict_folder_path, is_load=True)
                return graph
            except Exception:
                pass  # fall through to the regular run-directory model
        file_path = os.path.join(folder_path, model_name)
        graph = pickle_loads(file_path)
        trans_deeplearning_algo_model(graph, folder_path, is_load=True)
        return graph
163
+
164
+
165
def has_graph(graph_id, object_id, is_predict=False):
    """Return True when a pickled graph model exists for this run."""
    run_dir = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
    model_name = settings.MODEL_SAVE_NAME.format(graph_id=graph_id, object_id=object_id)
    return os.path.exists(os.path.join(run_dir, model_name))
171
+
172
+
173
def push_result(graph_id, object_id, data, is_predict=False):
    """Append one module run result as a pickle file in the result directory."""
    key = MOD_RETURN_RESULT.format(graph_id=graph_id)
    result_dir = os.path.join(
        settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id), key)
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    # Beijing-time (UTC+8) timestamp plus a random prefix and a trailing
    # counter digit keep file names effectively unique and time-ordered.
    stamp = (datetime.utcnow() + timedelta(hours=8)).strftime("%Y%m%d%H%M%S%f")
    candidate = os.path.join(result_dir, generate_random_str(16) + stamp + "0")
    # In the unlikely event of a collision, extend the name with (last digit + 1)
    # until it is unique (same scheme as the original recursive helper).
    while os.path.isfile(candidate):
        candidate = candidate + str(int(candidate[-1]) + 1)
    pickle_dumps(candidate, data)
194
+
195
+
196
def pop_result(graph_id, object_id, is_predict=False):
    """Remove and return one stored module result, or None when none exist.

    Files are consumed in ``os.walk`` order; the chosen file is deleted from
    disk after its pickle payload is loaded.
    """
    key = MOD_RETURN_RESULT.format(graph_id=graph_id)
    folder_path = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
    result_folder_path = os.path.join(folder_path, key)
    if not os.path.exists(result_folder_path):
        os.makedirs(result_folder_path)
    file_path = None
    for fold_name, subfolders, filenames in os.walk(result_folder_path):
        if filenames:
            file_path = os.path.join(fold_name, filenames[0])
            # BUG FIX: stop at the first file found. The original had no
            # break and a loop-body `else: return None`, so it returned None
            # whenever any later directory in the walk had no files — even
            # after a result file had already been found.
            break
    if file_path is None:
        return None
    res = pickle_loads(file_path)
    os.remove(file_path)
    return res
213
+
214
+
215
def clear_result(graph_id, object_id, is_predict=False):
    """Delete the whole result directory (and all result files) for this run."""
    key = MOD_RETURN_RESULT.format(graph_id=graph_id)
    result_dir = os.path.join(
        settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id), key)
    if os.path.exists(result_dir):
        shutil.rmtree(result_dir)
222
+
223
+
224
def save_algo_model(graph_id, object_id, node_id, algo_mod_type, model, folder_path=None):
    """Save deep-learning model object(s) to files and return the file name(s).

    *model* may be a single model or a list; non-model items (None, plain
    containers/scalars) are skipped. Returns the name for a single saved
    model, a list of names otherwise; unknown framework types get *model*
    back unchanged.

    Raises:
        Exception: wrapping any failure, including framework import errors.
    """
    try:
        res = list()
        if not folder_path:
            folder_path = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
        if algo_mod_type == ALGO_MOD_TYPE.TENSORFLOW:
            try:
                import tensorflow as tf
            except ImportError:  # narrowed from a bare except
                raise ImportError("载入tensorflow error")
            saver = tf.compat.v1.train.Saver()
        elif algo_mod_type == ALGO_MOD_TYPE.KERAS:
            pass  # Keras models expose .save() themselves; nothing to import here
        elif algo_mod_type == ALGO_MOD_TYPE.PYTORCH:
            try:
                import torch
            except ImportError:  # narrowed from a bare except
                # BUG FIX: message was truncated ("载入torch"); made consistent
                # with the tensorflow/keras import errors.
                raise ImportError("载入torch error")
        else:
            # Not a known DL framework: hand the object back untouched.
            return model
        models = model if isinstance(model, list) else [model]
        for idx, m in enumerate(models):
            # Skip placeholders that are clearly not model objects.
            if m is None or isinstance(m, (str, list, tuple, int, dict, float)):
                continue
            model_name = settings.ALGO_MODEL_SAVE_NAME.format(
                graph_id=graph_id,
                node_id=node_id,
                algo_mod_type=algo_mod_type,
                idx=idx
            )
            if algo_mod_type == ALGO_MOD_TYPE.TENSORFLOW:
                model_name = f"{model_name}.ckpt"
                # NOTE(review): tf.compat.v1.train.Saver.save expects a
                # Session as its first argument — confirm `m` is a Session
                # here; left unchanged because the intent cannot be verified
                # from this file alone.
                saver.save(m, os.path.join(folder_path, model_name))
            elif algo_mod_type == ALGO_MOD_TYPE.KERAS:
                model_name = f"{model_name}.h5"
                m.save(os.path.join(folder_path, model_name))
            elif algo_mod_type == ALGO_MOD_TYPE.PYTORCH:
                model_name = f"{model_name}.pt"
                torch.save(m, os.path.join(folder_path, model_name))
            else:
                raise ValueError("algo_mod_type(%s) error" % algo_mod_type)
            res.append(model_name)
        return res[0] if len(res) == 1 else res
    except Exception:
        raise Exception("无法保存深度学习模型:\n" + str(traceback.format_exc()))
270
+
271
+
272
def load_algo_model(graph_id, object_id, node_id, algo_mod_type, model_name, folder_path=None):
    """Load deep-learning model(s) previously written by ``save_algo_model``.

    *model_name* may be a single file name or a list of names. Returns the
    loaded model for a single name, a list of models otherwise; unknown
    framework types get *model_name* back unchanged.

    Raises:
        Exception: wrapping any failure, including framework import errors.
    """
    try:
        res = list()
        if not folder_path:
            folder_path = os.path.join(settings.RUN_MODEL_SAVE_PATH, str(graph_id), str(object_id))
        if algo_mod_type == ALGO_MOD_TYPE.TENSORFLOW:
            try:
                import tensorflow as tf
            except ImportError:  # narrowed from a bare except
                raise ImportError("载入tensorflow error")
            saver = tf.compat.v1.train.Saver()
        elif algo_mod_type == ALGO_MOD_TYPE.KERAS:
            try:
                from tensorflow.keras.models import load_model
            except ImportError:  # narrowed from a bare except
                raise ImportError("载入keras.load_model error")
        elif algo_mod_type == ALGO_MOD_TYPE.PYTORCH:
            try:
                import torch
            except ImportError:  # narrowed from a bare except
                # BUG FIX: message was truncated ("载入torch"); made consistent
                # with the tensorflow/keras import errors.
                raise ImportError("载入torch error")
        else:
            # Not a known DL framework: hand the name(s) back untouched.
            return model_name
        model_names = model_name if isinstance(model_name, list) else [model_name]
        for idx, mn in enumerate(model_names):
            if algo_mod_type == ALGO_MOD_TYPE.TENSORFLOW:
                # BUG FIX: the original called tf.Seesion() and saver.resore(),
                # neither of which exists (typos for Session / restore) — this
                # branch could never have run successfully.
                with tf.compat.v1.Session() as sess:
                    # NOTE(review): Saver.restore returns None; the restored
                    # state lives in `sess`, so `model` is None here — confirm
                    # what callers actually expect for TENSORFLOW models.
                    model = saver.restore(sess, os.path.join(folder_path, mn))
            elif algo_mod_type == ALGO_MOD_TYPE.KERAS:
                model = load_model(os.path.join(folder_path, mn))
            elif algo_mod_type == ALGO_MOD_TYPE.PYTORCH:
                model = torch.load(os.path.join(folder_path, mn))
            else:
                raise ValueError("algo_mod_type(%s) error" % algo_mod_type)
            res.append(model)
        return res[0] if len(res) == 1 else res
    except Exception:
        raise Exception("无法载入深度学习模型:\n" + traceback.format_exc())
311
+
312
+
313
def trans_deeplearning_algo_model(graph, folder_path=None, is_load=False):
    """Swap deep-learning node models between in-memory objects and file names.

    With ``is_load=False``, live models found on deep-learning nodes are saved
    to disk and replaced by their file names (so the graph can be pickled);
    with ``is_load=True``, string file names are replaced by loaded models.
    """
    if not folder_path:
        if graph.scene == SCENETYPE.SDKTEST:
            folder_path = os.path.join(
                settings.SDK_MODEL_SAVE_PATH, str(graph.id), str(graph.info["object_id"]))
        elif graph.scene == SCENETYPE.ONLINE:
            # str() added for consistency with the SDKTEST branch above.
            folder_path = os.path.join(
                settings.RUN_MODEL_SAVE_PATH, str(graph.id), str(graph.info["object_id"]))
    for k, v in graph.a.nodes_dict.items():
        if v.module.parent.info["cls_name"] != MODULES.DEEPLEARNING:
            continue
        if not v.module.info.get("algo_type_id"):
            continue
        algo_mod_type_name = v.module.info.get('algo_mod_type_name')
        if not (algo_mod_type_name and hasattr(v.algo_instance, "model")):
            continue
        func = None
        if is_load:
            # A string model attribute is a saved file name -> load it back.
            if isinstance(v.algo_instance.model, str):
                func = load_algo_model
        else:
            # A live (non-string) model must be externalized before pickling.
            if not isinstance(v.algo_instance.model, str):
                func = save_algo_model
        if func is not None:
            # BUG FIX: the original passed 5 positional args to the
            # 6-parameter save/load functions, shifting every argument after
            # graph_id by one slot (folder_path landed in the model/model_name
            # parameter). Pass object_id explicitly and folder_path by keyword.
            v.algo_instance.model = func(
                graph.id,
                graph.info.get("object_id"),
                v.info['id'],
                algo_mod_type_name,
                v.algo_instance.model,
                folder_path=folder_path,
            )
342
+
@@ -0,0 +1,183 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # import redis
5
+ # import pickle
6
+ # import traceback
7
+ # from PipeGraphPy.storage.base import ParamsPoolBase
8
+ # from PipeGraphPy.utils.format import binary_to_utf8
9
+ # from collections import defaultdict
10
+ # from PipeGraphPy.config import settings
11
+ # from PipeGraphPy.utils.redis_operate import pickle_dumps, pickle_loads
12
+ # from PipeGraphPy.constants import (NODE_STATUS_KEY, NODE_OUTPUT_KEY,
13
+ # GRAPH_STATUS_KEY, GRAPH_KEY, MOD_RETURN_RESULT)
14
+ #
15
+ # from PipeGraphPy.logger import log
16
+ # redis_conf = dict(
17
+ # host=settings.REDIS_HOST,
18
+ # port=settings.REDIS_PORT,
19
+ # db=settings.REDIS_DB,
20
+ # )
21
+ # redis_key_ttl = settings.REDIS_KEY_TTL
22
+ #
23
+ #
24
+ # redis_conn = redis.Redis(
25
+ # host=redis_conf['host'], port=redis_conf['port'], db=redis_conf['db'])
26
+ #
27
+ #
28
+ # class ParamsPool(ParamsPoolBase):
29
+ # '''传参池'''
30
+ #
31
+ # def __init__(self, graph_id, is_predict=False):
32
+ # self.graph_id = graph_id
33
+ # self.is_predict = is_predict
34
+ #
35
+ # def add_params(self, node, idx, data):
36
+ # '''添加传参值'''
37
+ # save_params(self.graph_id, node.id, idx, data, self.is_predict)
38
+ #
39
+ # def get_params(self, node, idx):
40
+ # '''获取传参值'''
41
+ # res = get_params(self.graph_id, node.id, idx,
42
+ # is_pd=True, is_predict=self.is_predict)
43
+ # return res
44
+ #
45
+ # def add_params_by_list(self, node, data_list):
46
+ # '''通过数据列表添加传参'''
47
+ # for idx, data in enumerate(data_list):
48
+ # self.add_params(node, idx, data)
49
+ #
50
+ # def is_end(self, node):
51
+ # '''判断节点是否已运行完'''
52
+ # return True if self.get_params(node, 0) is not None else False
53
+ #
54
+ # def check_pass_params(self, node):
55
+ # '''循环父节点判断父节点是否已结束'''
56
+ # return all([self.is_end(i) for i in node.fathers])
57
+ #
58
+ # def gen_pass_params(self, edges_dict, node):
59
+ # '''生成节点要传入的参数'''
60
+ # input_output_dict = defaultdict(list)
61
+ # input_output_list = list()
62
+ # # 循环所有父节点, 组合node和idx值
63
+ # for n in node.fathers:
64
+ # pass_idx = edges_dict[(n.id, node.id)]
65
+ # out_idx_str, input_idx_str = pass_idx.split('-')
66
+ # input_output_list.extend(zip(input_idx_str.split(
67
+ # ','), [(n, int(i)) for i in out_idx_str.split(',')]))
68
+ #
69
+ # for i, o in input_output_list:
70
+ # input_output_dict[int(i)].append(o)
71
+ #
72
+ # # 根据node和idx获取节点,识别列表,把接口是列表类型的数据放入一个列表中
73
+ # input_params_list = list()
74
+ # for k, v in input_output_dict.items():
75
+ # if len(v) > 1:
76
+ # input_params_list.append((k, [self.get_params(*i) for i in v]))
77
+ # elif len(v) == 1:
78
+ # input_params_list.append((k, self.get_params(*v[0])))
79
+ # else:
80
+ # raise Exception('传参错误')
81
+ # # 参数排序
82
+ # input_params_list.sort(key=lambda x: x[0])
83
+ # arg = [i[1] for i in input_params_list]
84
+ # # 当list输入为1个元素时要转成list
85
+ # for i, item in enumerate(node.input_data_type):
86
+ # if item.startswith('list') and not isinstance(arg[i], list):
87
+ # arg[i] = [arg[i]]
88
+ # return arg
89
+ #
90
+ #
91
+ # def save_node_status(node, is_predict=False):
92
+ # '''保存节点到redis'''
93
+ # key = NODE_STATUS_KEY.format(graph_id=node.info['graph_id'],
94
+ # node_id=node.id)
95
+ # redis_conn.hmset(key, node.to_dict())
96
+ #
97
+ #
98
+ # def get_node_status(graph_id, node_id, is_predict=False):
99
+ # '''获取节点状态到从redis'''
100
+ # fields = ['val', 'is_pass', 'status', 'run_log']
101
+ # key = NODE_STATUS_KEY.format(graph_id=graph_id, node_id=node_id)
102
+ # res = redis_conn.hmget(key, fields)
103
+ # return dict(zip(fields, binary_to_utf8(res)))
104
+ #
105
+ #
106
+ # def save_graph_status(graph_id, is_predict=False):
107
+ # '''保存图的状态到redis'''
108
+ # key = GRAPH_STATUS_KEY.format(graph_id=graph_id)
109
+ # redis_conn.hmset(key, graph_id)
110
+ #
111
+ #
112
+ # def get_graph_status(graph_id, is_predict=False):
113
+ # '''获取图状态到从redis'''
114
+ # fields = ['pid', 'status', 'run_log']
115
+ # key = GRAPH_STATUS_KEY.format(graph_id=graph_id)
116
+ # res = redis_conn.hmget(key, fields)
117
+ # return dict(zip(fields, binary_to_utf8(res)))
118
+ #
119
+ #
120
+ # def save_params(graph_id, node_id, idx, data, is_predict=False):
121
+ # '''保存图模块输出结果到redis'''
122
+ # key = NODE_OUTPUT_KEY.format(graph_id=graph_id,
123
+ # node_id=node_id,
124
+ # idx=idx)
125
+ # pickle_dumps(redis_conn, key, data)
126
+ #
127
+ #
128
+ # def get_params(graph_id, node_id, idx, is_pd=False, is_predict=False):
129
+ # '''获取图模块输出结果从redis'''
130
+ # key = NODE_OUTPUT_KEY.format(graph_id=graph_id,
131
+ # node_id=node_id,
132
+ # idx=idx)
133
+ # return pickle_loads(redis_conn, key)
134
+ #
135
+ #
136
+ # def delete_params(graph_id, is_predict=False):
137
+ # '''删除所有模块输出结果从redis'''
138
+ # try:
139
+ # pattern_key = 'graph_output_%s_*' % graph_id
140
+ # key_list = redis_conn.keys(pattern=pattern_key)
141
+ # if key_list:
142
+ # res = redis_conn.delete(*key_list)
143
+ # return res
144
+ # return 1
145
+ # except Exception:
146
+ # log.error(
147
+ # traceback.format_exc(),
148
+ # graph_id=graph_id)
149
+ #
150
+ #
151
+ # def save_graph(graph_id, data, is_predict=False):
152
+ # '''保存图到redis'''
153
+ # key = GRAPH_KEY.format(graph_id=graph_id)
154
+ # pickle_dumps(redis_conn, key, data)
155
+ #
156
+ #
157
+ # def get_graph(graph_id, is_predict=False):
158
+ # '''获取图从redis'''
159
+ # key = GRAPH_KEY.format(graph_id=graph_id)
160
+ # return pickle_loads(redis_conn, key)
161
+ #
162
+ #
163
+ # def push_result(graph_id, data, is_predict=False):
164
+ # '''增加模块进行结果到redis'''
165
+ # key = MOD_RETURN_RESULT.format(graph_id=graph_id)
166
+ # pd_bytes = pickle.dumps(data)
167
+ # res = redis_conn.lpush(key, pd_bytes)
168
+ # if not res:
169
+ # raise Exception('保存数据到redis发生错误')
170
+ # return res
171
+ #
172
+ #
173
+ # def pop_result(graph_id, is_predict=False):
174
+ # '''从redis获取模块运行结果'''
175
+ # try:
176
+ # key = MOD_RETURN_RESULT.format(graph_id=graph_id)
177
+ # res = redis_conn.rpop(name=key)
178
+ # if not res:
179
+ # return None
180
+ # res = pickle.loads(res)
181
+ # return res
182
+ # except Exception:
183
+ # log.error(traceback.format_exc(), graph_id=graph_id)