junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +214 -0
- junshan_kit/DataProcessor.py +306 -16
- junshan_kit/DataSets.py +330 -18
- junshan_kit/Evaluate_Metrics.py +113 -0
- junshan_kit/FiguresHub.py +286 -0
- junshan_kit/ModelsHub.py +239 -0
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +350 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +690 -0
- junshan_kit/Print_Info.py +109 -0
- junshan_kit/TrainingHub.py +324 -0
- junshan_kit/kit.py +83 -24
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/METADATA +6 -2
- junshan_kit-2.7.3.dist-info/RECORD +20 -0
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/WHEEL +1 -1
- junshan_kit-2.2.8.dist-info/RECORD +0 -7
|
@@ -0,0 +1,690 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import sys, os, torch, random, glob
|
|
3
|
+
import argparse
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
6
|
+
sys.path.append(os.path.join(script_dir, 'src'))
|
|
7
|
+
from junshan_kit import ModelsHub, Check_Info
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class args:
    """Command-line front end that collects the options of a training run."""

    def __init__(self):
        pass

    # <args>
    def get_args(self, argv=None):
        """Parse the command line and attach the name-mapping tables.

        Args:
            argv: Optional list of argument strings. When ``None`` (the
                default) argparse falls back to ``sys.argv[1:]``, which is
                what this method always did before the parameter existed.

        Returns:
            The ``argparse.Namespace`` with the parsed options plus three
            extra attributes mapping command-line abbreviations to full
            names: ``model_name_mapping``, ``data_name_mapping`` and
            ``optimizers_name_mapping``.
        """
        parser = argparse.ArgumentParser(description="Combined config argument example")

        data_name_mapping = {
            "MNIST": "MNIST",
            "CIFAR100": "CIFAR100",
            "Caltech101": "Caltech101_Resize_32",
            "Duke": "Duke",
            "AIP": "Adult_Income_Prediction",
            "CCFD": "Credit_Card_Fraud_Detection",
            "Ijcnn": "Ijcnn",
            "DHI": "Diabetes_Health_Indicators",
            "EVP": "Electric_Vehicle_Population",
            "GHP": "Global_House_Purchase",
            "HL": "Health_Lifestyle",
            "HQC": "Homesite_Quote_Conversion",
            "TN_Weather": "TN_Weather_2020_2025",
        }

        optimizers_mapping = {
            "ADAM": "ADAM",
            "SGD": "SGD",
            "Bundle": "Bundle",
            "ALR_SMAG": "ALR-SMAG",
            "SPBM_TR": "SPBM-TR",
            "SPBM_PF": "SPBM-PF",
            "SPSmax": "SPSmax",
            "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
            "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
            "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
            "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
            "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
            "SPBM_PF_NoneCut": "SPBM-PF-NoneCut",
        }

        model_mapping = {
            "LS": "LeastSquares",
            "LRBL2": "LogRegressionBinaryL2",
            "ResNet18": "ResNet18",
        }

        # The "allowed" lists only feed the --train help text. They are
        # derived from the mapping keys so the help can never drift away
        # from what the mappings actually accept.
        # <allowed_models>
        allowed_models = list(model_mapping)
        # <allowed_models>

        # <allowed_optimizers>
        allowed_optimizers = list(optimizers_mapping)
        # <allowed_optimizers>

        # <allowed_datasets>
        allowed_datasets = list(data_name_mapping)
        # <allowed_datasets>

        # <args_from_command>
        parser.add_argument(
            "--train",
            type=str,
            nargs="+",  # Allow multiple configs
            required=True,
            help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
        )

        parser.add_argument(
            "--e",
            type=int,
            required=True,
            help="Number of training epochs. Example: --e 50"
        )

        parser.add_argument(
            "--seed",
            type=int,
            default=42,
            help="Random seed for experiment reproducibility. Default: 42"
        )

        parser.add_argument(
            "--bs",
            type=int,
            required=True,
            help="Batch size for training. Example: --bs 128"
        )

        # Not marked required: the help text documents a default of 0, and a
        # default is unreachable on a required option.
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="The number of cuda. Example: --cuda 1 (default=0) "
        )

        parser.add_argument(
            "--s",
            type=float,
            default=1.0,
            help="Proportion of dataset to use for training split. Example: --s 0.8 (default=1.0)"
        )

        parser.add_argument(
            "--subset",
            type=float,
            nargs=2,
            help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
        )

        # nargs=1 yields a one-element list (or None when the flag is absent).
        parser.add_argument(
            "--time_str",
            type=str,
            nargs=1,
            help = "the str of time"
        )

        parser.add_argument(
            "--send_email",
            type=str,
            nargs=3,
            help = "from_email to_email, from_pwd"
        )

        parser.add_argument(
            "--user_search_grid",
            type=int,
            nargs=1,
            help = "search_grid: 1: "
        )
        # <args_from_command>

        parsed = parser.parse_args(argv)
        parsed.model_name_mapping = model_mapping
        parsed.data_name_mapping = data_name_mapping
        parsed.optimizers_name_mapping = optimizers_mapping

        return parsed
