yms-kan 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. yms_kan-0.0.2/MANIFEST.in +1 -0
  2. {yms_kan-0.0.1/yms_kan.egg-info → yms_kan-0.0.2}/PKG-INFO +1 -1
  3. {yms_kan-0.0.1 → yms_kan-0.0.2}/pyproject.toml +12 -0
  4. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/MultKAN.py +5 -2
  5. yms_kan-0.0.2/yms_kan/assets/img/mult_symbol.png +0 -0
  6. yms_kan-0.0.2/yms_kan/assets/img/sum_symbol.png +0 -0
  7. yms_kan-0.0.2/yms_kan/train_eval_utils.py +318 -0
  8. yms_kan-0.0.2/yms_kan/version.py +1 -0
  9. {yms_kan-0.0.1 → yms_kan-0.0.2/yms_kan.egg-info}/PKG-INFO +1 -1
  10. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/SOURCES.txt +4 -1
  11. yms_kan-0.0.1/yms_kan/train_eval_utils.py +0 -175
  12. yms_kan-0.0.1/yms_kan/version.py +0 -1
  13. {yms_kan-0.0.1 → yms_kan-0.0.2}/LICENSE +0 -0
  14. {yms_kan-0.0.1 → yms_kan-0.0.2}/README.md +0 -0
  15. {yms_kan-0.0.1 → yms_kan-0.0.2}/setup.cfg +0 -0
  16. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/KANLayer.py +0 -0
  17. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/LBFGS.py +0 -0
  18. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/MLP.py +0 -0
  19. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/Symbolic_KANLayer.py +0 -0
  20. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/__init__.py +0 -0
  21. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/compiler.py +0 -0
  22. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/experiment.py +0 -0
  23. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/feynman.py +0 -0
  24. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/hypothesis.py +0 -0
  25. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/spline.py +0 -0
  26. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/tool.py +0 -0
  27. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan/utils.py +0 -0
  28. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/dependency_links.txt +0 -0
  29. {yms_kan-0.0.1 → yms_kan-0.0.2}/yms_kan.egg-info/top_level.txt +0 -0
@@ -0,0 +1 @@
1
+ recursive-include yms_kan/assets/img *.png
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yms_kan
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: My awesome package
5
5
  Author-email: yms <11@qq.com>
6
6
  License-Expression: MIT
@@ -17,3 +17,15 @@ license-files = ["LICENSE"]
17
17
  [tool.setuptools.dynamic]
18
18
  # 明确指定版本号来源
19
19
  version = {attr = "yms_kan.version.__version__"}
20
+
21
+ [tool.setuptools]
22
+ # 包含非代码文件
23
+ include-package-data = true
24
+
25
+ [tool.setuptools.package-data]
26
+ # 指定包内资源文件的匹配规则
27
+ yms_kan = [
28
+ "assets/img/*.png", # 包含所有png文件
29
+ "assets/img/*.svg", # 可扩展其他格式
30
+ "assets/**/*" # 递归包含子目录
31
+ ]
@@ -3,6 +3,7 @@ import math
3
3
  import os
4
4
  import random
5
5
  import sys
6
+ from importlib.resources import files
6
7
 
7
8
  import matplotlib.pyplot as plt
8
9
  import numpy as np
@@ -1299,7 +1300,8 @@ class MultKAN(nn.Module):
1299
1300
  N = n = width_out[l + 1]
1300
1301
  for j in range(n):
1301
1302
  id_ = j
1302
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1303
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1304
+ path = files('yms_kan') / "assets/img/sum_symbol.png"
1303
1305
  im = plt.imread(path)
1304
1306
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1305
1307
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
@@ -1315,7 +1317,8 @@ class MultKAN(nn.Module):
1315
1317
  n_mult = width[l + 1][1]
1316
1318
  for j in range(n_mult):
1317
1319
  id_ = j + n_sum
1318
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1320
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1321
+ path = files('yms_kan') / "assets/img/mult_symbol.png"
1319
1322
  im = plt.imread(path)
