yms-kan 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yms_kan/MultKAN.py CHANGED
@@ -3,6 +3,7 @@ import math
3
3
  import os
4
4
  import random
5
5
  import sys
6
+ from importlib.resources import files
6
7
 
7
8
  import matplotlib.pyplot as plt
8
9
  import numpy as np
@@ -1299,7 +1300,8 @@ class MultKAN(nn.Module):
1299
1300
  N = n = width_out[l + 1]
1300
1301
  for j in range(n):
1301
1302
  id_ = j
1302
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1303
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/sum_symbol.png"
1304
+ path = files('yms_kan') / "assets/img/sum_symbol.png"
1303
1305
  im = plt.imread(path)
1304
1306
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1305
1307
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
@@ -1315,7 +1317,8 @@ class MultKAN(nn.Module):
1315
1317
  n_mult = width[l + 1][1]
1316
1318
  for j in range(n_mult):
1317
1319
  id_ = j + n_sum
1318
- path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1320
+ # path = os.path.dirname(os.path.abspath(__file__)) + "/assets/img/mult_symbol.png"
1321
+ path = files('yms_kan') / "assets/img/mult_symbol.png"
1319
1322
  im = plt.imread(path)
1320
1323
  left = DC_to_NFC([1 / (2 * N) + id_ / N - y2, 0])[0]
1321
1324
  right = DC_to_NFC([1 / (2 * N) + id_ / N + y2, 0])[0]
Binary file
Binary file
@@ -12,22 +12,14 @@ from tqdm import tqdm
12
12
  from yms_kan import LBFGS
13
13
 
14
14
 
15
- class TaskType(Enum):
16
- classification = auto()
17
- zlm = auto()
18
-
19
-
20
15
  def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
21
16
  lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
22
17
  lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
23
18
  stop_grid_update_step=100,
24
19
  save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
25
20
  singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
26
- # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
27
- # 'precisions', 'recalls', 'f1-scores']
28
- # initialize_results_file(results_file, result_info)
29
- all_predictions = []
30
- all_labels = []
21
+ # all_predictions = []
22
+ # all_labels = []
31
23
  if lamb > 0. and not model.save_act:
32
24
  print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
33
25
 
@@ -52,8 +44,8 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
52
44
 
53
45
  lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
54
46
 
55
- results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
56
- 'precisions': [], 'recalls': [], 'f1-scores': []}
47
+ results = {'train_loss': .0, 'val_loss': .0, 'regularize': .0, 'all_predictions': [],
48
+ 'all_labels': []}
57
49
 
58
50
  steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
59
51
 
@@ -126,6 +118,7 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
126
118
  optimizer.step()
127
119
  train_pbar.set_postfix(loss=train_loss.item())
128
120
 
121
+ # print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
129
122
  val_loss = torch.zeros(1).to(model.device)
130
123
  with torch.no_grad():
131
124
  test_indices = np.arange(dataset['test_input'].shape[0])
@@ -149,14 +142,14 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
149
142
  diffs = torch.abs(outputs - label)
150
143
  closest_indices = torch.argmin(diffs, dim=1)
151
144
  closest_values = label[closest_indices]
152
- all_predictions.extend(closest_values.detach().cpu().numpy())
153
- all_labels.extend(batch_test_label.detach().cpu().numpy())
145
+ results['all_predictions'].extend(closest_values.detach().cpu().numpy())
146
+ results['all_labels'].extend(batch_test_label.detach().cpu().numpy())
154
147
 
155
148
  lr_scheduler.step(val_loss)
156
149
 
157
- results['train_losses'].append(train_loss.cpu().item())
158
- results['val_losses'].append(val_loss.cpu().item())
159
- results['regularize'].append(reg_.cpu().item())
150
+ results['train_loss'] = train_loss.item()
151
+ results['val_loss'] = val_loss.item()
152
+ results['regularize'] = reg_.item()
160
153
 
161
154
  if save_fig and epoch % save_fig_freq == 0:
162
155
  model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
@@ -170,6 +163,156 @@ def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", ep
170
163
  model.symbolic_enabled = old_symbolic_enabled
