junshan-kit 2.2.8__py2.py3-none-any.whl → 2.5.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junshan_kit/DataHub.py +114 -0
- junshan_kit/DataProcessor.py +224 -12
- junshan_kit/DataSets.py +331 -18
- junshan_kit/Evaluate_Metrics.py +40 -0
- junshan_kit/ModelsHub.py +212 -0
- junshan_kit/ParametersHub.py +419 -0
- junshan_kit/Print_Info.py +63 -0
- junshan_kit/TrainingHub.py +174 -0
- junshan_kit/kit.py +93 -23
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.5.1.dist-info}/METADATA +2 -4
- junshan_kit-2.5.1.dist-info/RECORD +13 -0
- junshan_kit-2.2.8.dist-info/RECORD +0 -7
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.5.1.dist-info}/WHEEL +0 -0
junshan_kit/ParametersHub.py
@@ -0,0 +1,419 @@
+import numpy as np
+import sys, os, torch, random
+import argparse
+import junshan_kit.ModelsHub as ModelsHub
+
+
+class check_args:
+    def __init__(self):
+        pass
+
+    def get_args(self):
+        parser = argparse.ArgumentParser(description="Combined config argument example")
+
+        allowed_models = ["LS", "LRL2", "ResNet18"]
+        allowed_optimizers = ["ADAM", "SGD", "Bundle"]
+
+        allowed_datasets = [
+            "MNIST",
+            "CIFAR100",
+            "AIP",
+            "CCFD",
+        ]
+
+        optimizers_mapping = {
+            "ADAM": "ADAM",
+            "SGD": "SGD",
+            "Bundle": "Bundle"
+        }
+
+        model_mapping = {
+            "LS": "LeastSquares",
+            "LRL2": "LogRegressionBinaryL2",
+            "ResNet18": "ResNet18"
+        }
+
+        data_name_mapping = {
+            "MNIST": "MNIST",
+            "CIFAR100": "CIFAR100",
+            "AIP": "Adult_Income_Prediction",
+            "CCFD": "Credit_Card_Fraud_Detection"
+        }
+
+        # Single combined argument that can appear multiple times
+        parser.add_argument(
+            "--train",
+            type=str,
+            nargs="+",  # Allow multiple configs
+            required=True,
+            help=f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). models: {model_mapping}, datasets: {allowed_datasets}, optimizers: {allowed_optimizers}"
+        )
+
+        parser.add_argument(
+            "--e",
+            type=int,
+            required=True,
+            help="Number of training epochs. Example: --e 50"
+        )
+
+        parser.add_argument(
+            "--seed",
+            type=int,
+            default=42,
+            help="Random seed for experiment reproducibility. Default: 42"
+        )
+
+        parser.add_argument(
+            "--bs",
+            type=int,
+            required=True,
+            help="Batch size for training. Example: --bs 128"
+        )
+
+        parser.add_argument(
+            "--cuda",
+            type=int,
+            default=0,
+            help="CUDA device index. Example: --cuda 1 (default: 0)"
+        )
+
+        parser.add_argument(
+            "--s",
+            type=float,
+            default=1.0,
+            help="Proportion of dataset to use for training split. Example: --s 0.8 (default=1.0)"
+        )
+
+        parser.add_argument(
+            "--subset",
+            type=float,
+            nargs=2,
+            help="Two subset ratios (train, test), e.g., --subset 0.7 0.3, or two sample counts, e.g., --subset 500 500"
+        )
+
+        args = parser.parse_args()
+        args.model_name_mapping = model_mapping
+        args.data_name_mapping = data_name_mapping
+        args.optimizers_name_mapping = optimizers_mapping
+
+        if args.subset is not None:
+            self.check_subset_info(args, parser)
+
+        self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)
+
+        return args
+
+    def check_subset_info(self, args, parser):
+        total = sum(args.subset)
+        if args.subset[0] > 1:
+            # Values are absolute sample counts: each must be >= 1.
+            for i in args.subset:
+                if i < 1:
+                    parser.error(f"Invalid --subset {args.subset}: each sample count must be >= 1")
+        else:
+            # Values are ratios: they must sum to 1.0 (up to float tolerance).
+            if abs(total - 1.0) > 1e-6:
+                parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f})")
+
+    def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
+        # Parse and validate each train group
+        for cfg in args.train:
+            try:
+                model, dataset, optimizer = cfg.split("-")
+
+                if model not in allowed_models:
+                    parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
+                if optimizer not in allowed_optimizers:
+                    parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
+                if dataset not in allowed_datasets:
+                    parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
+
+            except ValueError:
+                parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
+
+        # Every (model, dataset) pair must have a matching builder in ModelsHub.
+        for cfg in args.train:
+            model_name, dataset_name, optimizer_name = cfg.split("-")
+            builder = f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"
+            if not hasattr(ModelsHub, builder):
+                parser.error(f"ModelsHub has no builder '{builder}'")
+
+def get_train_group(args):
+    training_group = []
+    for cfg in args.train:
+        model, dataset, optimizer = cfg.split("-")
+        training_group.append((args.model_name_mapping[model], args.data_name_mapping[dataset], args.optimizers_name_mapping[optimizer]))
+
+    return training_group
+
+
+def set_paras(args, OtherParas):
+    Paras = {
+        # Name of the folder where results will be saved.
+        "results_folder_name": OtherParas["results_folder_name"],
+
+        # Print loss every N epochs.
+        "epoch_log_interval": 1,
+
+        # Timestamp string for result saving.
+        "time_str": OtherParas["time_str"],
+
+        # Random seed.
+        "seed": args.seed,
+
+        # Device used for training.
+        "cuda": f"cuda:{args.cuda}",
+
+        # Batch size.
+        "batch_size": args.bs,
+
+        # Epochs.
+        "epochs": args.e,
+
+        # Proportion of data used for the training split.
+        "split_train_data": args.s,
+
+        # Optional (train, test) subset ratios or counts.
+        "select_subset": args.subset
+    }
+
+    Paras = model_list(Paras)
+    Paras = model_type(Paras)
+    Paras = data_list(Paras)
+    Paras = optimizer_dict(Paras, OtherParas)
+    Paras = device(Paras)
+
+    return Paras
+
+
+def set_seed(seed=42):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+def device(Paras) -> dict:
+    device = torch.device(Paras["cuda"] if torch.cuda.is_available() else "cpu")
+    Paras["device"] = device
+    use_color = sys.stdout.isatty()
+    Paras["use_color"] = use_color
+
+    return Paras
+
+def model_list(Paras) -> dict:
+    model_list = [
+        "ResNet18",
+        "ResNet34",
+        "LeastSquares",
+        "LogRegressionBinary",
+        "LogRegressionBinaryL2",
+    ]
+    Paras["model_list"] = model_list
+    return Paras
+
+def model_type(Paras) -> dict:
+    model_type = {
+        "ResNet18": "multi",
+        "ResNet34": "multi",
+        "LeastSquares": "multi",
+        "LogRegressionBinary": "binary",
+        "LogRegressionBinaryL2": "binary",
+    }
+
+    Paras["model_type"] = model_type
+    return Paras
+
+def data_list(Paras) -> dict:
+    """
+    Attach a predefined list of dataset names to the parameter dictionary.
+
+    The predefined datasets include:
+    - Duke:
+        - classes: 2
+        - data: 42 (38 + 4)
+        - features: 7,129
+    - Ijcnn:
+        - classes: 2
+        - data: (35,000 + 91,701)
+        - features: 22
+    - w8a:
+        - classes: 2
+        - data: (49,749 + 14,951)
+        - features: 300
+    - RCV1
+    - Shuttle
+    - Letter
+    - Vowel
+    - MNIST
+    - CIFAR100
+    - CALTECH101_Resize_32
+    - Adult_Income_Prediction
+    - Credit_Card_Fraud_Detection
+    """
+
+    data_list = [
+        "Duke",
+        "Ijcnn",
+        "w8a",
+        "RCV1",
+        "Shuttle",
+        "Letter",
+        "Vowel",
+        "MNIST",
+        "CIFAR100",
+        "CALTECH101_Resize_32",
+        "Adult_Income_Prediction",
+        "Credit_Card_Fraud_Detection"
+    ]
+    Paras["data_list"] = data_list
+    return Paras
+
+
+def optimizer_dict(Paras, OtherParas) -> dict:
+    optimizer_dict = {
+        # ----------------- ADAM --------------------
+        "ADAM": {
+            "params": {
+                # "alpha": [2 * 1e-3],
+                "alpha": (
+                    [0.5 * 1e-3, 1e-3, 2 * 1e-3]
+                    if OtherParas["SeleParasOn"]
+                    else [0.0005]
+                ),
+                "epsilon": [1e-8],
+                "beta1": [0.9],
+                "beta2": [0.999],
+            },
+        },
+        # ------------- ALR-SMAG --------------------
+        "ALR-SMAG": {
+            "params": {
+                "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
+                "eta_max": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [0.125]
+                ),
+                "beta": [0.9],
+            },
+        },
+        # ------------ Bundle -----------------------
+        "Bundle": {
+            "params": {
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [0.25]
+                ),
+                "cutting_number": [10],
+            },
+        },
+        # ------------------- SGD -------------------
+        "SGD": {
+            "params": {
+                "alpha": (
+                    [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.5]
+                )
+            }
+        },
+        # ------------------- SPSmax ----------------
+        "SPSmax": {
+            "params": {
+                "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
+                "gamma": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [0.125]
+                ),
+            },
+        },
+        # -------------- SPBM-PF --------------------
+        "SPBM-PF": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(9, 20)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
+        # -------------- SPBM-TR --------------------
+        "SPBM-TR": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(9, 20)]
+                    if OtherParas["SeleParasOn"]
+                    else [256]
+                ),
+                "cutting_number": [10],
+            },
+        },
+
+        # ----------- SPBM-TR-NoneLower -------------
+        "SPBM-TR-NoneLower": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(0, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [256]
+                ),
+                "cutting_number": [10],
+            },
+        },
+        # ----------- SPBM-TR-NoneSpecial -----------
+        "SPBM-TR-NoneSpecial": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(-8, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [1]
+                ),
+                "cutting_number": [10],
+            },
+        },
+        # ------------- SPBM-PF-NoneLower -----------
+        "SPBM-PF-NoneLower": {
+            "params": {
+                "M": [1e-5],
+                "delta": (
+                    [2**i for i in range(0, 9)]
+                    if OtherParas["SeleParasOn"]
+                    else [0]
+                ),
+                "cutting_number": [10],
+            },
+        },
+    }
+
+    Paras["optimizer_dict"] = optimizer_dict
+    return Paras
+
+
+def metrics() -> dict:
+    metrics = {
+        "epoch_loss": [],
+        "training_loss": [],
+        "test_loss": [],
+        "iter_loss": [],
+        "training_acc": [],
+        "test_acc": [],
+        "grad_norm": [],
+        "per_epoch_loss": []
+    }
+    return metrics
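The new ParametersHub module owns the whole CLI-to-configuration pipeline. Below is a minimal sketch of how an entry script might wire it together; the script name, the OtherParas values, and the LS-MNIST-SGD combination are illustrative assumptions rather than part of the package (check_args additionally requires a matching Build_* builder to exist in ModelsHub):

import junshan_kit.ParametersHub as ParametersHub

# Hypothetical entry point, e.g. invoked as:
#   python main.py --train LS-MNIST-SGD --e 50 --bs 128 --cuda 0 --subset 0.7 0.3
# (--subset takes two ratios summing to 1.0, or two sample counts such as 500 500)
args = ParametersHub.check_args().get_args()
ParametersHub.set_seed(args.seed)

# OtherParas keys follow those referenced by set_paras/optimizer_dict above;
# the concrete values are placeholders.
OtherParas = {
    "results_folder_name": "results",
    "time_str": "2025-01-01_00-00-00",
    "SeleParasOn": False,  # True sweeps the hyperparameter grids instead
}
Paras = ParametersHub.set_paras(args, OtherParas)

# Each group is a (model, dataset, optimizer) tuple resolved through the mappings.
for model_name, data_name, optimizer_name in ParametersHub.get_train_group(args):
    print(model_name, data_name, optimizer_name, Paras["device"])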
junshan_kit/Print_Info.py
@@ -0,0 +1,63 @@
+from junshan_kit import ParametersHub
+
+
+# -------------------------------------------------------------
+def training_group(training_group):
+    print("--------------------- training_group ------------------")
+    for g in training_group:
+        print(g)
+    print("-------------------------------------------------------")
+
+
+def training_info(args, use_color, data_name, optimizer_name, folder_path, hyperparams, Paras, model_name):
+    if use_color:
+        print("\033[90m" + "-" * 115 + "\033[0m")
+        print(
+            f"\033[32m✅ \033[34mDataset:\033[32m {data_name}, \t\033[34mBatch-size:\033[32m {args.bs}, \t\033[34m(training, test) = \033[32m ({Paras['train_data_num']}, {Paras['test_data_num']}), \t\033[34m device:\033[32m {Paras['device']}"
+        )
+        print(
+            f"\033[32m✅ \033[34mOptimizer:\033[32m {optimizer_name}, \t\033[34mParams:\033[32m {hyperparams}"
+        )
+        print(
+            f'\033[32m✅ \033[34mmodel:\033[32m {model_name}, \t\033[34mmodel type:\033[32m {Paras["model_type"][model_name]},\t\033[34m loss_fn:\033[32m {Paras["loss_fn"]}'
+        )
+        print(f"\033[32m✅ \033[34mfolder_path:\033[32m {folder_path}")
+        print("\033[90m" + "-" * 115 + "\033[0m")
+
+    else:
+        print("-" * 115)
+        print(
+            f"✅ Dataset: {data_name}, \tBatch-size: {args.bs}, \t(training, test) = ({Paras['train_data_num']}, {Paras['test_data_num']}), \tdevice: {Paras['device']}"
+        )
+        print(f"✅ Optimizer: {optimizer_name}, \tParams: {hyperparams}")
+        print(
+            f'✅ model: {model_name}, \tmodel type: {Paras["model_type"][model_name]}, \tloss_fn: {Paras["loss_fn"]}'
+        )
+        print(f"✅ folder_path: {folder_path}")
+        print("-" * 115)
+
+# <Step_7_2>
+
+def print_per_epoch_info(use_color, epoch, Paras, epoch_loss, training_loss, training_acc, test_loss, test_acc, run_time):
+    epochs = Paras["epochs"]
+    # result = [(k, f"{v:.4f}") for k, v in run_time.items()]
+    if use_color:
+        print(
+            f'\033[34m epoch = \033[32m{epoch+1}/{epochs}\033[0m,\t\b'
+            f'\033[34m epoch_loss = \033[32m{epoch_loss[epoch+1]:.4e}\033[0m,\t\b'
+            f'\033[34m train_loss = \033[32m{training_loss[epoch+1]:.4e}\033[0m,\t\b'
+            f'\033[34m train_acc = \033[32m{100 * training_acc[epoch+1]:.2f}%\033[0m,\t\b'
+            f'\033[34m test_acc = \033[32m{100 * test_acc[epoch+1]:.2f}%\033[0m,\t\b'
+            f'\033[34m time (ep, tr, te) = \033[32m({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})\033[0m')
+    else:
+        print(
+            f'epoch = {epoch+1}/{epochs},\t'
+            f'epoch_loss = {epoch_loss[epoch+1]:.4e},\t'
+            f'train_loss = {training_loss[epoch+1]:.4e},\t'
+            f'train_acc = {100 * training_acc[epoch+1]:.2f}%,\t'
+            f'test_acc = {100 * test_acc[epoch+1]:.2f}%,\t'
+            f'time (ep, tr, te) = ({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})')
+
+
+def data_info():
+    print(ParametersHub.data_list.__doc__)
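Print_Info is a thin console-reporting layer over these structures. A small sketch (the tuple passed to training_group is illustrative):

from junshan_kit import Print_Info

# Dump the dataset notes stored in ParametersHub.data_list.__doc__.
Print_Info.data_info()

# Echo the resolved (model, dataset, optimizer) tuples before training starts.
Print_Info.training_group([("LeastSquares", "MNIST", "SGD")])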
junshan_kit/TrainingHub.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.utils.data as Data
+from torch.nn.utils import parameters_to_vector
+from junshan_kit import DataHub, ParametersHub, Evaluate_Metrics
+
+def chosen_loss_fn(model_name, Paras):
+    # ---------------------------------------
+    # This model takes an additional regularization parameter.
+    if model_name == "LogRegressionBinaryL2":
+        Paras["lambda"] = 1e-3
+    # ---------------------------------------
+
+    if model_name in ["LeastSquares"]:
+        loss_fn = nn.MSELoss()
+
+    else:
+        if Paras["model_type"][model_name] == "binary":
+            loss_fn = nn.BCEWithLogitsLoss()
+
+        elif Paras["model_type"][model_name] == "multi":
+            loss_fn = nn.CrossEntropyLoss()
+
+        else:
+            raise ValueError(f"No loss function is defined for model type '{Paras['model_type'][model_name]}'")
+
+    Paras["loss_fn"] = loss_fn
+
+    return loss_fn, Paras
+
+
+def load_data(model_name, data_name, Paras):
+    # Paths used by the (currently disabled) LIBSVM loaders below.
+    train_path = f"./exp_data/{data_name}/training_data"
+    test_path = f"./exp_data/{data_name}/test_data"
+
+    if data_name == "MNIST":
+        train_dataset, test_dataset, transform = DataHub.MNIST(Paras, model_name)
+
+    elif data_name == "CIFAR100":
+        train_dataset, test_dataset, transform = DataHub.CIFAR100(Paras, model_name)
+
+    elif data_name == "Adult_Income_Prediction":
+        train_dataset, test_dataset, transform = DataHub.Adult_Income_Prediction(Paras)
+
+    elif data_name == "Credit_Card_Fraud_Detection":
+        train_dataset, test_dataset, transform = DataHub.Credit_Card_Fraud_Detection(Paras)
+
+    # elif data_name == "CALTECH101_Resize_32":
+    #     Paras["train_ratio"] = 0.7
+    #     train_dataset, test_dataset, transform = datahub.caltech101_Resize_32(
+    #         Paras["seed"], Paras["train_ratio"], split=True
+    #     )
+
+    # elif data_name in ["Vowel", "Letter", "Shuttle", "w8a"]:
+    #     Paras["train_ratio"] = Paras["split_train_data"][data_name]
+    #     train_dataset, test_dataset, transform = datahub.get_libsvm_data(
+    #         train_path + ".txt", test_path + ".txt", data_name
+    #     )
+
+    # elif data_name in ["RCV1", "Duke", "Ijcnn"]:
+    #     Paras["train_ratio"] = Paras["split_train_data"][data_name]
+    #     train_dataset, test_dataset, transform = datahub.get_libsvm_bz2_data(
+    #         train_path + ".bz2", test_path + ".bz2", data_name, Paras
+    #     )
+
+    else:
+        raise ValueError(f"Unknown data_name '{data_name}'")
+
+    return train_dataset, test_dataset, transform
+
+
+def get_dataloader(data_name, train_dataset, test_dataset, Paras):
+    ParametersHub.set_seed(Paras["seed"])
+    g = torch.Generator()
+    g.manual_seed(Paras["seed"])
+
+    train_loader = Data.DataLoader(
+        dataset=train_dataset,
+        shuffle=True,
+        batch_size=Paras["batch_size"],
+        generator=g,
+        num_workers=4,
+    )
+
+    test_loader = Data.DataLoader(
+        dataset=test_dataset,
+        shuffle=False,
+        batch_size=Paras["batch_size"],
+        generator=g,
+        num_workers=4,
+    )
+
+    return train_loader, test_loader
+
+def chosen_optimizer(optimizer_name, model, hyperparams, Paras):
+    if optimizer_name == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=hyperparams["alpha"])
+
+    elif optimizer_name == "ADAM":
+        optimizer = torch.optim.Adam(
+            model.parameters(),
+            lr=hyperparams["alpha"],
+            betas=(hyperparams["beta1"], hyperparams["beta2"]),
+            eps=hyperparams["epsilon"],
+        )
+
+    else:
+        raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+    return optimizer
+
+def load_model_dataloader(base_model_fun, initial_state_dict, data_name, train_dataset, test_dataset, Paras):
+    ParametersHub.set_seed(Paras["seed"])
+    model = base_model_fun()
+    model.load_state_dict(initial_state_dict)
+    model.to(Paras["device"])
+    train_loader, test_loader = get_dataloader(data_name, train_dataset, test_dataset, Paras)
+
+    return model, train_loader, test_loader
+
+def train(train_loader, optimizer_name, optimizer, model, loss_fn, Paras):
+    metrics = ParametersHub.metrics()
+    for epoch in range(Paras["epochs"]):
+        for index, (X, Y) in enumerate(train_loader):
+            X, Y = X.to(Paras["device"]), Y.to(Paras["device"])
+
+            if optimizer_name in ["SGD", "ADAM"]:
+                optimizer.zero_grad()
+                loss = Evaluate_Metrics.compute_epoch_loss(X, Y, model, loss_fn, Paras)
+                loss.backward()
+                optimizer.step()
+
+            elif optimizer_name in [
+                "Bundle",
+                "SBPM",
+            ]:
+                def closure():
+                    optimizer.zero_grad()
+                    loss = Evaluate_Metrics.compute_epoch_loss(X, Y, model, loss_fn, Paras)
+                    loss.backward()
+                    return loss
+
+                loss = optimizer.step(closure)
+
+            else:
+                raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+            # Record the very first mini-batch loss and gradient norm.
+            if index == 0 and epoch == 0:
+                metrics["iter_loss"].append(loss.item())
+                metrics["epoch_loss"].append(loss.item())
+                with torch.no_grad():
+                    g_k = parameters_to_vector(
+                        [
+                            p.grad if p.grad is not None else torch.zeros_like(p)
+                            for p in model.parameters()
+                        ]
+                    )
+                metrics["grad_norm"].append(torch.norm(g_k, p=2).detach().cpu().item())
+
+            metrics["per_epoch_loss"].append(loss.item())
+
+        # Running mean of all mini-batch losses seen so far, logged once per epoch.
+        metrics["epoch_loss"].append(np.mean(metrics["per_epoch_loss"]).item())
+
+    return metrics
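TrainingHub composes the pieces above into a run. A sketch of one configuration end to end, assuming the optimizer is SGD or ADAM (the only ones chosen_optimizer supports), the dataset is reachable by DataHub, and the OtherParas values are placeholders as before:

from junshan_kit import ModelsHub, ParametersHub, TrainingHub

args = ParametersHub.check_args().get_args()
Paras = ParametersHub.set_paras(args, {"results_folder_name": "results",
                                       "time_str": "2025-01-01_00-00-00",
                                       "SeleParasOn": False})

model_name, data_name, optimizer_name = ParametersHub.get_train_group(args)[0]

train_dataset, test_dataset, _ = TrainingHub.load_data(model_name, data_name, Paras)
loss_fn, Paras = TrainingHub.chosen_loss_fn(model_name, Paras)

# Builder resolved via the Build_{model}_{dataset} convention that check_args validates.
model = getattr(ModelsHub, f"Build_{model_name}_{data_name}")().to(Paras["device"])
train_loader, test_loader = TrainingHub.get_dataloader(data_name, train_dataset, test_dataset, Paras)

# Take one hyperparameter draw from the grids in ParametersHub.optimizer_dict.
hyperparams = {k: v[0] for k, v in Paras["optimizer_dict"][optimizer_name]["params"].items()}
optimizer = TrainingHub.chosen_optimizer(optimizer_name, model, hyperparams, Paras)

metrics = TrainingHub.train(train_loader, optimizer_name, optimizer, model, loss_fn, Paras)
print(metrics["epoch_loss"][-1], metrics["grad_norm"])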