|
|
177
|
+
# <args>
|
|
178
|
+
|
|
179
|
+
def UpdateOtherParas(args, OtherParas):
    """Copy the optional command-line settings into ``OtherParas``.

    Reads ``args.time_str``, ``args.user_search_grid`` and ``args.send_email``
    (each either ``None`` or a sequence) and writes the keys
    ``user_search_grid``, ``send_email`` (plus ``from_email``, ``to_email``,
    ``from_pwd`` when mail settings are present), ``time_str`` and
    ``results_folder_name``. ``OtherParas["exp_name"]`` must already exist.
    The dictionary is modified in place and also returned.
    """
    # Use the caller-supplied timestamp when there is one, otherwise "now".
    stamp = (
        datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        if args.time_str is None
        else args.time_str[0]
    )

    grid_choice = args.user_search_grid
    OtherParas["user_search_grid"] = None if grid_choice is None else grid_choice[0]

    mail_settings = args.send_email
    if mail_settings is None:
        OtherParas["send_email"] = False
    else:
        for position, key in enumerate(("from_email", "to_email", "from_pwd")):
            OtherParas[key] = mail_settings[position]
        OtherParas["send_email"] = True

    OtherParas["time_str"] = stamp
    OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'

    return OtherParas
|
|
202
|
+
|
|
203
|
+
def get_train_group(args):
    """Translate the ``--train`` entries into (model, dataset, optimizer) names.

    Each entry of ``args.train`` must have the form
    ``model-dataset-optimizer`` using the abbreviations that are the keys of
    ``args.model_name_mapping``, ``args.data_name_mapping`` and
    ``args.optimizers_name_mapping``.

    Returns:
        A list of 3-tuples holding the mapped full names, in input order.

    Raises:
        ValueError: an entry does not split into exactly three parts.
        KeyError: an abbreviation is not a key of its mapping.
    """
    training_group = []
    for cfg in args.train:
        pieces = cfg.split("-")
        if len(pieces) != 3:
            raise ValueError(
                f"Invalid --train entry {cfg!r}: expected the form "
                f"'model-dataset-optimizer', got {len(pieces)} part(s)"
            )

        lookups = zip(
            ("model", "dataset", "optimizer"),
            pieces,
            (args.model_name_mapping, args.data_name_mapping, args.optimizers_name_mapping),
        )
        full_names = []
        for kind, short_name, table in lookups:
            if short_name not in table:
                # Same exception type as the plain dict lookup, but with
                # enough context to locate the typo.
                raise KeyError(
                    f"Unknown {kind} {short_name!r} in --train entry {cfg!r}; "
                    f"allowed: {sorted(table)}"
                )
            full_names.append(table[short_name])
        training_group.append(tuple(full_names))

    return training_group
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def set_paras(args, OtherParas):
    """Assemble the ``Paras`` dictionary that configures one experiment.

    Values come from the parsed command line (``args.seed``, ``args.cuda``,
    ``args.bs``, ``args.e``, ``args.s``, ``args.subset``) and from
    ``OtherParas`` (``results_folder_name``, ``time_str``,
    ``user_search_grid``). The dictionary is then passed through
    ``model_list``, ``model_type``, ``data_list``, ``optimizer_paras_dict``
    and ``device``, in that order, and the result is returned.
    """
    Paras = {
        # Folder that receives the results.
        "results_folder_name": OtherParas["results_folder_name"],
        # Loss is printed every this-many epochs.
        "epoch_log_interval": 1,
        "use_log_scale": True,
        # Timestamp used when saving results.
        "time_str": OtherParas["time_str"],
        # Random seed.
        "seed": args.seed,
        # Device string, e.g. "cuda:0".
        "cuda": f"cuda:{args.cuda}",
        # Batch size.
        "batch_size": args.bs,
        # Number of epochs.
        "epochs": args.e,
        # Value of --s (training split proportion).
        "split_train_data": args.s,
        # Value of --subset (two floats, or None).
        "select_subset": args.subset,
        # Container for results.
        "Results_dict": {},
        # Value taken from OtherParas (None when --user_search_grid is absent).
        "user_search_grid": OtherParas["user_search_grid"],
    }

    # Each step adds its own keys to Paras and hands the dictionary back.
    for add_section in (model_list, model_type, data_list):
        Paras = add_section(Paras)
    Paras = optimizer_paras_dict(Paras, OtherParas)
    Paras = device(Paras)

    return Paras
|
|
257
|
+
|
|
258
|
+
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Also sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
265
|
+
|
|
266
|
+
def device(Paras) -> dict:
    """Add the torch device and the colour flag to ``Paras``.

    ``Paras["device"]`` becomes ``torch.device(Paras["cuda"])`` when CUDA is
    available and ``torch.device("cpu")`` otherwise. ``Paras["use_color"]``
    is whether ``sys.stdout`` is attached to a terminal. ``Paras`` is
    modified in place and returned.
    """
    if torch.cuda.is_available():
        target = f"{Paras['cuda']}"
    else:
        target = "cpu"
    Paras["device"] = torch.device(target)

    Paras["use_color"] = sys.stdout.isatty()

    return Paras
|
|
273
|
+
|
|
274
|
+
def model_list(Paras) -> dict:
    """Store the fixed list of model names under ``Paras["model_list"]``.

    ``Paras`` is modified in place and returned.
    """
    Paras["model_list"] = [
        "ResNet18",
        "ResNet34",
        "LeastSquares",
        "LogRegressionBinary",
        "LogRegressionBinaryL2",
    ]
    return Paras
|
|
284
|
+
|
|
285
|
+
def model_type(Paras) -> dict:
    """Store the model-name -> task-type table under ``Paras["model_type"]``.

    Every model is tagged either ``"multi"`` or ``"binary"``. ``Paras`` is
    modified in place and returned.
    """
    kinds = dict.fromkeys(("ResNet18", "ResNet34", "LeastSquares"), "multi")
    kinds.update(dict.fromkeys(("LogRegressionBinary", "LogRegressionBinaryL2"), "binary"))

    Paras["model_type"] = kinds
    return Paras
