junshan-kit 2.2.8__py2.py3-none-any.whl → 2.5.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,419 @@
+ import numpy as np
+ import sys, os, torch, random
+ import argparse
+ import junshan_kit.ModelsHub as ModelsHub
+
+
+ class check_args:
+     def __init__(self):
+         pass
+
+     def get_args(self):
+         parser = argparse.ArgumentParser(description="Combined config argument example")
+
+         allowed_models = ["LS", "LRL2", "ResNet18"]
+         allowed_optimizers = ["ADAM", "SGD", "Bundle"]
+
+         allowed_datasets = ["MNIST",
+                             "CIFAR100",
+                             "AIP",
+                             "CCFD",
+                             ]
+
+         optimizers_mapping = {
+             "ADAM": "ADAM",
+             "SGD": "SGD",
+             "Bundle": "Bundle"
+         }
+
+         model_mapping = {
+             "LS": "LeastSquares",
+             "LRL2": "LogRegressionBinaryL2",
+             "ResNet18": "ResNet18"
+         }
+
+         data_name_mapping = {
+             "MNIST": "MNIST",
+             "CIFAR100": "CIFAR100",
+             "AIP": "Adult_Income_Prediction",
+             "CCFD": "Credit_Card_Fraud_Detection"
+         }
+
+         # Single combined argument that can appear multiple times
+         parser.add_argument(
+             "--train",
+             type=str,
+             nargs="+",  # allow multiple configs
+             required=True,
+             help=f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). "
+                  f"models: {model_mapping}, datasets: {allowed_datasets}, optimizers: {allowed_optimizers}"
+         )
+
+         parser.add_argument(
+             "--e",
+             type=int,
+             required=True,
+             help="Number of training epochs. Example: --e 50"
+         )
+
+         parser.add_argument(
+             "--seed",
+             type=int,
+             default=42,
+             help="Random seed for experiment reproducibility. Default: 42"
+         )
+
+         parser.add_argument(
+             "--bs",
+             type=int,
+             required=True,
+             help="Batch size for training. Example: --bs 128"
+         )
+
+         parser.add_argument(
+             "--cuda",
+             type=int,
+             default=0,
+             help="CUDA device index. Example: --cuda 1 (default: 0)"
+         )
+
+         parser.add_argument(
+             "--s",
+             type=float,
+             default=1.0,
+             # required=True,
+             help="Proportion of the dataset to use for the training split. Example: --s 0.8 (default: 1.0)"
+         )
+
+         parser.add_argument(
+             "--subset",
+             type=float,
+             nargs=2,
+             # required=True,
+             help="Two subset sizes (train, test), given as ratios or counts, e.g. --subset 0.7 0.3 or --subset 500 500"
+         )
+
+         args = parser.parse_args()
+         args.model_name_mapping = model_mapping
+         args.data_name_mapping = data_name_mapping
+         args.optimizers_name_mapping = optimizers_mapping
+
+         if args.subset is not None:
+             self.check_subset_info(args, parser)
+
+         self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)
+
+         return args
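+
+     # Minimal usage sketch (hypothetical driver; "main.py" is an assumed
+     # script name, not part of this package):
+     #   args = check_args().get_args()
+     # invoked as, e.g.:
+     #   python main.py --train LRL2-AIP-SGD ResNet18-CIFAR100-ADAM --e 50 --bs 128 --cuda 0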
+
+     def check_subset_info(self, args, parser):
+         total = sum(args.subset)
+         if args.subset[0] > 1:
+             # Values look like absolute counts: each must be greater than 1.
+             for i in args.subset:
+                 if i < 1:
+                     parser.error(f"Invalid --subset {args.subset}: subset counts must be > 1")
+         else:
+             # Values look like ratios: they must sum to 1.0 (within floating-point tolerance).
+             if abs(total - 1.0) > 1e-9:
+                 parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f})")
+
+     def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
+         # Parse and validate each train group
+         for cfg in args.train:
+             try:
+                 model, dataset, optimizer = cfg.split("-")
+             except ValueError:
+                 parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
+
+             if model not in allowed_models:
+                 parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
+             if optimizer not in allowed_optimizers:
+                 parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
+             if dataset not in allowed_datasets:
+                 parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
+
+         # Every (model, dataset) pair must have a matching builder in ModelsHub.
+         for cfg in args.train:
+             model_name, dataset_name, optimizer_name = cfg.split("-")
+             builder = f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"
+             if not hasattr(ModelsHub, builder):
+                 parser.error(f"ModelsHub has no builder '{builder}'")
+
+
+ def get_train_group(args):
+     training_group = []
+     for cfg in args.train:
+         model, dataset, optimizer = cfg.split("-")
+         training_group.append((args.model_name_mapping[model],
+                                args.data_name_mapping[dataset],
+                                args.optimizers_name_mapping[optimizer]))
+
+     return training_group
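+
+ # For example, with args.train == ["LRL2-AIP-SGD"], get_train_group returns
+ # [("LogRegressionBinaryL2", "Adult_Income_Prediction", "SGD")], using the
+ # mappings attached to args in get_args.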
+
+
+ def set_paras(args, OtherParas):
+     Paras = {
+         # Name of the folder where results will be saved.
+         "results_folder_name": OtherParas["results_folder_name"],
+
+         # Print loss every N epochs.
+         "epoch_log_interval": 1,
+
+         # Timestamp string for result saving (assumed to be supplied via
+         # OtherParas, mirroring results_folder_name).
+         "time_str": OtherParas["time_str"],
+
+         # Random seed
+         "seed": args.seed,
+
+         # Device used for training.
+         "cuda": f"cuda:{args.cuda}",
+
+         # Batch size
+         "batch_size": args.bs,
+
+         # Epochs
+         "epochs": args.e,
+
+         # Proportion of data used for the training split
+         "split_train_data": args.s,
+
+         # Subset sizes (train, test), or None
+         "select_subset": args.subset
+     }
+
+     Paras = model_list(Paras)
+     Paras = model_type(Paras)
+     Paras = data_list(Paras)
+     Paras = optimizer_dict(Paras, OtherParas)
+     Paras = device(Paras)
+
+     return Paras
+
+
+ def set_seed(seed=42):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def device(Paras) -> dict:
+     device = torch.device(Paras["cuda"] if torch.cuda.is_available() else "cpu")
+     Paras["device"] = device
+     # Only colorize terminal output when stdout is an interactive TTY.
+     use_color = sys.stdout.isatty()
+     Paras["use_color"] = use_color
+
+     return Paras
+
+
+ def model_list(Paras) -> dict:
+     model_list = [
+         "ResNet18",
+         "ResNet34",
+         "LeastSquares",
+         "LogRegressionBinary",
+         "LogRegressionBinaryL2",
+     ]
+     Paras["model_list"] = model_list
+     return Paras
+
+
+ def model_type(Paras) -> dict:
+     model_type = {
+         "ResNet18": "multi",
+         "ResNet34": "multi",
+         "LeastSquares": "multi",
+         "LogRegressionBinary": "binary",
+         "LogRegressionBinaryL2": "binary",
+     }
+
+     Paras["model_type"] = model_type
+     return Paras
+
+
+ def data_list(Paras) -> dict:
+     """
+     Attach a predefined list of dataset names to the parameter dictionary.
+
+     The predefined datasets include:
+     - Duke:
+         - classes: 2
+         - data: 42 (38 + 4)
+         - features: 7,129
+     - Ijcnn:
+         - classes: 2
+         - data: (35,000 + 91,701)
+         - features: 22
+     - w8a:
+         - classes: 2
+         - data: (49,749 + 14,951)
+         - features: 300
+     - RCV1
+     - Shuttle
+     - Letter
+     - Vowel
+     - MNIST
+     - CIFAR100
+     - CALTECH101_Resize_32
+     - Adult_Income_Prediction
+     - Credit_Card_Fraud_Detection
+     """
+
+     data_list = [
+         "Duke",
+         "Ijcnn",
+         "w8a",
+         "RCV1",
+         "Shuttle",
+         "Letter",
+         "Vowel",
+         "MNIST",
+         "CIFAR100",
+         "CALTECH101_Resize_32",
+         "Adult_Income_Prediction",
+         "Credit_Card_Fraud_Detection"
+     ]
+     Paras["data_list"] = data_list
+     return Paras
+
+
+ def optimizer_dict(Paras, OtherParas) -> dict:
+     optimizer_dict = {
+         # ----------------- ADAM --------------------
+         "ADAM": {
+             "params": {
+                 # "alpha": [2 * 1e-3],
+                 "alpha": (
+                     [0.5 * 1e-3, 1e-3, 2 * 1e-3]
+                     if OtherParas["SeleParasOn"]
+                     else [0.0005]
+                 ),
+                 "epsilon": [1e-8],
+                 "beta1": [0.9],
+                 "beta2": [0.999],
+             },
+         },
+         # ------------- ALR-SMAG --------------------
+         "ALR-SMAG": {
+             "params": {
+                 "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
+                 "eta_max": (
+                     [2**i for i in range(-8, 9)]
+                     if OtherParas["SeleParasOn"]
+                     else [0.125]
+                 ),
+                 "beta": [0.9],
+             },
+         },
+         # ------------ Bundle -----------------------
+         "Bundle": {
+             "params": {
+                 "delta": (
+                     [2**i for i in range(-8, 9)]
+                     if OtherParas["SeleParasOn"]
+                     else [0.25]
+                 ),
+                 "cutting_number": [10],
+             },
+         },
+         # ------------------- SGD -------------------
+         "SGD": {
+             "params": {
+                 "alpha": (
+                     [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.5]
+                 )
+             }
+         },
330
+ # ------------------- SPSmax ----------------
331
+ "SPSmax": {
332
+ "params": {
333
+ "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
334
+ "gamma": (
335
+ [2**i for i in range(-8, 9)]
336
+ if OtherParas["SeleParasOn"]
337
+ else [0.125]),
338
+ },
339
+ },
340
+ # -------------- SPBM-PF --------------------
341
+ "SPBM-PF": {
342
+ "params": {
343
+ "M": [1e-5],
344
+ "delta": (
345
+ [2**i for i in range(9, 20)]
346
+ if OtherParas["SeleParasOn"]
347
+ else [1]
348
+ ),
349
+ "cutting_number": [10],
350
+ },
351
+ },
352
+ # -------------- SPBM-TR --------------------
353
+ "SPBM-TR": {
354
+ "params": {
355
+ "M": [1e-5],
356
+ "delta": (
357
+ [2**i for i in range(9, 20)]
358
+ if OtherParas["SeleParasOn"]
359
+ else [256]
360
+ ),
361
+ "cutting_number": [10],
362
+ },
363
+ },
364
+
365
+ # ----------- SPBM-TR-NoneLower -------------
366
+ "SPBM-TR-NoneLower": {
367
+ "params": {
368
+ "M": [1e-5],
369
+ "delta": (
370
+ [2**i for i in range(0, 9)]
371
+ if OtherParas["SeleParasOn"]
372
+ else [256]
373
+ ),
374
+ "cutting_number": [10],
375
+ },
376
+ },
377
+ # ----------- SPBM-TR-NoneSpecial -----------
378
+ "SPBM-TR-NoneSpecial": {
379
+ "params": {
380
+ "M": [1e-5],
381
+ "delta": (
382
+ [2**i for i in range(-8, 9)]
383
+ if OtherParas["SeleParasOn"]
384
+ else [1]
385
+ ),
386
+ "cutting_number": [10],
387
+ },
388
+ },
389
+ # ------------- SPBM-PF-NoneLower -----------
390
+ "SPBM-PF-NoneLower": {
391
+ "params": {
392
+ "M": [1e-5],
393
+ "delta": (
394
+ [2**i for i in range(0, 9)]
395
+ if OtherParas["SeleParasOn"]
396
+ else [0]
397
+ ),
398
+ "cutting_number": [10],
399
+ },
400
+ },
401
+ }
402
+
403
+ Paras["optimizer_dict"] = optimizer_dict
404
+ return Paras
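+
+ # For example, when OtherParas["SeleParasOn"] is falsy, the ADAM entry reduces
+ # to single candidates (alpha=0.0005, epsilon=1e-8, beta1=0.9, beta2=0.999);
+ # when truthy, the lists form a hyperparameter search grid (how the grid is
+ # swept is not shown in this file).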
+
+
+ def metrics() -> dict:
+     metrics = {
+         "epoch_loss": [],
+         "training_loss": [],
+         "test_loss": [],
+         "iter_loss": [],
+         "training_acc": [],
+         "test_acc": [],
+         "grad_norm": [],
+         "per_epoch_loss": []
+     }
+     return metrics
@@ -0,0 +1,63 @@
+ from junshan_kit import ParametersHub
+
+
+ # -------------------------------------------------------------
+ def training_group(training_group):
+     print("--------------------- training_group ------------------")
+     for g in training_group:
+         print(g)
+     print("-------------------------------------------------------")
+
+
+ def training_info(args, use_color, data_name, optimizer_name, folder_path, hyperparams, Paras, model_name):
+     if use_color:
+         print("\033[90m" + "-" * 115 + "\033[0m")
+         print(
+             f"\033[32m✅ \033[34mDataset:\033[32m {data_name}, \t\033[34mBatch-size:\033[32m {args.bs}, \t\033[34m(training, test) = \033[32m ({Paras['train_data_num']}, {Paras['test_data_num']}), \t\033[34m device:\033[32m {Paras['device']}"
+         )
+         print(
+             f"\033[32m✅ \033[34mOptimizer:\033[32m {optimizer_name}, \t\033[34mParams:\033[32m {hyperparams}"
+         )
+         print(
+             f'\033[32m✅ \033[34mmodel:\033[32m {model_name}, \t\033[34mmodel type:\033[32m {Paras["model_type"][model_name]},\t\033[34m loss_fn:\033[32m {Paras["loss_fn"]}'
+         )
+         print(f"\033[32m✅ \033[34mfolder_path:\033[32m {folder_path}")
+         print("\033[90m" + "-" * 115 + "\033[0m")
+
+     else:
+         # Plain-text branch, mirroring the colorized one (the original indexed
+         # Paras["batch_size"][data_name] and val/test sample keys that set_paras
+         # does not define).
+         print("-" * 115)
+         print(
+             f"✅ Dataset: {data_name}, \tBatch-size: {args.bs}, \t(training, test) = ({Paras['train_data_num']}, {Paras['test_data_num']}), \tdevice: {Paras['device']}"
+         )
+         print(f"✅ Optimizer: {optimizer_name}, \tParams: {hyperparams}")
+         print(
+             f'✅ model: {model_name}, \tmodel type: {Paras["model_type"][model_name]}, \tloss_fn: {Paras["loss_fn"]}'
+         )
+         print(f"✅ folder_path: {folder_path}")
+         print("-" * 115)
+
+ # <Step_7_2>
+
+ def print_per_epoch_info(use_color, epoch, Paras, epoch_loss, training_loss, training_acc, test_loss, test_acc, run_time):
+     # Paras["epochs"] is set to args.e (an int) in set_paras; the per-dataset
+     # indexing used here originally would fail on an int.
+     epochs = Paras["epochs"]
+     # result = [(k, f"{v:.4f}") for k, v in run_time.items()]
+     if use_color:
+         print(
+             f'\033[34m epoch = \033[32m{epoch+1}/{epochs}\033[0m,\t\b'
+             f'\033[34m epoch_loss = \033[32m{epoch_loss[epoch+1]:.4e}\033[0m,\t\b'
+             f'\033[34m train_loss = \033[32m{training_loss[epoch+1]:.4e}\033[0m,\t\b'
+             f'\033[34m train_acc = \033[32m{100 * training_acc[epoch+1]:.2f}%\033[0m,\t\b'
+             f'\033[34m test_acc = \033[32m{100 * test_acc[epoch+1]:.2f}%\033[0m,\t\b'
+             f'\033[34m time (ep, tr, te) = \033[32m({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})\033[0m')
+     else:
+         print(
+             f'epoch = {epoch+1}/{epochs},\t'
+             f'epoch_loss = {epoch_loss[epoch+1]:.4e},\t'
+             f'train_loss = {training_loss[epoch+1]:.4e},\t'
+             f'train_acc = {100 * training_acc[epoch+1]:.2f}%,\t'
+             f'test_acc = {100 * test_acc[epoch+1]:.2f}%,\t'
+             f'time (ep, tr, te) = ({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})')
+
+
+
+ def data_info():
+     print(ParametersHub.data_list.__doc__)
@@ -0,0 +1,174 @@
+ import torch, time
+ import torch.nn as nn
+ import numpy as np
+ import torch.utils.data as Data
+ from torch.nn.utils import parameters_to_vector
+ from junshan_kit import DataHub, ParametersHub, TrainingHub, Evaluate_Metrics
+
+
+ def chosen_loss_fn(model_name, Paras):
+     # ---------------------------------------
+     # This model takes an additional regularization parameter.
+     if model_name == "LogRegressionBinaryL2":
+         Paras["lambda"] = 1e-3
+     # ---------------------------------------
+
+     if model_name in ["LeastSquares"]:
+         loss_fn = nn.MSELoss()
+
+     else:
+         if Paras["model_type"][model_name] == "binary":
+             loss_fn = nn.BCEWithLogitsLoss()
+
+         elif Paras["model_type"][model_name] == "multi":
+             loss_fn = nn.CrossEntropyLoss()
+
+         else:
+             print(f"\033[91mNo loss function matches model type '{Paras['model_type'][model_name]}'!\033[0m")
+             assert False
+
+     Paras["loss_fn"] = loss_fn
+
+     return loss_fn, Paras
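+
+ # Note on the split above: nn.BCEWithLogitsLoss expects raw logits with float
+ # targets, while nn.CrossEntropyLoss expects logits with integer class labels,
+ # so the "binary"/"multi" model_type also constrains the label dtype each
+ # dataset loader must produce.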
+
+
+ def load_data(model_name, data_name, Paras):
+     # Paths for datasets loaded from local files (used by the commented-out branches below).
+     train_path = f"./exp_data/{data_name}/training_data"
+     test_path = f"./exp_data/{data_name}/test_data"
+
+     if data_name == "MNIST":
+         train_dataset, test_dataset, transform = DataHub.MNIST(Paras, model_name)
+
+     elif data_name == "CIFAR100":
+         train_dataset, test_dataset, transform = DataHub.CIFAR100(Paras, model_name)
+
+     elif data_name == "Adult_Income_Prediction":
+         train_dataset, test_dataset, transform = DataHub.Adult_Income_Prediction(Paras)
+
+     elif data_name == "Credit_Card_Fraud_Detection":
+         train_dataset, test_dataset, transform = DataHub.Credit_Card_Fraud_Detection(Paras)
+
+     # elif data_name == "CALTECH101_Resize_32":
+     #     Paras["train_ratio"] = 0.7
+     #     train_dataset, test_dataset, transform = datahub.caltech101_Resize_32(
+     #         Paras["seed"], Paras["train_ratio"], split=True
+     #     )
+
+     # elif data_name in ["Vowel", "Letter", "Shuttle", "w8a"]:
+     #     Paras["train_ratio"] = Paras["split_train_data"][data_name]
+     #     train_dataset, test_dataset, transform = datahub.get_libsvm_data(
+     #         train_path + ".txt", test_path + ".txt", data_name
+     #     )
+
+     # elif data_name in ["RCV1", "Duke", "Ijcnn"]:
+     #     Paras["train_ratio"] = Paras["split_train_data"][data_name]
+     #     train_dataset, test_dataset, transform = datahub.get_libsvm_bz2_data(
+     #         train_path + ".bz2", test_path + ".bz2", data_name, Paras
+     #     )
+
+     else:
+         transform = None
+         print(f"Unknown data_name: {data_name}")
+         assert False
+
+     return train_dataset, test_dataset, transform
+
+
+ def get_dataloader(data_name, train_dataset, test_dataset, Paras):
+     # Re-seed globally, then give the loaders their own seeded generator so the
+     # shuffle order is reproducible.
+     ParametersHub.set_seed(Paras["seed"])
+     g = torch.Generator()
+     g.manual_seed(Paras["seed"])
+
+     train_loader = Data.DataLoader(
+         dataset=train_dataset,
+         shuffle=True,
+         batch_size=Paras["batch_size"],
+         generator=g,
+         num_workers=4,
+     )
+
+     test_loader = Data.DataLoader(
+         dataset=test_dataset,
+         shuffle=False,
+         batch_size=Paras["batch_size"],
+         generator=g,
+         num_workers=4,
+     )
+
+     return train_loader, test_loader
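+
+ # Caveat (general PyTorch behavior, not specific to this package): with
+ # num_workers > 0, randomness from Python's `random` or NumPy inside
+ # Dataset.__getitem__ is not covered by the seeded generator above; fully
+ # deterministic augmentations would also need a worker_init_fn that seeds
+ # each worker.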
+
+
+ def chosen_optimizer(optimizer_name, model, hyperparams, Paras):
+     if optimizer_name == "SGD":
+         optimizer = torch.optim.SGD(model.parameters(), lr=hyperparams["alpha"])
+
+     elif optimizer_name == "ADAM":
+         optimizer = torch.optim.Adam(
+             model.parameters(),
+             lr=hyperparams["alpha"],
+             betas=(hyperparams["beta1"], hyperparams["beta2"]),
+             eps=hyperparams["epsilon"],
+         )
+
+     else:
+         raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+     return optimizer
+
+
+ def load_model_dataloader(base_model_fun, initial_state_dict, data_name, train_dataset, test_dataset, Paras):
+     # Rebuild the model from the same initial weights so every run starts from
+     # an identical state.
+     ParametersHub.set_seed(Paras["seed"])
+     model = base_model_fun()
+     model.load_state_dict(initial_state_dict)
+     model.to(Paras["device"])
+     train_loader, test_loader = TrainingHub.get_dataloader(data_name, train_dataset, test_dataset, Paras)
+
+     return model, train_loader, test_loader
+
+
+ def train(train_loader, optimizer_name, optimizer, model, loss_fn, Paras):
+     metrics = ParametersHub.metrics()
+     for epoch in range(Paras["epochs"]):
+         for index, (X, Y) in enumerate(train_loader):
+             X, Y = X.to(Paras["device"]), Y.to(Paras["device"])
+
+             if optimizer_name in ["SGD", "ADAM"]:
+                 optimizer.zero_grad()
+                 loss = Evaluate_Metrics.compute_epoch_loss(X, Y, model, loss_fn, Paras)
+                 loss.backward()
+                 optimizer.step()
+
+             elif optimizer_name in [
+                 "Bundle",
+                 "SBPM",
+             ]:
+                 # These optimizers re-evaluate the loss internally, so they
+                 # receive a closure instead of a precomputed gradient.
+                 def closure():
+                     optimizer.zero_grad()
+                     loss = Evaluate_Metrics.compute_epoch_loss(X, Y, model, loss_fn, Paras)
+                     loss.backward()
+                     return loss
+
+                 loss = optimizer.step(closure)
+
+             else:
+                 raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+             # Record the very first batch: initial loss and gradient norm.
+             if index == 0 and epoch == 0:
+                 metrics["iter_loss"].append(loss.item())
+                 metrics["epoch_loss"].append(loss.item())
+                 with torch.no_grad():
+                     g_k = parameters_to_vector(
+                         [
+                             p.grad if p.grad is not None else torch.zeros_like(p)
+                             for p in model.parameters()
+                         ]
+                     )
+                     metrics["grad_norm"].append(torch.norm(g_k, p=2).detach().cpu().item())
+
+             metrics["per_epoch_loss"].append(loss.item())
+
+         # Note: per_epoch_loss is never reset, so this records a running mean
+         # over all batches seen so far, not the mean of this epoch alone.
+         metrics["epoch_loss"].append(np.mean(metrics["per_epoch_loss"]).item())
+
+     return metrics
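+
+
+ # A minimal end-to-end sketch of how these helpers compose (hypothetical
+ # driver code, not part of this package; `model` is assumed to come from the
+ # matching ModelsHub builder, e.g. Build_ResNet18_CIFAR100):
+ #
+ #   loss_fn, Paras = chosen_loss_fn("ResNet18", Paras)
+ #   train_dataset, test_dataset, _ = load_data("ResNet18", "CIFAR100", Paras)
+ #   train_loader, test_loader = get_dataloader("CIFAR100", train_dataset, test_dataset, Paras)
+ #   optimizer = chosen_optimizer("ADAM", model, {"alpha": 0.0005, "beta1": 0.9,
+ #                                                "beta2": 0.999, "epsilon": 1e-8}, Paras)
+ #   metrics = train(train_loader, "ADAM", optimizer, model, loss_fn, Paras)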