PipeGraphPy 2.0.6__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PipeGraphPy/__init__.py +10 -0
- PipeGraphPy/common.py +4 -0
- PipeGraphPy/config/__init__.py +276 -0
- PipeGraphPy/config/custom.py +6 -0
- PipeGraphPy/config/default_settings.py +125 -0
- PipeGraphPy/constants.py +421 -0
- PipeGraphPy/core/__init__.py +2 -0
- PipeGraphPy/core/anchor.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/edge.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/graph.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/graph_base.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/__init__.py +3 -0
- PipeGraphPy/core/modcls/base.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/branchselect.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/classifier.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/cluster.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/datacharts.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/deeplearning.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/endscript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/ensemble.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/evaluate.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/exportdata.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/handlescript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/importdata.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/merge.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/mergescript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/metrics.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/postprocessor.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/preprocessor.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/pythonscript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/regressor.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/selector.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/selectscript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/special.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/split.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/splitscript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/startscript.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modcls/transformer.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/module.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/modules/__init__.py +65 -0
- PipeGraphPy/core/modules/classifier/__init__.py +2 -0
- PipeGraphPy/core/modules/cluster/__init__.py +0 -0
- PipeGraphPy/core/modules/custom/__init__.py +0 -0
- PipeGraphPy/core/modules/custom/classifier/__init__.py +0 -0
- PipeGraphPy/core/modules/datacharts/__init__.py +5 -0
- PipeGraphPy/core/modules/datacharts/dataview.py +28 -0
- PipeGraphPy/core/modules/deeplearning/__init__.py +0 -0
- PipeGraphPy/core/modules/ensemble/__init__.py +0 -0
- PipeGraphPy/core/modules/evaluate/__init__.py +0 -0
- PipeGraphPy/core/modules/exportdata/__init__.py +0 -0
- PipeGraphPy/core/modules/importdata/__init__.py +0 -0
- PipeGraphPy/core/modules/merge/__init__.py +0 -0
- PipeGraphPy/core/modules/model_selector/__init__.py +3 -0
- PipeGraphPy/core/modules/postprocessor/__init__.py +0 -0
- PipeGraphPy/core/modules/preprocessor/__init__.py +0 -0
- PipeGraphPy/core/modules/pythonscript/__init__.py +0 -0
- PipeGraphPy/core/modules/regressor/__init__.py +0 -0
- PipeGraphPy/core/modules/selector/__init__.py +0 -0
- PipeGraphPy/core/modules/special/__init__.py +0 -0
- PipeGraphPy/core/modules/split/__init__.py +0 -0
- PipeGraphPy/core/modules/transformer/__init__.py +0 -0
- PipeGraphPy/core/node.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/core/pipegraph.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/db/__init__.py +2 -0
- PipeGraphPy/db/models.cp39-win_amd64.pyd +0 -0
- PipeGraphPy/db/utils.py +106 -0
- PipeGraphPy/decorators.py +42 -0
- PipeGraphPy/logger.py +170 -0
- PipeGraphPy/plot/__init__.py +0 -0
- PipeGraphPy/plot/draw.py +424 -0
- PipeGraphPy/storage/__init__.py +10 -0
- PipeGraphPy/storage/base.py +2 -0
- PipeGraphPy/storage/dict_backend.py +102 -0
- PipeGraphPy/storage/file_backend.py +342 -0
- PipeGraphPy/storage/redis_backend.py +183 -0
- PipeGraphPy/tools.py +388 -0
- PipeGraphPy/utils/__init__.py +1 -0
- PipeGraphPy/utils/check.py +179 -0
- PipeGraphPy/utils/core.py +295 -0
- PipeGraphPy/utils/examine.py +259 -0
- PipeGraphPy/utils/file_operate.py +101 -0
- PipeGraphPy/utils/format.py +303 -0
- PipeGraphPy/utils/functional.py +422 -0
- PipeGraphPy/utils/handle_graph.py +31 -0
- PipeGraphPy/utils/lock.py +1 -0
- PipeGraphPy/utils/mq.py +54 -0
- PipeGraphPy/utils/osutil.py +29 -0
- PipeGraphPy/utils/redis_operate.py +195 -0
- PipeGraphPy/utils/str_handle.py +122 -0
- PipeGraphPy/utils/version.py +108 -0
- PipeGraphPy-2.0.6.dist-info/METADATA +17 -0
- PipeGraphPy-2.0.6.dist-info/RECORD +94 -0
- PipeGraphPy-2.0.6.dist-info/WHEEL +5 -0
- PipeGraphPy-2.0.6.dist-info/top_level.txt +1 -0
PipeGraphPy/tools.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
import os
|
|
4
|
+
import signal
|
|
5
|
+
import copy
|
|
6
|
+
import time
|
|
7
|
+
import json
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from contextlib import contextmanager
|
|
10
|
+
from PipeGraphPy.db.models import GraphsTB, PredictRecordTB, BacktestRecordTB, ObjectGraphsTB
|
|
11
|
+
from datetime import datetime, timedelta
|
|
12
|
+
from PipeGraphPy.config import settings
|
|
13
|
+
from PipeGraphPy.db.utils import update_node_params
|
|
14
|
+
from PipeGraphPy.constants import MODULES, STATUS
|
|
15
|
+
from PipeGraphPy.common import multi_graph
|
|
16
|
+
from PipeGraphPy.logger import log
|
|
17
|
+
from PipeGraphPy.utils.file_operate import pickle_loads
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def print_to_run_log(*args, graph_id=None, rlog_record_id=None):
    """Write content into the graph's run log so the front-end can display it.

    params:
        args: pieces of content to log; each is str()-ed and joined by newlines
        graph_id: id of the running graph (required)
        rlog_record_id: run-record id; when set, the message is also sent to
            the application logger
    """
    assert graph_id, 'graph_id必传'
    msg = '\n'.join(str(part) for part in args)
    # Echo to stdout when the SDK is configured to show logs.
    if settings.SDK_SHOW_LOG:
        print(msg)
    # Only persist to the DB log when graph_id is a real integer id.
    if isinstance(graph_id, int):
        GraphsTB.add_log(graph_id, msg)
    if rlog_record_id:
        log.info(msg)
