yms-kan 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/spline.py ADDED
@@ -0,0 +1,144 @@
1
+ import torch
2
+
3
+
4
def B_batch(x, grid, k=0, extend=True, device='cpu'):
    '''
    Evaluate x on B-spline bases (Cox-de Boor recursion, computed iteratively).

    Args:
    -----
        x : 2D torch.tensor
            inputs, shape (batch, in_dim)
        grid : 2D torch.tensor
            grid points, shape (in_dim, number of grid points)
        k : int
            the piecewise polynomial order of splines.
        extend : bool
            Unused; kept only for backward compatibility of the signature.
            Extend the grid beforehand (e.g. with extend_grid) if needed.
        device : str
            Unused; kept only for backward compatibility. The result lives on
            the device of x / grid.

    Returns:
    --------
        spline values : 3D torch.tensor
            shape (batch, in_dim, number of grid points - k - 1), i.e.
            (batch, in_dim, G+k) for G grid intervals extended by k points on
            each side. The dtype is floating point for every k, including k=0.
            Bases are built on half-open intervals [t_j, t_{j+1}), so an x equal
            to the last grid point gets zero in every basis.

    Example
    -------
    >>> from yms_kan.spline import B_batch
    >>> x = torch.rand(100,2)
    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
    >>> B_batch(x, grid, k=3).shape
    torch.Size([100, 2, 7])
    '''
    x = x.unsqueeze(dim=2)        # (batch, in_dim, 1)
    grid = grid.unsqueeze(dim=0)  # (1, in_dim, n_grid)

    # Order-0 bases: indicator of [t_j, t_{j+1}). The comparison product is a
    # bool tensor; cast it so k=0 yields the same kind of tensor as k>0 (a bool
    # result cannot go through einsum in coef2curve or lstsq in curve2coef).
    dtype = torch.promote_types(x.dtype, grid.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    value = ((x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])).to(dtype)

    # Raise the order one step at a time:
    # B_{j,p}(x) = (x - t_j)/(t_{j+p} - t_j) * B_{j,p-1}(x)
    #            + (t_{j+p+1} - x)/(t_{j+p+1} - t_{j+1}) * B_{j+1,p-1}(x)
    for p in range(1, k + 1):
        left = (x - grid[:, :, :-(p + 1)]) / (grid[:, :, p:-1] - grid[:, :, :-(p + 1)])
        right = (grid[:, :, p + 1:] - x) / (grid[:, :, p + 1:] - grid[:, :, 1:(-p)])
        value = left * value[:, :, :-1] + right * value[:, :, 1:]
        # A degenerate grid (repeated knots) divides by zero above; nan_to_num
        # maps NaN -> 0 and +/-inf -> the largest finite values.
        value = torch.nan_to_num(value)

    return value
48
+
49
+
50
+
51
def coef2curve(x_eval, grid, coef, k, device="cpu"):
    '''
    Turn B-spline coefficients into curve values at the points x_eval.

    For every sample, input dimension and output dimension, the result is the
    sum over basis index b of B_b(x) * coef[in, out, b].

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        grid : 2D torch.tensor
            shape (in_dim, number of grid points); the number of grid points
            must equal coef.shape[2] + k + 1 so the basis axis lines up.
        coef : 3D torch.tensor
            shape (in_dim, out_dim, number of bases)
        k : int
            the piecewise polynomial order of splines.
        device : str
            Unused; coef is moved to the device of the evaluated bases instead.

    Returns:
    --------
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
    '''
    basis = B_batch(x_eval, grid, k=k)
    coef_on_basis_device = coef.to(basis.device)
    # Contract the basis axis: (batch, in, bases) x (in, out, bases) -> (batch, in, out)
    return torch.einsum('ijk,jlk->ijl', basis, coef_on_basis_device)
