PipeGraphPy 2.0.6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. PipeGraphPy/__init__.py +10 -0
  2. PipeGraphPy/common.py +4 -0
  3. PipeGraphPy/config/__init__.py +276 -0
  4. PipeGraphPy/config/custom.py +6 -0
  5. PipeGraphPy/config/default_settings.py +125 -0
  6. PipeGraphPy/constants.py +421 -0
  7. PipeGraphPy/core/__init__.py +2 -0
  8. PipeGraphPy/core/anchor.cp39-win_amd64.pyd +0 -0
  9. PipeGraphPy/core/edge.cp39-win_amd64.pyd +0 -0
  10. PipeGraphPy/core/graph.cp39-win_amd64.pyd +0 -0
  11. PipeGraphPy/core/graph_base.cp39-win_amd64.pyd +0 -0
  12. PipeGraphPy/core/modcls/__init__.py +3 -0
  13. PipeGraphPy/core/modcls/base.cp39-win_amd64.pyd +0 -0
  14. PipeGraphPy/core/modcls/branchselect.cp39-win_amd64.pyd +0 -0
  15. PipeGraphPy/core/modcls/classifier.cp39-win_amd64.pyd +0 -0
  16. PipeGraphPy/core/modcls/cluster.cp39-win_amd64.pyd +0 -0
  17. PipeGraphPy/core/modcls/datacharts.cp39-win_amd64.pyd +0 -0
  18. PipeGraphPy/core/modcls/deeplearning.cp39-win_amd64.pyd +0 -0
  19. PipeGraphPy/core/modcls/endscript.cp39-win_amd64.pyd +0 -0
  20. PipeGraphPy/core/modcls/ensemble.cp39-win_amd64.pyd +0 -0
  21. PipeGraphPy/core/modcls/evaluate.cp39-win_amd64.pyd +0 -0
  22. PipeGraphPy/core/modcls/exportdata.cp39-win_amd64.pyd +0 -0
  23. PipeGraphPy/core/modcls/handlescript.cp39-win_amd64.pyd +0 -0
  24. PipeGraphPy/core/modcls/importdata.cp39-win_amd64.pyd +0 -0
  25. PipeGraphPy/core/modcls/merge.cp39-win_amd64.pyd +0 -0
  26. PipeGraphPy/core/modcls/mergescript.cp39-win_amd64.pyd +0 -0
  27. PipeGraphPy/core/modcls/metrics.cp39-win_amd64.pyd +0 -0
  28. PipeGraphPy/core/modcls/postprocessor.cp39-win_amd64.pyd +0 -0
  29. PipeGraphPy/core/modcls/preprocessor.cp39-win_amd64.pyd +0 -0
  30. PipeGraphPy/core/modcls/pythonscript.cp39-win_amd64.pyd +0 -0
  31. PipeGraphPy/core/modcls/regressor.cp39-win_amd64.pyd +0 -0
  32. PipeGraphPy/core/modcls/selector.cp39-win_amd64.pyd +0 -0
  33. PipeGraphPy/core/modcls/selectscript.cp39-win_amd64.pyd +0 -0
  34. PipeGraphPy/core/modcls/special.cp39-win_amd64.pyd +0 -0
  35. PipeGraphPy/core/modcls/split.cp39-win_amd64.pyd +0 -0
  36. PipeGraphPy/core/modcls/splitscript.cp39-win_amd64.pyd +0 -0
  37. PipeGraphPy/core/modcls/startscript.cp39-win_amd64.pyd +0 -0
  38. PipeGraphPy/core/modcls/transformer.cp39-win_amd64.pyd +0 -0
  39. PipeGraphPy/core/module.cp39-win_amd64.pyd +0 -0
  40. PipeGraphPy/core/modules/__init__.py +65 -0
  41. PipeGraphPy/core/modules/classifier/__init__.py +2 -0
  42. PipeGraphPy/core/modules/cluster/__init__.py +0 -0
  43. PipeGraphPy/core/modules/custom/__init__.py +0 -0
  44. PipeGraphPy/core/modules/custom/classifier/__init__.py +0 -0
  45. PipeGraphPy/core/modules/datacharts/__init__.py +5 -0
  46. PipeGraphPy/core/modules/datacharts/dataview.py +28 -0
  47. PipeGraphPy/core/modules/deeplearning/__init__.py +0 -0
  48. PipeGraphPy/core/modules/ensemble/__init__.py +0 -0
  49. PipeGraphPy/core/modules/evaluate/__init__.py +0 -0
  50. PipeGraphPy/core/modules/exportdata/__init__.py +0 -0
  51. PipeGraphPy/core/modules/importdata/__init__.py +0 -0
  52. PipeGraphPy/core/modules/merge/__init__.py +0 -0
  53. PipeGraphPy/core/modules/model_selector/__init__.py +3 -0
  54. PipeGraphPy/core/modules/postprocessor/__init__.py +0 -0
  55. PipeGraphPy/core/modules/preprocessor/__init__.py +0 -0
  56. PipeGraphPy/core/modules/pythonscript/__init__.py +0 -0
  57. PipeGraphPy/core/modules/regressor/__init__.py +0 -0
  58. PipeGraphPy/core/modules/selector/__init__.py +0 -0
  59. PipeGraphPy/core/modules/special/__init__.py +0 -0
  60. PipeGraphPy/core/modules/split/__init__.py +0 -0
  61. PipeGraphPy/core/modules/transformer/__init__.py +0 -0
  62. PipeGraphPy/core/node.cp39-win_amd64.pyd +0 -0
  63. PipeGraphPy/core/pipegraph.cp39-win_amd64.pyd +0 -0
  64. PipeGraphPy/db/__init__.py +2 -0
  65. PipeGraphPy/db/models.cp39-win_amd64.pyd +0 -0
  66. PipeGraphPy/db/utils.py +106 -0
  67. PipeGraphPy/decorators.py +42 -0
  68. PipeGraphPy/logger.py +170 -0
  69. PipeGraphPy/plot/__init__.py +0 -0
  70. PipeGraphPy/plot/draw.py +424 -0
  71. PipeGraphPy/storage/__init__.py +10 -0
  72. PipeGraphPy/storage/base.py +2 -0
  73. PipeGraphPy/storage/dict_backend.py +102 -0
  74. PipeGraphPy/storage/file_backend.py +342 -0
  75. PipeGraphPy/storage/redis_backend.py +183 -0
  76. PipeGraphPy/tools.py +388 -0
  77. PipeGraphPy/utils/__init__.py +1 -0
  78. PipeGraphPy/utils/check.py +179 -0
  79. PipeGraphPy/utils/core.py +295 -0
  80. PipeGraphPy/utils/examine.py +259 -0
  81. PipeGraphPy/utils/file_operate.py +101 -0
  82. PipeGraphPy/utils/format.py +303 -0
  83. PipeGraphPy/utils/functional.py +422 -0
  84. PipeGraphPy/utils/handle_graph.py +31 -0
  85. PipeGraphPy/utils/lock.py +1 -0
  86. PipeGraphPy/utils/mq.py +54 -0
  87. PipeGraphPy/utils/osutil.py +29 -0
  88. PipeGraphPy/utils/redis_operate.py +195 -0
  89. PipeGraphPy/utils/str_handle.py +122 -0
  90. PipeGraphPy/utils/version.py +108 -0
  91. PipeGraphPy-2.0.6.dist-info/METADATA +17 -0
  92. PipeGraphPy-2.0.6.dist-info/RECORD +94 -0
  93. PipeGraphPy-2.0.6.dist-info/WHEEL +5 -0
  94. PipeGraphPy-2.0.6.dist-info/top_level.txt +1 -0
