junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,690 @@
1
+ import numpy as np
2
+ import sys, os, torch, random, glob
3
+ import argparse
4
+ from datetime import datetime
5
+ script_dir = os.path.dirname(os.path.abspath(__file__))
6
+ sys.path.append(os.path.join(script_dir, 'src'))
7
+ from junshan_kit import ModelsHub, Check_Info
8
+
9
+
10
class args:
    """Thin wrapper that builds and parses the training command line."""

    def __init__(self):
        pass

    # <args>
    def get_args(self):
        """Parse command-line options and attach the name-mapping tables.

        Returns:
            argparse.Namespace carrying the CLI values plus the
            ``model_name_mapping``, ``data_name_mapping`` and
            ``optimizers_name_mapping`` dictionaries.
        """
        parser = argparse.ArgumentParser(description="Combined config argument example")

        # <allowed_models>
        allowed_models = ["LS", "LRBL2", "ResNet18"]
        # <allowed_models>

        # <allowed_optimizers>
        allowed_optimizers = [
            "ADAM",
            "ALR_SMAG",
            "Bundle",
            "SGD",
            "SPBM_TR",
            "SPBM_PF",
            "SPSmax",
            "SPBM_TR_NoneSpecial",
            "SPBM_TR_NoneLower",
            "SPBM_PF_NoneLower",
        ]
        # <allowed_optimizers>

        # <allowed_datasets>
        # BUG FIX: the original literal ended with "]," which turned
        # `allowed_datasets` into a 1-tuple wrapping the list and garbled
        # the --train help text.
        allowed_datasets = [
            "MNIST",
            "CIFAR100",
            "Caltech101",
            "AIP",
            "CCFD",
            "Duke",
            "Ijcnn",
            "DHI",
            "EVP",
            "GHP",
            "HL",
            "HQC",
            "TN_Weather",
        ]
        # <allowed_datasets>

        # Abbreviation -> canonical dataset name.
        data_name_mapping = {
            "MNIST": "MNIST",
            "CIFAR100": "CIFAR100",
            "Caltech101": "Caltech101_Resize_32",
            "Duke": "Duke",
            "AIP": "Adult_Income_Prediction",
            "CCFD": "Credit_Card_Fraud_Detection",
            "Ijcnn": "Ijcnn",
            "DHI": "Diabetes_Health_Indicators",
            "EVP": "Electric_Vehicle_Population",
            "GHP": "Global_House_Purchase",
            "HL": "Health_Lifestyle",
            "HQC": "Homesite_Quote_Conversion",
            "TN_Weather": "TN_Weather_2020_2025",
        }

        # CLI token (underscored) -> canonical optimizer name (hyphenated).
        optimizers_mapping = {
            "ADAM": "ADAM",
            "SGD": "SGD",
            "Bundle": "Bundle",
            "ALR_SMAG": "ALR-SMAG",
            "SPBM_TR": "SPBM-TR",
            "SPBM_PF": "SPBM-PF",
            "SPSmax": "SPSmax",
            "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
            "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
            "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
            "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
            "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
            "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
        }

        # Abbreviation -> canonical model class name.
        model_mapping = {
            "LS": "LeastSquares",
            "LRBL2": "LogRegressionBinaryL2",
            "ResNet18": "ResNet18"
        }

        # <args_from_command>
        parser.add_argument(
            "--train",
            type=str,
            nargs="+",  # Allow multiple configs
            required=True,
            help=f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
        )

        parser.add_argument(
            "--e",
            type=int,
            required=True,
            help="Number of training epochs. Example: --e 50"
        )

        parser.add_argument(
            "--seed",
            type=int,
            default=42,
            help="Random seed for experiment reproducibility. Default: 42"
        )

        parser.add_argument(
            "--bs",
            type=int,
            required=True,
            help="Batch size for training. Example: --bs 128"
        )

        # BUG FIX: `required=True` made the documented default=0
        # unreachable; the flag is now optional, as its help text promises.
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="The number of cuda. Example: --cuda 1 (default=0) "
        )

        parser.add_argument(
            "--s",
            type=float,
            default=1.0,
            help="Proportion of dataset to use for training split. Example: --s 0.8 (default=1.0)"
        )

        parser.add_argument(
            "--subset",
            type=float,
            nargs=2,
            help="Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
        )

        parser.add_argument(
            "--time_str",
            type=str,
            nargs=1,
            help="the str of time"
        )

        parser.add_argument(
            "--send_email",
            type=str,
            nargs=3,
            help="from_email to_email, from_pwd"
        )

        parser.add_argument(
            "--user_search_grid",
            type=int,
            nargs=1,
            help="search_grid: 1: "
        )
        # <args_from_command>

        # Renamed local (was `args`) to avoid shadowing the class name.
        parsed = parser.parse_args()
        parsed.model_name_mapping = model_mapping
        parsed.data_name_mapping = data_name_mapping
        parsed.optimizers_name_mapping = optimizers_mapping

        return parsed
    # <args>