171
164
  return results
172
165
 
173
-
174
- if __name__ == '__main__':
175
- print(TaskType.zlm.value)
166
+ # def train_val(model, dataset: dict, batch_size, batch_size_test, opt="LBFGS", epochs=100, lamb=0.,
167
+ # lamb_l1=1., label=None, lamb_entropy=2., lamb_coef=0.,
168
+ # lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1,
169
+ # stop_grid_update_step=100,
170
+ # save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video',
171
+ # singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n'):
172
+ # # result_info = ['epoch','train_losses', 'val_losses', 'regularize', 'accuracies',
173
+ # # 'precisions', 'recalls', 'f1-scores']
174
+ # # initialize_results_file(results_file, result_info)
175
+ # all_predictions = []
176
+ # all_labels = []
177
+ # if lamb > 0. and not model.save_act:
178
+ # print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
179
+ #
180
+ # old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)
181
+ # if label is not None:
182
+ # label = label.to(model.device)
183
+ #
184
+ # if loss_fn is None:
185
+ # loss_fn = lambda x, y: torch.mean((x - y) ** 2)
186
+ # else:
187
+ # loss_fn = loss_fn
188
+ #
189
+ # grid_update_freq = int(stop_grid_update_step / grid_update_num)
190
+ #
191
+ # if opt == "Adam":
192
+ # optimizer = torch.optim.Adam(model.get_params(), lr=lr)
193
+ # elif opt == "LBFGS":
194
+ # optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, line_search_fn="strong_wolfe",
195
+ # tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
196
+ # else:
197
+ # optimizer = torch.optim.SGD(model.get_params(), lr=lr)
198
+ #
199
+ # lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9)
200
+ #
201
+ # results = {'train_losses': [], 'val_losses': [], 'regularize': [], 'accuracies': [],
202
+ # 'precisions': [], 'recalls': [], 'f1-scores': []}
203
+ #
204
+ # steps = math.ceil(dataset['train_input'].shape[0] / batch_size)
205
+ #
206
+ # train_loss = torch.zeros(1).to(model.device)
207
+ # reg_ = torch.zeros(1).to(model.device)
208
+ #
209
+ # def closure():
210
+ # nonlocal train_loss, reg_
211
+ # optimizer.zero_grad()
212
+ # pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding, y_th=y_th)
213
+ # loss = loss_fn(pred, batch_train_label)
214
+ # if model.save_act:
215
+ # if reg_metric == 'edge_backward':
216
+ # model.attribute()
217
+ # if reg_metric == 'node_backward':
218
+ # model.node_attribute()
219
+ # reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
220
+ # else:
221
+ # reg_ = torch.tensor(0.)
222
+ # objective = loss + lamb * reg_
223
+ # train_loss = (train_loss * batch_num + objective.detach()) / (batch_num + 1)
224
+ # objective.backward()
225
+ # return objective
226
+ #
227
+ # if save_fig:
228
+ # if not os.path.exists(img_folder):
229
+ # os.makedirs(img_folder)
230
+ #
231
+ # for epoch in range(epochs):
232
+ #
233
+ # if epoch == epochs - 1 and old_save_act:
234
+ # model.save_act = True
235
+ #
236
+ # if save_fig and epoch % save_fig_freq == 0:
237
+ # save_act = model.save_act
238
+ # model.save_act = True
239
+ #
240
+ # train_indices = np.arange(dataset['train_input'].shape[0])
241
+ # np.random.shuffle(train_indices)
242
+ # train_pbar = tqdm(range(steps), desc=f'Epoch {epoch + 1}/{epochs} Training', file=sys.stdout)
243
+ # for batch_num in train_pbar:
244
+ # step = epoch * steps + batch_num + 1
245
+ # i = batch_num * batch_size
246
+ # batch_train_id = train_indices[i:i + batch_size]
247
+ # batch_train_input = dataset['train_input'][batch_train_id].to(model.device)
248
+ # batch_train_label = dataset['train_label'][batch_train_id].to(model.device)
249
+ #
250
+ # if step % grid_update_freq == 0 and step < stop_grid_update_step and update_grid and step >= start_grid_update_step:
251
+ # model.update_grid(batch_train_input)
252
+ #
253
+ # if opt == "LBFGS":
254
+ # optimizer.step(closure)
255
+ #
256
+ # else:
257
+ # optimizer.zero_grad()
258
+ # pred = model.forward(batch_train_input, singularity_avoiding=singularity_avoiding,
259
+ # y_th=y_th)
260
+ # loss = loss_fn(pred, batch_train_label)
261
+ # if model.save_act:
262
+ # if reg_metric == 'edge_backward':
263
+ # model.attribute()
264
+ # if reg_metric == 'node_backward':
265
+ # model.node_attribute()
266
+ # reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
267
+ # else:
268
+ # reg_ = torch.tensor(0.)
269
+ # loss = loss + lamb * reg_
270
+ # train_loss = (train_loss * batch_num + loss.detach()) / (batch_num + 1)
271
+ # loss.backward()
272
+ # optimizer.step()
273
+ # train_pbar.set_postfix(loss=train_loss.item())
274
+ #
275
+ # print(f'{epoch}/{epochs}:train_loss:{train_loss.item()}')
276
+ # val_loss = torch.zeros(1).to(model.device)
277
+ # with torch.no_grad():
278
+ # test_indices = np.arange(dataset['test_input'].shape[0])
279
+ # np.random.shuffle(test_indices)
280
+ # test_steps = math.ceil(dataset['test_input'].shape[0] / batch_size_test)
281
+ # test_pbar = tqdm(range(test_steps), desc=f'Epoch {epoch + 1}/{epochs} Validation', file=sys.stdout)
282
+ # for batch_num in test_pbar:
283
+ # i = batch_num * batch_size_test
284
+ # batch_test_id = test_indices[i:i + batch_size_test]
285
+ # batch_test_input = dataset['test_input'][batch_test_id].to(model.device)
286
+ # batch_test_label = dataset['test_label'][batch_test_id].to(model.device)
287
+ #
288
+ # outputs = model.forward(batch_test_input, singularity_avoiding=singularity_avoiding,
289
+ # y_th=y_th)
290
+ #
291
+ # loss = loss_fn(outputs, batch_test_label)
292
+ #
293
+ # val_loss = (val_loss * batch_num + loss.detach()) / (batch_num + 1)
294
+ # test_pbar.set_postfix(loss=loss.item(), val_loss=val_loss.item())
295
+ # if label is not None:
296
+ # diffs = torch.abs(outputs - label)
297
+ # closest_indices = torch.argmin(diffs, dim=1)
298
+ # closest_values = label[closest_indices]
299
+ # all_predictions.extend(closest_values.detach().cpu().numpy())
300
+ # all_labels.extend(batch_test_label.detach().cpu().numpy())
301
+ #
302
+ # lr_scheduler.step(val_loss)
303
+ #
304
+ # results['train_losses'].append(train_loss.cpu().item())
305
+ # results['val_losses'].append(val_loss.cpu().item())
306
+ # results['regularize'].append(reg_.cpu().item())
307
+ #
308
+ # if save_fig and epoch % save_fig_freq == 0:
309
+ # model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch),
310
+ # beta=beta)
311
+ # plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=100)
312
+ # plt.close()
313
+ # model.save_act = save_act
314
+ #
315
+ # # append_to_results_file(results_file, results, result_info)
316
+ # model.log_history('fit')
317
+ # model.symbolic_enabled = old_symbolic_enabled
318
+ # return results
yms_kan/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.0.1" # 初始版本
1
+ __version__ = "0.0.2" # 初始版本
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: yms_kan
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: My awesome package
5
5
  Author-email: yms <11@qq.com>