PipeGraphPy/tools.py ADDED
@@ -0,0 +1,388 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ import os
4
+ import signal
5
+ import copy
6
+ import time
7
+ import json
8
+ import pandas as pd
9
+ from contextlib import contextmanager
10
+ from PipeGraphPy.db.models import GraphsTB, PredictRecordTB, BacktestRecordTB, ObjectGraphsTB
11
+ from datetime import datetime, timedelta
12
+ from PipeGraphPy.config import settings
13
+ from PipeGraphPy.db.utils import update_node_params
14
+ from PipeGraphPy.constants import MODULES, STATUS
15
+ from PipeGraphPy.common import multi_graph
16
+ from PipeGraphPy.logger import log
17
+ from PipeGraphPy.utils.file_operate import pickle_loads
18
+
19
+
20
def print_to_run_log(*args, graph_id=None, rlog_record_id=None):
    """Write content to the run_log so the front end can display it.

    params:
        args: pieces of content to print
        graph_id: id of the running graph (required)
        rlog_record_id: id of the run record; when set, also mirror to log.info
    """
    assert graph_id, 'graph_id必传'
    # One line per argument, joined into a single message.
    msg = '\n'.join(str(part) for part in args)
    if settings.SDK_SHOW_LOG:
        print(msg)
    # Only persist to the graphs table when graph_id is a real db id.
    if isinstance(graph_id, int):
        GraphsTB.add_log(graph_id, msg)
    if rlog_record_id:
        log.info(msg)