178
+
179
def UpdateOtherParas(args, OtherParas):
    """Fill `OtherParas` with run-time options derived from the parsed CLI args.

    Records the timestamp, the search-grid flag, optional e-mail
    credentials, and the results-folder name, then returns the dict.
    """
    # Prefer an explicitly supplied timestamp; otherwise stamp "now".
    if args.time_str is None:
        stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    else:
        stamp = args.time_str[0]

    grid = args.user_search_grid
    OtherParas["user_search_grid"] = None if grid is None else grid[0]

    # E-mail notification is enabled only when credentials were given.
    mail = args.send_email
    if mail is None:
        OtherParas["send_email"] = False
    else:
        OtherParas["from_email"] = mail[0]
        OtherParas["to_email"] = mail[1]
        OtherParas["from_pwd"] = mail[2]
        OtherParas["send_email"] = True

    OtherParas["time_str"] = stamp
    OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'

    return OtherParas
202
+
203
def get_train_group(args):
    """Translate each "model-dataset-optimizer" CLI triple into full names.

    Returns a list of (model, dataset, optimizer) tuples resolved via the
    mapping tables attached to `args`.
    """
    groups = []
    for spec in args.train:
        model_key, data_key, opt_key = spec.split("-")
        groups.append((
            args.model_name_mapping[model_key],
            args.data_name_mapping[data_key],
            args.optimizers_name_mapping[opt_key],
        ))

    return groups
210
+
211
+
212
def set_paras(args, OtherParas):
    """Assemble the master parameter dictionary for one training run.

    Combines CLI values (seed, batch size, epochs, ...) with run-level
    options from `OtherParas`, then enriches the result with the static
    model/data tables, the optimizer search grids, and the torch device.
    """
    Paras = {}

    # Bookkeeping / output options.
    Paras["results_folder_name"] = OtherParas["results_folder_name"]  # where results are written
    Paras["epoch_log_interval"] = 1  # print loss every N epochs
    Paras["use_log_scale"] = True
    Paras["time_str"] = OtherParas["time_str"]  # timestamp used in result paths

    # Training options taken from the command line.
    Paras["seed"] = args.seed
    Paras["cuda"] = f"cuda:{args.cuda}"
    Paras["batch_size"] = args.bs
    Paras["epochs"] = args.e
    Paras["split_train_data"] = args.s
    Paras["select_subset"] = args.subset

    # Run-time state.
    Paras["Results_dict"] = {}
    Paras["user_search_grid"] = OtherParas["user_search_grid"]

    # Enrich with the static tables and the torch device.
    for enrich in (model_list, model_type, data_list):
        Paras = enrich(Paras)
    Paras = optimizer_paras_dict(Paras, OtherParas)
    Paras = device(Paras)

    return Paras
