junshan-kit 2.5.1__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +108 -8
- junshan_kit/DataProcessor.py +86 -8
- junshan_kit/DataSets.py +29 -30
- junshan_kit/Evaluate_Metrics.py +75 -2
- junshan_kit/FiguresHub.py +286 -0
- junshan_kit/ModelsHub.py +32 -5
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +350 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +390 -119
- junshan_kit/Print_Info.py +57 -11
- junshan_kit/TrainingHub.py +190 -40
- junshan_kit/kit.py +39 -50
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.7.3.dist-info}/METADATA +7 -1
- junshan_kit-2.7.3.dist-info/RECORD +20 -0
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.7.3.dist-info}/WHEEL +1 -1
- junshan_kit-2.5.1.dist-info/RECORD +0 -13
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""
|
|
2
|
+
----------------------------------------------------------------------
|
|
3
|
+
>>> Author : Junshan Yin
|
|
4
|
+
>>> Last Updated : 2025-11-14
|
|
5
|
+
----------------------------------------------------------------------
|
|
6
|
+
"""
|
|
7
|
+
import math, os
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import matplotlib as mpl
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from junshan_kit import kit, ParametersHub
|
|
13
|
+
|
|
14
|
+
def marker_schedule(marker_schedule=None):
    """Return a name -> matplotlib marker-code mapping.

    Parameters
    ----------
    marker_schedule : str or None
        ``"SPBM"`` selects the per-optimizer markers used in the SPBM
        experiments; any other value (including ``None``) selects a generic
        table keyed by a descriptive marker name.

    Returns
    -------
    dict[str, str]
        A newly built dictionary on every call, so callers may mutate it.
    """
    if marker_schedule != "SPBM":
        # Generic table: descriptive name -> matplotlib marker code.
        return {
            "point": ".",
            "pixel": ",",
            "circle": "o",
            "triangle_down": "v",
            "triangle_up": "^",
            "triangle_left": "<",
            "triangle_right": ">",
            "tri_down": "1",
            "tri_up": "2",
            "tri_left": "3",
            "tri_right": "4",
            "square": "s",
            "pentagon": "p",
            "star": "*",
            "hexagon1": "h",
            "hexagon2": "H",
            "plus": "+",
            "x": "x",
            "diamond": "D",
            "thin_diamond": "d",
            "vline": "|",
            "hline": "_",
        }

    # Per-optimizer markers for the SPBM experiments.
    # NOTE(review): "ADAM" and "SPBM-TR" both use the square marker "s";
    # confirm this sharing is intended.
    return {
        "ADAM": "s",      # square
        "ALR-SMAG": "h",  # hexagon
        "Bundle": "o",    # circle
        "SGD": "p",       # pentagon
        "SPSmax": "4",    # tri-right
        "SPBM-PF": "*",   # star
        "SPBM-TR": "s",   # square
    }
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def colors_schedule(colors_schedule=None):
    """Return an optimizer-name -> hex colour mapping.

    Parameters
    ----------
    colors_schedule : str or None
        ``"SPBM"`` selects the palette used in the SPBM experiments; any other
        value (including ``None``) selects a generic palette.

    Returns
    -------
    dict[str, str]
        A newly built dictionary on every call, so callers may mutate it.
    """
    if colors_schedule != "SPBM":
        # Generic palette (the matplotlib "tab10" hex values).
        # NOTE(review): "dddd" and "xxx" look like placeholder keys — confirm.
        return {
            "ADAM": "#1f77b4",
            "ALR-SMAG": "#ff7f0e",
            "Bundle": "#2ca02c",
            "SGD": "#d62728",
            "SPSmax": "#9467bd",
            "SPBM-PF": "#8c564b",
            "SPBM-TR": "#e377c2",
            "dddd": "#7f7f7f",
            "xxx": "#bcbd22",
            "ED": "#17becf",
        }

    # Palette used for the SPBM experiments.
    return {
        "ADAM": "#7f7f7f",
        "ALR-SMAG": "#796378",
        "Bundle": "#17becf",
        "SGD": "#2ca02c",
        "SPSmax": "#BA6262",
        "SPBM-PF": "#1f77b4",
        "SPBM-TR": "#d62728",
    }
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def Search_Paras(Paras, args, model_name, data_name, optimizer_name, metric_key = "training_loss"):
    """Plot one curve per hyper-parameter setting in a grid and save it as a PDF.

    Every entry of ``Paras["Results_dict"][model_name][data_name][optimizer_name]``
    (keyed by a parameter string) gets its own subplot, three subplots per row.
    The figure is written to
    ``{Paras["Results_folder"]}/{metric_key}_{model_abbr}_{data_name}_{optimizer_name}.pdf``.

    Parameters
    ----------
    Paras : dict
        Run configuration and results. Keys read here: "Results_dict", "seed",
        "time_str", "train_data_num", "test_data_num", "device",
        "Results_folder" and, optionally, "use_log_scale".
    args :
        Not used; kept for call-site compatibility.
    model_name, data_name, optimizer_name : str
        Keys selecting which results to plot.
    metric_key : str
        Per-setting entry plotted on the y-axis (default "training_loss").
        A log y-scale is used when ``Paras["use_log_scale"]`` is truthy and the
        key contains "loss" or "grad".

    Returns
    -------
    None. If there are no settings to plot, a message is printed and nothing
    is saved.
    """
    param_dict = Paras["Results_dict"][model_name][data_name][optimizer_name]

    num_plots = len(param_dict)
    if num_plots == 0:
        # plt.subplots(0, ...) would raise; there is nothing to draw anyway.
        print(f"⚠️ No results to plot for {model_name} / {data_name} / {optimizer_name}.")
        return

    cols = 3
    rows = math.ceil(num_plots / cols)

    # squeeze=False keeps `axes` 2-D for every grid size, so flatten() is always valid.
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
    axes = axes.flatten()

    use_log = Paras.get('use_log_scale', False) and any(k in metric_key for k in ['loss', 'grad'])

    for ax, (param_str, info) in zip(axes, param_dict.items()):
        ax.plot(info.get(metric_key, []))
        ax.set_title(f'time = {info["train_time"]:.2f}, seed: {Paras["seed"]}, ID: {Paras["time_str"]} \n params = {param_str}', fontsize=10)
        ax.set_xlabel("epochs")
        ax.set_ylabel(ParametersHub.fig_ylabel(metric_key))
        ax.grid(True)
        if use_log:
            ax.set_yscale("log")

    # Delete the unused subplots in the last row.
    for ax in axes[num_plots:]:
        fig.delaxes(ax)

    fig.suptitle(f'{model_name} on {data_name} - {optimizer_name} (training/test samples: {Paras["train_data_num"]}/{Paras["test_data_num"]}), {Paras["device"]}', fontsize=16)
    # Leave the top 10% of the figure for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.9))

    filename = f'{Paras["Results_folder"]}/{metric_key}_{ParametersHub.model_abbr(model_name)}_{data_name}_{optimizer_name}.pdf'
    fig.savefig(filename)
    print(f"✅ Saved: {filename}")
    # Close only this figure; other figures the caller has open are left alone.
    plt.close(fig)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def Read_Results_from_pkl(info_dict, Exp_name, model_name):
    """Load the stored metric curve for every (dataset, optimizer) pair.

    For each ``data_name`` in ``info_dict`` and each optimizer listed under
    ``info_dict[data_name]["optimizer"]``, the results pickle is read from::

        {Exp_name}/seed_{seed}/{model_name}/{data_name}/{optimizer}/
        train_{n_train}_test_{n_test}/Batch_size_{batch_size}/epoch_{epochs}/
        {ID}/Results_{model_abbr}_{data_name}_{optimizer}.pkl

    Parameters
    ----------
    info_dict : dict
        ``{data_name: cfg}``. Keys of ``cfg`` read here: "optimizer"
        (``{optimizer_name: opt_cfg}``, where ``opt_cfg`` holds "ID"), "seed",
        "train_test" (indexable pair), "batch_size", "epochs" and "metric_key".
    Exp_name : str
        Root folder of the experiment results.
    model_name : str
        Model whose results are loaded.

    Returns
    -------
    collections.defaultdict(dict)
        ``result[data_name][optimizer_name] == {"metrics": ..., "param_str": ...}``,
        where "metrics" is ``pickle[param_str][cfg["metric_key"]]``. Datasets
        with no optimizers do not appear in the result.
    """
    results = defaultdict(dict)

    for data_name, cfg in info_dict.items():
        for opt_name, opt_cfg in cfg["optimizer"].items():
            run_dir = "/".join([
                f"{Exp_name}",
                f'seed_{cfg["seed"]}',
                f"{model_name}",
                f"{data_name}",
                f"{opt_name}",
                f'train_{cfg["train_test"][0]}_test_{cfg["train_test"][1]}',
                f'Batch_size_{cfg["batch_size"]}',
                f'epoch_{cfg["epochs"]}',
                f'{opt_cfg["ID"]}',
            ])
            pkl_file = f'{run_dir}/Results_{ParametersHub.model_abbr(model_name)}_{data_name}_{opt_name}.pkl'

            loaded = kit.read_pkl_data(pkl_file)
            # The pickle is keyed by the optimizer's parameter string.
            param_str = ParametersHub.opt_paras_str(opt_cfg)

            results[data_name][opt_name] = {
                "metrics": loaded[param_str][cfg["metric_key"]],
                "param_str": param_str,
            }

    return results
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def Mul_Plot(model_name, info_dict, Exp_name = "SPBM", cols = 3, save_path = None, save_name = None, fig_show = False):
    """Compare optimizers on several datasets (one subplot per dataset) and save a PDF.

    Results are loaded with ``Read_Results_from_pkl``. Each subplot draws one
    curve per optimizer, using colours from ``colors_schedule("SPBM")``. It can
    also highlight selected epochs with markers from ``marker_schedule("SPBM")``.
    One shared legend is placed below the grid.

    Parameters
    ----------
    model_name : str
        Model whose results are plotted; also names the default output file.
    info_dict : dict
        ``{data_name: cfg}``. Keys of ``cfg`` read here and by
        ``Read_Results_from_pkl``: "optimizer", "seed", "train_test",
        "batch_size", "epochs", "metric_key" and, optionally, "marker"
        (epoch indices to highlight, or None).
    Exp_name : str
        Root folder of the stored experiment results.
    cols : int
        Number of subplot columns.
    save_path :
        Not used. The output path is derived from ``save_name``; this
        parameter is kept for call-site compatibility.
    save_name : str or None
        ``None`` saves the figure as ``{model_name}.pdf``. Otherwise the
        directory ``save_name`` is created and the figure is saved as
        ``{save_name}/{save_name}.pdf``.
    fig_show : bool
        Also display the figure with ``plt.show()`` before closing it.

    Notes
    -----
    Updates the global ``matplotlib.rcParams`` (fonts, font size).
    Optimizers missing from the SPBM colour/marker tables fall back to
    matplotlib's colour cycle and the "o" marker. Marker epochs outside a
    curve's range are skipped. If there is nothing to plot, a message is
    printed and no file is written.
    """
    # Serif text with Times New Roman preferred. matplotlib searches
    # font.serif in priority order, so uninstalled entries fall through.
    # STIX maths matches the serif text.
    preferred = "Times New Roman"
    mpl.rcParams["font.serif"] = [preferred] + [f for f in mpl.rcParams["font.serif"] if f != preferred]
    mpl.rcParams["font.family"] = "serif"
    mpl.rcParams["mathtext.fontset"] = "stix"
    mpl.rcParams["axes.unicode_minus"] = False
    mpl.rcParams["font.size"] = 12

    # draw_data[data_name][optimizer_name] == {"metrics": ..., "param_str": ...}
    draw_data = Read_Results_from_pkl(info_dict, Exp_name, model_name)
    num_datasets = len(draw_data)
    if num_datasets == 0:
        # plt.subplots(0, cols) would raise; there is nothing to draw.
        print("⚠️ No results to plot.")
        return

    colors = colors_schedule("SPBM")
    markers = marker_schedule("SPBM")

    nrows = math.ceil(num_datasets / cols)
    fig, axes = plt.subplots(nrows, cols, figsize=(5 * cols, 4 * nrows), squeeze=False)
    axes = axes.flatten()

    for ax, (data_name, curves) in zip(axes, draw_data.items()):
        cfg = info_dict[data_name]
        metric_key = cfg["metric_key"]
        marker_epochs = cfg.get("marker")
        if marker_epochs is not None:
            marker_epochs = np.asarray(marker_epochs, dtype=int)

        for optimizer_name, entry in curves.items():
            metric_list = entry["metrics"]
            (line,) = ax.plot(metric_list, label=optimizer_name, color=colors.get(optimizer_name))

            if marker_epochs is not None:
                values = np.array(metric_list)
                # Keep only epochs that exist on this curve instead of raising IndexError.
                idx = marker_epochs[(marker_epochs >= 0) & (marker_epochs < len(values))]
                ax.scatter(idx, values[idx], marker=markers.get(optimizer_name, "o"), color=line.get_color())

        ax.set_title(f'{data_name}', fontsize=12)
        ax.set_xlabel("epochs", fontsize=12)
        ax.set_ylabel(ParametersHub.fig_ylabel(metric_key), fontsize=12)
        if any(k in metric_key for k in ['loss', 'grad']):
            ax.set_yscale("log")
        ax.grid(True)

    # Hide the unused subplots in the last row.
    for ax in axes[num_datasets:]:
        ax.axis('off')

    # Shared legend: one entry per label; later subplots overwrite the handle.
    by_label = {}
    for ax in axes[:num_datasets]:
        handles, labels = ax.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            by_label[label] = handle

    fig.legend(
        list(by_label.values()),
        list(by_label.keys()),
        loc="lower center",
        bbox_to_anchor=(0.5, -0.08),
        ncol=max(1, len(by_label)),
        fontsize=12,
    )

    fig.tight_layout()
    if save_name is None:
        out_file = f'{model_name}.pdf'
    else:
        os.makedirs(save_name, exist_ok=True)
        out_file = f'{save_name}/{save_name}.pdf'
    # bbox_inches="tight" keeps the legend, which sits below the axes area.
    fig.savefig(out_file, bbox_inches="tight")
    if fig_show:
        plt.show()
    plt.close(fig)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def Opt_Paras_Plot(model_name, info_dict, Exp_name = "SPBM", svae_path = None, save_name = None, fig_show = False):
    """Plot the curves of several optimizer settings on one dataset and save a PDF.

    Results are loaded with ``Read_Results_from_pkl``. Each curve is labelled
    ``{optimizer_name}_{param_str}`` and coloured via ``colors_schedule("SPBM")``.
    Optimizers missing from that table fall back to matplotlib's colour cycle.

    Parameters
    ----------
    model_name : str
        Model whose results are plotted; also names the default output file.
    info_dict : dict
        ``{data_name: cfg}`` describing exactly one dataset. See
        ``Read_Results_from_pkl`` for the keys read; "metric_key" also picks
        the y-label and whether a log scale is used.
    Exp_name : str
        Root folder of the stored experiment results.
    svae_path :
        Not used. The name is kept as-is for call-site compatibility.
    save_name : str or None
        ``None`` saves the figure as ``{model_name}.pdf``. Otherwise the
        directory ``save_name`` is created and the figure is saved as
        ``{save_name}/{save_name}.pdf``.
    fig_show : bool
        Also display the figure with ``plt.show()`` before closing it.

    Raises
    ------
    AssertionError
        If the loaded results span more than one dataset.

    Notes
    -----
    Updates the global ``matplotlib.rcParams`` (fonts, font size). If there is
    nothing to plot, a message is printed and no file is written.
    """
    # Serif text with Times New Roman preferred. matplotlib searches
    # font.serif in priority order, so uninstalled entries fall through.
    preferred = "Times New Roman"
    mpl.rcParams["font.serif"] = [preferred] + [f for f in mpl.rcParams["font.serif"] if f != preferred]
    mpl.rcParams["font.family"] = "serif"
    mpl.rcParams["mathtext.fontset"] = "stix"
    mpl.rcParams["axes.unicode_minus"] = False
    mpl.rcParams["font.size"] = 12

    # draw_data[data_name][optimizer_name] == {"metrics": ..., "param_str": ...}
    draw_data = Read_Results_from_pkl(info_dict, Exp_name, model_name)

    if len(draw_data) > 1:
        print('*' * 40)
        print("Only one data can be drawn at a time.")
        print(info_dict.keys())
        print('*' * 40)
        # Explicit raise: a bare `assert False` is stripped under `python -O`.
        raise AssertionError("Only one data can be drawn at a time.")

    if not draw_data:
        # Nothing loaded: there is no dataset name to look up below.
        print("⚠️ No results to plot.")
        return

    data_name, curves = next(iter(draw_data.items()))
    metric_key = info_dict[data_name]["metric_key"]
    colors = colors_schedule("SPBM")

    fig, ax = plt.subplots(figsize=(9, 6))
    for optimizer_name, entry in curves.items():
        ax.plot(entry["metrics"], label=f'{optimizer_name}_{entry["param_str"]}',
                color=colors.get(optimizer_name))

    ax.set_title(f'{data_name}')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=1)
    ax.grid(True)
    if any(k in metric_key for k in ['loss', 'grad']):
        ax.set_yscale("log")
    ax.set_xlabel("epochs")
    ax.set_ylabel(f'{ParametersHub.fig_ylabel(metric_key)}')
    # Lay out only after the labels exist so they are taken into account.
    fig.tight_layout()

    if save_name is None:
        out_file = f'{model_name}.pdf'
    else:
        os.makedirs(save_name, exist_ok=True)
        out_file = f'{save_name}/{save_name}.pdf'
    # bbox_inches="tight" keeps the legend, which sits below the axes.
    fig.savefig(out_file, bbox_inches="tight")

    if fig_show:
        plt.show()

    plt.close(fig)
|
|
285
|
+
|
|
286
|
+
|
junshan_kit/ModelsHub.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch.nn as nn
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
# ---------------- Build ResNet18 - Caltech101 -----------------------
|
|
8
|
-
def
|
|
8
|
+
def Build_ResNet18_Caltech101_Resize_32():
|
|
9
9
|
|
|
10
10
|
"""
|
|
11
11
|
1. Modify the first convolutional layer for smaller input (e.g., 32x32 instead of 224x224)
|
|
@@ -72,7 +72,7 @@ def Build_ResNet34_MNIST():
|
|
|
72
72
|
return model
|
|
73
73
|
|
|
74
74
|
# ---------------- Build ResNet34 - Caltech101 -----------------------
|
|
75
|
-
def
|
|
75
|
+
def Build_ResNet34_Caltech101_Resize_32():
|
|
76
76
|
|
|
77
77
|
model = resnet34(weights=None)
|
|
78
78
|
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
|
|
@@ -104,7 +104,7 @@ def Build_LeastSquares_CIFAR100():
|
|
|
104
104
|
nn.Linear(3 * 32 * 32, 100))
|
|
105
105
|
|
|
106
106
|
# ---------------- LeastSquares - Caltech101 ------------------
|
|
107
|
-
def
|
|
107
|
+
def Build_LeastSquares_Caltech101_Resize_32():
|
|
108
108
|
return nn.Sequential(
|
|
109
109
|
nn.Flatten(),
|
|
110
110
|
nn.Linear(3*32*32, 101)
|
|
@@ -204,9 +204,36 @@ def Build_LogRegressionBinaryL2_w8a():
|
|
|
204
204
|
|
|
205
205
|
# ---------------------------------------------------------
|
|
206
206
|
def Build_LogRegressionBinaryL2_Adult_Income_Prediction():
    """Return a single linear layer (108 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(108, 1)
    return nn.Sequential(layer)
|
|
208
209
|
|
|
209
210
|
|
|
210
211
|
def Build_LogRegressionBinaryL2_Credit_Card_Fraud_Detection():
    """Return a single linear layer (30 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(30, 1)
    return nn.Sequential(layer)
|
|
212
214
|
|
|
215
|
+
|
|
216
|
+
def Build_LogRegressionBinaryL2_Diabetes_Health_Indicators():
    """Return a single linear layer (52 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(52, 1)
    return nn.Sequential(layer)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def Build_LogRegressionBinaryL2_Electric_Vehicle_Population():
    """Return a single linear layer (835 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(835, 1)
    return nn.Sequential(layer)
|
|
224
|
+
|
|
225
|
+
def Build_LogRegressionBinaryL2_Global_House_Purchase():
    """Return a single linear layer (81 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(81, 1)
    return nn.Sequential(layer)
|
|
228
|
+
|
|
229
|
+
def Build_LogRegressionBinaryL2_Health_Lifestyle():
    """Return a single linear layer (15 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(15, 1)
    return nn.Sequential(layer)
|
|
232
|
+
|
|
233
|
+
def Build_LogRegressionBinaryL2_Homesite_Quote_Conversion():
    """Return a single linear layer (655 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(655, 1)
    return nn.Sequential(layer)
|
|
236
|
+
|
|
237
|
+
def Build_LogRegressionBinaryL2_TN_Weather_2020_2025():
    """Return a single linear layer (121 input features -> 1 output) wrapped in nn.Sequential.

    No activation is applied inside the model.
    """
    layer = nn.Linear(121, 1)
    return nn.Sequential(layer)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from junshan_kit.OptimizerHup import SPBM, SPBM_func
|
|
2
|
+
import torch, time, os
|
|
3
|
+
from torch.optim.optimizer import Optimizer
|
|
4
|
+
from torch.nn.utils import parameters_to_vector, vector_to_parameters
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SPSmax(Optimizer):
    """Stochastic Polyak step size capped at ``gamma`` (SPS_max).

    Each call to :meth:`step` moves all model parameters along the negative
    flattened gradient ``g``, with step size
    ``min((loss - f_star) / (c * ||g||^2 + 1e-8), gamma)``.

    Parameters
    ----------
    params : iterable
        Handed to ``torch.optim.Optimizer``. The update itself reads and writes
        ``model.parameters()``.
    model : torch.nn.Module
        Model whose parameters are updated.
    hyperparams : dict
        Must contain "c" and "gamma".
    Paras : dict
        ``Paras["f_star"]`` (when present and not None) is the target loss
        value; 0 is used otherwise.

    Attributes
    ----------
    step_size : list
        Step size used at each call to :meth:`step`, in call order.
    """

    def __init__(self, params, model, hyperparams, Paras):
        super().__init__(params, dict())
        self.model = model
        self.c = hyperparams['c']
        self.gamma = hyperparams['gamma']
        # A missing or None f_star both mean "use 0".
        given_f_star = Paras['f_star'] if 'f_star' in Paras else None
        self.f_star = 0 if given_f_star is None else given_f_star
        self.step_size = []

    def step(self, closure=None):
        """Run one SPS_max update and return the loss produced by ``closure``.

        ``closure`` is required. It is expected to reset the gradients,
        compute the loss, call ``backward()`` and return the loss tensor.

        Raises
        ------
        RuntimeError
            If ``closure`` is None.
        """
        if closure is None:
            raise RuntimeError("Closure required for SPSmax")

        loss = closure()

        with torch.no_grad():
            weights = list(self.model.parameters())
            flat_x = parameters_to_vector(weights)
            # Parameters that received no gradient contribute zeros.
            grads = [torch.zeros_like(w) if w.grad is None else w.grad for w in weights]
            flat_g = parameters_to_vector(grads)

            # 1e-8 guards against division by zero when the gradient vanishes.
            denominator = (self.c * torch.norm(flat_g, p=2) ** 2) + 1e-8
            eta = min((loss - self.f_star) / denominator, self.gamma)
            self.step_size.append(eta)

            vector_to_parameters(flat_x - eta * flat_g, weights)

        # The loss is returned as a tensor.
        return loss
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ALR_SMAG(Optimizer):
    """Polyak-type step size computed from a momentum buffer (ALR-SMAG).

    For each step, with loss ``f`` and flattened gradient ``g``::

        d   <- beta * d + g
        eta <- min((f - f_star) / (c * ||d||^2 + 1e-8), eta_max)
        x   <- x - eta * g

    NOTE(review): the momentum buffer ``d`` only enters the step size; the
    parameters move along ``g``, not ``d``. Confirm this matches the intended
    ALR-SMAG update before relying on it.

    Parameters
    ----------
    params : iterable
        Handed to ``torch.optim.Optimizer``. The update itself reads and writes
        ``model.parameters()``.
    model : torch.nn.Module
        Model whose parameters are updated.
    hyperparams : dict
        Must contain "c", "eta_max" and "beta".
    Paras : dict
        ``Paras["f_star"]`` (when present and not None) is the target loss
        value; 0 is used otherwise.

    Attributes
    ----------
    step_size : list
        Step size used at each call to :meth:`step`, in call order.
    d_k : torch.Tensor
        Flattened momentum buffer, initialised to zeros.
    """

    def __init__(self, params, model, hyperparams, Paras):
        defaults = dict()
        super().__init__(params, defaults)
        self.model = model
        self.c = hyperparams['c']
        self.eta_max = hyperparams['eta_max']
        self.beta = hyperparams['beta']
        if 'f_star' not in Paras or Paras['f_star'] is None:
            self.f_star = 0
        else:
            self.f_star = Paras['f_star']
        self.step_size = []
        # Zeros with the same size, dtype and device as the flattened parameters.
        self.d_k = torch.zeros_like(parameters_to_vector(self.model.parameters()))

    def step(self, closure=None):
        """Run one ALR-SMAG update and return the loss produced by ``closure``.

        ``closure`` is required. It is expected to reset the gradients,
        compute the loss, call ``backward()`` and return the loss tensor.

        Raises
        ------
        RuntimeError
            If ``closure`` is None.
        """
        if closure is None:
            raise RuntimeError("Closure required for ALR_SMAG")

        loss = closure()

        with torch.no_grad():
            xk = parameters_to_vector(self.model.parameters())
            # Parameters that received no gradient contribute zeros.
            g_k = parameters_to_vector([p.grad if p.grad is not None else torch.zeros_like(p) for p in self.model.parameters()])

            # Heavy-ball accumulation of gradients.
            self.d_k = self.beta * self.d_k + g_k
            # 1e-8 guards against division by zero when the buffer vanishes.
            step_size = (loss - self.f_star) / ((self.c * torch.norm(self.d_k, p=2) ** 2) + 1e-8)
            step_size = min(step_size, self.eta_max)
            self.step_size.append(step_size)

            xk = xk - step_size * g_k
            vector_to_parameters(xk, self.model.parameters())

        # The loss is returned as a tensor.
        return loss
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class Bundle(Optimizer):
    """Bundle (cutting-plane) optimizer driven by the helpers in ``SPBM_func``.

    On each step the current point, gradient and loss are handed to
    ``SPBM_func.add_cutting``. The resulting histories go to
    ``SPBM_func.get_var``, and ``SPBM_func.bundle`` produces the next
    parameter vector.

    Parameters
    ----------
    params : iterable
        Handed to ``torch.optim.Optimizer``. The update itself reads and writes
        ``model.parameters()``.
    model : torch.nn.Module
        Model whose parameters are updated.
    hyperparams : dict
        Must contain "cutting_number" (passed to ``add_cutting``; presumably
        the maximum number of stored cutting planes — confirm) and "delta"
        (passed to ``get_var`` and ``bundle``).
    Paras : dict
        Run configuration, forwarded unchanged to ``SPBM_func.bundle``.
    """

    def __init__(self, params, model, hyperparams, Paras):
        defaults = dict()
        super().__init__(params, defaults)
        self.model = model
        self.cutting_num = hyperparams['cutting_number']
        self.delta = hyperparams['delta']
        self.Paras = Paras

        # Histories of points, gradients and losses used to build cutting planes.
        self.x_his, self.g_his, self.f_his = [], [], []

    def step(self, closure=None):
        """Run one bundle update and return the loss produced by ``closure``.

        ``closure`` is required. It is expected to reset the gradients,
        compute the loss, call ``backward()`` and return the loss tensor.

        Raises
        ------
        RuntimeError
            If ``closure`` is None.
        """
        if closure is None:
            raise RuntimeError("Closure required for Bundle")

        loss = closure()

        with torch.no_grad():
            xk = parameters_to_vector(self.model.parameters())
            # Parameters that received no gradient contribute zeros.
            g_k = parameters_to_vector([p.grad if p.grad is not None else torch.zeros_like(p) for p in self.model.parameters()])

            # Add the cutting plane at the current point.
            # NOTE(review): the returned histories are not stored back on self,
            # so this relies on add_cutting updating self.x_his/f_his/g_his in
            # place — confirm in SPBM_func.
            x_his, f_his, g_his = SPBM_func.add_cutting(self.x_his, self.f_his, self.g_his, xk.detach().clone(), g_k.detach().clone(), loss.detach().clone(), self.cutting_num)

            # Coefficients of the dual problem (the second value is not used here).
            Gk, _rk, ek = SPBM_func.get_var(x_his, f_his, g_his, self.delta)

            # Solve the dual subproblem for the next point.
            xk = SPBM_func.bundle(Gk, ek, xk, self.delta, self.Paras)

            vector_to_parameters(xk, self.model.parameters())

        # The loss is returned as a tensor.
        return loss
|
|
129
|
+
|
|
130
|
+
|