1320
1323
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1321
1324
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
@@ -0,0 +1,318 @@
1
+ import math
2
+ import os
3
+ import sys
4
+ from enum import Enum, auto
5
+
6
+ import numpy as np
7
+ import torch
8
+ from matplotlib import pyplot as plt
9
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
10
+ from tqdm import tqdm
11
+
12
+ from yms_kan import LBFGS
13
+
14
+
15
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    """Train ``model`` with shuffled mini-batches and validate it after every epoch.

    Args:
        model: KAN-style model. This function uses ``model.device``, ``model.save_act``,
            ``model.forward``, ``model.get_params``, ``model.get_reg``, ``model.update_grid``,
            ``model.attribute`` / ``model.node_attribute``, ``model.disable_symbolic_in_fit``,
            ``model.plot``, ``model.log_history`` and ``model.symbolic_enabled``.
        dataset: dict with 'train_input', 'train_label', 'test_input' and 'test_label' tensors,
            indexed along their first dimension.
        batch_size: training mini-batch size.
        batch_size_test: validation mini-batch size.
        opt: "Adam" or "LBFGS"; any other value selects plain SGD.
        epochs: number of passes over the training set.
        lamb: regularisation weight. It only has an effect when ``model.save_act`` is True;
            otherwise it is reset to 0.
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, reg_metric: forwarded to ``model.get_reg``.
        label: optional tensor of candidate label values. When given, every validation output is
            snapped to its closest candidate and collected in ``results['all_predictions']``.
        update_grid, grid_update_num, start_grid_update_step, stop_grid_update_step: schedule for
            ``model.update_grid``, expressed in the 1-based global step counter.
        loss_fn: callable ``(pred, target) -> scalar tensor``; defaults to mean squared error.
        lr: initial learning rate. ReduceLROnPlateau lowers it based on the validation loss.
        save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder: periodic ``model.plot``
            snapshots written to ``<img_folder>/<epoch>.jpg``.
        singularity_avoiding, y_th: forwarded to ``model.forward``.

    Returns:
        dict with these keys, all describing the final epoch:
        - 'train_loss', 'val_loss' and 'regularize' (floats);
        - 'all_predictions' / 'all_labels' (lists of numpy rows; empty when ``label`` is None).
    """
    lamb_disabled = lamb > 0. and not model.save_act
    if lamb_disabled:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if lamb_disabled:
        # Do what the message says. Otherwise the regulariser would still be added in
        # epochs where save_fig temporarily switches model.save_act on.
        lamb = 0.

    if label is not None:
        label = label.to(model.device)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    # Clamp to 1: when stop_grid_update_step < grid_update_num the quotient truncates to 0,
    # and ``step % 0`` would raise ZeroDivisionError.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_loss': .0, 'val_loss': .0, 'regularize': .0, 'all_predictions': [],
               'all_labels': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
    batch_objective = None  # objective of the current batch, measured before the update

    def compute_objective(inputs, targets):
        """Return the loss plus the weighted regulariser for one batch (updates ``reg_``)."""
        nonlocal reg_
        pred = model.forward(inputs, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, targets)
        if model.save_act:
            if reg_metric == 'edge_backward':
                model.attribute()
            if reg_metric == 'node_backward':
                model.node_attribute()
            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg_ = torch.tensor(0.)
        return loss + lamb * reg_

    def closure():
        nonlocal batch_objective
        optimizer.zero_grad()
        objective = compute_objective(batch_train_input, batch_train_label)
        # The strong-Wolfe line search may call this closure several times per step.
        # Only the first evaluation (pre-update loss) is counted toward train_loss.
        if batch_objective is None:
            batch_objective = objective.detach()
        objective.backward()
        return objective

    if save_fig:
        os.makedirs(img_folder, exist_ok=True)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        take_snapshot = save_fig and epoch % save_fig_freq == 0
        if take_snapshot:
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if (update_grid and start_grid_update_step <= step < stop_grid_update_step
                    and step % grid_update_freq == 0):
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                batch_objective = None
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                objective = compute_objective(batch_train_input, batch_train_label)
                batch_objective = objective.detach()
                objective.backward()
                optimizer.step()

            # Running mean of the per-batch objective over the current epoch.
            train_loss = (train_loss * batch_num + batch_objective) / (batch_num + 1)
            train_pbar.set_postfix(loss=train_loss.item())

        val_loss = torch.zeros(1).to(model.device)
        epoch_predictions = []
        epoch_labels = []
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    # Snap each output to the nearest candidate in `label`.
                    # NOTE(review): assumes outputs is (batch, 1) and label is 1-D — confirm.
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    epoch_predictions.extend(closest_values.detach().cpu().numpy())
                    epoch_labels.extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_loss'] = train_loss.item()
        results['val_loss'] = val_loss.item()
        results['regularize'] = reg_.item()
        # Keep only the latest epoch's predictions, like the other final-epoch statistics.
        # Previously they accumulated over every epoch.
        results['all_predictions'] = epoch_predictions
        results['all_labels'] = epoch_labels

        if take_snapshot:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(os.path.join(img_folder, str(epoch) + '.jpg'), bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results
165
+
166
+ # def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
167
+ # lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
168
+ # lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
169
+ # stop_grid_update_step=100,
170
+ # save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
171
+ # singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
172
+ # # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
173
+ # # 'precisions', 'recalls', 'f1-scores']
174
+ # # initialize_results_file(results_file, result_info)
175
+ # all_predictions = []
176
+ # all_labels = []
177
+ # if lamb > 0. and not model.save_act:
178
+ # print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
179
+ #
180
+ # old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
181
+ # if label is not None:
182
+ # label = label.to(model.device)
183
+ #
184
+ # if loss_fn is None:
185
+ # loss_fn = lambda x, y: torch.mean((x - y) ** 2)
186
+ # else:
187
+ # loss_fn = loss_fn
188
+ #
189
+ # grid_update_freq = int(stop_grid_update_step / grid_update_num)
190
+ #
191
+ # if opt == "Adam":
192
+ # optimizer = torch.optim.Adam(model.get_params(), lr=lr)
193
+ # elif opt == "LBFGS":
194
+ # optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
195
+ # tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
196
+ # else:
197
+ # optimizer = torch.optim.SGD(model.get_params(), lr=lr)
198
+ #
199
+ # lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
200
+ #
201
+ # results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
202
+ # 'precisions': [], 'recalls': [], 'f1-scores': []}
203
+ #
204
+ # steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
205
+ #
206
+ # train_loss = torch.zeros(1).to(model.device)
207
+ # reg_ = torch.zeros(1).to(model.device)
208
+ #
209
+ # def closure():
210
+ # nonlocal train_loss, reg_
211
+ # optimizer.zero_grad()
212
+ # pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
213
+ # loss = loss_fn(pred, batch_train_label)
214
+ # if model.save_act:
215
+ # if reg_metric == 'edge_backward':
216
+ # model.attribute()
217
+ # if reg_metric == 'node_backward':
218
+ # model.node_attribute()
219
+ # reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
220
+ # else:
221
+ # reg_ = torch.tensor(0.)
222
+ # objective = loss + lamb * reg_
223
+ # train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
224
+ # objective.backward()
225
+ # return objective
226
+ #
227
+ # if save_fig:
228
+ # if not os.path.exists(img_folder):
229
+ # os.makedirs(img_folder)
230
+ #
231
+ # for epoch in range(epochs):
232
+ #
233
+ # if epoch == epochs - 1 and old_save_act:
234
+ # model.save_act = True
235
+ #
236
+ # if save_fig and epoch % save_fig_freq == 0:
237
+ # save_act = model.save_act
238
+ # model.save_act = True
239
+ #
240
+ # train_indices = np.arange(dataset['train_input'].shape[0])
241
+ # np.random.shuffle(train_indices)
242
+ # train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
243
+ # for batch_num in train_pbar:
244
+ # step = epoch * steps + batch_num + 1
245
+ # i = batch_num * batch_size
246
+ # batch_train_id = train_indices[i:i + batch_size]
247
+ # batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
248
+ # batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
249
+ #
250
+ # if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
251
+ # model.update_grid(batch_train_input)
252
+ #
253
+ # if opt == "LBFGS":
254
+ # optimizer.step(closure)
255
+ #
256
+ # else:
257
+ # optimizer.zero_grad()
258
+ # pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
259
+ # y_th=y_th)
260
+ # loss = loss_fn(pred, batch_train_label)
261
+ # if model.save_act:
262
+ # if reg_metric == 'edge_backward':
263
+ # model.attribute()
264
+ # if reg_metric == 'node_backward':
265
+ # model.node_attribute()
266
+ # reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
267
+ # else:
268
+ # reg_ = torch.tensor(0.)
269
+ # loss = loss + lamb * reg_
270
+ # train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
271
+ # loss.backward()
272
+ # optimizer.step()
273
+ # train_pbar.set_postfix(loss=train_loss.item())
274
+ #
275
+ # print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
276
+ # val_loss = torch.zeros(1).to(model.device)
277
+ # with torch.no_grad():
278
+ # test_indices = np.arange(dataset['test_input'].shape[0])
279
+ # np.random.shuffle(test_indices)
280
+ # test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
281
+ # test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
282
+ # for batch_num in test_pbar:
283
+ # i = batch_num * batch_size_test
284
+ # batch_test_id = test_indices[i:i + batch_size_test]
285
+ # batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
286
+ # batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
287
+ #
288
+ # outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
289
+ # y_th=y_th)
290
+ #
291
+ # loss = loss_fn(outputs, batch_test_label)
292
+ #
293
+ # val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
294
+ # test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
295
+ # if label is not None:
296
+ # diffs = torch.abs(outputs - label)
297
+ # closest_indices = torch.argmin(diffs, dim=1)
298
+ # closest_values = label[closest_indices]
299
+ # all_predictions.extend(closest_values.detach().cpu().numpy())
300
+ # all_labels.extend(batch_test_label.detach().cpu().numpy())
301
+ #
302
+ # lr_scheduler.step(val_loss)
303
+ #
304
+ # results['train_losses'].append(train_loss.cpu().item())
305
+ # results['val_losses'].append(val_loss.cpu().item())
306
+ # results['regularize'].append(reg_.cpu().item())
307
+ #
308
+ # if save_fig and epoch % save_fig_freq == 0:
309
+ # model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
310
+ # beta=beta)
311
+ # plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
312
+ # plt.close()
313
+ # model.save_act = save_act
314
+ #
315
+ # # append_to_results_file(results_file, results, result_info)
316
+ # model.log_history('fit')
317
+ # model.symbolic_enabled = old_symbolic_enabled
318
+ # return results
@@ -0,0 +1 @@
1
+ __version__ = "0.0.2" # package version; read by setuptools via [tool.setuptools.dynamic] in pyproject.toml
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yms_kan
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: My awesome package
5
5
  Author-email: yms <11@qq.com>
6
6
  License-Expression: MIT
@@ -1,4 +1,5 @@
1
1
  LICENSE
2
+ MANIFEST.in
2
3
  README.md
3
4
  pyproject.toml
4
5
  yms_kan/KANLayer.py
@@ -19,4 +20,6 @@ yms_kan/version.py
19
20
  yms_kan.egg-info/PKG-INFO
20
21
  yms_kan.egg-info/SOURCES.txt
21
22
  yms_kan.egg-info/dependency_links.txt
22
- yms_kan.egg-info/top_level.txt
23
+ yms_kan.egg-info/top_level.txt
24
+ yms_kan/assets/img/mult_symbol.png
25
+ yms_kan/assets/img/sum_symbol.png
@@ -1,175 +0,0 @@
1
- import math
2
- import os
3
- import sys
4
- from enum import Enum, auto
5
-
6
- import numpy as np
7
- import torch
8
- from matplotlib import pyplot as plt
9
- from torch.optim.lr_scheduler import ReduceLROnPlateau
10
- from tqdm import tqdm
11
-
12
- from yms_kan import LBFGS
13
-
14
-
15
class TaskType(Enum):
    """Enumerated task types.

    The values 1 and 2 are exactly what ``auto()`` would assign.
    NOTE(review): the meaning of ``zlm`` is not evident from this file.
    """

    classification = 1
    zlm = 2
-
19
-
20
def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
              lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
              lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
              stop_grid_update_step=100,
              save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
              singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
    """Train ``model`` with shuffled mini-batches and validate it after every epoch.

    Args:
        model: KAN-style model. This function uses ``model.device``, ``model.save_act``,
            ``model.forward``, ``model.get_params``, ``model.get_reg``, ``model.update_grid``,
            ``model.attribute`` / ``model.node_attribute``, ``model.disable_symbolic_in_fit``,
            ``model.plot``, ``model.log_history`` and ``model.symbolic_enabled``.
        dataset: dict with 'train_input', 'train_label', 'test_input' and 'test_label' tensors,
            indexed along their first dimension.
        batch_size: training mini-batch size.
        batch_size_test: validation mini-batch size.
        opt: "Adam" or "LBFGS"; any other value selects plain SGD.
        epochs: number of passes over the training set.
        lamb: regularisation weight. It only has an effect when ``model.save_act`` is True;
            otherwise it is reset to 0.
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff, reg_metric: forwarded to ``model.get_reg``.
        label: optional tensor of candidate label values. When given, every validation output is
            snapped to its closest candidate.
        update_grid, grid_update_num, start_grid_update_step, stop_grid_update_step: schedule for
            ``model.update_grid``, expressed in the 1-based global step counter.
        loss_fn: callable ``(pred, target) -> scalar tensor``; defaults to mean squared error.
        lr: initial learning rate. ReduceLROnPlateau lowers it based on the validation loss.
        save_fig, in_vars, out_vars, beta, save_fig_freq, img_folder: periodic ``model.plot``
            snapshots written to ``<img_folder>/<epoch>.jpg``.
        singularity_avoiding, y_th: forwarded to ``model.forward``.

    Returns:
        dict with these keys:
        - 'train_losses', 'val_losses', 'regularize': per-epoch lists of floats;
        - 'accuracies', 'precisions', 'recalls', 'f1-scores': created empty and not filled here;
        - 'all_predictions' / 'all_labels': the final epoch's snapped predictions and true labels
          (lists of numpy rows; empty when ``label`` is None).
    """
    lamb_disabled = lamb > 0. and not model.save_act
    if lamb_disabled:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')

    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
    if lamb_disabled:
        # Do what the message says. Otherwise the regulariser would still be added in
        # epochs where save_fig temporarily switches model.save_act on.
        lamb = 0.

    if label is not None:
        label = label.to(model.device)

    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    # Clamp to 1: when stop_grid_update_step < grid_update_num the quotient truncates to 0,
    # and ``step % 0`` would raise ZeroDivisionError.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        optimizer = torch.optim.SGD(model.get_params(), lr=lr)

    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)

    results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
               'precisions': [], 'recalls': [], 'f1-scores': []}

    steps = math.ceil(dataset['train_input'].shape[0] / batch_size)

    train_loss = torch.zeros(1).to(model.device)
    reg_ = torch.zeros(1).to(model.device)
    batch_objective = None  # objective of the current batch, measured before the update

    def compute_objective(inputs, targets):
        """Return the loss plus the weighted regulariser for one batch (updates ``reg_``)."""
        nonlocal reg_
        pred = model.forward(inputs, singularity_avoiding=singularity_avoiding, y_th=y_th)
        loss = loss_fn(pred, targets)
        if model.save_act:
            if reg_metric == 'edge_backward':
                model.attribute()
            if reg_metric == 'node_backward':
                model.node_attribute()
            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg_ = torch.tensor(0.)
        return loss + lamb * reg_

    def closure():
        nonlocal batch_objective
        optimizer.zero_grad()
        objective = compute_objective(batch_train_input, batch_train_label)
        # The strong-Wolfe line search may call this closure several times per step.
        # Only the first evaluation (pre-update loss) is counted toward train_loss.
        if batch_objective is None:
            batch_objective = objective.detach()
        objective.backward()
        return objective

    if save_fig:
        os.makedirs(img_folder, exist_ok=True)

    for epoch in range(epochs):

        if epoch == epochs - 1 and old_save_act:
            model.save_act = True

        take_snapshot = save_fig and epoch % save_fig_freq == 0
        if take_snapshot:
            save_act = model.save_act
            model.save_act = True

        train_indices = np.arange(dataset['train_input'].shape[0])
        np.random.shuffle(train_indices)
        train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
        for batch_num in train_pbar:
            step = epoch * steps + batch_num + 1
            i = batch_num * batch_size
            batch_train_id = train_indices[i:i + batch_size]
            batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
            batch_train_label = dataset['train_label'][batch_train_id].to(model.device)

            if (update_grid and start_grid_update_step <= step < stop_grid_update_step
                    and step % grid_update_freq == 0):
                model.update_grid(batch_train_input)

            if opt == "LBFGS":
                batch_objective = None
                optimizer.step(closure)
            else:
                optimizer.zero_grad()
                objective = compute_objective(batch_train_input, batch_train_label)
                batch_objective = objective.detach()
                objective.backward()
                optimizer.step()

            # Running mean of the per-batch objective over the current epoch.
            train_loss = (train_loss * batch_num + batch_objective) / (batch_num + 1)
            train_pbar.set_postfix(loss=train_loss.item())

        val_loss = torch.zeros(1).to(model.device)
        epoch_predictions = []
        epoch_labels = []
        with torch.no_grad():
            test_indices = np.arange(dataset['test_input'].shape[0])
            np.random.shuffle(test_indices)
            test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
            test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
            for batch_num in test_pbar:
                i = batch_num * batch_size_test
                batch_test_id = test_indices[i:i + batch_size_test]
                batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
                batch_test_label = dataset['test_label'][batch_test_id].to(model.device)

                outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
                                        y_th=y_th)

                loss = loss_fn(outputs, batch_test_label)

                val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
                test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
                if label is not None:
                    # Snap each output to the nearest candidate in `label`.
                    # NOTE(review): assumes outputs is (batch, 1) and label is 1-D — confirm.
                    diffs = torch.abs(outputs - label)
                    closest_indices = torch.argmin(diffs, dim=1)
                    closest_values = label[closest_indices]
                    epoch_predictions.extend(closest_values.detach().cpu().numpy())
                    epoch_labels.extend(batch_test_label.detach().cpu().numpy())

        lr_scheduler.step(val_loss)

        results['train_losses'].append(train_loss.cpu().item())
        results['val_losses'].append(val_loss.cpu().item())
        results['regularize'].append(reg_.cpu().item())
        # These lists used to be built and then discarded. They are now exposed (latest
        # epoch only) so callers can derive accuracy/precision/recall from them.
        results['all_predictions'] = epoch_predictions
        results['all_labels'] = epoch_labels

        if take_snapshot:
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
                       beta=beta)
            plt.savefig(os.path.join(img_folder, str(epoch) + '.jpg'), bbox_inches='tight', dpi=100)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    model.symbolic_enabled = old_symbolic_enabled
    return results
172
-
173
-
174
if __name__ == "__main__":
    # Manual smoke check: print the integer value behind the ``zlm`` member.
    print(TaskType["zlm"].value)
@@ -1 +0,0 @@
1
- __version__ = "0.0.1" # initial version
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes