yms-kan 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yms_kan/KANLayer.py +364 -0
- yms_kan/LBFGS.py +492 -0
- yms_kan/MLP.py +361 -0
- yms_kan/MultKAN.py +3085 -0
- yms_kan/Symbolic_KANLayer.py +270 -0
- yms_kan/__init__.py +4 -0
- yms_kan/compiler.py +498 -0
- yms_kan/experiment.py +50 -0
- yms_kan/feynman.py +739 -0
- yms_kan/hypothesis.py +695 -0
- yms_kan/spline.py +144 -0
- yms_kan/tool.py +304 -0
- yms_kan/train_eval_utils.py +175 -0
- yms_kan/utils.py +661 -0
- yms_kan/version.py +1 -0
- yms_kan-0.0.1.dist-info/METADATA +11 -0
- yms_kan-0.0.1.dist-info/RECORD +20 -0
- yms_kan-0.0.1.dist-info/WHEEL +5 -0
- yms_kan-0.0.1.dist-info/licenses/LICENSE +21 -0
- yms_kan-0.0.1.dist-info/top_level.txt +1 -0
yms_kan/spline.py
ADDED
@@ -0,0 +1,144 @@
import torch


def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    evaluate x on B-spline bases

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (number of splines, number of samples)
        grid : 2D torch.tensor
            grids, shape (number of splines, number of grid points)
        k : int
            the piecewise polynomial order of splines.
        extend : bool
            If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
        device : str
            device

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.

    Example
    -------
    >>> from yms_kan.spline import B_batch
    >>> x = torch.rand(100, 2)
    >>> grid = torch.linspace(-1, 1, steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    '''

    # note: `extend` and `device` are kept for API compatibility but unused here;
    # grid extension is handled separately by extend_grid below
    x = x.unsqueeze(dim=2)
    grid = grid.unsqueeze(dim=0)

    if k == 0:
        # degree-0 basis: indicator of each grid interval
        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
    else:
        # Cox-de Boor recursion on the degree-(k-1) bases
        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)

        value = (x - grid[:, :, :-(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, :-(k + 1)]) * B_km1[:, :, :-1] + (
                grid[:, :, k + 1:] - x) / (grid[:, :, k + 1:] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]

    # in case the grid is degenerate
    value = torch.nan_to_num(value)
    return value


def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    convert B-spline coefficients to B-spline curves: evaluate x on B-spline curves
    (summing up B_batch results over the B-spline basis).

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)
        k : int
            the piecewise polynomial order of splines.
        device : str
            device

    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
    '''

    b_splines = B_batch(x_eval, grid, k=k)
    y_eval = torch.einsum('ijk,jlk->ijl', b_splines, coef.to(b_splines.device))

    return y_eval


def curve2coef(x_eval, y_eval, grid, k):
    '''
    convert B-spline curves to B-spline coefficients using least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, G+2k)
        k : int
            spline order

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, G+k)
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1
    mat = B_batch(x_eval, grid, k)
    # (batch, in_dim, n_coef) -> (in_dim, out_dim, batch, n_coef)
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
    # (batch, in_dim, out_dim) -> (in_dim, out_dim, batch, 1)
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
    device = mat.device

    try:
        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
    except Exception:
        # fall back to a manual ridge-regularized pseudo-inverse if lstsq fails
        lamb = 1e-8
        XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0, 1, 3, 2), mat)
        Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0, 1, 3, 2), y_eval)
        n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
        identity = torch.eye(n, n)[None, None, :, :].expand(n1, n2, n, n).to(device)
        A = XtX + lamb * identity
        B = Xty
        coef = (A.pinverse() @ B)[:, :, :, 0]

    return coef


def extend_grid(grid, k_extend=0):
    '''
    extend the grid by k_extend points on each side, spaced by the average grid step
    '''
    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    for i in range(k_extend):
        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)

    return grid
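The three helpers above compose into a fit/evaluate round trip: build an extended grid, evaluate the curve from coefficients, and recover the coefficients by least squares. A minimal sketch (not part of the package; the sizes are illustrative and follow the shapes documented in the docstrings above):

import torch
from yms_kan.spline import coef2curve, curve2coef, extend_grid

in_dim, out_dim, G, k, num = 2, 3, 10, 3, 100
grid = torch.linspace(-1, 1, steps=G + 1)[None, :].expand(in_dim, G + 1)
grid = extend_grid(grid, k_extend=k)        # k extra points on each side
x = torch.rand(num, in_dim) * 2 - 1         # samples in [-1, 1]
coef = torch.rand(in_dim, out_dim, G + k)   # arbitrary spline coefficients
y = coef2curve(x, grid, coef, k=k)          # (num, in_dim, out_dim)
coef_fit = curve2coef(x, y, grid, k=k)      # least-squares recovery
print((coef - coef_fit).abs().max())        # should be near zero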
yms_kan/tool.py
ADDED
@@ -0,0 +1,304 @@
import os

import numpy as np
import torch
import wandb


def initialize_results_file(results_file, result_info):
    """
    Initialize the results file, ensuring it exists and that its first line
    contains the expected header.

    Args:
        results_file (str): path to the results file.
        result_info (list): header entries to write on the first line.
    """
    # join the header entries, separated by two spaces
    result_info_str = "  ".join(result_info) + '\n'
    # check whether the file already exists
    if os.path.exists(results_file):
        # if it exists, read its first line
        with open(results_file, "r") as f:
            first_line = f.readline().strip()
        # check whether the first line already matches result_info
        if first_line == result_info_str.strip():
            print(f"File {results_file} already exists and its first line matches result_info; nothing written.")
        else:
            # if not, overwrite the file with result_info
            with open(results_file, "w") as f:
                f.write(result_info_str)
            print(f"File {results_file} has been re-initialized.")
    else:
        # if the file does not exist, create it and write result_info
        with open(results_file, "w") as f:
            f.write(result_info_str)
        print(f"File {results_file} created and result_info written.")


def write_results_file(file_path: str,
                       data_dict: dict,
                       column_order: list,
                       float_precision: int = 5) -> None:
    """
    Generic formatted text-row writer (list-valued data).

    Args:
        file_path: target file path
        data_dict: data dictionary; keys are column names, values are lists
        column_order: column order; entries are keys of data_dict
        float_precision: number of decimal places for floats (default: 5)
    """
    # validate the data format
    rows = None
    for key in data_dict:
        if not isinstance(data_dict[key], list):
            raise ValueError(f"Value for key '{key}' is not a list")
        if rows is None:
            rows = len(data_dict[key])
        else:
            if len(data_dict[key]) != rows:
                raise ValueError("All lists in data_dict must have the same length")

    # helper: format a single value
    def format_value(value, column_name):
        if isinstance(value, (int, np.integer)):
            return f"{value:d}"
        elif isinstance(value, (float, np.floating)):
            if column_name in ['train_losses', 'val_losses']:
                return f"{value:.{float_precision + 1}f}"
            elif column_name == 'lrs':
                return f"{value:.8f}"
            else:
                return f"{value:.{float_precision}f}"
        elif isinstance(value, str):
            return value
        else:
            return str(value)

    # compute the column widths
    column_widths = []
    for col in column_order:
        dict_key = 'val_accuracies' if col == 'accuracies' else col
        if dict_key not in data_dict:
            raise ValueError(f"Missing required column: {dict_key}")
        values = data_dict[dict_key]

        max_width = len(col)
        for val in values:
            fmt_val = format_value(val, col)
            max_width = max(max_width, len(fmt_val))
        column_widths.append(max_width)

    # build the formatted lines
    lines = []
    for i in range(rows):
        row = []
        for j, col in enumerate(column_order):
            dict_key = 'val_accuracies' if col == 'accuracies' else col
            val = data_dict[dict_key][i]
            fmt_val = format_value(val, col)

            # alignment: the last column is left-aligned, the others right-aligned
            if j == len(column_order) - 1:
                fmt_val = fmt_val.ljust(column_widths[j])
            else:
                fmt_val = fmt_val.rjust(column_widths[j])
            row.append(fmt_val)
        lines.append("  ".join(row) + '\n')

    # write to the file
    with open(file_path, 'a', encoding='utf-8') as f:
        f.writelines(lines)


def append_to_results_file(file_path: str,
                           data_dict: dict,
                           column_order: list,
                           float_precision: int = 5) -> None:
    """
    Generic formatted text-row append function.

    Args:
        file_path: target file path
        data_dict: data dictionary; keys are column names
        column_order: column order; entries are keys of data_dict
        float_precision: number of decimal places for floats (default: 5)
    """
    # check whether every value in data_dict is a list
    all_values_are_lists = all(isinstance(value, list) for value in data_dict.values())
    if all_values_are_lists:
        num_rows = len(next(iter(data_dict.values())))
        # process row by row
        for row_index in range(num_rows):
            formatted_data = []
            column_widths = []
            for col in column_order:
                # handle the key alias
                dict_key = 'val_accuracies' if col == 'accuracies' else col
                # skip the column if the key is missing
                if dict_key not in data_dict:
                    continue
                value_list = data_dict[dict_key]
                if row_index >= len(value_list):
                    continue
                value = value_list[row_index]

                # format according to the data type
                if isinstance(value, (int, np.integer)):
                    fmt_value = f"{value:d}"
                elif isinstance(value, (float, np.floating)):
                    if col in ['train_losses', 'val_losses']:  # losses keep float_precision + 1 decimal places
                        fmt_value = f"{value:.{float_precision + 1}f}"
                    elif col == 'lrs':  # learning rates keep 8 decimal places
                        fmt_value = f"{value:.8f}"
                    else:
                        fmt_value = f"{value:.{float_precision}f}"
                elif isinstance(value, str):
                    fmt_value = value
                else:  # convert any other type to a string
                    fmt_value = str(value)

                # column width is the larger of the column-name length and the value length
                column_width = max(len(col), len(fmt_value))
                column_widths.append(column_width)

                # alignment: the last column is left-aligned, the others right-aligned
                if col == column_order[-1]:
                    fmt_value = fmt_value.ljust(column_width)
                else:
                    fmt_value = fmt_value.rjust(column_width)

                formatted_data.append(fmt_value)

            # build the line and append it; columns are separated by two spaces
            if formatted_data:
                line = "  ".join(formatted_data) + '\n'
                with open(file_path, 'a', encoding='utf-8') as f:
                    f.write(line)
    else:
        # scalar values: apply the same logic to a single row
        column_widths = []
        formatted_data = []
        for col in column_order:
            # handle the key alias
            dict_key = 'val_accuracies' if col == 'accuracies' else col
            # skip the column if the key is missing
            if dict_key not in data_dict:
                continue

            value = data_dict[dict_key]

            # format according to the data type
            if isinstance(value, (int, np.integer)):
                fmt_value = f"{value:d}"
            elif isinstance(value, (float, np.floating)):
                if col in ['train_losses', 'val_losses']:  # losses keep float_precision + 1 decimal places
                    fmt_value = f"{value:.{float_precision + 1}f}"
                elif col == 'lrs':  # learning rates keep 8 decimal places
                    fmt_value = f"{value:.8f}"
                else:
                    fmt_value = f"{value:.{float_precision}f}"
            elif isinstance(value, str):
                fmt_value = value
            else:  # convert any other type to a string
                fmt_value = str(value)

            # column width is the larger of the column-name length and the value length
            column_width = max(len(col), len(fmt_value))
            column_widths.append(column_width)

            # alignment: the last column is left-aligned, the others right-aligned
            if col == column_order[-1]:
                fmt_value = fmt_value.ljust(column_width)
            else:
                fmt_value = fmt_value.rjust(column_width)

            formatted_data.append(fmt_value)

        # build the line and append it; columns are separated by two spaces
        if formatted_data:
            line = "  ".join(formatted_data) + '\n'
            with open(file_path, 'a', encoding='utf-8') as f:
                f.write(line)

def get_wandb_key(key_path='tools/wandb_key.txt'):
    with open(key_path, 'r', encoding='utf-8') as f:
        key = f.read().strip()  # strip whitespace/newlines so wandb.login accepts the key
    return key


def wandb_use(project=None, name=None, key_path='tools/wandb_key.txt'):
    run = None
    if project is not None:
        wandb_key = get_wandb_key(key_path)
        wandb.login(key=wandb_key)
        run = wandb.init(project=project, name=name)
    return run
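A usage sketch for the results-file helpers (assumed workflow; the file name and column names are illustrative): initialize the header once, then append one formatted row per epoch.

from yms_kan.tool import append_to_results_file, initialize_results_file

columns = ['epoch', 'train_losses', 'val_losses', 'lrs']
initialize_results_file('results.txt', columns)
append_to_results_file('results.txt',
                       {'epoch': 1, 'train_losses': 0.123456,
                        'val_losses': 0.234567, 'lrs': 0.001},
                       column_order=columns)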
yms_kan/train_eval_utils.py
ADDED
@@ -0,0 +1,175 @@
import math
import os
import sys
from enum import Enum, auto

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from yms_kan import LBFGS


class TaskType(Enum):
    classification = auto()
    zlm = auto()


def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    # collected for later metric computation (currently unused)
    all_predictions = []
    all_labels = []
    if lamb > 0. and not model.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if label is not None:
        label = label.to(model.device)

    # default to mean-squared error if no loss function is given
    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    grid_update_freq = int(stop_grid_update_step / grid_update_num)

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
               'precisions': [], 'recalls': [], 'f1-scores': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)

    def closure():
        # LBFGS closure: recompute the objective (loss + regularization) and its gradients
        nonlocal train_loss, reg_
        optimizer.zero_grad()
        pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, batch_train_label)
        if model.save_act:
            if reg_metric == 'edge_backward':
                model.attribute()
            if reg_metric == 'node_backward':
                model.node_attribute()
            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg_ = torch.tensor(0.)
        objective = loss + lamb * reg_
        # running average of the training loss over the batches seen so far
        train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
        objective.backward()
        return objective

    if save_fig:
        if not os.path.exists(img_folder):
            os.makedirs(img_folder)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        if save_fig and epoch % save_fig_freq == 0:
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
                                     y_th=y_th)
                loss = loss_fn(pred, batch_train_label)
                if model.save_act:
                    if reg_metric == 'edge_backward':
                        model.attribute()
                    if reg_metric == 'node_backward':
                        model.node_attribute()
                    reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
                else:
                    reg_ = torch.tensor(0.)
                loss = loss + lamb * reg_
                train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
                loss.backward()
                optimizer.step()
            train_pbar.set_postfix(loss=train_loss.item())

        val_loss = torch.zeros(1).to(model.device)
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    # map each output to the nearest value in `label`
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    all_predictions.extend(closest_values.detach().cpu().numpy())
                    all_labels.extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_losses'].append(train_loss.cpu().item())
        results['val_losses'].append(val_loss.cpu().item())
        results['regularize'].append(reg_.cpu().item())

        if save_fig and epoch % save_fig_freq == 0:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results


if __name__ == '__main__':
    print(TaskType.zlm.value)