79
+
80
+
81
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
    '''
    Fit B-spline coefficients to sampled curve values by least squares.

    Args:
    -----
        x_eval : 2D torch.tensor
            shape (batch, in_dim)
        y_eval : 3D torch.tensor
            shape (batch, in_dim, out_dim)
        grid : 2D torch.tensor
            shape (in_dim, number of grid points)
        k : int
            spline order
        lamb : float
            Ridge term added to the normal equations. It is used only by the
            fallback solver when torch.linalg.lstsq raises. Default: 1e-8.

    Returns:
    --------
        coef : 3D torch.tensor
            shape (in_dim, out_dim, number of grid points - k - 1)
    '''
    batch = x_eval.shape[0]
    in_dim = x_eval.shape[1]
    out_dim = y_eval.shape[2]
    n_coef = grid.shape[1] - k - 1

    mat = B_batch(x_eval, grid, k)  # (batch, in_dim, n_coef)
    # One design matrix per (input, output) pair: (in_dim, out_dim, batch, n_coef)
    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
    # Targets as column vectors: (in_dim, out_dim, batch, 1)
    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)

    try:
        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
    except RuntimeError:
        # Previously the error was swallowed and `coef` was left unbound, so the
        # function crashed with UnboundLocalError at `return`. Solve the
        # ridge-regularised normal equations with a pseudo-inverse instead.
        print('lstsq failed')
        mat_t = mat.transpose(-1, -2)
        xtx = mat_t @ mat       # (in_dim, out_dim, n_coef, n_coef)
        xty = mat_t @ y_eval    # (in_dim, out_dim, n_coef, 1)
        identity = torch.eye(n_coef, dtype=xtx.dtype, device=xtx.device)
        coef = (torch.linalg.pinv(xtx + lamb * identity) @ xty)[:, :, :, 0]

    return coef
132
+
133
+
134
def extend_grid(grid, k_extend=0):
    '''
    Pad every row of a grid with k_extend extra points on each side.

    The step used for padding is the row's average spacing,
    (last - first) / (number of points - 1).

    Args:
    -----
        grid : 2D torch.tensor
            shape (n_rows, n_points)
        k_extend : int
            number of points added at each end

    Returns:
    --------
        2D torch.tensor of shape (n_rows, n_points + 2 * k_extend). The input
        tensor object itself is returned when nothing is added.
    '''
    spacing = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)

    lower, upper = grid[:, [0]], grid[:, [-1]]
    below, above = [], []
    for _ in range(k_extend):
        # Step outward one spacing at a time from the current outermost points.
        lower = lower - spacing
        upper = upper + spacing
        below.append(lower)
        above.append(upper)

    if not below:
        return grid
    return torch.cat(below[::-1] + [grid] + above, dim=1)
yms_kan/tool.py ADDED
@@ -0,0 +1,304 @@
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import wandb
6
+
7
+
8
def initialize_results_file(results_file, result_info):
    """
    Make sure results_file exists and starts with the header built from result_info.

    The header is the items of result_info joined by a space, followed by a
    newline. If the file already starts with that header it is left untouched.
    Otherwise (missing file, or a different first line) the file is overwritten
    with just the header. A status message is printed in every case.

    Args:
        results_file (str): path of the results file.
        result_info (list): column names written as the header line.
    """
    header = " ".join(result_info) + '\n'

    if not os.path.exists(results_file):
        with open(results_file, "w", encoding="utf-8") as f:
            f.write(header)
        print(f"文件 {results_file} 已创建并写入 result_info。")
        return

    # Use UTF-8 explicitly, like the row writers in this module. The platform
    # default encoding (e.g. GBK on Chinese Windows) could otherwise mis-read
    # or mis-write non-ASCII headers.
    with open(results_file, "r", encoding="utf-8") as f:
        first_line = f.readline().strip()

    if first_line == header.strip():
        print(f"文件 {results_file} 已存在且第一行已包含 result_info,不进行写入。")
    else:
        with open(results_file, "w", encoding="utf-8") as f:
            f.write(header)
        print(f"文件 {results_file} 已被重新初始化。")