6
6
  License-Expression: MIT
@@ -1,7 +1,7 @@
1
1
  yms_kan/KANLayer.py,sha256=-V2Fh5wvPYvfF1tmQVxJKWvvaAHiwo2EiFpd8VDgB1c,14149
2
2
  yms_kan/LBFGS.py,sha256=OPeRPDp40jaVH4qPoBDMEub7TPhyvw7pbqwQar3OZ1A,17620
3
3
  yms_kan/MLP.py,sha256=ryLzSuBrsGlSHRLwnQZCNj-Ru9BwXJYoHNkAwX14N64,12804
4
- yms_kan/MultKAN.py,sha256=sVVYmI5DSQPCU4y3s-JH7jPTS1dL9ztFCCB5KPbrjrU,121930
4
+ yms_kan/MultKAN.py,sha256=n58W6tORBDuh_-pUVp-ER-R9KNJKdYjFaisDx8wJpWw,122113
5
5
  yms_kan/Symbolic_KANLayer.py,sha256=WhJzC5IMIpXI_K7aYamOrWTK7uckxVdsM9N4oLZMO3I,9897
6
6
  yms_kan/__init__.py,sha256=O2c6DIG4PHavXF2v7K9jNqMbJXWr4-gTN3Vs1YSlc64,120
7
7
  yms_kan/compiler.py,sha256=7bVwDNX0xmLAjQ8V1FdmkIIIibmy_W5eaeSKBlYL0Vc,18632