35
+
36
def print_to_predict_log(*args, plog_record_id=None):
    """Write content to the predict log so the front end can display it.

    params:
        args: pieces of content to print
        plog_record_id: id of the predict record (required)
    """
    assert plog_record_id, 'plog_record_id必传'
    msg = '\n'.join(str(part) for part in args)
    PredictRecordTB.add_log(plog_record_id, msg)
44
+
45
def update_params_value(key, value, node):
    """Update a parameter of the given node, in memory and — when the
    owning graph is database-backed — in the database as well.

    returns:
        1 always.
    """
    node.params[key] = value
    # Look up the in-memory graph this node belongs to.
    owning_graph = multi_graph.get(node.graph_info["id"])
    if owning_graph.use_db:
        update_node_params(node.id, key, value)
    return 1
52
+
53
def update_params_source(key, value, node):
    """Update the ``source`` field of a node parameter in the database
    (only when the owning graph is database-backed).

    returns:
        1 always.
    """
    owning_graph = multi_graph.get(node.graph_info["id"])
    if owning_graph.use_db:
        update_node_params(node.id, key, value, value_key="source")
    return 1
59
+
60
+
61
def predict_to_csv(datas, filename="",
                   graph_id='', object_id='', node_id='',
                   plog_record_id='', online_plog_record_id="",
                   filepath=""):
    """Save prediction results to a csv file.

    When ``filepath`` is given the data is written there directly; otherwise
    the file is placed under the configured predict-result directory, named
    from the predict / online-predict record id, the node id and a timestamp.

    returns:
        1 on success, 0 when no record id was supplied to derive a name from.
    """
    assert isinstance(datas, pd.DataFrame), "预测保存的数据必须是DataFrame"
    assert graph_id, "未传graph_id"
    assert object_id, "未传object_id"
    assert node_id, "未传node_id"
    if filepath:
        if not filepath.endswith(".csv"):
            raise Exception("预测数据只能保存csv文件")
        datas.to_csv(filepath, encoding="utf_8", index=False)
        return 1
    if not plog_record_id and not online_plog_record_id:
        return 0
    save_dir = os.path.join(
        settings.PREDICT_RESULT_SAVE_PATH,
        str(graph_id),
        str(object_id))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if filename:
        target_name = filename
    else:
        # The online record id takes precedence when both ids are present.
        if online_plog_record_id:
            prefix = "online_predict_%s" % online_plog_record_id
        else:
            prefix = "predict_%s" % plog_record_id
        # NOTE(review): timestamp is UTC+8 (Beijing time) — confirm callers expect that.
        stamp = (datetime.utcnow() + timedelta(hours=8)).strftime("%Y%m%d%H%M%S")
        target_name = "%s_%s_%s.csv" % (prefix, str(node_id), stamp)
    datas.to_csv(os.path.join(save_dir, target_name), encoding="utf_8", index=False)
    return 1