36
+
37
+
38
def write_results_file(file_path: str,
                       data_dict: dict,
                       column_order: list,
                       float_precision: int = 5) -> None:
    """
    Append aligned, space-separated rows built from column-wise list data.

    Each column is padded to the widest of its name and its formatted values.
    The last column is left-aligned and the others are right-aligned.

    Args:
        file_path: target file; rows are appended with UTF-8 encoding.
        data_dict: maps column name -> list of values. All lists must have the
            same length. The column name 'accuracies' is read from the key
            'val_accuracies'.
        column_order: column names in output order.
        float_precision: decimal places for floats (default 5). Columns
            'train_losses' and 'val_losses' get one extra place; 'lrs' always
            gets 8.

    Raises:
        ValueError: if a value is not a list, the lists differ in length, or a
            column in column_order has no matching key.
    """
    # Validate the data layout.
    rows = None
    for key, values in data_dict.items():
        if not isinstance(values, list):
            raise ValueError(f"Value for key '{key}' is not a list")
        if rows is None:
            rows = len(values)
        elif len(values) != rows:
            raise ValueError("All lists in data_dict must have the same length")
    if rows is None:
        # Empty data_dict: there are no rows (previously range(None) raised TypeError).
        rows = 0

    def format_value(value, column_name):
        """Render one cell according to its type and column name."""
        if isinstance(value, (int, np.integer)):
            return f"{value:d}"
        elif isinstance(value, (float, np.floating)):
            if column_name in ['train_losses', 'val_losses']:
                return f"{value:.{float_precision + 1}f}"
            elif column_name == 'lrs':
                return f"{value:.8f}"
            else:
                return f"{value:.{float_precision}f}"
        elif isinstance(value, str):
            return value
        else:
            return str(value)

    # Format every cell once and derive each column's width from it.
    formatted_columns = []
    column_widths = []
    for col in column_order:
        dict_key = 'val_accuracies' if col == 'accuracies' else col
        if dict_key not in data_dict:
            raise ValueError(f"Missing required column: {dict_key}")
        cells = [format_value(val, col) for val in data_dict[dict_key]]
        formatted_columns.append(cells)
        column_widths.append(max([len(col)] + [len(cell) for cell in cells]))

    last = len(column_order) - 1
    lines = []
    for i in range(rows):
        row = [
            cells[i].ljust(column_widths[j]) if j == last else cells[i].rjust(column_widths[j])
            for j, cells in enumerate(formatted_columns)
        ]
        lines.append(" ".join(row) + '\n')

    with open(file_path, 'a', encoding='utf-8') as f:
        f.writelines(lines)
112
+
113
+
114
def append_to_results_file(file_path: str,
                           data_dict: dict,
                           column_order: list,
                           float_precision: int = 5) -> None:
    """
    Append one or more aligned, space-separated rows to a results file.

    Two input layouts are accepted:
      * every value in data_dict is a list -> one row per list index, up to
        the longest list;
      * otherwise -> a single row built from the values as they are.

    A column in column_order is skipped for a row when its key is missing or,
    in list mode, when its list is too short for that row. Each cell is padded
    to max(len(column name), len(formatted value)), so rows line up under a
    header made of the column names joined by single spaces. The last column
    is left-aligned and the others are right-aligned. Nothing is written when
    no cell is produced.

    Args:
        file_path: target file; rows are appended with UTF-8 encoding.
        data_dict: column name -> value (or list of values). The column name
            'accuracies' is read from the key 'val_accuracies'.
        column_order: column names in output order.
        float_precision: decimal places for floats (default 5). Columns
            'train_losses' and 'val_losses' get one extra place; 'lrs' always
            gets 8.
    """
    last_col = column_order[-1] if column_order else None

    def format_cell(col, value):
        """Format one value for column `col` and pad it to that column's width."""
        if isinstance(value, (int, np.integer)):
            text = f"{value:d}"
        elif isinstance(value, (float, np.floating)):
            if col in ['train_losses', 'val_losses']:
                text = f"{value:.{float_precision + 1}f}"
            elif col == 'lrs':
                text = f"{value:.8f}"
            else:
                text = f"{value:.{float_precision}f}"
        elif isinstance(value, str):
            text = value
        else:
            text = str(value)
        width = max(len(col), len(text))
        return text.ljust(width) if col == last_col else text.rjust(width)

    # Normalise both layouts to a list of {key: value} rows. In list mode a key
    # is present in a row only if its list reaches that index. An empty
    # data_dict yields no rows (previously next(iter(...)) raised StopIteration).
    if all(isinstance(value, list) for value in data_dict.values()):
        num_rows = max((len(values) for values in data_dict.values()), default=0)
        rows = [
            {key: values[i] for key, values in data_dict.items() if i < len(values)}
            for i in range(num_rows)
        ]
    else:
        rows = [data_dict]

    lines = []
    for row in rows:
        cells = []
        for col in column_order:
            dict_key = 'val_accuracies' if col == 'accuracies' else col
            if dict_key in row:
                cells.append(format_cell(col, row[dict_key]))
        if cells:
            lines.append(" ".join(cells) + '\n')

    if lines:
        with open(file_path, 'a', encoding='utf-8') as f:
            f.writelines(lines)