|
|
296
|
+
|
|
297
|
+
def data_list(Paras) -> dict:
    """Store the fixed list of dataset names under ``Paras["data_list"]``.

    ``Paras`` is modified in place and returned.
    """
    Paras["data_list"] = [
        "Duke", "Ijcnn", "w8a", "RCV1", "Shuttle", "Letter", "Vowel",
        "MNIST", "CIFAR100", "Caltech101_Resize_32",
        "Adult_Income_Prediction",
        "Credit_Card_Fraud_Detection",
        "Diabetes_Health_Indicators",
        "Electric_Vehicle_Population",
        "Global_House_Purchase",
        "Health_Lifestyle",
        "Homesite_Quote_Conversion",
        "TN_Weather_2020_2025",
    ]
    return Paras
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def optimizer_paras_dict(Paras, OtherParas)->dict:
    """Store the per-optimizer hyper-parameter grid in ``Paras``.

    For every optimizer the grid is ``{"params": {name: [candidates]}}``.
    When ``OtherParas["SeleParasOn"]`` is truthy the tunable parameters get
    their full candidate list; otherwise they get a single fixed value.
    The result is written to ``Paras["optimizer_search_grid"]`` and
    ``Paras`` is returned.
    """
    searching = OtherParas["SeleParasOn"]

    def choose(candidates, fixed):
        # Full candidate list while searching, single fixed value otherwise.
        return candidates if searching else fixed

    def powers_of_two(start, stop):
        return [2**exponent for exponent in range(start, stop)]

    def c_values():
        return choose([0.1, 0.5, 1, 5, 10], [0.1])

    def spbm(start, stop, fixed_delta):
        # All SPBM variants share M and cutting_number; only delta differs.
        return {
            "params": {
                "M": [1e-5],
                "delta": choose(powers_of_two(start, stop), [fixed_delta]),
                "cutting_number": [10],
            },
        }

    search_grid = {}

    # ----------------- ADAM --------------------
    search_grid["ADAM"] = {
        "params": {
            "alpha": choose([0.5 * 1e-3, 1e-3, 2 * 1e-3], [1e-3]),
            "epsilon": [1e-8],
            "beta1": [0.9],
            "beta2": [0.999],
        },
    }
    # ------------- ALR-SMAG --------------------
    search_grid["ALR-SMAG"] = {
        "params": {
            "c": c_values(),
            "eta_max": choose(powers_of_two(-8, 9), [0.125]),
            "beta": [0.9],
        },
    }
    # ------------ Bundle -----------------------
    search_grid["Bundle"] = {
        "params": {
            "delta": choose(powers_of_two(-8, 9), [0.01]),
            "cutting_number": [10],
        },
    }
    # ------------------- SGD -------------------
    search_grid["SGD"] = {
        "params": {
            "alpha": choose(powers_of_two(-8, 9), [0.001]),
        },
    }
    # ------------------- SPSmax ----------------
    search_grid["SPSmax"] = {
        "params": {
            "c": c_values(),
            "gamma": choose(powers_of_two(-8, 9), [0.125]),
        },
    }
    # -------------- SPBM family ----------------
    search_grid["SPBM-PF"] = spbm(9, 20, 1)
    search_grid["SPBM-TR"] = spbm(9, 20, 256)
    search_grid["SPBM-TR-NoneLower"] = spbm(0, 9, 256)
    search_grid["SPBM-TR-NoneSpecial"] = spbm(-8, 9, 1)
    search_grid["SPBM-TR-NoneCut"] = spbm(-8, 9, 1)
    search_grid["SPBM-PF-NoneLower"] = spbm(0, 9, 0)
    search_grid["SPBM-PF-NoneCut"] = spbm(-8, 9, 1)

    Paras["optimizer_search_grid"] = search_grid
    return Paras