96
+
97
+
98
def read_predict_csv(graph_id, object_id, start_date=None, end_date=None):
    """Read a model's historical prediction data.

    args:
        graph_id: id of the prediction model
        object_id: id of the modeling object
        start_date: first prediction-batch date (inclusive); when omitted only
            the most recent batch is returned. Format YYYYmmdd or YYYY-mm-dd.
        end_date: last prediction-batch date (inclusive); same rules.
    return:
        DataFrame of the saved predictions plus a ``file_date`` column
        (YYYYmmdd of the batch); empty DataFrame when nothing is saved.
        Returns None in the "image" run environment (NPMOS path unimplemented).
    """
    datas = pd.DataFrame()
    predict_save_path = os.path.join(settings.PREDICT_RESULT_SAVE_PATH, str(graph_id), str(object_id))
    if not os.path.exists(predict_save_path):
        return datas
    file_list = os.listdir(predict_save_path)
    # Filenames end in "_<YYYYmmddHHMMSS>.csv"; map timestamp -> filename.
    file_time_dict = {(i.split("_")[-1]).replace(".csv", ""): i for i in file_list}
    time_list = list(file_time_dict.keys())
    if settings.RUN_ENV != "image":
        # Guard: an existing-but-empty directory would otherwise crash on max().
        if not time_list:
            return datas
        if start_date and end_date:
            start_date = str(start_date).replace('-', "")
            end_date = str(end_date).replace('-', "")
            frames = []
            daterange = pd.date_range(start_date, end_date, freq="D").to_list()
            for d in daterange:
                match_times = [i for i in time_list if i.startswith(d.strftime("%Y%m%d"))]
                if len(match_times) > 1:
                    # Several batches on the same day: keep the latest one.
                    match_time = str(max(map(int, match_times)))
                elif len(match_times) == 1:
                    match_time = match_times[0]
                else:
                    continue
                df = pd.read_csv(os.path.join(predict_save_path, file_time_dict[match_time]))
                df["file_date"] = match_time[:8]
                frames.append(df)
            # DataFrame.append was removed in pandas 2.0 — concatenate instead.
            if frames:
                datas = pd.concat(frames)
            return datas
        else:
            match_time = str(max(map(int, time_list)))
            datas = pd.read_csv(os.path.join(predict_save_path, file_time_dict[match_time]))
            datas["file_date"] = match_time[:8]
            return datas
    else:
        # Call the NPMOS interface — TODO: not implemented yet.
        pass
139
+
140
+
141
def train_to_csv(datas, graph_id='', node_id=''):
    """Persist a node's training-run output as ``run_<node_id>.csv`` under
    the run-result directory for the graph.

    returns:
        1 always.
    """
    assert isinstance(datas, pd.DataFrame), "预测保存的数据必须是DataFrame"
    assert graph_id, "未传graph_id"
    assert node_id, "未传node_id"
    run_save_path = os.path.join(settings.RUN_RESULT_SAVE_PATH, str(graph_id))
    if not os.path.exists(run_save_path):
        os.makedirs(run_save_path)
    target_file = os.path.join(run_save_path, "run_%s.csv" % str(node_id))
    datas.to_csv(target_file, encoding="utf_8", index=False)
    return 1
154
+
155
+
156
def read_train_csv(graph_id, node_id):
    """Read the training-run csv data previously saved by ``train_to_csv``.

    args:
        graph_id: id of the model
        node_id: node id; when truthy, only that node's file is read.
    return:
        DataFrame of the saved rows plus a ``node_id`` column; empty
        DataFrame when no matching file exists.
    """
    run_save_path = os.path.join(settings.RUN_RESULT_SAVE_PATH, str(graph_id))
    file_list = os.listdir(run_save_path)

    def _node_id_of(fname):
        # Files are written as "run_<node_id>.csv" by train_to_csv, so the
        # node id is the piece after the last underscore. (The original
        # split on "-", which never matches that naming scheme and made
        # the int() conversion below raise for every file.)
        return str(str(fname).split("_")[-1]).replace(".csv", "")

    if node_id:
        file_list = [f for f in file_list if int(_node_id_of(f)) == int(node_id)]
    frames = []
    for f in file_list:
        df = pd.read_csv(os.path.join(run_save_path, f))
        df["node_id"] = _node_id_of(f)
        frames.append(df)
    # DataFrame.append was removed in pandas 2.0 — concatenate instead.
    return pd.concat(frames) if frames else pd.DataFrame()