224
+
225
+
226
+ # def append_to_results_file(file_path: str,
227
+ # data_dict: dict,
228
+ # column_order: list,
229
+ # column_widths: list = None,
230
+ # float_precision: int = 5) -> None:
231
+ # """
232
+ # 通用格式化文本行写入函数
233
+ #
234
+ # 参数:
235
+ # file_path: 目标文件路径
236
+ # data_dict: 包含数据的字典,键为列名
237
+ # column_order: 列顺序列表,元素为字典键
238
+ # column_widths: 每列字符宽度列表 (可选)
239
+ # float_precision: 浮点数精度位数 (默认4位)
240
+ # """
241
+ # formatted_data = []
242
+ #
243
+ # # 遍历指定列顺序处理数据
244
+ # for i, col in enumerate(column_order):
245
+ # # 处理字典键的别名
246
+ # if col == 'accuracies':
247
+ # dict_key = 'val_accuracies'
248
+ # else:
249
+ # dict_key = col
250
+ #
251
+ # if dict_key not in data_dict:
252
+ # raise ValueError(f"Missing required column: {dict_key}")
253
+ #
254
+ # value = data_dict[dict_key]
255
+ #
256
+ # # 根据数据类型进行格式化
257
+ # if isinstance(value, (int, np.integer)):
258
+ # fmt_value = f"{value:d}"
259
+ # elif isinstance(value, (float, np.floating)):
260
+ # if col in ['train_losses', 'val_losses']: # 如果列名是'train_losses'或'val_losses',保留浮点数精度位数+1位
261
+ # fmt_value = f"{value:.{float_precision + 1}f}"
262
+ # elif col == 'lr': # 如果列名是'lr',保留8位小数
263
+ # fmt_value = f"{value:.8f}"
264
+ # else:
265
+ # fmt_value = f"{value:.{float_precision}f}"
266
+ # elif isinstance(value, str):
267
+ # fmt_value = value
268
+ # else: # 处理其他类型转换为字符串
269
+ # fmt_value = str(value)
270
+ #
271
+ # # 应用列宽对齐
272
+ # if column_widths and i < len(column_widths):
273
+ # try:
274
+ # if i == len(column_order) - 1: # 最后一列左边对齐
275
+ # fmt_value = fmt_value.ljust(column_widths[i])
276
+ # else:
277
+ # fmt_value = fmt_value.rjust(column_widths[i])
278
+ # except TypeError: # 处理非字符串类型
279
+ # if i == len(column_order) - 1: # 最后一列左边对齐
280
+ # fmt_value = str(fmt_value).ljust(column_widths[i])
281
+ # else:
282
+ # fmt_value = str(fmt_value).rjust(column_widths[i])
283
+ #
284
+ # formatted_data.append(fmt_value)
285
+ #
286
+ # # 构建文本行并写入
287
+ # line = '\t'.join(formatted_data) + '\n'
288
+ # with open(file_path, 'a', encoding='utf-8') as f:
289
+ # f.write(line)
290
+
291
+
292
def get_wandb_key(key_path='tools/wandb_key.txt'):
    """
    Read the Weights & Biases API key stored in a text file.

    Surrounding whitespace, such as the trailing newline most editors add, is
    stripped so the value can be passed straight to wandb.login(key=...).

    Args:
        key_path: path of a UTF-8 text file containing the key.

    Returns:
        The key as a string without leading/trailing whitespace.
    """
    with open(key_path, 'r', encoding='utf-8') as f:
        return f.read().strip()
296
+
297
+
298
def wandb_use(project=None, name=None, key_path='tools/wandb_key.txt'):
    """
    Log in to Weights & Biases and start a run, or do nothing.

    Args:
        project: W&B project name; when None, no login or run happens.
        name: optional run name passed to wandb.init.
        key_path: file holding the API key, read with get_wandb_key.

    Returns:
        The object returned by wandb.init, or None when project is None.
    """
    if project is None:
        return None
    wandb.login(key=get_wandb_key(key_path))
    return wandb.init(project=project, name=name)
@@ -0,0 +1,175 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ from enum import Enum, auto
5
+
6
+ import numpy as np
7
+ import torch
8
+ from matplotlib import pyplot as plt
9
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
10
+ from tqdm import tqdm
11
+
12
+ from yms_kan import LBFGS
13
+
14
+
15
class TaskType(Enum):
    """Enumeration of task types, numbered from 1 as enum.auto() would assign."""

    classification = 1
    zlm = 2
