junshan-kit 2.5.1__py2.py3-none-any.whl → 2.8.5__py2.py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +108 -8
- junshan_kit/DataProcessor.py +133 -8
- junshan_kit/DataSets.py +29 -30
- junshan_kit/Evaluate_Metrics.py +75 -2
- junshan_kit/FiguresHub.py +290 -0
- junshan_kit/ModelsHub.py +32 -5
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +352 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +406 -119
- junshan_kit/Print_Info.py +58 -12
- junshan_kit/TrainingHub.py +190 -40
- junshan_kit/kit.py +39 -50
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.8.5.dist-info}/METADATA +7 -1
- junshan_kit-2.8.5.dist-info/RECORD +20 -0
- {junshan_kit-2.5.1.dist-info → junshan_kit-2.8.5.dist-info}/WHEEL +1 -1
- junshan_kit-2.5.1.dist-info/RECORD +0 -13
junshan_kit/ParametersHub.py
CHANGED
@@ -1,60 +1,110 @@
 import numpy as np
-import sys, os, torch, random
+import sys, os, torch, random, glob
 import argparse
-
+from datetime import datetime
+script_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(script_dir, 'src'))
+from junshan_kit import ModelsHub, Check_Info
 
 
-class
+class args:
     def __init__(self):
         pass
-
+
+    # <args>
     def get_args(self):
         parser = argparse.ArgumentParser(description="Combined config argument example")
 
-
-
+        # <allowed_models>
+        allowed_models = ["LS", "LRBL2", "ResNet18"]
+        # <allowed_models>
+
+        # <allowed_optimizers>
+        allowed_optimizers = [
+            "ADAM",
+            "ALR_SMAG",
+            "Bundle",
+            "SGD",
+            "SPBM_TR",
+            "SPBM_PF",
+            "SPSmax",
+            "SPBM_TR_NoneSpecial",
+            "SPBM_TR_NoneLower",
+            "SPBM_PF_NoneLower",
+        ]
+        # <allowed_optimizers>
+
+        # <allowed_datasets>
+        allowed_datasets = [
+            "MNIST",
+            "CIFAR100",
+            "Caltech101",
+            "AIP",
+            "CCFD",
+            "Duke",
+            "Ijcnn",
+            "DHI",
+            "EVP",
+            "GHP",
+            "HL",
+            "HQC",
+            "TN_Weather",
+        ],
+        # <allowed_datasets>
+        data_name_mapping = {
+            "MNIST": "MNIST",
+            "CIFAR100": "CIFAR100",
+            "Caltech101": "Caltech101_Resize_32",
+            "Duke": "Duke",
+            "AIP": "Adult_Income_Prediction",
+            "CCFD": "Credit_Card_Fraud_Detection",
+            "Ijcnn": "Ijcnn",
+            "RCV1": "RCV1",
+            "w8a": "w8a",
+            "DHI":"Diabetes_Health_Indicators",
+            "EVP": "Electric_Vehicle_Population",
+            "GHP": "Global_House_Purchase",
+            "HL": "Health_Lifestyle",
+            "HQC": "Homesite_Quote_Conversion",
+            "TN_Weather": "TN_Weather_2020_2025",
+        }
 
-        allowed_datasets = ["MNIST",
-            "CIFAR100",
-            "AIP",
-            "CCFD",
-            ]
-
         optimizers_mapping = {
             "ADAM": "ADAM",
             "SGD": "SGD",
-            "Bundle": "Bundle"
+            "Bundle": "Bundle",
+            "ALR_SMAG": "ALR-SMAG",
+            "SPBM_TR": "SPBM-TR",
+            "SPBM_PF": "SPBM-PF",
+            "SPSmax": "SPSmax",
+            "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
+            "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
+            "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
+            "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
+            "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
+            "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
         }
 
         model_mapping = {
             "LS": "LeastSquares",
-            "
+            "LRBL2": "LogRegressionBinaryL2",
             "ResNet18": "ResNet18"
         }
-
-        data_name_mapping = {
-            "MNIST": "MNIST",
-            "CIFAR100": "CIFAR100",
-            "AIP": "Adult_Income_Prediction",
-            "CCFD": "Credit_Card_Fraud_Detection"
-        }
 
-
-
-        # Single combined argument that can appear multiple times
+        # <args_from_command>
         parser.add_argument(
             "--train",
             type=str,
             nargs="+", # Allow multiple configs
             required=True,
-            help = f"Format: model-dataset-optimizer (e.g., ResNet18-
+            help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
        )
 
        parser.add_argument(
-
-
-
-
+            "--e",
+            type=int,
+            required=True,
+            help="Number of training epochs. Example: --e 50"
        )
 
        parser.add_argument(
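Editor's note: the combined `--train` value follows the `model-dataset-optimizer` format described in the help string above, and the new mapping dictionaries translate the short names into full names. The following stand-alone sketch is illustrative only (the `resolve` helper is hypothetical; the mapping entries are copied from this hunk):

```python
# Hypothetical helper, not part of the package; mappings mirror this hunk.
model_mapping = {"LS": "LeastSquares", "LRBL2": "LogRegressionBinaryL2", "ResNet18": "ResNet18"}
data_name_mapping = {"CIFAR100": "CIFAR100", "AIP": "Adult_Income_Prediction"}
optimizers_mapping = {"ADAM": "ADAM", "SPBM_TR": "SPBM-TR"}

def resolve(spec):
    # "ResNet18-CIFAR100-ADAM" -> short names for model, dataset, optimizer
    model, dataset, optimizer = spec.split("-")
    return model_mapping[model], data_name_mapping[dataset], optimizers_mapping[optimizer]

print(resolve("ResNet18-CIFAR100-ADAM"))  # ('ResNet18', 'CIFAR100', 'ADAM')
print(resolve("LRBL2-AIP-SPBM_TR"))       # ('LogRegressionBinaryL2', 'Adult_Income_Prediction', 'SPBM-TR')
```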
@@ -88,63 +138,84 @@ class check_args:
         )
 
         parser.add_argument(
-
-
-
-
-
+            "--subset",
+            type=float,
+            nargs=2,
+            # required=True,
+            help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
         )
 
-
-
-
-
+        parser.add_argument(
+            "--time_str",
+            type=str,
+            nargs=1,
+            # required=True,
+            help = "the str of time"
+        )
 
+        parser.add_argument(
+            "--send_email",
+            type=str,
+            nargs=3,
+            # required=True,
+            help = "from_email to_email, from_pwd"
+        )
 
-
-
+        parser.add_argument(
+            "--user_search_grid",
+            type=int,
+            nargs=1,
+            # required=True,
+            help = "search_grid: 1: "
+        )
 
+        parser.add_argument(
+            "--OptParas",
+            type=int,
+            nargs=1,
+            help="Number of optimization steps for parameter tuning (default: 1)"
+        )
+        # <args_from_command>
 
-
+        args = parser.parse_args()
+        args.model_name_mapping = model_mapping
+        args.data_name_mapping = data_name_mapping
+        args.optimizers_name_mapping = optimizers_mapping
 
         return args
+    # <args>
+
+def UpdateOtherParas(args, OtherParas):
+    # <time_str>
+    if args.time_str is not None:
+        time_str = args.time_str[0]
+    else:
+        time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        except ValueError:
-            parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
-
-        for cfg in args.train:
-            model_name, dataset_name, optimizer_name = cfg.split("-")
-            try:
-                f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
-
-            except:
-                print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
-                assert False
+    # <user_search_grid>
+    if args.user_search_grid is not None:
+        OtherParas["user_search_grid"] = args.user_search_grid[0]
+    else:
+        OtherParas["user_search_grid"] = None
+
+    # <send_email>
+    if args.send_email is not None:
+        OtherParas["from_email"] = args.send_email[0]
+        OtherParas["to_email"] = args.send_email[1]
+        OtherParas["from_pwd"] = args.send_email[2]
+        OtherParas["send_email"] = True
+    else:
+        OtherParas["send_email"] = False
+
+    if args.OptParas is not None:
+        OtherParas["SeleParasOn"] = False
+    else:
+        OtherParas["SeleParasOn"] = True
+
+    OtherParas["time_str"] = time_str
+    OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'
+
+    return OtherParas
 
 def get_train_group(args):
     training_group = []
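Editor's note: most of the new optional flags use `nargs`, so `argparse` returns lists even for single values, which is why the code above indexes `args.time_str[0]` and `args.user_search_grid[0]`. A minimal stand-alone sketch (the parser below only mirrors the argument definitions in the diff; the argv values are made up):

```python
import argparse

# Stand-alone parser shaped like the arguments added above (not the package's class).
parser = argparse.ArgumentParser()
parser.add_argument("--train", type=str, nargs="+", required=True)
parser.add_argument("--e", type=int, required=True)
parser.add_argument("--subset", type=float, nargs=2)
parser.add_argument("--time_str", type=str, nargs=1)
parser.add_argument("--send_email", type=str, nargs=3)
parser.add_argument("--user_search_grid", type=int, nargs=1)
parser.add_argument("--OptParas", type=int, nargs=1)

args = parser.parse_args([
    "--train", "ResNet18-CIFAR100-ADAM", "LRBL2-AIP-SGD",
    "--e", "50",
    "--subset", "0.7", "0.3",
    "--user_search_grid", "1",
])
print(args.subset)            # [0.7, 0.3]
print(args.user_search_grid)  # [1]  -> hence args.user_search_grid[0] above
print(args.time_str)          # None -> falls back to a datetime-based string
```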
@@ -163,8 +234,10 @@ def set_paras(args, OtherParas):
         # Print loss every N epochs.
         "epoch_log_interval": 1,
 
+        "use_log_scale": True,
+
         # Timestamp string for result saving.
-        "time_str": ["time_str"],
+        "time_str": OtherParas["time_str"],
 
         # Random seed
         "seed": args.seed,
@@ -182,18 +255,23 @@ def set_paras(args, OtherParas):
         "split_train_data": args.s,
 
         # select_subset
-        "select_subset": args.subset
+        "select_subset": args.subset,
+
+        # Results_dict
+        "Results_dict": {},
+
+        # type: bool
+        "user_search_grid": OtherParas["user_search_grid"],
     }
 
     Paras = model_list(Paras)
     Paras = model_type(Paras)
     Paras = data_list(Paras)
-    Paras =
+    Paras = optimizer_paras_dict(Paras, OtherParas)
     Paras = device(Paras)
 
     return Paras
 
-
 def set_seed(seed=42):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
@@ -234,34 +312,6 @@ def model_type(Paras) -> dict:
     return Paras
 
 def data_list(Paras) -> dict:
-    """
-    Attach a predefined list of dataset names to the parameter dictionary.
-
-    The predefined datasets include:
-    - Duke:
-        - classes: 2
-        - data: 42 (38 + 4)
-        - features: 7,129
-    - Ijcnn:
-        - classes: 2
-        - data: (35,000 + 91,701)
-        - features: 22
-    - w8a:
-        - classes: 2
-        - data: (49,749 + 14,951)
-        - features: 300
-    - RCV1
-    - Shuttle
-    - Letter
-    - Vowel
-    - MNIST
-    - CIFAR100
-    - CALTECH101_Resize_32
-    - Adult Income Prediction
-    -
-    - Credit_Card_Fraud_Detection
-    """
-
     data_list = [
         "Duke",
         "Ijcnn",
@@ -272,15 +322,21 @@ def data_list(Paras) -> dict:
         "Vowel",
         "MNIST",
         "CIFAR100",
-        "
+        "Caltech101_Resize_32",
         "Adult_Income_Prediction",
-        "Credit_Card_Fraud_Detection"
+        "Credit_Card_Fraud_Detection",
+        "Diabetes_Health_Indicators",
+        "Electric_Vehicle_Population",
+        "Global_House_Purchase",
+        "Health_Lifestyle",
+        "Homesite_Quote_Conversion",
+        "TN_Weather_2020_2025"
     ]
     Paras["data_list"] = data_list
     return Paras
 
 
-def
+def optimizer_paras_dict(Paras, OtherParas)->dict:
     optimizer_dict = {
         # ----------------- ADAM --------------------
         "ADAM": {
@@ -289,7 +345,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "alpha": (
                     [0.5 * 1e-3, 1e-3, 2 * 1e-3]
                     if OtherParas["SeleParasOn"]
-                    else [
+                    else [1e-3]
                 ),
                 "epsilon": [1e-8],
                 "beta1": [0.9],
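Editor's note: each optimizer entry follows the same `SeleParasOn` pattern: a list of candidate values when hyperparameter selection is enabled, a single default otherwise. A small illustrative sketch using the ADAM `alpha` values from this hunk (the helper function is hypothetical):

```python
# Hypothetical helper illustrating the SeleParasOn toggle used in the grids above.
def alpha_grid(sele_paras_on):
    return [0.5 * 1e-3, 1e-3, 2 * 1e-3] if sele_paras_on else [1e-3]

print(alpha_grid(True))   # [0.0005, 0.001, 0.002]  -> values searched over
print(alpha_grid(False))  # [0.001]                 -> fixed default
```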
@@ -314,7 +370,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "delta": (
                     [2**i for i in range(-8, 9)]
                     if OtherParas["SeleParasOn"]
-                    else [0.
+                    else [0.01]
                 ),
                 "cutting_number": [10],
             },
@@ -323,7 +379,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
         "SGD": {
             "params": {
                 "alpha": (
-                    [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.
+                    [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.001]
                 )
             }
         },
@@ -386,6 +442,18 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "cutting_number": [10],
             },
         },
+        # ----------- SPBM-TR-NoneCut -----------
+        "SPBM-TR-NoneCut": {
+            "params": {
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
+
         # ------------- SPBM-PF-NoneLower -----------
         "SPBM-PF-NoneLower": {
             "params": {
@@ -398,13 +466,22 @@ def optimizer_dict(Paras, OtherParas)->dict:
                 "cutting_number": [10],
             },
         },
+        # ----------- SPBM-PF-NoneCut -----------
+        "SPBM-PF-NoneCut": {
+            "params": {
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
     }
 
-    Paras["
+    Paras["optimizer_search_grid"] = optimizer_dict
     return Paras
 
-
-
 def metrics()->dict:
     metrics = {
         "epoch_loss": [],
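Editor's note: the `params` blocks are plain dicts of candidate lists, so a full search grid is just their Cartesian product. A hypothetical expansion sketch (the package's own selection loop is not shown in this diff); the `delta` and `cutting_number` values are taken from the SPBM entries above:

```python
from itertools import product

# Grid shaped like the SPBM entries above; when SeleParasOn is off,
# each list collapses to a single default value.
params_grid = {
    "delta": [2**i for i in range(-8, 9)],   # 17 candidates
    "cutting_number": [10],                  # 1 candidate
}

keys, values = list(params_grid.keys()), list(params_grid.values())
candidates = [dict(zip(keys, combo)) for combo in product(*values)]

print(len(candidates))  # 17 combinations: 17 deltas x 1 cutting_number
print(candidates[0])    # {'delta': 0.00390625, 'cutting_number': 10}
```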
@@ -416,4 +493,214 @@ def metrics()->dict:
         "grad_norm": [],
         "per_epoch_loss": []
     }
-    return metrics
+    return metrics
+
+
+def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
+
+    keys, values = list(params_gird.keys()), list(params_gird.values())
+
+    Paras["Results_folder"] = f'./{Paras["results_folder_name"]}/seed_{Paras["seed"]}/{model_name}/{data_name}/{optimizer_name}/train_{Paras["train_data_num"]}_test_{Paras["test_data_num"]}/Batch_size_{Paras["batch_size"]}/epoch_{Paras["epochs"]}/{Paras["time_str"]}'
+    os.makedirs(Paras["Results_folder"], exist_ok=True)
+
+    return keys, values, Paras
+
+
+def fig_ylabel(str_name):
+
+    ylabel = {
+        "training_loss": "training loss",
+        "test_loss": "test loss",
+        "training_acc": "training accuracy",
+        "test_acc": "test accuracy",
+        "grad_norm": "grad norm",
+        "per_epoch_loss": "per epoch loss",
+        "epoch_loss": "epoch loss",
+    }
+    return ylabel[str_name]
+
+def model_abbr(model_name):
+    name_map = {
+        "LogRegressionBinaryL2": "LRBL2",
+        "ResNet18": "ResNet18",
+        "ResNet34": "ResNet34",
+        "LstSquares": "LS"
+    }
+    return name_map[model_name]
+
+def dataset_abbr(model_name):
+    name_map = {
+        "MNIST": "MNIST",
+        "CIFAR100": "CIFAR100",
+        "Duke": "Duke",
+        "Ijcnn": "Ijcnn",
+        "Adult_Income_Prediction": "AIP",
+        "Credit_Card_Frau_Detection": "CCFD",
+        "Diabetes_Health_Indicators": "DHI",
+        "Electric_Vehicle_Population": "EVP",
+        "Global_House_Purchase": "GHP",
+        "Health_Lifestyle": "HL",
+    }
+    return name_map[model_name]
+
+def model_full_name(model_name):
+    model_mapping = {
+        "LS": "LeastSquares",
+        "LRBL2": "LogRegressionBinaryL2",
+        "ResNet18": "ResNet18",
+    }
+    return model_mapping[model_name]
+# <optimizers_full_name>
+def optimizers_full_name(optimizer_name):
+    name_map = {
+        "ADAM": "ADAM",
+        "SGD": "SGD",
+        "Bundle": "Bundle",
+        "ALR_SMAG": "ALR-SMAG",
+        "SPBM_TR": "SPBM-TR",
+        "SPBM_PF": "SPBM-PF",
+        "SPSmax": "SPSmax",
+        "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
+        "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
+        "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
+        "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
+        "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
+        "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
+    }
+    return name_map[optimizer_name]
+# <optimizers_full_name>
+
+# <dataset_full_name>
+def dataset_full_name(dataset_name):
+    name_map = {
+        "MNIST": "MNIST",
+        "CIFAR100": "CIFAR100",
+        "Caltech101": "Caltech101_Resize_32",
+        "Duke": "Duke",
+        "AIP": "Adult_Income_Prediction",
+        "CCFD": "Credit_Card_Fraud_Detection",
+        "Ijcnn": "Ijcnn",
+        "DHI":"Diabetes_Health_Indicators",
+        "EVP": "Electric_Vehicle_Population",
+        "GHP": "Global_House_Purchase",
+        "HL": "Health_Lifestyle",
+        "HQC": "Homesite_Quote_Conversion",
+        "TN_Weather": "TN_Weather_2020_2025",
+    }
+    return name_map[dataset_name]
+# <dataset_full_name>
+
+def opt_paras_str(opt_paras_dict):
+    # Example: "k1_v1_k2_v2_..."
+
+    keys = list(opt_paras_dict.keys())
+    values = list(opt_paras_dict.values())
+
+    param_str = "_".join(f"{k}_{v}" for k, v in zip(keys, values) if k != "ID")
+
+    return param_str
+# <set_marker_point>
+def set_marker_point(epoch_num: int) -> list:
+    marker_point = {
+        1: [0],
+        4: [0, 2, 4],
+        6: [0, 2, 4, 6],
+        8: [0, 2, 4, 6, 8],
+        10: [0, 2, 4, 6, 8, 10],
+        50: [0, 10, 20, 30, 40, 50],
+        100: [0, 20, 40, 60, 80, 100],
+        200: [0, 40, 80, 120, 160, 200],
+    }
+    if epoch_num not in marker_point:
+        raise ValueError(f"No marker defined for epoch {epoch_num}")
+
+    return marker_point[epoch_num]
+
+# <set_marker_point>
+# <results_path_to_info>
+def results_path_to_info(path_list):
+    info_dict = {}
+
+    for path in path_list:
+        parts = path.split("/")
+        seed = parts[1]
+        model_name = parts[2]
+        data_name = parts[3]
+        optimizer = parts[4]
+        train_test = parts[5].split("_")
+        batch_size = parts[6].split("_")[2]
+        epochs = parts[7].split("_")[1]
+        ID = parts[8]
+
+        if model_name not in info_dict:
+            info_dict[model_name] = {}
+
+        if data_name not in info_dict[model_name]:
+            info_dict[model_name][data_name] = {}
+
+        if optimizer not in info_dict[model_name][data_name]:
+            info_dict[model_name][data_name][optimizer] = {}
+
+        info_dict[model_name][data_name][optimizer][ID] = {
+            "seed": seed.split("_")[1],
+            "epochs": int(epochs),
+            "train_test": (train_test[1], train_test[3]),
+            "batch_size": batch_size,
+            "marker": set_marker_point(int(epochs)),
+            "optimizer":{
+                f"{optimizer}":{
+                    "ID": ID,
+                }
+            }
+        }
+
+    return info_dict
+# <results_path_to_info>
+
+# <update_info_dict>
+def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
+    for data_name in draw_data_list:
+        for i in draw_data[data_name]:
+            optimizer_name, ID, Opt_Paras = i
+
+            if data_name not in results_dict[model_name].keys():
+                print('*' * 40)
+                print(f'{data_name} not in results')
+                print('*' * 40)
+                assert False
+
+            # Check if optimizer_name exists in results_dict
+            if optimizer_name not in results_dict[model_name][data_name]:
+                print('*' * 40)
+                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.')
+                print('*' * 40)
+                assert False
+
+            # Check if ID exists in results_dict
+            if ID not in results_dict[model_name][data_name][optimizer_name]:
+                print('*' * 60)
+                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.')
+                print('*' * 60)
+                assert False
+
+            # Initialize info_dict[data_name] if it does not exist
+            if data_name not in info_dict:
+                info_dict[data_name] = results_dict[model_name][data_name][optimizer_name][ID].copy()
+
+            # Update optimizer parameters
+            if "optimizer" not in info_dict[data_name]:
+                info_dict[data_name]["optimizer"] = {}
+            info_dict[data_name]["optimizer"][optimizer_name] = Opt_Paras
+            info_dict[data_name]["optimizer"][optimizer_name]["ID"] = ID
+
+            # Update metric_key
+            info_dict[data_name]["metric_key"] = metric_key_dict[data_name]
+
+    return info_dict
+# <update_info_dict>
+
+def get_results_all_pkl_path(results_folder):
+
+    pattern = os.path.join(results_folder, "**", "*.pkl")
+
+    return glob.glob(pattern, recursive=True)
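Editor's note: `results_path_to_info` assumes the directory layout written by `hyperparas_and_path`, indexing fixed positions after splitting on "/". The real paths come from `get_results_all_pkl_path` and would end in a `.pkl` file; the example below only walks the directory components, and the folder name, seed, sizes, and timestamp are made-up values:

```python
# Hypothetical path in the layout produced by hyperparas_and_path above.
path = ("Results_demo/seed_42/ResNet18/CIFAR100/ADAM/"
        "train_50000_test_10000/Batch_size_128/epoch_100/2025-01-01_12-00-00")
parts = path.split("/")

print(parts[1])                # 'seed_42'              -> seed
print(parts[2], parts[3])      # 'ResNet18' 'CIFAR100'  -> model, dataset
print(parts[4])                # 'ADAM'                 -> optimizer
print(parts[5].split("_"))     # ['train', '50000', 'test', '10000']
print(parts[6].split("_")[2])  # '128'                  -> batch size
print(parts[7].split("_")[1])  # '100'                  -> epochs; also keys set_marker_point(100)
print(parts[8])                # '2025-01-01_12-00-00'  -> run ID (time_str)
```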
|