|
|
35
|
+
|
|
36
|
+
def print_to_predict_log(*args, plog_record_id=None):
    """Write content into the predict log so the front-end can display it.

    params:
        args: pieces of content to log; each is str()-ed and joined by newlines
        plog_record_id: predict-record id (required)
    """
    assert plog_record_id, 'plog_record_id必传'
    lines = [str(item) for item in args]
    PredictRecordTB.add_log(plog_record_id, '\n'.join(lines))
|
|
44
|
+
|
|
45
|
+
def update_params_value(key, value, node):
    """Update a parameter on the current node, persisting it when the graph is DB-backed."""
    node.params[key] = value
    owner_graph = multi_graph.get(node.graph_info["id"])
    # Only DB-backed graphs need the change mirrored into the database.
    if owner_graph.use_db:
        update_node_params(node.id, key, value)
    return 1
|
|
52
|
+
|
|
53
|
+
def update_params_source(key, value, node):
    """Update the 'source' field of a node parameter in the database (DB-backed graphs only)."""
    owner_graph = multi_graph.get(node.graph_info["id"])
    if owner_graph.use_db:
        update_node_params(node.id, key, value, value_key="source")
    return 1
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def predict_to_csv(datas, filename="",
                   graph_id='', object_id='', node_id='',
                   plog_record_id='', online_plog_record_id="",
                   filepath=""):
    """Persist prediction results as a CSV file.

    params:
        datas: pandas.DataFrame with the prediction results
        filename: explicit file name; when empty a name is auto-generated
        graph_id: id of the predict graph (required)
        object_id: id of the modeling object (required)
        node_id: id of the node that produced the data (required)
        plog_record_id: predict-record id (offline run)
        online_plog_record_id: predict-record id (online run)
        filepath: when given, write directly to this .csv path instead of the
            configured save directory
    returns:
        1 when the file was written, 0 when no record id was supplied
    """
    assert isinstance(datas, pd.DataFrame), "预测保存的数据必须是DataFrame"
    assert graph_id, "未传graph_id"
    assert object_id, "未传object_id"
    assert node_id, "未传node_id"
    if filepath:
        if not filepath.endswith(".csv"):
            raise Exception("预测数据只能保存csv文件")
        datas.to_csv(filepath, encoding="utf_8", index=False)
    else:
        # Without any record id there is nothing to associate the file with.
        if not plog_record_id and not online_plog_record_id:
            return 0
        predict_save_path = os.path.join(
            settings.PREDICT_RESULT_SAVE_PATH,
            str(graph_id),
            str(object_id))
        # exist_ok avoids the check-then-create race when two predictions
        # create the directory concurrently.
        os.makedirs(predict_save_path, exist_ok=True)
        if filename:
            predict_filename = filename
        else:
            if online_plog_record_id:
                auto_file_name_prefix = "online_predict_%s" % online_plog_record_id
            else:
                auto_file_name_prefix = "predict_%s" % plog_record_id
            # Timestamp in UTC+8 (Beijing time) so file names sort by batch time;
            # read_predict_csv parses this trailing timestamp back out.
            predict_filename = "%s_%s_%s.csv" % (
                auto_file_name_prefix,
                str(node_id),
                (datetime.utcnow() + timedelta(hours=8)).strftime("%Y%m%d%H%M%S"))
        datas.to_csv(os.path.join(predict_save_path, predict_filename), encoding="utf_8", index=False)
    return 1
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def read_predict_csv(graph_id, object_id, start_date=None, end_date=None):
    """Read a model's historical prediction data from the saved CSV batches.

    args:
        graph_id: id of the predict graph (model)
        object_id: id of the modeling object
        start_date: first batch date (inclusive); when omitted, only the most
            recent batch is returned. Format: YYYYmmdd or YYYY-mm-dd
        end_date: last batch date (inclusive); same format as start_date
    returns:
        DataFrame of the selected batches with an added "file_date" column,
        or an empty DataFrame when nothing was saved. In the "image" run
        environment this is a stub (returns None).
    """
    datas = pd.DataFrame()
    predict_save_path = os.path.join(settings.PREDICT_RESULT_SAVE_PATH, str(graph_id), str(object_id))
    if not os.path.exists(predict_save_path):
        return datas
    file_list = os.listdir(predict_save_path)
    # File names end with "_<YYYYmmddHHMMSS>.csv"; index them by that timestamp.
    file_time_dict = {(i.split("_")[-1]).replace(".csv", ""): i for i in file_list}
    time_list = list(file_time_dict.keys())
    if settings.RUN_ENV != "image":
        if start_date and end_date:
            start_date = str(start_date).replace('-', "")
            end_date = str(end_date).replace('-', "")
            daterange = pd.date_range(start_date, end_date, freq="D").to_list()
            frames = []
            for d in daterange:
                day = d.strftime("%Y%m%d")
                match_times = [t for t in time_list if t.startswith(day)]
                if not match_times:
                    continue
                # Multiple batches on the same day: keep the most recent one.
                if len(match_times) > 1:
                    match_time = str(max(map(int, match_times)))
                else:
                    match_time = match_times[0]
                df = pd.read_csv(os.path.join(predict_save_path, file_time_dict[match_time]))
                df["file_date"] = match_time[:8]
                frames.append(df)
            # DataFrame.append was removed in pandas 2.0; concatenate once instead.
            if frames:
                datas = pd.concat(frames)
            return datas
        else:
            # No date range: return only the latest batch.
            match_time = str(max(map(int, time_list)))
            datas = pd.read_csv(os.path.join(predict_save_path, file_time_dict[match_time]))
            datas["file_date"] = match_time[:8]
            return datas
    else:
        # TODO: query the NPMOS service in the image environment (not implemented).
        pass
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def train_to_csv(datas, graph_id='', node_id=''):
    """Persist a node's training-run data as "run_<node_id>.csv".

    params:
        datas: pandas.DataFrame with the run results
        graph_id: id of the train graph (required)
        node_id: id of the node that produced the data (required)
    returns:
        1 on success
    """
    assert isinstance(datas, pd.DataFrame), "预测保存的数据必须是DataFrame"
    assert graph_id, "未传graph_id"
    assert node_id, "未传node_id"
    run_save_path = os.path.join(
        settings.RUN_RESULT_SAVE_PATH,
        str(graph_id))
    # exist_ok avoids the check-then-create race when two runs create the
    # directory concurrently.
    os.makedirs(run_save_path, exist_ok=True)
    run_filename = "run_%s.csv" % str(node_id)
    datas.to_csv(os.path.join(run_save_path, run_filename), encoding="utf_8", index=False)
    return 1
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def read_train_csv(graph_id, node_id):
    """Read the data saved by a training run for a graph's nodes.

    args:
        graph_id: id of the train graph (model)
        node_id: node id; when truthy, only that node's file is read
    returns:
        DataFrame of all matching files with an added "node_id" column
        (empty DataFrame when nothing matches)
    """

    def _file_node_id(filename):
        # train_to_csv writes "run_<node_id>.csv", so the id is the last
        # "_"-separated token. (The previous split on "-" never matched
        # these file names.)
        return str(str(filename).split("_")[-1]).replace(".csv", "")

    run_save_path = os.path.join(settings.RUN_RESULT_SAVE_PATH, str(graph_id))
    file_list = os.listdir(run_save_path)
    datas = pd.DataFrame()
    if node_id:
        file_list = [f for f in file_list if int(_file_node_id(f)) == int(node_id)]
    frames = []
    for f in file_list:
        df = pd.read_csv(os.path.join(run_save_path, f))
        df["node_id"] = _file_node_id(f)
        frames.append(df)
    # DataFrame.append was removed in pandas 2.0; concatenate once instead.
    if frames:
        datas = pd.concat(frames)
    return datas
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def read_backtest_data(graph_id, object_id, start_date=None, end_date=None, record_id=None):
    """Read (and, when a date range is given, first trigger) a backtest result.

    args:
        graph_id: id of the backtest graph (required)
        object_id: id of the modeling object (required)
        start_date: backtest start date (inclusive), format YYYY-mm-dd; when
            start_date AND end_date are given, a new backtest is scheduled and
            this call blocks, polling every 5 seconds, until it finishes
        end_date: backtest end date (inclusive), format YYYY-mm-dd
        record_id: a specific backtest-record id to read; when start_date,
            end_date and record_id are all omitted, the latest record is read
    returns:
        DataFrame with two columns: the true-value column and its
        "<name>_predict" counterpart, indexed as stored in the pickle
    raises:
        Exception when the model is untrained, already backtesting, the
        backtest fails, or the result file/columns are missing
    """
    datas = pd.DataFrame()
    if start_date and end_date:
        # A date range means "run a fresh backtest": validate model state,
        # schedule it, then poll until it succeeds or errors.
        graph_info = GraphsTB.get(id=graph_id)
        if graph_info["status"] != STATUS.SUCCESS:
            raise Exception("该模型%s还未训练成功" % graph_id)
        if graph_info["b_status"] in [STATUS.WAITRUN, STATUS.WAITEXE]:
            raise Exception("模型%s正在等待回测,无法再次回测" % graph_id)
        if graph_info["b_status"] in [STATUS.RUNNING]:
            raise Exception("模型%s正在回测,无法再次回测" % graph_id)
        # Mark the graph as waiting-to-run; an external worker picks it up.
        ObjectGraphsTB.set(
            b_status=STATUS.WAITRUN,
            backtest_params=json.dumps({"start_dt": start_date, "end_dt":end_date})
        ).where(graph_id=graph_id, object_id=object_id)
        # Busy-wait on the DB status flag. NOTE(review): no timeout — if the
        # worker never updates b_status this loop blocks forever.
        while True:
            time.sleep(5)
            b_status = ObjectGraphsTB.map_one("b_status").where(graph_id=graph_id, object_id=object_id)
            if b_status == STATUS.ERROR:
                # Surface the backtest log in the error when a record exists.
                backtest_record_id = ObjectGraphsTB.map_one("backtest_record_id").where(graph_id=graph_id, object_id=object_id)
                if backtest_record_id:
                    record_log = BacktestRecordTB.map_one("log").where(id=backtest_record_id)
                    raise Exception("模型%s回测报错,回测日志:\n%s" % (graph_id, record_log))
                raise Exception("模型%s回测报错" % graph_id)
            elif b_status == STATUS.SUCCESS:
                break
            else:
                continue
    # Resolve which backtest record to read: explicit record_id wins,
    # otherwise the latest record stored on the object-graph row.
    if record_id:
        backtest_record_id = record_id
    else:
        backtest_record_id = ObjectGraphsTB.map_one("backtest_record_id").where(graph_id=graph_id, object_id=object_id)
    if not backtest_record_id:
        raise Exception("模型%s没有获取到回测记录" % graph_id)
    backtest_save_path = os.path.join(settings.BACKTEST_RESULT_SAVE_PATH, str(graph_id))
    backtest_save_file = os.path.join(backtest_save_path, "backtest_%s.pkl" % backtest_record_id)
    if not os.path.isfile(backtest_save_file):
        raise Exception("模型%s没有回测记录%s的回测数据" % (graph_id, record_id))
    # The pickle holds a dict with at least "data" (DataFrame) and,
    # optionally, "label_columns".
    datas = pickle_loads(backtest_save_file)
    label_columns = datas.get("label_columns")
    if label_columns:
        # label_columns may be a list or a single name; the first entry is
        # the true-value column, "<name>_predict" the prediction column.
        y_true_column = label_columns[0] if isinstance(label_columns, list) else label_columns
        y_pred_column = y_true_column + "_predict"
        if y_true_column not in datas["data"].columns:
            raise Exception("模型%s回测数据里未找到真实值%s" % (graph_id, y_true_column))
        if y_pred_column not in datas["data"].columns:
            raise Exception("模型%s回测数据里未找到预测值%s" % (graph_id, y_pred_column))
    else:
        # No label metadata: infer the single "*_predict" column instead.
        y_pred_columns = [i for i in datas["data"].columns if str(i).endswith("_predict")]
        if len(y_pred_columns) > 1:
            raise Exception("回测结果有多个_predict列")
        if len(y_pred_columns) == 0:
            raise Exception("回测结果中未找到_predict列")
        y_pred_column = y_pred_columns[0]
        y_true_column = str(y_pred_column).replace("_predict", "")
    return datas["data"][[y_true_column,y_pred_column]]
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def get_model_save_path(graph_id, object_id, node_id):
    """Return the save path for a node's custom model.

    args:
        graph_id: id of the graph (model)
        object_id: id of the modeling object
        node_id: id of the node
    returns:
        str: directory path where the node's model is stored
    """
    if settings.RUN_ENV == 'sdk':
        model_save_path = settings.SDK_MODEL_SAVE_PATH
        nodes_model_save_path = os.path.join(model_save_path, str(graph_id), str(object_id), str(node_id))
    else:
        model_save_path = settings.RUN_MODEL_SAVE_PATH
        # str(node_id): the SDK branch already casts, and os.path.join raises
        # TypeError when handed an int node_id.
        nodes_model_save_path = os.path.join(model_save_path, str(graph_id), str(object_id), 'nodes', str(node_id))
    return nodes_model_save_path
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def update_nwp_config(nwp_config, node):
    """Update the nwp_config parameter of the single import-data ancestor of *node*.

    params:
        nwp_config: dict, the new NWP configuration (must be a non-empty dict)
        node: the node whose ancestors are searched for the import-data node
    returns:
        1 on success
    raises:
        Exception when the graph is not registered, when zero or multiple
        import nodes are found, or when nwp_config is invalid
    """
    # Collect the data-import ancestor nodes of this node.
    import_nodes = []
    graph = multi_graph.get(node.graph_info["id"])
    if graph is None:
        raise Exception("全局未找到graph")

    # Walk all ancestors and keep unique nodes whose module class is IMPORT.
    for n in graph.a._iter_fathers(node):
        if n.module.parent.info["cls_name"] == MODULES.IMPORT and n not in import_nodes:
            import_nodes.append(n)
    # Exactly one import node must exist, and it must carry an nwp_config.
    if len(import_nodes) > 1:
        raise Exception("导入数据节点有多个,只能更新一个导入数据节点的nwp_config")
    if len(import_nodes) == 0:
        raise Exception("未找到Algodata或StrategyAlgodata数据导入节点")
    import_node = import_nodes[0]
    if import_node.params.get("nwp_config") is None:
        raise Exception("数据导入节点%s没有nwp_config参数" % import_node.info["cls_name"])

    # Validate the nwp_config value before applying it.
    if not isinstance(nwp_config, dict):
        raise Exception("nwp_config格式不正确")
    if not nwp_config:
        raise Exception("nwp_config传值为空")
    # Update the in-memory node parameters.
    import_node.params["nwp_config"] = nwp_config
    if graph.use_db:
        # Mirror the change into the database.
        # NOTE(review): stored via str(dict), not JSON — confirm readers parse it.
        update_node_params(import_node.id, "nwp_config", str(nwp_config))
    print_to_run_log("更新nwp_config为:%s" % str(nwp_config), graph_id=node.graph_info["id"])
    return 1
