junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,286 @@
+ """
+ ----------------------------------------------------------------------
+ >>> Author       : Junshan Yin
+ >>> Last Updated : 2025-11-14
+ ----------------------------------------------------------------------
+ """
+ import math, os
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import matplotlib as mpl
+ from collections import defaultdict
+ from junshan_kit import kit, ParametersHub
+
+ def marker_schedule(marker_schedule=None):
+
+     if marker_schedule == "SPBM":
+         based_marker = {
+             "ADAM": "s",      # square
+             "ALR-SMAG": "h",  # hexagon
+             "Bundle": "o",    # circle
+             "SGD": "p",       # pentagon
+             "SPSmax": "4",    # tri-right
+             "SPBM-PF": "*",   # star
+             "SPBM-TR": "s",   # square
+         }
+
+     else:
+         based_marker = {
+             "point": ".",           # point marker
+             "pixel": ",",           # pixel marker
+             "circle": "o",          # circle
+             "triangle_down": "v",   # down triangle
+             "triangle_up": "^",     # up triangle
+             "triangle_left": "<",   # left triangle
+             "triangle_right": ">",  # right triangle
+             "tri_down": "1",        # tri-down
+             "tri_up": "2",          # tri-up
+             "tri_left": "3",        # tri-left
+             "tri_right": "4",       # tri-right
+             "square": "s",          # square
+             "pentagon": "p",        # pentagon
+             "star": "*",            # star
+             "hexagon1": "h",        # hexagon 1
+             "hexagon2": "H",        # hexagon 2
+             "plus": "+",            # plus
+             "x": "x",               # x
+             "diamond": "D",         # diamond
+             "thin_diamond": "d",    # thin diamond
+             "vline": "|",           # vertical line
+             "hline": "_",           # horizontal line
+         }
+
+     return based_marker
+
+
+ def colors_schedule(colors_schedule=None):
+
+     if colors_schedule == "SPBM":
+         based_color = {
+             "ADAM": "#7f7f7f",
+             "ALR-SMAG": "#796378",
+             "Bundle": "#17becf",
+             "SGD": "#2ca02c",
+             "SPSmax": "#BA6262",
+             "SPBM-PF": "#1f77b4",
+             "SPBM-TR": "#d62728",
+         }
+
+     else:
+         based_color = {
+             "ADAM": "#1f77b4",
+             "ALR-SMAG": "#ff7f0e",
+             "Bundle": "#2ca02c",
+             "SGD": "#d62728",
+             "SPSmax": "#9467bd",
+             "SPBM-PF": "#8c564b",
+             "SPBM-TR": "#e377c2",
+             "dddd": "#7f7f7f",
+             "xxx": "#bcbd22",
+             "ED": "#17becf",
+         }
+     return based_color
+
+
+ def Search_Paras(Paras, args, model_name, data_name, optimizer_name, metric_key="training_loss"):
+
+     param_dict = Paras["Results_dict"][model_name][data_name][optimizer_name]
+
+     num_plots = len(param_dict)
+     cols = 3
+     rows = math.ceil(num_plots / cols)
+
+     fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows))
+     axes = axes.flatten()
+
+     for idx, (param_str, info) in enumerate(param_dict.items()):
+         ax = axes[idx]
+         metric_list = info.get(metric_key, [])
+         ax.plot(metric_list)
+         ax.set_title(f'time = {info["train_time"]:.2f}, seed: {Paras["seed"]}, ID: {Paras["time_str"]} \n params = {param_str}', fontsize=10)
+         ax.set_xlabel("epochs")
+         ax.set_ylabel(ParametersHub.fig_ylabel(metric_key))
+         ax.grid(True)
+         if Paras.get('use_log_scale', False) and any(k in metric_key for k in ['loss', 'grad']):
+             ax.set_yscale("log")
+
+     # Delete the redundant subplots
+     for i in range(len(param_dict), len(axes)):
+         fig.delaxes(axes[i])
+
+     plt.suptitle(f'{model_name} on {data_name} - {optimizer_name} (training/test samples: {Paras["train_data_num"]}/{Paras["test_data_num"]}), {Paras["device"]}', fontsize=16)
+     plt.tight_layout(rect=(0, 0, 1, 0.9))
+
+     filename = f'{Paras["Results_folder"]}/{metric_key}_{ParametersHub.model_abbr(model_name)}_{data_name}_{optimizer_name}.pdf'
+     fig.savefig(filename)
+     print(f"✅ Saved: {filename}")
+     plt.close('all')
+
+
+ def Read_Results_from_pkl(info_dict, Exp_name, model_name):
+     draw_data = defaultdict(dict)
+     for data_name, info in info_dict.items():
+         for optimizer_name, info_opt in info["optimizer"].items():
+
+             pkl_path = f'{Exp_name}/seed_{info["seed"]}/{model_name}/{data_name}/{optimizer_name}/train_{info["train_test"][0]}_test_{info["train_test"][1]}/Batch_size_{info["batch_size"]}/epoch_{info["epochs"]}/{info_opt["ID"]}/Results_{ParametersHub.model_abbr(model_name)}_{data_name}_{optimizer_name}.pkl'
+
+             data_ = kit.read_pkl_data(pkl_path)
+
+             param_str = ParametersHub.opt_paras_str(info["optimizer"][optimizer_name])
+
+             # Store both the metric list and the parameter string
+             draw_data[data_name][optimizer_name] = {
+                 "metrics": data_[param_str][info["metric_key"]],
+                 "param_str": param_str
+             }
+
+     return draw_data
+
+
+ def Mul_Plot(model_name, info_dict, Exp_name="SPBM", cols=3, save_path=None, save_name=None, fig_show=False):
+     # matplotlib settings: serif fonts (Times New Roman) for text and STIX for math
+     mpl.rcParams["font.family"] = "serif"
+     mpl.rcParams["font.serif"] = ["Times New Roman"]
+     mpl.rcParams["mathtext.fontset"] = "stix"
+     mpl.rcParams["axes.unicode_minus"] = False
+     mpl.rcParams["font.size"] = 12
+
+     # Read data
+     draw_data = defaultdict(dict)
+     for data_name, info in info_dict.items():
+         for optimizer_name, info_opt in info["optimizer"].items():
+
+             pkl_path = f'{Exp_name}/seed_{info["seed"]}/{model_name}/{data_name}/{optimizer_name}/train_{info["train_test"][0]}_test_{info["train_test"][1]}/Batch_size_{info["batch_size"]}/epoch_{info["epochs"]}/{info_opt["ID"]}/Results_{ParametersHub.model_abbr(model_name)}_{data_name}_{optimizer_name}.pkl'
+
+             data_ = kit.read_pkl_data(pkl_path)
+
+             param_str = ParametersHub.opt_paras_str(info["optimizer"][optimizer_name])
+
+             draw_data[data_name][optimizer_name] = data_[param_str][info["metric_key"]]
+
+     # Draw figures
+     num_datasets = len(draw_data)
+
+     nrows = math.ceil(num_datasets / cols)
+
+     fig, axes = plt.subplots(nrows, cols, figsize=(5 * cols, 4 * nrows), squeeze=False)
+     axes = axes.flatten()
+
+     for idx, (data_name, info) in enumerate(draw_data.items()):
+         ax = axes[idx]
+         for optimizer_name, metric_list in info.items():
+             ax.plot(metric_list, label=optimizer_name, color=colors_schedule("SPBM")[optimizer_name])
+
+             # marker
+             if info_dict[data_name]["marker"] is not None:
+                 x = np.array(info_dict[data_name]["marker"])
+
+                 metric_list_arr = np.array(metric_list)
+
+                 ax.scatter(x, metric_list_arr[x], marker=marker_schedule("SPBM")[optimizer_name], color=colors_schedule("SPBM")[optimizer_name])
+
+         ax.set_title(f'{data_name}', fontsize=12)
+         ax.set_xlabel("epochs", fontsize=12)
+         ax.set_ylabel(ParametersHub.fig_ylabel(info_dict[data_name]["metric_key"]), fontsize=12)
+         if any(k in info_dict[data_name]["metric_key"] for k in ['loss', 'grad']):
+             ax.set_yscale("log")
+         ax.grid(True)
+
+     # Hide redundant axes
+     for ax in axes[num_datasets:]:
+         ax.axis('off')
+
+     # legend
+     all_handles, all_labels = [], []
+     for ax in axes[:num_datasets]:
+         h, l = ax.get_legend_handles_labels()
+         all_handles.extend(h)
+         all_labels.extend(l)
+
+     # duplicate removal
+     unique = dict(zip(all_labels, all_handles))
+     handles = list(unique.values())
+     labels = list(unique.keys())
+
+     fig.legend(
+         handles,
+         labels,
+         loc="lower center",
+         bbox_to_anchor=(0.5, -0.08),
+         ncol=len(handles),
+         fontsize=12
+     )
+
+     plt.tight_layout()
+     if save_path is None:
+         if save_name is None:
+             save_path = f'{model_name}.pdf'
+         else:
+             os.makedirs(save_name, exist_ok=True)
+             save_path = f'{save_name}/{save_name}.pdf'
+     plt.savefig(save_path, bbox_inches="tight")
+     if fig_show:
+         plt.show()
+     plt.close()  # Close the figure
+
+
+ def Opt_Paras_Plot(model_name, info_dict, Exp_name="SPBM", save_path=None, save_name=None, fig_show=False):
+
+     mpl.rcParams["font.family"] = "serif"
+     mpl.rcParams["font.serif"] = ["Times New Roman"]
+     mpl.rcParams["mathtext.fontset"] = "stix"
+     mpl.rcParams["axes.unicode_minus"] = False
+     mpl.rcParams["font.size"] = 12
+
+     # Read data
+     draw_data = Read_Results_from_pkl(info_dict, Exp_name, model_name)
+
+     if len(draw_data) > 1:
+         print('*' * 40)
+         print("Only one dataset can be drawn at a time.")
+         print(info_dict.keys())
+         print('*' * 40)
+         raise AssertionError("Opt_Paras_Plot expects a single dataset.")
+
+     plt.figure(figsize=(9, 6))  # Optional: set figure size
+
+     data_name = None
+
+     for data_name, _info in draw_data.items():
+         for optimizer_name, metric_dict in _info.items():
+             plt.plot(metric_dict["metrics"], label=f'{optimizer_name}_{metric_dict["param_str"]}',
+                      color=colors_schedule("SPBM")[optimizer_name])
+
+     if data_name is not None:
+         plt.title(f'{data_name}')
+
+     plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=1)
+     plt.grid(True)
+
+     if any(k in info_dict[data_name]["metric_key"] for k in ['loss', 'grad']):
+         plt.yscale("log")
+
+     plt.tight_layout()  # Adjust layout so the legend fits
+     plt.xlabel("epochs")
+     plt.ylabel(f'{ParametersHub.fig_ylabel(info_dict[data_name]["metric_key"])}')
+     if save_path is None:
+         if save_name is None:
+             save_path = f'{model_name}.pdf'
+         else:
+             os.makedirs(save_name, exist_ok=True)
+             save_path = f'{save_name}/{save_name}.pdf'
+     plt.savefig(save_path, bbox_inches="tight")
+
+     if fig_show:
+         plt.show()
+
+     plt.close()
+
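The plotting helpers above all rebuild a results path from a nested info_dict, but the expected schema is only implied by the keys the code reads. The sketch below shows one plausible layout; the module name PlottingHub is a guess (this diff does not show file names), and every dataset name, optimizer name, and value is illustrative rather than taken from the package.

```python
# Hypothetical usage of Mul_Plot; only the dict keys are taken from the
# code above. Names, values, and the module path are assumptions.
from junshan_kit import PlottingHub  # assumed module name for this file

info_dict = {
    "MNIST": {                          # one subplot per dataset
        "seed": 42,                     # -> seed_42/ in the pkl path
        "train_test": (60000, 10000),   # -> train_60000_test_10000/
        "batch_size": 128,              # -> Batch_size_128/
        "epochs": 50,                   # -> epoch_50/
        "metric_key": "training_loss",  # which metric series to plot
        "marker": None,                 # or a list of epochs to mark
        "optimizer": {
            "SGD": {"ID": "2025-11-14_0001", "lr": 0.1},
            "SPBM-TR": {"ID": "2025-11-14_0002", "delta": 1.0},
        },
    },
}

PlottingHub.Mul_Plot("LeNet", info_dict, Exp_name="SPBM",
                     cols=3, save_name="figs", fig_show=False)
```

Each optimizer sub-dict is passed through ParametersHub.opt_paras_str to form the key used inside the pkl file, so it presumably has to match the hyper-parameters recorded at training time.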
@@ -0,0 +1,239 @@
+ import torchvision, torch, random
+ import numpy as np
+ from torchvision.models import resnet18, resnet34, ResNet18_Weights, ResNet34_Weights
+ import torch.nn as nn
+
+
+ # ---------------- Build ResNet18 - Caltech101 -----------------------
+ def Build_ResNet18_Caltech101_Resize_32():
+
+     """
+     1. Modify the first convolutional layer for smaller input (e.g., 32x32 instead of 224x224)
+        Original: kernel_size=7, stride=2, padding=3 → changed to 3x3 kernel, stride=1, padding=1
+
+     2. Adjust the final fully connected layer to match the number of Caltech101 classes (101)
+     """
+     model = resnet18(weights=None)
+     model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 1
+     model.fc = nn.Linear(model.fc.in_features, 101)  # 2
+
+     return model
+
+
+ # ---------------- Build ResNet18 - CIFAR100 -----------------------
+ def Build_ResNet18_CIFAR100():
+     """
+     1. Modify the first convolutional layer for smaller input (e.g., 32x32 instead of 224x224)
+        Original: kernel_size=7, stride=2, padding=3 → changed to 3x3 kernel, stride=1, padding=1
+
+     2. Adjust the final fully connected layer to match the number of CIFAR-100 classes (100)
+     """
+
+     model = resnet18(weights=None)
+     # model = resnet18(weights=ResNet18_Weights.DEFAULT)
+     model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 1
+     model.fc = nn.Linear(model.fc.in_features, 100)  # 2
+
+     return model
+
+
+ # ---------------- Build ResNet18 - MNIST ----------------------------
+ def Build_ResNet18_MNIST():
+     """
+     1. Modify the first convolutional layer to accept grayscale input (1 channel instead of 3)
+        Original: in_channels=3 → changed to in_channels=1
+
+     2. Adjust the final fully connected layer to match the number of MNIST classes (10)
+     """
+
+     model = resnet18(weights=None)
+     model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)  # 1
+     model.fc = nn.Linear(model.fc.in_features, 10)  # 2
+
+     return model
+
+
+ # ---------------- Build ResNet34 - CIFAR100 -----------------------
+ def Build_ResNet34_CIFAR100():
+
+     model = resnet34(weights=None)
+     model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+     model.fc = nn.Linear(model.fc.in_features, 100)
+     return model
+
+ # ---------------- Build ResNet34 - MNIST ----------------------------
+ def Build_ResNet34_MNIST():
+     # Do not load the pre-trained weights
+     model = resnet34(weights=None)
+
+     model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
+     model.fc = nn.Linear(model.fc.in_features, 10)
+
+     return model
+
+ # ---------------- Build ResNet34 - Caltech101 -----------------------
+ def Build_ResNet34_Caltech101_Resize_32():
+
+     model = resnet34(weights=None)
+     model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+     model.fc = nn.Linear(model.fc.in_features, 101)
+     return model
+
+
+ #**************************************************************
+ # ---------------------- LeastSquares -------------------------
+ #**************************************************************
+ # ---------------- LeastSquares - MNIST -----------------------
+ def Build_LeastSquares_MNIST():
+     """
+     1. Flatten MNIST images (1x28x28 → 784)
+     2. Use a linear layer for multi-class classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(28 * 28, 10))
+
+ # ---------------- LeastSquares - CIFAR100 --------------------
+ def Build_LeastSquares_CIFAR100():
+     """
+     1. Flatten CIFAR-100 images (3x32x32 → 3072)
+     2. Use a linear layer for multi-class classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(3 * 32 * 32, 100))
+
+ # ---------------- LeastSquares - Caltech101 ------------------
+ def Build_LeastSquares_Caltech101_Resize_32():
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(3 * 32 * 32, 101)
+     )
+
+
+ #*************************************************************
+ # --------------- LogRegressionBinary ------------------------
+ #*************************************************************
+ # -------------- LogRegressionBinary - MNIST ------------------
+ def Build_LogRegressionBinary_MNIST():
+     """
+     1. Flatten MNIST images (1x28x28 → 784)
+     2. Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(28 * 28, 1))
+
+
+ # --------------- LogRegressionBinary - CIFAR100 --------------
+ def Build_LogRegressionBinary_CIFAR100():
+     """
+     1. Flatten CIFAR-100 images (3x32x32 → 3072)
+     2. Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(3 * 32 * 32, 1))
+
+ # -------------- LogRegressionBinary - RCV1 ------------------
+ def Build_LogRegressionBinary_RCV1():
+     """
+     Use a linear layer for binary classification (RCV1: 47,236 features)
+     """
+     return nn.Sequential(
+         nn.Linear(47236, 1))
+
+ # <LogRegressionBinaryL2>
+ #**************************************************************
+ # ------------- LogRegressionBinaryL2 -------------------------
+ #**************************************************************
+ def Build_LogRegressionBinaryL2_RCV1():
+     """
+     Use a linear layer for binary classification (RCV1: 47,236 features)
+     """
+     return nn.Sequential(
+         nn.Linear(47236, 1))
+ # <LogRegressionBinaryL2>
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_MNIST():
+     """
+     1. Flatten MNIST images (1x28x28 -> 784)
+     2. Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(28 * 28, 1))
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_CIFAR100():
+     """
+     1. Flatten CIFAR-100 images (3x32x32 -> 3072)
+     2. Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(3 * 32 * 32, 1))
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_Duke():
+     """
+     Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(7129, 1))
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_Ijcnn():
+     """
+     Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(22, 1))
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_w8a():
+     """
+     Use a linear layer for binary classification
+     """
+     return nn.Sequential(
+         nn.Flatten(),
+         nn.Linear(300, 1))
+
+ # ---------------------------------------------------------
+ def Build_LogRegressionBinaryL2_Adult_Income_Prediction():
+     return nn.Sequential(
+         nn.Linear(108, 1))
+
+ def Build_LogRegressionBinaryL2_Credit_Card_Fraud_Detection():
+     return nn.Sequential(
+         nn.Linear(30, 1))
+
+ def Build_LogRegressionBinaryL2_Diabetes_Health_Indicators():
+     return nn.Sequential(
+         nn.Linear(52, 1))
+
+ def Build_LogRegressionBinaryL2_Electric_Vehicle_Population():
+     return nn.Sequential(
+         nn.Linear(835, 1))
+
+ def Build_LogRegressionBinaryL2_Global_House_Purchase():
+     return nn.Sequential(
+         nn.Linear(81, 1))
+
+ def Build_LogRegressionBinaryL2_Health_Lifestyle():
+     return nn.Sequential(
+         nn.Linear(15, 1))
+
+ def Build_LogRegressionBinaryL2_Homesite_Quote_Conversion():
+     return nn.Sequential(
+         nn.Linear(655, 1))
+
+ def Build_LogRegressionBinaryL2_TN_Weather_2020_2025():
+     return nn.Sequential(
+         nn.Linear(121, 1))
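Each builder above adapts a stock architecture to a fixed input shape, so a quick forward pass makes the expected tensor shapes concrete. A minimal sketch, assuming the builders live in a module named ModelsHub (the file name is not shown in this diff):

```python
# Shape check for the CIFAR-100 builders; the module path is an assumption.
import torch
from junshan_kit.ModelsHub import Build_ResNet18_CIFAR100, Build_LeastSquares_CIFAR100

x = torch.randn(8, 3, 32, 32)        # a dummy batch of CIFAR-100 images

resnet = Build_ResNet18_CIFAR100()   # 3x3 stem keeps the 32x32 resolution
linear = Build_LeastSquares_CIFAR100()

print(resnet(x).shape)               # torch.Size([8, 100])
print(linear(x).shape)               # torch.Size([8, 100])
```

The binary builders follow the same pattern but end in nn.Linear(d, 1), so they emit a single raw logit per sample; a sigmoid is presumably applied downstream, e.g. via BCEWithLogitsLoss.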
@@ -0,0 +1,130 @@
+ from junshan_kit.OptimizerHup import SPBM, SPBM_func
+ import torch, time, os
+ from torch.optim.optimizer import Optimizer
+ from torch.nn.utils import parameters_to_vector, vector_to_parameters
+
+
+ class SPSmax(Optimizer):
+     def __init__(self, params, model, hyperparams, Paras):
+         defaults = dict()
+         super().__init__(params, defaults)
+         self.model = model
+         self.c = hyperparams['c']
+         self.gamma = hyperparams['gamma']
+         if 'f_star' not in Paras or Paras['f_star'] is None:
+             self.f_star = 0
+         else:
+             self.f_star = Paras['f_star']
+         self.step_size = []
+
+     def step(self, closure=None):
+         if closure is None:
+             raise RuntimeError("Closure required for SPSmax")
+
+         # Reset the gradients and perform the forward/backward computation
+         loss = closure()
+
+         with torch.no_grad():
+             xk = parameters_to_vector(self.model.parameters())
+             g_k = parameters_to_vector([p.grad if p.grad is not None else torch.zeros_like(p) for p in self.model.parameters()])
+
+             # Step size: min((f(x_k) - f_star) / (c * ||g_k||^2), gamma)
+             step_size = (loss - self.f_star) / ((self.c * torch.norm(g_k, p=2) ** 2) + 1e-8)
+             step_size = min(step_size, self.gamma)
+             self.step_size.append(step_size)
+
+             # Update
+             xk = xk - step_size * g_k
+
+             vector_to_parameters(xk, self.model.parameters())
+
+         # Temporarily return the loss (tensor type)
+         return loss
+
+
+ class ALR_SMAG(Optimizer):
+     def __init__(self, params, model, hyperparams, Paras):
+         defaults = dict()
+         super().__init__(params, defaults)
+         self.model = model
+         self.c = hyperparams['c']
+         self.eta_max = hyperparams['eta_max']
+         self.beta = hyperparams['beta']
+         if 'f_star' not in Paras or Paras['f_star'] is None:
+             self.f_star = 0
+         else:
+             self.f_star = Paras['f_star']
+         self.step_size = []
+         self.d_k = torch.zeros_like(parameters_to_vector(self.model.parameters()))
+
+     def step(self, closure=None):
+         if closure is None:
+             raise RuntimeError("Closure required for ALR_SMAG")
+
+         # Reset the gradients and perform the forward/backward computation
+         loss = closure()
+
+         with torch.no_grad():
+             xk = parameters_to_vector(self.model.parameters())
+             g_k = parameters_to_vector([p.grad if p.grad is not None else torch.zeros_like(p) for p in self.model.parameters()])
+
+             # Momentum direction: d_k = beta * d_{k-1} + g_k
+             self.d_k = self.beta * self.d_k + g_k
+             # Step size: min((f(x_k) - f_star) / (c * ||d_k||^2), eta_max)
+             step_size = (loss - self.f_star) / ((self.c * torch.norm(self.d_k, p=2) ** 2) + 1e-8)
+             step_size = min(step_size, self.eta_max)
+             self.step_size.append(step_size)
+
+             # Update
+             xk = xk - step_size * g_k
+
+             vector_to_parameters(xk, self.model.parameters())
+
+         # Temporarily return the loss (tensor type)
+         return loss
+
+
+ class Bundle(Optimizer):
+     def __init__(self, params, model, hyperparams, Paras):
+         defaults = dict()
+         super().__init__(params, defaults)
+         self.model = model
+         self.cutting_num = hyperparams['cutting_number']
+         self.delta = hyperparams['delta']
+         self.Paras = Paras
+
+         self.x_his, self.g_his, self.f_his = [], [], []
+
+     def step(self, closure=None):
+         if closure is None:
+             raise RuntimeError("Closure required for Bundle")
+
+         # Reset the gradients and perform the forward/backward computation
+         loss = closure()
+
+         with torch.no_grad():
+             xk = parameters_to_vector(self.model.parameters())
+             g_k = parameters_to_vector([p.grad if p.grad is not None else torch.zeros_like(p) for p in self.model.parameters()])
+
+             # Add a cutting plane to the bundle (bounded by cutting_num)
+             x_his, f_his, g_his = SPBM_func.add_cutting(self.x_his, self.f_his, self.g_his, xk.detach().clone(), g_k.detach().clone(), loss.detach().clone(), self.cutting_num)
+
+             # Coefficients of the dual problem
+             Gk, rk, ek = SPBM_func.get_var(x_his, f_his, g_his, self.delta)
+
+             # Solve the dual problem
+             xk = SPBM_func.bundle(Gk, ek, xk, self.delta, self.Paras)
+
+             vector_to_parameters(xk, self.model.parameters())
+
+         # Return the loss (tensor type)
+         return loss
+
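All three optimizers require a closure that re-evaluates the loss and populates the gradients, unlike the usual argument-free opt.step() call. A hedged usage sketch, assuming the classes defined above are in scope; the model, data, and hyper-parameter values are illustrative:

```python
# Minimal closure-based training step with SPSmax; values are illustrative.
import torch

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
opt = SPSmax(model.parameters(), model,
             hyperparams={"c": 0.5, "gamma": 2.0},
             Paras={"f_star": 0.0})

X, y = torch.randn(32, 10), torch.randn(32, 1)

def closure():
    opt.zero_grad()                     # reset gradients
    loss = criterion(model(X), y)
    loss.backward()                     # populate p.grad for parameters_to_vector
    return loss

loss = opt.step(closure)                # applies min((f - f*) / (c * ||g||^2), gamma)
```

ALR_SMAG and Bundle are driven the same way; only the hyperparams dict changes ({'c', 'eta_max', 'beta'} and {'cutting_number', 'delta'}, respectively).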