|
|
469
|
+
|
|
470
|
+
def metrics()->dict:
    """Return a new dictionary with one empty list per tracked metric name."""
    tracked = (
        "epoch_loss",
        "training_loss",
        "test_loss",
        "iter_loss",
        "training_acc",
        "test_acc",
        "grad_norm",
        "per_epoch_loss",
    )
    # A separate list object is created for every key.
    return {name: [] for name in tracked}
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
    """Split a parameter grid into keys/values and create the results folder.

    The folder path is assembled from ``Paras`` (``results_folder_name``,
    ``seed``, ``train_data_num``, ``test_data_num``, ``batch_size``,
    ``epochs``, ``time_str``) and the model, dataset and optimizer names.
    It is stored in ``Paras["Results_folder"]`` and created on disk if it
    does not exist yet.

    Returns:
        ``(keys, values, Paras)`` where ``keys`` and ``values`` are the
        lists of keys and values of ``params_gird``.
    """
    grid_keys = list(params_gird.keys())
    grid_values = list(params_gird.values())

    segments = [
        ".",
        f'{Paras["results_folder_name"]}',
        f'seed_{Paras["seed"]}',
        f'{model_name}',
        f'{data_name}',
        f'{optimizer_name}',
        f'train_{Paras["train_data_num"]}_test_{Paras["test_data_num"]}',
        f'Batch_size_{Paras["batch_size"]}',
        f'epoch_{Paras["epochs"]}',
        f'{Paras["time_str"]}',
    ]
    Paras["Results_folder"] = "/".join(segments)
    os.makedirs(Paras["Results_folder"], exist_ok=True)

    return grid_keys, grid_values, Paras
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
def fig_ylabel(str_name):
    """Return the y-axis label text for a metric key.

    Raises ``KeyError`` when ``str_name`` is not one of the known keys.
    """
    label_for = dict(
        epoch_loss="epoch loss",
        per_epoch_loss="per epoch loss",
        training_loss="training loss",
        test_loss="test loss",
        training_acc="training accuracy",
        test_acc="test accuracy",
        grad_norm="grad norm",
    )
    return label_for[str_name]
|
|
506
|
+
|
|
507
|
+
def model_abbr(model_name):
    """Return the short name of a model given its full name.

    This is the reverse direction of ``model_full_name``; for example
    ``"LeastSquares"`` maps to ``"LS"``.

    Raises:
        KeyError: ``model_name`` has no registered abbreviation.
    """
    name_map = {
        "LogRegressionBinaryL2": "LRBL2",
        "ResNet18": "ResNet18",
        "ResNet34": "ResNet34",
        "LeastSquares": "LS",
        # Misspelled key from earlier releases, kept so existing callers
        # that pass it continue to work.
        "LstSquares": "LS",
    }
    return name_map[model_name]
|
|
515
|
+
|
|
516
|
+
def dataset_abbr(model_name):
    """Return the short name of a dataset given its full name.

    This is the reverse direction of ``dataset_full_name``. The parameter
    is called ``model_name`` for backward compatibility, but the value
    looked up is a dataset name.

    Raises:
        KeyError: the name has no registered abbreviation.
    """
    name_map = {
        "MNIST": "MNIST",
        "CIFAR100": "CIFAR100",
        "Caltech101_Resize_32": "Caltech101",
        "Duke": "Duke",
        "Ijcnn": "Ijcnn",
        "Adult_Income_Prediction": "AIP",
        "Credit_Card_Fraud_Detection": "CCFD",
        # Misspelled key from earlier releases, kept so existing callers
        # that pass it continue to work.
        "Credit_Card_Frau_Detection": "CCFD",
        "Diabetes_Health_Indicators": "DHI",
        "Electric_Vehicle_Population": "EVP",
        "Global_House_Purchase": "GHP",
        "Health_Lifestyle": "HL",
        "Homesite_Quote_Conversion": "HQC",
        "TN_Weather_2020_2025": "TN_Weather",
    }
    return name_map[model_name]
|
|
530
|
+
|
|
531
|
+
def model_full_name(model_name):
    """Return the full model name for a command-line abbreviation.

    Raises ``KeyError`` when the abbreviation is not registered.
    """
    return {
        "ResNet18": "ResNet18",
        "LRBL2": "LogRegressionBinaryL2",
        "LS": "LeastSquares",
    }[model_name]