257
+
258
def set_seed(seed=42):
    """Seed `random`, NumPy and torch (CPU + all GPUs) and force cuDNN
    into deterministic mode so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for bit-wise reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
265
+
266
def device(Paras) -> dict:
    """Resolve the torch device (falling back to CPU when CUDA is not
    available) and record whether stdout supports colored output."""
    target = Paras["cuda"] if torch.cuda.is_available() else "cpu"
    Paras["device"] = torch.device(target)
    # Color the logs only when attached to a real terminal.
    Paras["use_color"] = sys.stdout.isatty()

    return Paras
273
+
274
def model_list(Paras) -> dict:
    """Register the names of every supported model architecture."""
    Paras["model_list"] = [
        "ResNet18",
        "ResNet34",
        "LeastSquares",
        "LogRegressionBinary",
        "LogRegressionBinaryL2",
    ]
    return Paras
284
+
285
def model_type(Paras) -> dict:
    """Record whether each model is a multi-class or a binary classifier."""
    kind = {}
    for name in ("ResNet18", "ResNet34", "LeastSquares"):
        kind[name] = "multi"
    for name in ("LogRegressionBinary", "LogRegressionBinaryL2"):
        kind[name] = "binary"

    Paras["model_type"] = kind
    return Paras
296
+
297
def data_list(Paras) -> dict:
    """Register every dataset name the pipeline knows how to load."""
    Paras["data_list"] = [
        "Duke",
        "Ijcnn",
        "w8a",
        "RCV1",
        "Shuttle",
        "Letter",
        "Vowel",
        "MNIST",
        "CIFAR100",
        "Caltech101_Resize_32",
        "Adult_Income_Prediction",
        "Credit_Card_Fraud_Detection",
        "Diabetes_Health_Indicators",
        "Electric_Vehicle_Population",
        "Global_House_Purchase",
        "Health_Lifestyle",
        "Homesite_Quote_Conversion",
        "TN_Weather_2020_2025",
    ]
    return Paras
320
+
321
+
322
def optimizer_paras_dict(Paras, OtherParas) -> dict:
    """Attach the per-optimizer hyper-parameter search grid to `Paras`.

    When OtherParas["SeleParasOn"] is truthy each entry carries a full
    tuning grid; otherwise it carries a single fallback value.
    """
    tune = OtherParas["SeleParasOn"]

    def pow2(lo, hi):
        # Powers of two: [2**lo, ..., 2**(hi - 1)].
        return [2 ** i for i in range(lo, hi)]

    def spbm(lo, hi, fallback):
        # Shared layout of every SPBM-* entry; only the delta grid and
        # its fallback differ between variants.
        return {
            "params": {
                "M": [1e-5],
                "delta": pow2(lo, hi) if tune else [fallback],
                "cutting_number": [10],
            },
        }

    grid = {
        # ----------------- ADAM --------------------
        "ADAM": {
            "params": {
                "alpha": [0.5 * 1e-3, 1e-3, 2 * 1e-3] if tune else [1e-3],
                "epsilon": [1e-8],
                "beta1": [0.9],
                "beta2": [0.999],
            },
        },
        # ------------- ALR-SMAG --------------------
        "ALR-SMAG": {
            "params": {
                "c": [0.1, 0.5, 1, 5, 10] if tune else [0.1],
                "eta_max": pow2(-8, 9) if tune else [0.125],
                "beta": [0.9],
            },
        },
        # ------------ Bundle -----------------------
        "Bundle": {
            "params": {
                "delta": pow2(-8, 9) if tune else [0.01],
                "cutting_number": [10],
            },
        },
        # ------------------- SGD -------------------
        "SGD": {
            "params": {
                "alpha": pow2(-8, 9) if tune else [0.001],
            }
        },
        # ------------------- SPSmax ----------------
        "SPSmax": {
            "params": {
                "c": [0.1, 0.5, 1, 5, 10] if tune else [0.1],
                "gamma": pow2(-8, 9) if tune else [0.125],
            },
        },
        # -------------- SPBM variants --------------
        "SPBM-PF": spbm(9, 20, 1),
        "SPBM-TR": spbm(9, 20, 256),
        "SPBM-TR-NoneLower": spbm(0, 9, 256),
        "SPBM-TR-NoneSpecial": spbm(-8, 9, 1),
        "SPBM-TR-NoneCut": spbm(-8, 9, 1),
        "SPBM-PF-NoneLower": spbm(0, 9, 0),
        "SPBM-PF-NoneCut": spbm(-8, 9, 1),
    }
    # NOTE(review): "SPBM-PF-NoneSpecial" is listed in the optimizer name
    # mappings elsewhere in this file but has no grid here — confirm
    # whether an entry is missing.

    Paras["optimizer_search_grid"] = grid
    return Paras
469
+
470
def metrics() -> dict:
    """Return a fresh, empty container for every tracked training metric."""
    tracked = (
        "epoch_loss",
        "training_loss",
        "test_loss",
        "iter_loss",
        "training_acc",
        "test_acc",
        "grad_norm",
        "per_epoch_loss",
    )
    return {name: [] for name in tracked}
482
+
483
+
484
def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
    """Split the hyper-parameter grid into key/value lists and create the
    per-run results directory.

    Note: the parameter keeps its historical (misspelled) name
    `params_gird` so keyword callers are unaffected.
    """
    keys = list(params_gird.keys())
    values = list(params_gird.values())

    # One directory per (model, dataset, optimizer, split, batch size,
    # epochs, timestamp) combination.
    folder = "/".join([
        f'./{Paras["results_folder_name"]}',
        f'seed_{Paras["seed"]}',
        model_name,
        data_name,
        optimizer_name,
        f'train_{Paras["train_data_num"]}_test_{Paras["test_data_num"]}',
        f'Batch_size_{Paras["batch_size"]}',
        f'epoch_{Paras["epochs"]}',
        f'{Paras["time_str"]}',
    ])
    Paras["Results_folder"] = folder
    os.makedirs(folder, exist_ok=True)

    return keys, values, Paras
492
+
493
+
494
def fig_ylabel(str_name):
    """Map a metric key to the human-readable y-axis label used in plots.

    Raises KeyError for unknown metric keys.
    """
    labels = {
        "training_loss": "training loss",
        "test_loss": "test loss",
        "training_acc": "training accuracy",
        "test_acc": "test accuracy",
        "grad_norm": "grad norm",
        "per_epoch_loss": "per epoch loss",
        "epoch_loss": "epoch loss",
    }
    return labels[str_name]
506
+
507
def model_abbr(model_name):
    """Map a full model name to its CLI abbreviation.

    Inverse of `model_full_name`. Raises KeyError for unknown models.
    """
    name_map = {
        "LogRegressionBinaryL2": "LRBL2",
        "ResNet18": "ResNet18",
        "ResNet34": "ResNet34",
        # BUG FIX: key was misspelled "LstSquares", so the canonical
        # "LeastSquares" name (used by model_full_name and the model
        # list) could never be abbreviated.
        "LeastSquares": "LS",
    }
    return name_map[model_name]
515
+
516
def dataset_abbr(model_name):
    """Map a full dataset name to its CLI abbreviation.

    Inverse of `dataset_full_name`. The parameter keeps its historical
    (misleading) name `model_name` so keyword callers are unaffected.
    """
    name_map = {
        "MNIST": "MNIST",
        "CIFAR100": "CIFAR100",
        "Duke": "Duke",
        "Ijcnn": "Ijcnn",
        "Adult_Income_Prediction": "AIP",
        # BUG FIX: key was misspelled "Credit_Card_Frau_Detection";
        # the canonical name used elsewhere is "Credit_Card_Fraud_Detection".
        "Credit_Card_Fraud_Detection": "CCFD",
        "Diabetes_Health_Indicators": "DHI",
        "Electric_Vehicle_Population": "EVP",
        "Global_House_Purchase": "GHP",
        "Health_Lifestyle": "HL",
        # Added for consistency with dataset_full_name, which maps these too.
        "Caltech101_Resize_32": "Caltech101",
        "Homesite_Quote_Conversion": "HQC",
        "TN_Weather_2020_2025": "TN_Weather",
    }
    return name_map[model_name]
530
+
531
def model_full_name(model_name):
    """Expand a model abbreviation (LS, LRBL2, ResNet18) to its full name."""
    full_names = {
        "LS": "LeastSquares",
        "LRBL2": "LogRegressionBinaryL2",
        "ResNet18": "ResNet18",
    }
    return full_names[model_name]
538
# <optimizers_full_name>
def optimizers_full_name(optimizer_name):
    """Expand a CLI optimizer token (underscored) to its canonical
    hyphenated name, e.g. "SPBM_TR" -> "SPBM-TR".

    Raises KeyError for unknown tokens.
    """
    # Tokens whose canonical form equals the CLI spelling.
    table = {name: name for name in ("ADAM", "SGD", "Bundle", "SPSmax")}
    # Tokens whose canonical form swaps underscores for hyphens.
    for name in (
        "ALR_SMAG",
        "SPBM_TR",
        "SPBM_PF",
        "SPBM_TR_NoneSpecial",
        "SPBM_TR_NoneLower",
        "SPBM_TR_NoneCut",
        "SPBM_PF_NoneSpecial",
        "SPBM_PF_NoneLower",
        "SPBM_PF_NoneCut",
    ):
        table[name] = name.replace("_", "-")

    return table[optimizer_name]
# <optimizers_full_name>
557
+
558
# <dataset_full_name>
def dataset_full_name(dataset_name):
    """Expand a dataset abbreviation to the canonical dataset name.

    Raises KeyError for unknown abbreviations.
    """
    full_names = {
        "MNIST": "MNIST",
        "CIFAR100": "CIFAR100",
        "Caltech101": "Caltech101_Resize_32",
        "Duke": "Duke",
        "AIP": "Adult_Income_Prediction",
        "CCFD": "Credit_Card_Fraud_Detection",
        "Ijcnn": "Ijcnn",
        "DHI": "Diabetes_Health_Indicators",
        "EVP": "Electric_Vehicle_Population",
        "GHP": "Global_House_Purchase",
        "HL": "Health_Lifestyle",
        "HQC": "Homesite_Quote_Conversion",
        "TN_Weather": "TN_Weather_2020_2025",
    }
    return full_names[dataset_name]
# <dataset_full_name>
577
+
578
def opt_paras_str(opt_paras_dict):
    """Serialize optimizer hyper-parameters as "k1_v1_k2_v2_...",
    skipping the bookkeeping "ID" entry."""
    pieces = []
    for key, value in opt_paras_dict.items():
        if key == "ID":
            continue
        pieces.append(f"{key}_{value}")

    return "_".join(pieces)
587
# <set_marker_point>
def set_marker_point(epoch_num: int) -> list:
    """Return the x-axis marker positions used when plotting a run of
    `epoch_num` epochs.

    Raises:
        ValueError: when no marker layout is defined for `epoch_num`.
    """
    layouts = {
        1: [0],
        4: [0, 2, 4],
        6: [0, 2, 4, 6],
        8: [0, 2, 4, 6, 8],
        10: [0, 2, 4, 6, 8, 10],
        100: [0, 20, 40, 60, 80, 100],
        200: [0, 40, 80, 120, 160, 200],
    }
    if epoch_num in layouts:
        return layouts[epoch_num]
    raise ValueError(f"No marker defined for epoch {epoch_num}")

# <set_marker_point>
604
# <results_path_to_info>
def results_path_to_info(path_list):
    """Rebuild a nested run-info dictionary from results-folder paths.

    Each path is expected to look like
    ``<root>/seed_<s>/<model>/<data>/<optimizer>/train_<n>_test_<m>/Batch_size_<b>/epoch_<e>/<ID>``
    and is decoded into ``info_dict[model][data][optimizer][ID] -> metadata``.
    """
    info_dict = {}

    for path in path_list:
        parts = path.split("/")
        (seed_part, model_name, data_name, optimizer,
         split_part, batch_part, epoch_part, run_id) = parts[1:9]

        split_fields = split_part.split("_")    # ["train", n, "test", m]
        batch_size = batch_part.split("_")[2]   # "Batch_size_<b>"
        epochs = int(epoch_part.split("_")[1])  # "epoch_<e>"

        # Build the nested model -> data -> optimizer buckets lazily.
        bucket = (info_dict
                  .setdefault(model_name, {})
                  .setdefault(data_name, {})
                  .setdefault(optimizer, {}))

        bucket[run_id] = {
            "seed": seed_part.split("_")[1],
            "epochs": epochs,
            "train_test": (split_fields[1], split_fields[3]),
            "batch_size": batch_size,
            "marker": set_marker_point(epochs),
            "optimizer": {
                f"{optimizer}": {
                    "ID": run_id,
                }
            }
        }

    return info_dict
# <results_path_to_info>
643
+
644
# <update_info_dict>
def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
    """Merge plotting selections into `info_dict` for one model.

    For every dataset in `draw_data_list` and every
    (optimizer_name, ID, Opt_Paras) triple selected for it, validates
    that the run exists in `results_dict[model_name]`, then records the
    optimizer parameters (tagged with the run ID) and the dataset's
    metric key under `info_dict[data_name]`. Note: `Opt_Paras` is
    mutated in place (the "ID" key is added), matching historical use.

    Raises:
        KeyError: when the dataset / optimizer / run-ID is missing from
            `results_dict`. (Previously `assert False`, which silently
            disappears under `python -O`.)
    """
    for data_name in draw_data_list:
        for i in draw_data[data_name]:
            optimizer_name, ID, Opt_Paras = i

            # Validate dataset presence.
            if data_name not in results_dict[model_name].keys():
                print('*' * 40)
                print(f'{data_name} not in results')
                print('*' * 40)
                raise KeyError(f'{data_name} not in results')

            # Validate optimizer presence.
            if optimizer_name not in results_dict[model_name][data_name]:
                print('*' * 40)
                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.')
                print('*' * 40)
                raise KeyError(optimizer_name)

            # Validate run-ID presence.
            if ID not in results_dict[model_name][data_name][optimizer_name]:
                print('*' * 60)
                print(f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.')
                print('*' * 60)
                raise KeyError(ID)

            # First time we see this dataset: seed it from the stored run info.
            if data_name not in info_dict:
                info_dict[data_name] = results_dict[model_name][data_name][optimizer_name][ID].copy()

            # Attach (and tag) the optimizer's hyper-parameters.
            if "optimizer" not in info_dict[data_name]:
                info_dict[data_name]["optimizer"] = {}
            info_dict[data_name]["optimizer"][optimizer_name] = Opt_Paras
            info_dict[data_name]["optimizer"][optimizer_name]["ID"] = ID

            # Which metric this dataset should be plotted with.
            info_dict[data_name]["metric_key"] = metric_key_dict[data_name]

    return info_dict
# <update_info_dict>
685
+
686
def get_results_all_pkl_path(results_folder):
    """Recursively collect the path of every ``*.pkl`` file under
    `results_folder`."""
    pkl_pattern = os.path.join(results_folder, "**", "*.pkl")
    return glob.glob(pkl_pattern, recursive=True)