@@ -10,11 +10,13 @@ yms_kan/feynman.py,sha256=Eisf69K49s4C6UlPEi5LnNK_p5TUJQLBKxMp-sW0a9w,33687
10
10
  yms_kan/hypothesis.py,sha256=Ec20xadfgOSSWeZHQaGn-h9F2PY7LWFU3iniNI2Zd_4,23165
11
11
  yms_kan/spline.py,sha256=ZXyGwl2Sc-UrnrcuUXeUQkBOMnetaWcHrbpZaqatCvs,4345
12
12
  yms_kan/tool.py,sha256=CLIsOYWwG-A5PJvoyIP8cRBzX8iRhEssW-2uXdLfi-U,12124
13
- yms_kan/train_eval_utils.py,sha256=hLBBe60AaPvqT-VLPHXzxEj14dBQJZ4uJFIkEp3Neak,7612
13
+ yms_kan/train_eval_utils.py,sha256=73pA3-HDPDik_yCsDW0oF1dIvVu_vPeHbvJ08o26ygQ,14867
14
14
  yms_kan/utils.py,sha256=J07L-tgmc1OfU6Tl6mGwHJRizjFN75EJK8BxejaZLUc,23860
15
- yms_kan/version.py,sha256=KHooFYaAikQv9yQmi9BTz3oq7Zm24uaNYJmVw0-8YCU,39
16
- yms_kan-0.0.1.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
17
- yms_kan-0.0.1.dist-info/METADATA,sha256=K9MR1ucGkoyozsc7MZ2-FbboTyd1Xoau5GBCkCHV1vk,240
18
- yms_kan-0.0.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
- yms_kan-0.0.1.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
20
- yms_kan-0.0.1.dist-info/RECORD,,
15
+ yms_kan/version.py,sha256=qeSnHAh3t9Zb2L0FPUF5OaQvWEJcfTki6FmrfynjWz4,39
16
+ yms_kan/assets/img/mult_symbol.png,sha256=2f4xUKdweft-qUbHjFI5h9-smnEtc0FWq8hNYZhPAXY,6392
17
+ yms_kan/assets/img/sum_symbol.png,sha256=94QkMUzmEjlCq_yf14nMEQmettaq86FmlGfdl22b4XE,6210
18
+ yms_kan-0.0.2.dist-info/licenses/LICENSE,sha256=BJXDWyF4Groqtnp4Gi9puH4aLg7A2IC3MpHmC-cSxwc,1067
19
+ yms_kan-0.0.2.dist-info/METADATA,sha256=jTD-nNMWFF64GiFO2-bYQePNxJh4J1-yi4eniWT1djQ,240
20
+ yms_kan-0.0.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
21
+ yms_kan-0.0.2.dist-info/top_level.txt,sha256=Z_JDh6yZf-EiW1eKgL6ADsN2yqEMRMspi-o29JZ1WPo,8
22
+ yms_kan-0.0.2.dist-info/RECORD,,