172
+
173
+
174
def read_backtest_data(graph_id, object_id, start_date=None, end_date=None, record_id=None):
    """Read backtest result data for a model.

    args:
        graph_id: id of the backtest model (required)
        object_id: id of the modeling object being backtested (required)
        start_date: backtest start date (inclusive), format YYYY-mm-dd.
            When given together with end_date, a NEW backtest is queued and
            this call blocks (polling every 5 seconds) until it finishes.
        end_date: backtest end date (inclusive), format YYYY-mm-dd.
        record_id: id of an existing backtest record. When start_date,
            end_date and record_id are all omitted, the most recent
            backtest result for (graph_id, object_id) is read.
    return:
        DataFrame with two columns: the true-value column and its
        ``<label>_predict`` counterpart.
    raises:
        Exception: when the model is untrained, a backtest is already
            queued/running, the backtest fails, or the result file /
            expected columns cannot be found.
    example:
        >>> read_backtest_data(1371, '2023-07-01', '2023-07-26')   # run + wait
        >>> read_backtest_data(1371, record_id=302)                # by record
        >>> read_backtest_data(1371)                               # latest
    """
    datas = pd.DataFrame()
    if start_date and end_date:
        # Trigger a new backtest; refuse if one is queued or running.
        graph_info = GraphsTB.get(id=graph_id)
        if graph_info["status"] != STATUS.SUCCESS:
            raise Exception("该模型%s还未训练成功" % graph_id)
        if graph_info["b_status"] in [STATUS.WAITRUN, STATUS.WAITEXE]:
            raise Exception("模型%s正在等待回测,无法再次回测" % graph_id)
        if graph_info["b_status"] in [STATUS.RUNNING]:
            raise Exception("模型%s正在回测,无法再次回测" % graph_id)
        ObjectGraphsTB.set(
            b_status=STATUS.WAITRUN,
            backtest_params=json.dumps({"start_dt": start_date, "end_dt":end_date})
        ).where(graph_id=graph_id, object_id=object_id)
        # Poll the db every 5s until the backtest worker reports a result.
        while True:
            time.sleep(5)
            b_status = ObjectGraphsTB.map_one("b_status").where(graph_id=graph_id, object_id=object_id)
            if b_status == STATUS.ERROR:
                # Surface the worker's log in the exception when available.
                backtest_record_id = ObjectGraphsTB.map_one("backtest_record_id").where(graph_id=graph_id, object_id=object_id)
                if backtest_record_id:
                    record_log = BacktestRecordTB.map_one("log").where(id=backtest_record_id)
                    raise Exception("模型%s回测报错,回测日志:\n%s" % (graph_id, record_log))
                raise Exception("模型%s回测报错" % graph_id)
            elif b_status == STATUS.SUCCESS:
                break
            else:
                continue
    # Resolve which backtest record to read.
    if record_id:
        backtest_record_id = record_id
    else:
        backtest_record_id = ObjectGraphsTB.map_one("backtest_record_id").where(graph_id=graph_id, object_id=object_id)
    if not backtest_record_id:
        raise Exception("模型%s没有获取到回测记录" % graph_id)
    backtest_save_path = os.path.join(settings.BACKTEST_RESULT_SAVE_PATH, str(graph_id))
    backtest_save_file = os.path.join(backtest_save_path, "backtest_%s.pkl" % backtest_record_id)
    if not os.path.isfile(backtest_save_file):
        raise Exception("模型%s没有回测记录%s的回测数据" % (graph_id, record_id))
    # Pickled payload is a dict with at least a "data" DataFrame and,
    # optionally, "label_columns" naming the true-value column(s).
    datas = pickle_loads(backtest_save_file)
    label_columns = datas.get("label_columns")
    if label_columns:
        y_true_column = label_columns[0] if isinstance(label_columns, list) else label_columns
        y_pred_column = y_true_column + "_predict"
        if y_true_column not in datas["data"].columns:
            raise Exception("模型%s回测数据里未找到真实值%s" % (graph_id, y_true_column))
        if y_pred_column not in datas["data"].columns:
            raise Exception("模型%s回测数据里未找到预测值%s" % (graph_id, y_pred_column))
    else:
        # No label metadata: infer from the single "*_predict" column.
        y_pred_columns = [i for i in datas["data"].columns if str(i).endswith("_predict")]
        if len(y_pred_columns) > 1:
            raise Exception("回测结果有多个_predict列")
        if len(y_pred_columns) == 0:
            raise Exception("回测结果中未找到_predict列")
        y_pred_column = y_pred_columns[0]
        y_true_column = str(y_pred_column).replace("_predict", "")
    return datas["data"][[y_true_column,y_pred_column]]