|
|
354
|
+
|
|
355
|
+
def get_nwp_config(node):
    """Return the nwp_config parameter held by the single import-data ancestor of *node*.

    raises:
        Exception when the graph is not registered, when no import node is
        found, or when zero or multiple nwp_config values exist
    """
    owner_graph = multi_graph.get(node.graph_info["id"])
    if owner_graph is None:
        raise Exception("全局未找到graph")

    # Gather the unique data-import ancestors of this node.
    import_nodes = []
    for ancestor in owner_graph.a._iter_fathers(node):
        is_import = ancestor.module.parent.info["cls_name"] == MODULES.IMPORT
        if is_import and ancestor not in import_nodes:
            import_nodes.append(ancestor)
    if len(import_nodes) == 0:
        raise Exception("未找到Algodata或StrategyAlgodata数据导入节点")

    # Keep every nwp_config that is actually set on an import node.
    nwp_configs = [
        candidate.params["nwp_config"]
        for candidate in import_nodes
        if candidate.params.get("nwp_config") is not None
    ]

    if len(nwp_configs) > 1:
        raise Exception("数据导入节点存在多个nwp_config")

    if len(nwp_configs) == 0:
        raise Exception("数据导入节点不存在nwp_config参数")

    return nwp_configs[0]
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@contextmanager
def timeout(duration):
    """Context manager that raises TimeoutError if its block runs longer than *duration* seconds.

    params:
        duration: int, whole seconds before SIGALRM fires (Unix only)
    raises:
        TimeoutError when the managed block exceeds the duration
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f'block timedout after {duration} seconds')
    # Install our handler, remembering the previous one so it can be restored.
    previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(duration)
    try:
        yield
    finally:
        # Always cancel the alarm and restore the old handler, even when the
        # managed block raises — otherwise a pending SIGALRM could fire later
        # and kill unrelated code.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous_handler)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# coding: utf8
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
#!/usr/bin/env python
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
import traceback
|
|
5
|
+
from PipeGraphPy.db.models import GraphsTB, NodesTB, EdgesTB, ModulesTB
|
|
6
|
+
from PipeGraphPy.constants import GRAPHTYPE
|
|
7
|
+
from PipeGraphPy.storage import store
|
|
8
|
+
from PipeGraphPy.logger import log
|
|
9
|
+
from PipeGraphPy.core.module import get_template_type, Module
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def node_is_passed(node):
    """Return True when the node has been filtered out (no active is_pass=0 row exists)."""
    return not NodesTB.find(id=node.id, is_pass=0)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def check_before_evaluate(graph_id):
    """Check a model before generating its evaluation graph.

    parameters:
        graph_id: int, graph id
    return:
        True when the model can be evaluated; otherwise an Exception is
        raised whose message explains the reason
    """
    try:
        graph_info = GraphsTB.find_one(id=graph_id)
        if not graph_info:
            raise Exception("数据库不存在该模型信息")
        # Only ordinary training graphs may spawn an evaluation graph.
        if graph_info["category"] != GRAPHTYPE.TRAIN:
            raise Exception("只有一般图模型才能生成评估模型")
        # if not check_graph_run_success(graph_id):
        #     raise Exception('该模型还未训练成功,不能生成评估图')
        graph_model = store.load_graph(graph_id)
        # Collect the non-filtered estimator pipelines; exactly one must exist.
        pg_list = [
            graph_model.multi_pg.get(i)
            for i in graph_model.estimator_set
            if not node_is_passed(i)
        ]
        if len(pg_list) == 0:
            raise Exception("该模型不存在可用的优化器, 不能生成评估模型")
        if len(pg_list) > 1:
            raise Exception("该模型存在多个可用优化器,必须有且仅有一个可用优化器才能生成评估模型")
        return True
    except FileNotFoundError:
        # store.load_graph raises this when the serialized model is missing.
        log.error("模型文件不存在", graph_id=graph_id)
        raise Exception("模型文件不存在")
    except Exception as e:
        # Log the full traceback, then propagate the original exception.
        log.error(traceback.format_exc(), graph_id=graph_id)
        raise e
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def check_before_predict(graph_id):
    """Check a model before running prediction.

    parameters:
        graph_id: int, graph id
    return:
        True when prediction may run; otherwise an Exception is raised whose
        message explains the reason
    """
    try:
        graph_info = GraphsTB.find_one(id=graph_id)
        if not graph_info:
            raise Exception("数据库不存在该模型(%s)信息" % graph_id)
        # Subscript access for consistency with check_before_evaluate, which
        # reads the same find_one result via graph_info["category"].
        if graph_info["category"] != GRAPHTYPE.TRAIN:
            raise Exception("%s 只有训练模型才能执行预测" % graph_id)
        # A model whose last training failed can still predict, as long as a
        # previously trained graph exists in storage.
        if not store.has_graph(graph_id):
            raise Exception("%s 不存在训练好的模型" % graph_id)
        return True
    except FileNotFoundError:
        raise Exception("模型文件不存在")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def check_edge(
    graph_id,
    source_id,
    target_id,
    source_anchor=0,
    target_anchor=0,
    nodes_info_dict=None,
    is_run=False,
):
    """Validate that an edge between two nodes is legal.

    parameters:
        graph_id: int, graph id
        source_id: int, id of the source node
        target_id: int, id of the target node
        source_anchor: int, anchor index on the source node
        target_anchor: int, anchor index on the target node
        nodes_info_dict: dict, pre-fetched node info keyed by node id;
            optional — when omitted it is loaded from the DB and the
            duplicate-edge check is performed as well
        is_run: bool, True when called during graph execution (skips the
            "one edge per input anchor" DB check)
    return:
        True when the edge is valid; otherwise an Exception is raised whose
        message explains the reason
    """
    # A node may not connect to itself.
    if source_id == target_id:
        raise Exception("起始节点和结尾节点不能是同一个节点")
    if not nodes_info_dict:
        # Check whether these two nodes are already connected (not needed at
        # run time, when nodes_info_dict is supplied by the caller).
        edges_info = EdgesTB.find(
            graph_id=graph_id, source_id=source_id, target_id=target_id
        )
        if len(edges_info) != 0:
            raise Exception("这两个节点间已经存在一条连线")
        nodes_info = NodesTB.find(graph_id=graph_id)
        nodes_info_dict = {i["id"]: i for i in nodes_info}
    source_data_type, target_data_type = None, None
    # idx 0 = source endpoint, idx 1 = target endpoint.
    for idx, (node_id, anchor) in enumerate(
        [(source_id, source_anchor), (target_id, target_anchor)]
    ):
        name = "结尾" if idx else "起始"
        if anchor < 0:
            raise Exception("%s锚点值不能小于零(%s)" % (name, anchor))
        # The node itself must exist.
        node_info = nodes_info_dict.get(node_id)
        if not node_info:
            raise Exception("未找到%s节点信息: %s " % (name, node_id))
        # The module backing the node must exist.
        module_info = ModulesTB.find_one(id=node_info["mod_id"])
        if not module_info:
            raise Exception("未找到%s节点对应的组件信息: %s " % (name, node_id))
        template_type = get_template_type(node_info["mod_id"])
        # Both input and output anchor configuration must be present.
        if template_type["INPUT"] is None:
            raise Exception("没有%s节点(%s)的输入配置信息" % (name, node_id))
        if template_type["OUTPUT"] is None:
            raise Exception("没有%s节点(%s)的输出配置信息" % (name, node_id))
        # Anchors are numbered over inputs first, then outputs.
        total_anchor = len(template_type["INPUT"]) + len(template_type["OUTPUT"])
        if anchor > total_anchor:
            raise Exception("%s节点的锚点(%s)超出了锚点范围(%s)" % (name, anchor, total_anchor))
        if idx:
            # Target endpoint: its anchor must fall in the input range.
            if anchor > len(template_type["INPUT"]):
                raise Exception("结尾节点的锚点(%s)必须是图输入锚点" % anchor)
            target_data_type = template_type["INPUT"][anchor]
        else:
            # Source endpoint: its anchor must fall in the output range.
            if anchor - len(template_type["INPUT"]) < 0:
                raise Exception("起始节点的锚点(%s)不能是输入锚点" % anchor)
            source_data_type = template_type["OUTPUT"][
                anchor - len(template_type["INPUT"])
            ]
    # The source output type must be accepted by the target input type
    # (substring match on the declared type strings).
    if target_data_type.find(source_data_type) == -1:
        raise Exception("起始节点%s和结尾节点的数据类型不匹配%s" % (source_data_type, target_data_type))
    # Unless running, verify the target anchor does not already receive data
    # (anchors typed as "list" may accept multiple incoming edges).
    if not is_run:
        if target_data_type.find("list") == -1:
            edge_info = EdgesTB.find(
                graph_id=graph_id, target_id=target_id, target_anchor=target_anchor
            )
            if edge_info:
                raise Exception("目标节点的锚点只接收一个数据")
    # Cycle detection happens at run time, not here.
    return True
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def check_node(graph_id):
    """Validate the nodes of a graph. Placeholder — not implemented yet."""
    pass
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def check_module_code(module_id):
    """Validate the code of a custom module.

    module_id: id of the module to check
    """
    return Module(module_id).check_code()
|