|
|
538
|
+
# <optimizers_full_name>
|
|
539
|
+
def optimizers_full_name(optimizer_name):
    """Return the full optimizer name for a command-line abbreviation.

    Abbreviations use underscores; the corresponding full names use
    hyphens. Raises ``KeyError`` when the abbreviation is not registered.
    """
    full_names = dict(
        ADAM="ADAM",
        SGD="SGD",
        Bundle="Bundle",
        ALR_SMAG="ALR-SMAG",
        SPBM_TR="SPBM-TR",
        SPBM_PF="SPBM-PF",
        SPSmax="SPSmax",
        SPBM_TR_NoneSpecial="SPBM-TR-NoneSpecial",
        SPBM_TR_NoneLower="SPBM-TR-NoneLower",
        SPBM_TR_NoneCut="SPBM-TR-NoneCut",
        SPBM_PF_NoneSpecial="SPBM-PF-NoneSpecial",
        SPBM_PF_NoneLower="SPBM-PF-NoneLower",
        SPBM_PF_NoneCut="SPBM-PF-NoneCut",
    )
    return full_names[optimizer_name]
|
|
556
|
+
# <optimizers_full_name>
|
|
557
|
+
|
|
558
|
+
# <dataset_full_name>
|
|
559
|
+
def dataset_full_name(dataset_name):
    """Return the full dataset name for a command-line abbreviation.

    Raises ``KeyError`` when the abbreviation is not registered.
    """
    pairs = (
        ("MNIST", "MNIST"),
        ("CIFAR100", "CIFAR100"),
        ("Caltech101", "Caltech101_Resize_32"),
        ("Duke", "Duke"),
        ("AIP", "Adult_Income_Prediction"),
        ("CCFD", "Credit_Card_Fraud_Detection"),
        ("Ijcnn", "Ijcnn"),
        ("DHI", "Diabetes_Health_Indicators"),
        ("EVP", "Electric_Vehicle_Population"),
        ("GHP", "Global_House_Purchase"),
        ("HL", "Health_Lifestyle"),
        ("HQC", "Homesite_Quote_Conversion"),
        ("TN_Weather", "TN_Weather_2020_2025"),
    )
    return dict(pairs)[dataset_name]
|
|
576
|
+
# <dataset_full_name>
|
|
577
|
+
|
|
578
|
+
def opt_paras_str(opt_paras_dict):
    """Flatten optimizer parameters into one underscore-joined string.

    Every ``key: value`` pair becomes ``"key_value"`` and the pieces are
    joined with ``"_"`` in dictionary order, e.g. ``"k1_v1_k2_v2"``.
    The entry whose key is ``"ID"`` is left out.
    """
    pieces = []
    for name, value in opt_paras_dict.items():
        if name == "ID":
            continue
        pieces.append(f"{name}_{value}")
    return "_".join(pieces)