306
+
307
+
308
def get_model_save_path(graph_id, object_id, node_id):
    """Return the directory where a node's custom model is saved.

    args:
        graph_id: id of the prediction model
        object_id: id of the modeling object
        node_id: node id
    return:
        Path string; layout differs between the 'sdk' and server run envs.
    """
    if settings.RUN_ENV == 'sdk':
        model_save_path = settings.SDK_MODEL_SAVE_PATH
        nodes_model_save_path = os.path.join(model_save_path, str(graph_id), str(object_id), str(node_id))
    else:
        model_save_path = settings.RUN_MODEL_SAVE_PATH
        # str(node_id) for parity with the sdk branch: os.path.join raises
        # TypeError when node_id is passed as an int (the original passed
        # the raw value here).
        nodes_model_save_path = os.path.join(model_save_path, str(graph_id), str(object_id), 'nodes', str(node_id))
    return nodes_model_save_path
322
+
323
+
324
def update_nwp_config(nwp_config, node):
    """Replace the ``nwp_config`` parameter of the single data-import node
    that feeds ``node`` — in memory, and in the database for db-backed graphs.

    args:
        nwp_config: dict, the new configuration (must be a non-empty dict)
        node: the node whose ancestor import node should be updated
    return:
        1 on success.
    raises:
        Exception: when the graph cannot be found, when zero or multiple
            import ancestors exist, when the import node has no
            ``nwp_config`` parameter, or when ``nwp_config`` is invalid.
    """
    # Collect the data-import ancestor nodes of this node.
    import_nodes = []
    graph = multi_graph.get(node.graph_info["id"])
    if graph is None:
        raise Exception("全局未找到graph")

    for n in graph.a._iter_fathers(node):
        if n.module.parent.info["cls_name"] == MODULES.IMPORT and n not in import_nodes:
            import_nodes.append(n)
    # Exactly one import ancestor must exist to know which one to update.
    if len(import_nodes) > 1:
        raise Exception("导入数据节点有多个,只能更新一个导入数据节点的nwp_config")
    if len(import_nodes) == 0:
        raise Exception("未找到Algodata或StrategyAlgodata数据导入节点")
    import_node = import_nodes[0]
    if import_node.params.get("nwp_config") is None:
        raise Exception("数据导入节点%s没有nwp_config参数" % import_node.info["cls_name"])

    # Validate the format of the incoming nwp_config.
    if not isinstance(nwp_config, dict):
        raise Exception("nwp_config格式不正确")
    if not nwp_config:
        raise Exception("nwp_config传值为空")
    # Update the in-memory node parameter.
    import_node.params["nwp_config"] = nwp_config
    if graph.use_db:
        # Mirror the change into the database.
        update_node_params(import_node.id, "nwp_config", str(nwp_config))
    print_to_run_log("更新nwp_config为:%s" % str(nwp_config), graph_id=node.graph_info["id"])
    return 1
354
+
355
def get_nwp_config(node):
    """Return the ``nwp_config`` parameter of the data-import ancestor of
    the given node.

    raises:
        Exception: when the graph or an import ancestor cannot be found,
            or when zero / multiple nwp_config values exist.
    """
    graph = multi_graph.get(node.graph_info["id"])
    if graph is None:
        raise Exception("全局未找到graph")

    # Collect the distinct data-import ancestors of this node.
    import_nodes = []
    for ancestor in graph.a._iter_fathers(node):
        if ancestor.module.parent.info["cls_name"] == MODULES.IMPORT and ancestor not in import_nodes:
            import_nodes.append(ancestor)
    if len(import_nodes) == 0:
        raise Exception("未找到Algodata或StrategyAlgodata数据导入节点")

    nwp_configs = [
        n.params["nwp_config"]
        for n in import_nodes
        if n.params.get("nwp_config") is not None
    ]
    if len(nwp_configs) > 1:
        raise Exception("数据导入节点存在多个nwp_config")
    if len(nwp_configs) == 0:
        raise Exception("数据导入节点不存在nwp_config参数")
    return nwp_configs[0]
