junshan-kit 2.5.1__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +108 -8
- junshan_kit/DataProcessor.py +86 -8
- junshan_kit/DataSets.py +29 -30
- junshan_kit/Evaluate_Metrics.py +75 -2
- junshan_kit/FiguresHub.py +286 -0
- junshan_kit/ModelsHub.py +32 -5
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +350 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +390 -119
- junshan_kit/Print_Info.py +57 -11
- junshan_kit/TrainingHub.py +190 -40
- junshan_kit/kit.py +39 -50
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.7.3.dist-info}/METADATA +7 -1
- junshan_kit-2.7.3.dist-info/RECORD +20 -0
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.7.3.dist-info}/WHEEL +1 -1
- junshan_kit-2.5.1.dist-info/RECORD +0 -13
junshan_kit/ParametersHub.py
CHANGED

@@ -1,60 +1,108 @@
 import numpy as np
-import sys, os, torch, random
+import sys, os, torch, random, glob
 import argparse
-
+from datetime import datetime
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(script_dir, 'src'))
+from junshan_kit import ModelsHub, Check_Info
 
 
-class
+class args:
     def __init__(self):
         pass
-
+
+    # <args>
     def get_args(self):
         parser = argparse.ArgumentParser(description="Combined config argument example")
 
-
-
+        # <allowed_models>
+        allowed_models = ["LS", "LRBL2", "ResNet18"]
+        # <allowed_models>
+
+        # <allowed_optimizers>
+        allowed_optimizers = [
+            "ADAM",
+            "ALR_SMAG",
+            "Bundle",
+            "SGD",
+            "SPBM_TR",
+            "SPBM_PF",
+            "SPSmax",
+            "SPBM_TR_NoneSpecial",
+            "SPBM_TR_NoneLower",
+            "SPBM_PF_NoneLower",
+        ]
+        # <allowed_optimizers>
+
+        # <allowed_datasets>
+        allowed_datasets = [
+            "MNIST",
+            "CIFAR100",
+            "Caltech101",
+            "AIP",
+            "CCFD",
+            "Duke",
+            "Ijcnn",
+            "DHI",
+            "EVP",
+            "GHP",
+            "HL",
+            "HQC",
+            "TN_Weather",
+        ],
+        # <allowed_datasets>
+        data_name_mapping = {
+            "MNIST": "MNIST",
+            "CIFAR100": "CIFAR100",
+            "Caltech101": "Caltech101_Resize_32",
+            "Duke": "Duke",
+            "AIP": "Adult_Income_Prediction",
+            "CCFD": "Credit_Card_Fraud_Detection",
+            "Ijcnn": "Ijcnn",
+            "DHI":"Diabetes_Health_Indicators",
+            "EVP": "Electric_Vehicle_Population",
+            "GHP": "Global_House_Purchase",
+            "HL": "Health_Lifestyle",
+            "HQC": "Homesite_Quote_Conversion",
+            "TN_Weather": "TN_Weather_2020_2025",
+        }
 
-        allowed_datasets = ["MNIST",
-            "CIFAR100",
-            "AIP",
-            "CCFD",
-            ]
-
         optimizers_mapping = {
             "ADAM": "ADAM",
             "SGD": "SGD",
-            "Bundle": "Bundle"
+            "Bundle": "Bundle",
+            "ALR_SMAG": "ALR-SMAG",
+            "SPBM_TR": "SPBM-TR",
+            "SPBM_PF": "SPBM-PF",
+            "SPSmax": "SPSmax",
+            "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
+            "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
+            "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
+            "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
+            "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
+            "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
         }
 
         model_mapping = {
             "LS": "LeastSquares",
-            "
+            "LRBL2": "LogRegressionBinaryL2",
             "ResNet18": "ResNet18"
         }
-
-        data_name_mapping = {
-            "MNIST": "MNIST",
-            "CIFAR100": "CIFAR100",
-            "AIP": "Adult_Income_Prediction",
-            "CCFD": "Credit_Card_Fraud_Detection"
-        }
 
-
-
-        # Single combined argument that can appear multiple times
+        # <args_from_command>
         parser.add_argument(
             "--train",
             type=str,
             nargs="+", # Allow multiple configs
             required=True,
-            help = f"Format: model-dataset-optimizer (e.g., ResNet18-
+            help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
        )
 
         parser.add_argument(
-
-
-
-
+            "--e",
+            type=int,
+            required=True,
+            help="Number of training epochs. Example: --e 50"
         )
 
         parser.add_argument(
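
Note: the abbreviations introduced above are resolved through the three mapping dictionaries before anything is built. A minimal sketch of that resolution, using only mappings shown in this hunk; the config string "LRBL2-MNIST-ADAM" and the `Build_<model>_<dataset>` name pattern (taken from the lookup removed in the next hunk) are illustrative, not a claim about the package's exact call path:

```python
# Sketch: resolving one "--train" config string through the new mappings.
# The mapping entries are copied from the hunk above; the config value is an example.
model_mapping = {"LS": "LeastSquares", "LRBL2": "LogRegressionBinaryL2", "ResNet18": "ResNet18"}
data_name_mapping = {"MNIST": "MNIST", "CIFAR100": "CIFAR100", "AIP": "Adult_Income_Prediction"}
optimizers_mapping = {"ADAM": "ADAM", "SGD": "SGD", "SPBM_TR": "SPBM-TR"}

cfg = "LRBL2-MNIST-ADAM"                     # model-dataset-optimizer
model, dataset, optimizer = cfg.split("-")

builder_name = f"Build_{model_mapping[model]}_{data_name_mapping[dataset]}"
print(builder_name)                          # Build_LogRegressionBinaryL2_MNIST
print(optimizers_mapping[optimizer])         # ADAM
```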
@@ -88,63 +136,69 @@ class check_args:
         )
 
         parser.add_argument(
-
-
-
-
-
+            "--subset",
+            type=float,
+            nargs=2,
+            # required=True,
+            help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
         )
 
+        parser.add_argument(
+            "--time_str",
+            type=str,
+            nargs=1,
+            # required=True,
+            help = "the str of time"
+        )
+
+        parser.add_argument(
+            "--send_email",
+            type=str,
+            nargs=3,
+            # required=True,
+            help = "from_email to_email, from_pwd"
+        )
+
+        parser.add_argument(
+            "--user_search_grid",
+            type=int,
+            nargs=1,
+            # required=True,
+            help = "search_grid: 1: "
+        )
+        # <args_from_command>
+
         args = parser.parse_args()
         args.model_name_mapping = model_mapping
         args.data_name_mapping = data_name_mapping
         args.optimizers_name_mapping = optimizers_mapping
 
-
-        if args.subset is not None:
-            self.check_subset_info(args, parser)
-
-
-        self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)
-
         return args
+    # <args>
+
+def UpdateOtherParas(args, OtherParas):
+    if args.time_str is not None:
+        time_str = args.time_str[0]
+    else:
+        time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                if model not in allowed_models:
-                    parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
-                if optimizer not in allowed_optimizers:
-                    parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
-                if dataset not in allowed_datasets:
-                    parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
-
-            except ValueError:
-                parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
-
-        for cfg in args.train:
-            model_name, dataset_name, optimizer_name = cfg.split("-")
-            try:
-                f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
-
-            except:
-                print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
-                assert False
+    if args.user_search_grid is not None:
+        OtherParas["user_search_grid"] = args.user_search_grid[0]
+    else:
+        OtherParas["user_search_grid"] = None
+
+    if args.send_email is not None:
+        OtherParas["from_email"] = args.send_email[0]
+        OtherParas["to_email"] = args.send_email[1]
+        OtherParas["from_pwd"] = args.send_email[2]
+        OtherParas["send_email"] = True
+    else:
+        OtherParas["send_email"] = False
+
+    OtherParas["time_str"] = time_str
+    OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'
+
+    return OtherParas
 
 def get_train_group(args):
     training_group = []
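
Note: this hunk widens the CLI surface; `--train` and `--e` stay required, the rest are optional. A minimal sketch of how the new flags parse, mirroring the argparse declarations above with invented example values (the package's own `UpdateOtherParas` is not reproduced, only its `--time_str` fallback):

```python
# Sketch of the new CLI surface: flag names and arities mirror the hunk above;
# the example argument values are made up for illustration.
import argparse
from datetime import datetime

parser = argparse.ArgumentParser(description="Combined config argument example")
parser.add_argument("--train", type=str, nargs="+", required=True)
parser.add_argument("--e", type=int, required=True)
parser.add_argument("--subset", type=float, nargs=2)
parser.add_argument("--time_str", type=str, nargs=1)
parser.add_argument("--send_email", type=str, nargs=3)
parser.add_argument("--user_search_grid", type=int, nargs=1)

args = parser.parse_args(
    ["--train", "ResNet18-CIFAR100-ADAM", "LRBL2-MNIST-SGD",
     "--e", "50", "--subset", "0.7", "0.3"]
)

# Same fallback the new UpdateOtherParas applies when --time_str is omitted.
time_str = args.time_str[0] if args.time_str is not None else datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print(args.train, args.e, args.subset, time_str)
```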
@@ -163,8 +217,10 @@ def set_paras(args, OtherParas):
         # Print loss every N epochs.
         "epoch_log_interval": 1,
 
+        "use_log_scale": True,
+
         # Timestamp string for result saving.
-        "time_str": ["time_str"],
+        "time_str": OtherParas["time_str"],
 
         # Random seed
         "seed": args.seed,
@@ -182,18 +238,23 @@ def set_paras(args, OtherParas):
         "split_train_data": args.s,
 
         # select_subset
-        "select_subset": args.subset
+        "select_subset": args.subset,
+
+        # Results_dict
+        "Results_dict": {},
+
+        # type: bool
+        "user_search_grid": OtherParas["user_search_grid"],
     }
 
     Paras = model_list(Paras)
     Paras = model_type(Paras)
     Paras = data_list(Paras)
-    Paras =
+    Paras = optimizer_paras_dict(Paras, OtherParas)
     Paras = device(Paras)
 
     return Paras
 
-
 def set_seed(seed=42):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -234,34 +295,6 @@ def model_type(Paras) -> dict:
     return Paras
 
 def data_list(Paras) -> dict:
-    """
-    Attach a predefined list of dataset names to the parameter dictionary.
-
-    The predefined datasets include:
-    - Duke:
-        - classes: 2
-        - data: 42 (38 + 4)
-        - features: 7,129
-    - Ijcnn:
-        - classes: 2
-        - data: (35,000 + 91,701)
-        - features: 22
-    - w8a:
-        - classes: 2
-        - data: (49,749 + 14,951)
-        - features: 300
-    - RCV1
-    - Shuttle
-    - Letter
-    - Vowel
-    - MNIST
-    - CIFAR100
-    - CALTECH101_Resize_32
-    - Adult Income Prediction
-    -
-    - Credit_Card_Fraud_Detection
-    """
-
     data_list = [
         "Duke",
         "Ijcnn",
@@ -272,15 +305,21 @@ def data_list(Paras) -> dict:
         "Vowel",
         "MNIST",
         "CIFAR100",
-        "
+        "Caltech101_Resize_32",
         "Adult_Income_Prediction",
-        "Credit_Card_Fraud_Detection"
+        "Credit_Card_Fraud_Detection",
+        "Diabetes_Health_Indicators",
+        "Electric_Vehicle_Population",
+        "Global_House_Purchase",
+        "Health_Lifestyle",
+        "Homesite_Quote_Conversion",
+        "TN_Weather_2020_2025"
     ]
     Paras["data_list"] = data_list
     return Paras
 
 
-def
+def optimizer_paras_dict(Paras, OtherParas)->dict:
     optimizer_dict = {
         # ----------------- ADAM --------------------
         "ADAM": {
@@ -289,7 +328,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
             "alpha": (
                 [0.5 * 1e-3, 1e-3, 2 * 1e-3]
                 if OtherParas["SeleParasOn"]
-                else [
+                else [1e-3]
             ),
             "epsilon": [1e-8],
             "beta1": [0.9],
@@ -314,7 +353,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
             "delta": (
                 [2**i for i in range(-8, 9)]
                 if OtherParas["SeleParasOn"]
-                else [0.
+                else [0.01]
             ),
             "cutting_number": [10],
         },
@@ -323,7 +362,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
         "SGD": {
             "params": {
                 "alpha": (
-                    [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.
+                    [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.001]
                 )
             }
         },
@@ -386,6 +425,19 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "cutting_number": [10],
             },
         },
+        # ----------- SPBM-TR-NoneCut -----------
+        "SPBM-TR-NoneCut": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
+
         # ------------- SPBM-PF-NoneLower -----------
         "SPBM-PF-NoneLower": {
             "params": {
@@ -398,13 +450,23 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "cutting_number": [10],
             },
         },
+        # ----------- SPBM-PF-NoneCut -----------
+        "SPBM-PF-NoneCut": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
     }
 
-    Paras["
+    Paras["optimizer_search_grid"] = optimizer_dict
     return Paras
 
-
-
 def metrics()->dict:
     metrics = {
         "epoch_loss": [],
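
Note: each optimizer entry in the table above maps hyperparameter names to candidate lists, and the whole table is now stored under `Paras["optimizer_search_grid"]`. A sketch of how such a grid can be expanded into concrete settings; the use of itertools.product is an assumption about downstream consumption, not something shown in this diff, and the narrow delta range is only there to keep the output short:

```python
# Sketch: expanding one per-optimizer grid (of the shape stored in
# Paras["optimizer_search_grid"]) into concrete hyperparameter settings.
import itertools

params_grid = {
    "M": [1e-5],
    "delta": [2**i for i in range(-2, 3)],   # the full grid in the diff uses range(-8, 9)
    "cutting_number": [10],
}
keys, values = list(params_grid.keys()), list(params_grid.values())
for combo in itertools.product(*values):
    print(dict(zip(keys, combo)))            # e.g. {'M': 1e-05, 'delta': 0.25, 'cutting_number': 10}
```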
@@ -416,4 +478,213 @@ def metrics()->dict:
         "grad_norm": [],
         "per_epoch_loss": []
     }
-    return metrics
+    return metrics
+
+
+def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
+
+    keys, values = list(params_gird.keys()), list(params_gird.values())
+
+    Paras["Results_folder"] = f'./{Paras["results_folder_name"]}/seed_{Paras["seed"]}/{model_name}/{data_name}/{optimizer_name}/train_{Paras["train_data_num"]}_test_{Paras["test_data_num"]}/Batch_size_{Paras["batch_size"]}/epoch_{Paras["epochs"]}/{Paras["time_str"]}'
+    os.makedirs(Paras["Results_folder"], exist_ok=True)
+
+    return keys, values, Paras
+
+
+def fig_ylabel(str_name):
+
+    ylabel = {
+        "training_loss": "training loss",
+        "test_loss": "test loss",
+        "training_acc": "training accuracy",
+        "test_acc": "test accuracy",
+        "grad_norm": "grad norm",
+        "per_epoch_loss": "per epoch loss",
+        "epoch_loss": "epoch loss",
+    }
+    return ylabel[str_name]
+
+def model_abbr(model_name):
+    name_map = {
+        "LogRegressionBinaryL2": "LRBL2",
+        "ResNet18": "ResNet18",
+        "ResNet34": "ResNet34",
+        "LstSquares": "LS"
+    }
+    return name_map[model_name]
+
+def dataset_abbr(model_name):
+    name_map = {
+        "MNIST": "MNIST",
+        "CIFAR100": "CIFAR100",
+        "Duke": "Duke",
+        "Ijcnn": "Ijcnn",
+        "Adult_Income_Prediction": "AIP",
+        "Credit_Card_Frau_Detection": "CCFD",
+        "Diabetes_Health_Indicators": "DHI",
+        "Electric_Vehicle_Population": "EVP",
+        "Global_House_Purchase": "GHP",
+        "Health_Lifestyle": "HL",
+    }
+    return name_map[model_name]
+
+def model_full_name(model_name):
+    model_mapping = {
+        "LS": "LeastSquares",
+        "LRBL2": "LogRegressionBinaryL2",
+        "ResNet18": "ResNet18",
+    }
+    return model_mapping[model_name]
+# <optimizers_full_name>
+def optimizers_full_name(optimizer_name):
+    name_map = {
+        "ADAM": "ADAM",
+        "SGD": "SGD",
+        "Bundle": "Bundle",
+        "ALR_SMAG": "ALR-SMAG",
+        "SPBM_TR": "SPBM-TR",
+        "SPBM_PF": "SPBM-PF",
+        "SPSmax": "SPSmax",
+        "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
+        "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
+        "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
+        "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
+        "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
+        "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
+    }
+    return name_map[optimizer_name]
+# <optimizers_full_name>
+
+# <dataset_full_name>
+def dataset_full_name(dataset_name):
+    name_map = {
+        "MNIST": "MNIST",
+        "CIFAR100": "CIFAR100",
+        "Caltech101": "Caltech101_Resize_32",
+        "Duke": "Duke",
+        "AIP": "Adult_Income_Prediction",
+        "CCFD": "Credit_Card_Fraud_Detection",
+        "Ijcnn": "Ijcnn",
+        "DHI":"Diabetes_Health_Indicators",
+        "EVP": "Electric_Vehicle_Population",
+        "GHP": "Global_House_Purchase",
+        "HL": "Health_Lifestyle",
+        "HQC": "Homesite_Quote_Conversion",
+        "TN_Weather": "TN_Weather_2020_2025",
+    }
+    return name_map[dataset_name]
+# <dataset_full_name>
+
+def opt_paras_str(opt_paras_dict):
+    # Example: "k1_v1_k2_v2_..."
+
+    keys = list(opt_paras_dict.keys())
+    values = list(opt_paras_dict.values())
+
+    param_str = "_".join(f"{k}_{v}" for k, v in zip(keys, values) if k != "ID")
+
+    return param_str
+# <set_marker_point>
+def set_marker_point(epoch_num: int) -> list:
+    marker_point = {
+        1: [0],
+        4: [0, 2, 4],
+        6: [0, 2, 4, 6],
+        8: [0, 2, 4, 6, 8],
+        10: [0, 2, 4, 6, 8, 10],
+        100: [0, 20, 40, 60, 80, 100],
+        200: [0, 40, 80, 120, 160, 200],
+    }
+    if epoch_num not in marker_point:
+        raise ValueError(f"No marker defined for epoch {epoch_num}")
+
+    return marker_point[epoch_num]
+
+# <set_marker_point>
+# <results_path_to_info>
+def results_path_to_info(path_list):
+    info_dict = {}
+
+    for path in path_list:
+        parts = path.split("/")
+        seed = parts[1]
+        model_name = parts[2]
+        data_name = parts[3]
+        optimizer = parts[4]
+        train_test = parts[5].split("_")
+        batch_size = parts[6].split("_")[2]
+        epochs = parts[7].split("_")[1]
+        ID = parts[8]
+
+        if model_name not in info_dict:
+            info_dict[model_name] = {}
+
+        if data_name not in info_dict[model_name]:
+            info_dict[model_name][data_name] = {}
+
+        if optimizer not in info_dict[model_name][data_name]:
+            info_dict[model_name][data_name][optimizer] = {}
+
+        info_dict[model_name][data_name][optimizer][ID] = {
+            "seed": seed.split("_")[1],
+            "epochs": int(epochs),
+            "train_test": (train_test[1], train_test[3]),
+            "batch_size": batch_size,
+            "marker": set_marker_point(int(epochs)),
+            "optimizer":{
+                f"{optimizer}":{
+                    "ID": ID,
+                }
+            }
+        }
+
+    return info_dict
+# <results_path_to_info>
+
+# <update_info_dict>
+def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
+    for data_name in draw_data_list:
+        for i in draw_data[data_name]:
+            optimizer_name, ID, Opt_Paras = i
+
+            if data_name not in results_dict[model_name].keys():
+                print('*' * 40)
+                print(f'{data_name} not in results')
+                print('*' * 40)
+                assert False
+
+            # Check if optimizer_name exists in results_dict
+            if optimizer_name not in results_dict[model_name][data_name]:
+                print('*' * 40)
+                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.')
+                print('*' * 40)
+                assert False
+
+            # Check if ID exists in results_dict
+            if ID not in results_dict[model_name][data_name][optimizer_name]:
+                print('*' * 60)
+                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.')
+                print('*' * 60)
+                assert False
+
+            # Initialize info_dict[data_name] if it does not exist
+            if data_name not in info_dict:
+                info_dict[data_name] = results_dict[model_name][data_name][optimizer_name][ID].copy()
+
+            # Update optimizer parameters
+            if "optimizer" not in info_dict[data_name]:
+                info_dict[data_name]["optimizer"] = {}
+            info_dict[data_name]["optimizer"][optimizer_name] = Opt_Paras
+            info_dict[data_name]["optimizer"][optimizer_name]["ID"] = ID
+
+            # Update metric_key
+            info_dict[data_name]["metric_key"] = metric_key_dict[data_name]
+
+    return info_dict
+# <update_info_dict>
+
+def get_results_all_pkl_path(results_folder):
+
+    pattern = os.path.join(results_folder, "**", "*.pkl")
+
+    return glob.glob(pattern, recursive=True)
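
Note: the new result helpers agree on one directory layout. `hyperparas_and_path` writes to `<results_folder_name>/seed_<seed>/<model>/<data>/<optimizer>/train_<n>_test_<m>/Batch_size_<b>/epoch_<e>/<time_str>`, and `results_path_to_info` reads the same positions back. A sketch of that decoding with invented values; the concrete path below is not from the package, only the split positions mirror the code above:

```python
# Sketch: decoding one results path with the same split positions used by
# results_path_to_info. Every concrete value here is invented for illustration.
path = ("Results_demo/seed_42/ResNet18/CIFAR100/ADAM/"
        "train_50000_test_10000/Batch_size_128/epoch_100/2025-01-01_12-00-00")
parts = path.split("/")

seed       = parts[1].split("_")[1]        # "42"
model_name = parts[2]                      # "ResNet18"
data_name  = parts[3]                      # "CIFAR100"
optimizer  = parts[4]                      # "ADAM"
train_test = parts[5].split("_")           # ["train", "50000", "test", "10000"]
batch_size = parts[6].split("_")[2]        # "128"
epochs     = int(parts[7].split("_")[1])   # 100
run_id     = parts[8]                      # stored as the ID in results_path_to_info

print(seed, model_name, data_name, optimizer,
      (train_test[1], train_test[3]), batch_size, epochs, run_id)
```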