|
|
587
|
+
# <set_marker_point>
|
|
588
|
+
def set_marker_point(epoch_num: int) -> list:
    """Return the marker positions registered for an epoch count.

    Only the epoch counts 1, 4, 6, 8, 10, 100 and 200 are registered; any
    other value raises ``ValueError``.
    """
    layouts = {1: [0]}
    # Small runs: a marker on every second epoch, last epoch included.
    for total in (4, 6, 8, 10):
        layouts[total] = list(range(0, total + 1, 2))
    # Long runs: six markers spaced a fifth of the run apart.
    for total in (100, 200):
        layouts[total] = list(range(0, total + 1, total // 5))

    markers = layouts.get(epoch_num)
    if markers is None:
        raise ValueError(f"No marker defined for epoch {epoch_num}")

    return markers
|
|
602
|
+
|
|
603
|
+
# <set_marker_point>
|
|
604
|
+
# <results_path_to_info>
|
|
605
|
+
def results_path_to_info(path_list):
    """Build a nested info dictionary from result paths.

    Each path is split on ``"/"`` and its components 1 through 8 are read
    as: seed folder, model name, dataset name, optimizer name, train/test
    folder, batch-size folder, epoch folder and run ID. The returned
    dictionary is indexed as ``info[model][dataset][optimizer][ID]``.
    """
    info_dict = {}

    for path in path_list:
        parts = path.split("/")
        seed_part, model_name, data_name, optimizer = parts[1], parts[2], parts[3], parts[4]
        split_tokens = parts[5].split("_")
        batch_size = parts[6].split("_")[2]
        epochs = parts[7].split("_")[1]
        ID = parts[8]

        # Create the intermediate levels on first use.
        runs = (
            info_dict
            .setdefault(model_name, {})
            .setdefault(data_name, {})
            .setdefault(optimizer, {})
        )

        epoch_count = int(epochs)
        runs[ID] = {
            "seed": seed_part.split("_")[1],
            "epochs": epoch_count,
            "train_test": (split_tokens[1], split_tokens[3]),
            "batch_size": batch_size,
            "marker": set_marker_point(int(epochs)),
            "optimizer": {f"{optimizer}": {"ID": ID}},
        }

    return info_dict
|
|
642
|
+
# <results_path_to_info>
|
|
643
|
+
|
|
644
|
+
# <update_info_dict>
|
|
645
|
+
def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
    """Fill ``info_dict`` with the runs selected for plotting.

    For every dataset in ``draw_data_list`` and every
    ``(optimizer_name, ID, Opt_Paras)`` triple in ``draw_data[dataset]``,
    the run is looked up in ``results_dict[model_name]``. The first run seen
    for a dataset seeds ``info_dict[dataset]`` with a shallow copy of its
    result entry. ``Opt_Paras`` (with ``"ID"`` added to it) is then stored
    under ``info_dict[dataset]["optimizer"][optimizer_name]`` and the
    dataset's metric key under ``info_dict[dataset]["metric_key"]``.

    NOTE: ``Opt_Paras`` is stored by reference and gets an ``"ID"`` key, so
    the caller's dictionary is modified. Because the copy is shallow, a
    nested ``"optimizer"`` dictionary is shared with ``results_dict``.

    Raises:
        AssertionError: the dataset, optimizer or ID is missing from
            ``results_dict[model_name]``.
    """
    def _fail(message, width):
        # Raised explicitly instead of ``assert False`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        banner = '*' * width
        print(banner)
        print(message)
        print(banner)
        raise AssertionError(message)

    for data_name in draw_data_list:
        for i in draw_data[data_name]:
            optimizer_name, ID, Opt_Paras = i

            if data_name not in results_dict[model_name]:
                _fail(f'{data_name} not in results', 40)

            # Check if optimizer_name exists in results_dict
            if optimizer_name not in results_dict[model_name][data_name]:
                _fail(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.', 40)

            # Check if ID exists in results_dict
            if ID not in results_dict[model_name][data_name][optimizer_name]:
                _fail(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.', 60)

            # Initialize info_dict[data_name] if it does not exist
            if data_name not in info_dict:
                info_dict[data_name] = results_dict[model_name][data_name][optimizer_name][ID].copy()

            # Update optimizer parameters
            if "optimizer" not in info_dict[data_name]:
                info_dict[data_name]["optimizer"] = {}
            info_dict[data_name]["optimizer"][optimizer_name] = Opt_Paras
            info_dict[data_name]["optimizer"][optimizer_name]["ID"] = ID

            # Update metric_key
            info_dict[data_name]["metric_key"] = metric_key_dict[data_name]

    return info_dict
|
|
684
|
+
# <update_info_dict>
|
|
685
|
+
|
|
686
|
+
def get_results_all_pkl_path(results_folder):
    """Return the paths of all ``*.pkl`` files under ``results_folder``.

    The search is recursive, so files in the folder itself and in any
    nested sub-folder are included.
    """
    return glob.glob(os.path.join(results_folder, "**", "*.pkl"), recursive=True)
|