379
+
380
+
381
@contextmanager
def timeout(duration):
    """Context manager that raises TimeoutError if its block runs longer
    than ``duration`` seconds. Unix only: relies on SIGALRM, so it must be
    used from the main thread.

    args:
        duration: timeout in whole seconds (signal.alarm granularity).
    raises:
        TimeoutError: when the block does not finish in time.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f'block timedout after {duration} seconds')
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Always cancel the alarm and restore the previous handler, even
        # when the block raised — the original leaked a pending alarm on
        # any exception, which could kill the process at a random later point.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
@@ -0,0 +1 @@
1
+ # coding: utf8
@@ -0,0 +1,179 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import traceback
5
+ from PipeGraphPy.db.models import GraphsTB, NodesTB, EdgesTB, ModulesTB
6
+ from PipeGraphPy.constants import GRAPHTYPE
7
+ from PipeGraphPy.storage import store
8
+ from PipeGraphPy.logger import log
9
+ from PipeGraphPy.core.module import get_template_type, Module
10
+
11
+
12
def node_is_passed(node):
    """Return True when the node has been filtered out, i.e. no active
    (is_pass=0) row for it exists in the nodes table."""
    return not NodesTB.find(id=node.id, is_pass=0)
16
+
17
+
18
def check_before_evaluate(graph_id):
    """Check a graph before generating its evaluation graph.

    parameters:
        graph_id: int, graph id
    return:
        True on success; otherwise an Exception describing the failure
        reason is raised.
    """
    try:
        graph_info = GraphsTB.find_one(id=graph_id)
        if not graph_info:
            raise Exception("数据库不存在该模型信息")
        if graph_info["category"] != GRAPHTYPE.TRAIN:
            raise Exception("只有一般图模型才能生成评估模型")
        # if not check_graph_run_success(graph_id):
        #     raise Exception('该模型还未训练成功,不能生成评估图')
        graph_model = store.load_graph(graph_id)
        # Usable (non-passed) estimators of the trained graph model.
        pg_list = [
            graph_model.multi_pg.get(i)
            for i in graph_model.estimator_set
            if not node_is_passed(i)
        ]
        # Exactly one usable estimator is required to build an evaluation graph.
        if len(pg_list) == 0:
            raise Exception("该模型不存在可用的优化器, 不能生成评估模型")
        if len(pg_list) > 1:
            raise Exception("该模型存在多个可用优化器,必须有且仅有一个可用优化器才能生成评估模型")
        return True
    except FileNotFoundError:
        # store.load_graph raises this when the trained model file is missing.
        log.error("模型文件不存在", graph_id=graph_id)
        raise Exception("模型文件不存在")
    except Exception as e:
        log.error(traceback.format_exc(), graph_id=graph_id)
        raise e
51
+
52
+
53
def check_before_predict(graph_id):
    """Check a model before running prediction.

    parameters:
        graph_id: int, graph id
    return:
        True on success; otherwise an Exception describing the failure
        reason is raised.
    """
    try:
        graph_info = GraphsTB.find_one(id=graph_id)
        if not graph_info:
            raise Exception("数据库不存在该模型(%s)信息" % graph_id)
        # Subscript access for consistency with check_before_evaluate —
        # the same find_one result is indexed there, so attribute access
        # here would fail on a mapping result.
        if graph_info["category"] != GRAPHTYPE.TRAIN:
            raise Exception("%s 只有训练模型才能执行预测" % graph_id)
        # A model whose last training failed can still predict, as long as
        # a trained model artifact exists.
        if not store.has_graph(graph_id):
            raise Exception("%s 不存在训练好的模型" % graph_id)
        return True
    except FileNotFoundError:
        raise Exception("模型文件不存在")
    except Exception as e:
        raise e
84
+
85
+
86
def check_edge(
    graph_id,
    source_id,
    target_id,
    source_anchor=0,
    target_anchor=0,
    nodes_info_dict=None,
    is_run=False,
):
    """Validate that an edge between two nodes is legal.

    parameters:
        graph_id: int, graph id
        source_id: int, id of the edge's start node
        target_id: int, id of the edge's end node
        source_anchor: int, anchor index on the start node
        target_anchor: int, anchor index on the end node
        nodes_info_dict: dict, all node info keyed by node id (optional;
            fetched from the database when not supplied)
        is_run: bool, True when validating at run time (skips the
            one-edge-per-input-anchor check)
    return:
        True on success; otherwise an Exception describing the failure
        reason is raised.
    """
    # An edge may not connect a node to itself.
    if source_id == target_id:
        raise Exception("起始节点和结尾节点不能是同一个节点")
    if not nodes_info_dict:
        # Check the two nodes are not already connected (skipped at run
        # time, where nodes_info_dict is supplied by the caller).
        edges_info = EdgesTB.find(
            graph_id=graph_id, source_id=source_id, target_id=target_id
        )
        if len(edges_info) != 0:
            raise Exception("这两个节点间已经存在一条连线")
        nodes_info = NodesTB.find(graph_id=graph_id)
        nodes_info_dict = {i["id"]: i for i in nodes_info}
    source_data_type, target_data_type = None, None
    # idx 0 = source endpoint, idx 1 = target endpoint.
    for idx, (node_id, anchor) in enumerate(
        [(source_id, source_anchor), (target_id, target_anchor)]
    ):
        name = "结尾" if idx else "起始"
        if anchor < 0:
            raise Exception("%s锚点值不能小于零(%s)" % (name, anchor))
        # The node must exist.
        node_info = nodes_info_dict.get(node_id)
        if not node_info:
            raise Exception("未找到%s节点信息: %s " % (name, node_id))
        # The node's module must exist.
        module_info = ModulesTB.find_one(id=node_info["mod_id"])
        if not module_info:
            raise Exception("未找到%s节点对应的组件信息: %s " % (name, node_id))
        template_type = get_template_type(node_info["mod_id"])
        # Input/output anchor configuration must be present.
        if template_type["INPUT"] is None:
            raise Exception("没有%s节点(%s)的输入配置信息" % (name, node_id))
        if template_type["OUTPUT"] is None:
            raise Exception("没有%s节点(%s)的输出配置信息" % (name, node_id))
        # Anchors are numbered: inputs first, then outputs.
        total_anchor = len(template_type["INPUT"]) + len(template_type["OUTPUT"])
        if anchor > total_anchor:
            raise Exception("%s节点的锚点(%s)超出了锚点范围(%s)" % (name, anchor, total_anchor))
        if idx:
            # Target endpoint: must land on an input anchor.
            if anchor > len(template_type["INPUT"]):
                raise Exception("结尾节点的锚点(%s)必须是图输入锚点" % anchor)
            target_data_type = template_type["INPUT"][anchor]
        else:
            # Source endpoint: must leave from an output anchor.
            if anchor - len(template_type["INPUT"]) < 0:
                raise Exception("起始节点的锚点(%s)不能是输入锚点" % anchor)
            source_data_type = template_type["OUTPUT"][
                anchor - len(template_type["INPUT"])
            ]
    # The source's output data type must be accepted by the target's input.
    if target_data_type.find(source_data_type) == -1:
        raise Exception("起始节点%s和结尾节点的数据类型不匹配%s" % (source_data_type, target_data_type))
    # A non-list input anchor may only receive a single incoming edge.
    if not is_run:
        if target_data_type.find("list") == -1:
            edge_info = EdgesTB.find(
                graph_id=graph_id, target_id=target_id, target_anchor=target_anchor
            )
            if edge_info:
                raise Exception("目标节点的锚点只接收一个数据")
    # Cycle detection happens at run time, not here.
    return True
167
+
168
+
169
def check_node(graph_id):
    """Validate the nodes of a graph. Placeholder — not implemented yet;
    always returns None."""
    return None
172
+
173
+
174
def check_module_code(module_id):
    """Validate the code of a custom module.

    parameters:
        module_id: id of the module to check
    return:
        Whatever ``Module.check_code`` reports for that module.
    """
    return Module(module_id).check_code()