18
+
19
+
20
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    """
    Train `model` on mini-batches and validate it after every epoch.

    Args:
        model: KAN-style model. This function uses model.device, model.save_act,
            model.get_params(), model.forward(x, singularity_avoiding=..., y_th=...),
            model.update_grid(x), model.get_reg(...), model.attribute(),
            model.node_attribute(), model.disable_symbolic_in_fit(lamb),
            model.plot(...), model.log_history('fit') and model.symbolic_enabled.
        dataset: dict with 'train_input', 'train_label', 'test_input' and
            'test_label' tensors, indexed by sample along dim 0.
        batch_size, batch_size_test: mini-batch sizes for training / validation.
        opt: "LBFGS", "Adam", or anything else for SGD.
        epochs: number of passes over the training set.
        lamb: weight of the regulariser; it only has an effect when
            model.save_act is True (otherwise the regulariser is 0).
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, reg_metric: forwarded
            to model.get_reg.
        label: optional tensor of candidate output values. Validation outputs
            are matched to the closest candidate (argmin of |outputs - label|
            along dim 1 — shapes are not checked here). The matches are
            collected locally and currently not returned.
        update_grid, grid_update_num, start_grid_update_step, stop_grid_update_step:
            the grid is refreshed every max(1, int(stop_grid_update_step /
            grid_update_num)) optimisation steps while
            start_grid_update_step <= step < stop_grid_update_step.
            grid_update_num <= 0 disables grid updates.
        loss_fn: loss(pred, target); defaults to mean squared error.
        lr: learning rate.
        save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder: plotting options.
        singularity_avoiding, y_th: forwarded to model.forward.

    Returns:
        dict of per-epoch lists: 'train_losses' (mean batch objective),
        'val_losses' (mean validation loss), 'regularize' (the last computed
        regulariser value). 'accuracies', 'precisions', 'recalls' and
        'f1-scores' are left empty.
    """
    all_predictions = []
    all_labels = []
    if lamb > 0. and not model.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if label is not None:
        label = label.to(model.device)

    if loss_fn is None:
        def loss_fn(x, y):
            return torch.mean((x - y) ** 2)

    # int(stop / num) is 0 whenever grid_update_num > stop_grid_update_step.
    # Previously `step % 0` raised ZeroDivisionError, even with update_grid=False,
    # because the modulo was evaluated first.
    if grid_update_num > 0:
        grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))
    else:
        update_grid = False
        grid_update_freq = 1

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
               'precisions': [], 'recalls': [], 'f1-scores': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
    # Objective of the most recent training forward pass (detached).
    last_objective = torch.zeros(1).to(model.device)

    def regularized_objective(inputs, targets):
        """Return (loss + lamb * reg, reg) for one batch; reg is 0 unless model.save_act."""
        pred = model.forward(inputs, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, targets)
        if model.save_act:
            if reg_metric == 'edge_backward':
                model.attribute()
            if reg_metric == 'node_backward':
                model.node_attribute()
            reg = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg = torch.tensor(0.)
        return loss + lamb * reg, reg

    def closure():
        # An L-BFGS optimiser with a line search may evaluate this closure
        # several times per step. So it only records the latest objective; the
        # epoch's running mean is updated once per batch in the loop below
        # (previously every evaluation was folded into the mean).
        nonlocal reg_, last_objective
        optimizer.zero_grad()
        objective, reg_ = regularized_objective(batch_train_input, batch_train_label)
        last_objective = objective.detach()
        objective.backward()
        return objective

    if save_fig:
        if not os.path.exists(img_folder):
            os.makedirs(img_folder)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        plot_this_epoch = save_fig and epoch % save_fig_freq == 0
        if plot_this_epoch:
            # Plotting needs saved activations; restore the flag after plotting.
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if (update_grid and start_grid_update_step <= step < stop_grid_update_step
                    and step % grid_update_freq == 0):
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                objective, reg_ = regularized_objective(batch_train_input, batch_train_label)
                last_objective = objective.detach()
                objective.backward()
                optimizer.step()

            # Running mean over this epoch's batches; batch_num restarts at 0 each epoch.
            train_loss = (train_loss * batch_num + last_objective) / (batch_num + 1)
            train_pbar.set_postfix(loss=train_loss.item())

        val_loss = torch.zeros(1).to(model.device)
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    all_predictions.extend(closest_values.detach().cpu().numpy())
                    all_labels.extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_losses'].append(train_loss.cpu().item())
        results['val_losses'].append(val_loss.cpu().item())
        results['regularize'].append(reg_.cpu().item())

        if plot_this_epoch:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results
172
+
173
+
174
if __name__ == '__main__':
    # Quick manual check: print the numeric value assigned to TaskType.zlm.
    zlm_value = TaskType.zlm.value
    print(zlm_value)