junshan-kit 2.5.1__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,60 +1,108 @@
1
1
  import numpy as np
2
- import sys, os, torch, random
2
+ import sys, os, torch, random, glob
3
3
  import argparse
4
- import junshan_kit.ModelsHub as ModelsHub
4
+ from datetime import datetime
5
+ script_dir = os.path.dirname(os.path.abspath(__file__))
6
+ sys.path.append(os.path.join(script_dir, 'src'))
7
+ from junshan_kit import ModelsHub, Check_Info
5
8
 
6
9
 
7
- class check_args:
10
+ class args:
8
11
  def __init__(self):
9
12
  pass
10
-
13
+
14
+ # <args>
11
15
  def get_args(self):
12
16
  parser = argparse.ArgumentParser(description="Combined config argument example")
13
17
 
14
- allowed_models = ["LS", "LRL2","ResNet18"]
15
- allowed_optimizers = ["ADAM", "SGD", "Bundle"]
18
+ # <allowed_models>
19
+ allowed_models = ["LS", "LRBL2", "ResNet18"]
20
+ # <allowed_models>
21
+
22
+ # <allowed_optimizers>
23
+ allowed_optimizers = [
24
+ "ADAM",
25
+ "ALR_SMAG",
26
+ "Bundle",
27
+ "SGD",
28
+ "SPBM_TR",
29
+ "SPBM_PF",
30
+ "SPSmax",
31
+ "SPBM_TR_NoneSpecial",
32
+ "SPBM_TR_NoneLower",
33
+ "SPBM_PF_NoneLower",
34
+ ]
35
+ # <allowed_optimizers>
36
+
37
+ # <allowed_datasets>
38
+ allowed_datasets = [
39
+ "MNIST",
40
+ "CIFAR100",
41
+ "Caltech101",
42
+ "AIP",
43
+ "CCFD",
44
+ "Duke",
45
+ "Ijcnn",
46
+ "DHI",
47
+ "EVP",
48
+ "GHP",
49
+ "HL",
50
+ "HQC",
51
+ "TN_Weather",
52
+ ],
53
+ # <allowed_datasets>
54
+ data_name_mapping = {
55
+ "MNIST": "MNIST",
56
+ "CIFAR100": "CIFAR100",
57
+ "Caltech101": "Caltech101_Resize_32",
58
+ "Duke": "Duke",
59
+ "AIP": "Adult_Income_Prediction",
60
+ "CCFD": "Credit_Card_Fraud_Detection",
61
+ "Ijcnn": "Ijcnn",
62
+ "DHI":"Diabetes_Health_Indicators",
63
+ "EVP": "Electric_Vehicle_Population",
64
+ "GHP": "Global_House_Purchase",
65
+ "HL": "Health_Lifestyle",
66
+ "HQC": "Homesite_Quote_Conversion",
67
+ "TN_Weather": "TN_Weather_2020_2025",
68
+ }
16
69
 
17
- allowed_datasets = ["MNIST",
18
- "CIFAR100",
19
- "AIP",
20
- "CCFD",
21
- ]
22
-
23
70
  optimizers_mapping = {
24
71
  "ADAM": "ADAM",
25
72
  "SGD": "SGD",
26
- "Bundle": "Bundle"
73
+ "Bundle": "Bundle",
74
+ "ALR_SMAG": "ALR-SMAG",
75
+ "SPBM_TR": "SPBM-TR",
76
+ "SPBM_PF": "SPBM-PF",
77
+ "SPSmax": "SPSmax",
78
+ "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
79
+ "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
80
+ "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
81
+ "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
82
+ "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
83
+ "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
27
84
  }
28
85
 
29
86
  model_mapping = {
30
87
  "LS": "LeastSquares",
31
- "LRL2": "LogRegressionBinaryL2",
88
+ "LRBL2": "LogRegressionBinaryL2",
32
89
  "ResNet18": "ResNet18"
33
90
  }
34
-
35
- data_name_mapping = {
36
- "MNIST": "MNIST",
37
- "CIFAR100": "CIFAR100",
38
- "AIP": "Adult_Income_Prediction",
39
- "CCFD": "Credit_Card_Fraud_Detection"
40
- }
41
91
 
42
-
43
-
44
- # Single combined argument that can appear multiple times
92
+ # <args_from_command>
45
93
  parser.add_argument(
46
94
  "--train",
47
95
  type=str,
48
96
  nargs="+", # Allow multiple configs
49
97
  required=True,
50
- help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR10-Adam). model: {model_mapping}, \n datasets: {allowed_datasets}, optimizers: {allowed_optimizers},"
98
+ help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
51
99
  )
52
100
 
53
101
  parser.add_argument(
54
- "--e",
55
- type=int,
56
- required=True,
57
- help="Number of training epochs. Example: --e 50"
102
+ "--e",
103
+ type=int,
104
+ required=True,
105
+ help="Number of training epochs. Example: --e 50"
58
106
  )
59
107
 
60
108
  parser.add_argument(
@@ -88,63 +136,69 @@ class check_args:
88
136
  )
89
137
 
90
138
  parser.add_argument(
91
- "--subset",
92
- type=float,
93
- nargs=2,
94
- # required=True,
95
- help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
139
+ "--subset",
140
+ type=float,
141
+ nargs=2,
142
+ # required=True,
143
+ help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
96
144
  )
97
145
 
146
+ parser.add_argument(
147
+ "--time_str",
148
+ type=str,
149
+ nargs=1,
150
+ # required=True,
151
+ help = "the str of time"
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--send_email",
156
+ type=str,
157
+ nargs=3,
158
+ # required=True,
159
+ help = "from_email to_email, from_pwd"
160
+ )
161
+
162
+ parser.add_argument(
163
+ "--user_search_grid",
164
+ type=int,
165
+ nargs=1,
166
+ # required=True,
167
+ help = "search_grid: 1: "
168
+ )
169
+ # <args_from_command>
170
+
98
171
  args = parser.parse_args()
99
172
  args.model_name_mapping = model_mapping
100
173
  args.data_name_mapping = data_name_mapping
101
174
  args.optimizers_name_mapping = optimizers_mapping
102
175
 
103
-
104
- if args.subset is not None:
105
- self.check_subset_info(args, parser)
106
-
107
-
108
- self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)
109
-
110
176
  return args
177
+ # <args>
178
+
179
+ def UpdateOtherParas(args, OtherParas):
180
+ if args.time_str is not None:
181
+ time_str = args.time_str[0]
182
+ else:
183
+ time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
111
184
 
112
- def check_subset_info(self, args, parser):
113
- total = sum(args.subset)
114
- if args.subset[0]>1:
115
- # CHECK
116
- for i in args.subset:
117
- if i < 1:
118
- parser.error(f"Invalid --subset {args.subset}: The number of subdata must > 1")
119
- else:
120
- if abs(total - 1.0) != 0.0:
121
- parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f}))")
122
-
123
-
124
- def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
125
- # Parse and validate each train_group
126
- for cfg in args.train:
127
- try:
128
- model, dataset, optimizer = cfg.split("-")
129
-
130
- if model not in allowed_models:
131
- parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
132
- if optimizer not in allowed_optimizers:
133
- parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
134
- if dataset not in allowed_datasets:
135
- parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
136
-
137
- except ValueError:
138
- parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
139
-
140
- for cfg in args.train:
141
- model_name, dataset_name, optimizer_name = cfg.split("-")
142
- try:
143
- f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
144
-
145
- except:
146
- print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
147
- assert False
185
+ if args.user_search_grid is not None:
186
+ OtherParas["user_search_grid"] = args.user_search_grid[0]
187
+ else:
188
+ OtherParas["user_search_grid"] = None
189
+
190
+ if args.send_email is not None:
191
+ OtherParas["from_email"] = args.send_email[0]
192
+ OtherParas["to_email"] = args.send_email[1]
193
+ OtherParas["from_pwd"] = args.send_email[2]
194
+ OtherParas["send_email"] = True
195
+ else:
196
+ OtherParas["send_email"] = False
197
+
198
+ OtherParas["time_str"] = time_str
199
+ OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'
200
+
201
+ return OtherParas
148
202
 
149
203
  def get_train_group(args):
150
204
  training_group = []
@@ -163,8 +217,10 @@ def set_paras(args, OtherParas):
163
217
  # Print loss every N epochs.
164
218
  "epoch_log_interval": 1,
165
219
 
220
+ "use_log_scale": True,
221
+
166
222
  # Timestamp string for result saving.
167
- "time_str": ["time_str"],
223
+ "time_str": OtherParas["time_str"],
168
224
 
169
225
  # Random seed
170
226
  "seed": args.seed,
@@ -182,18 +238,23 @@ def set_paras(args, OtherParas):
182
238
  "split_train_data": args.s,
183
239
 
184
240
  # select_subset
185
- "select_subset": args.subset
241
+ "select_subset": args.subset,
242
+
243
+ # Results_dict
244
+ "Results_dict": {},
245
+
246
+ # type: bool
247
+ "user_search_grid": OtherParas["user_search_grid"],
186
248
  }
187
249
 
188
250
  Paras = model_list(Paras)
189
251
  Paras = model_type(Paras)
190
252
  Paras = data_list(Paras)
191
- Paras = optimizer_dict(Paras, OtherParas)
253
+ Paras = optimizer_paras_dict(Paras, OtherParas)
192
254
  Paras = device(Paras)
193
255
 
194
256
  return Paras
195
257
 
196
-
197
258
  def set_seed(seed=42):
198
259
  torch.manual_seed(seed)
199
260
  torch.cuda.manual_seed_all(seed)
@@ -234,34 +295,6 @@ def model_type(Paras) -> dict:
234
295
  return Paras
235
296
 
236
297
  def data_list(Paras) -> dict:
237
- """
238
- Attach a predefined list of dataset names to the parameter dictionary.
239
-
240
- The predefined datasets include:
241
- - Duke:
242
- - classes: 2
243
- - data: 42 (38 + 4)
244
- - features: 7,129
245
- - Ijcnn:
246
- - classes: 2
247
- - data: (35,000 + 91,701)
248
- - features: 22
249
- - w8a:
250
- - classes: 2
251
- - data: (49,749 + 14,951)
252
- - features: 300
253
- - RCV1
254
- - Shuttle
255
- - Letter
256
- - Vowel
257
- - MNIST
258
- - CIFAR100
259
- - CALTECH101_Resize_32
260
- - Adult Income Prediction
261
- -
262
- - Credit_Card_Fraud_Detection
263
- """
264
-
265
298
  data_list = [
266
299
  "Duke",
267
300
  "Ijcnn",
@@ -272,15 +305,21 @@ def data_list(Paras) -> dict:
272
305
  "Vowel",
273
306
  "MNIST",
274
307
  "CIFAR100",
275
- "CALTECH101_Resize_32",
308
+ "Caltech101_Resize_32",
276
309
  "Adult_Income_Prediction",
277
- "Credit_Card_Fraud_Detection"
310
+ "Credit_Card_Fraud_Detection",
311
+ "Diabetes_Health_Indicators",
312
+ "Electric_Vehicle_Population",
313
+ "Global_House_Purchase",
314
+ "Health_Lifestyle",
315
+ "Homesite_Quote_Conversion",
316
+ "TN_Weather_2020_2025"
278
317
  ]
279
318
  Paras["data_list"] = data_list
280
319
  return Paras
281
320
 
282
321
 
283
- def optimizer_dict(Paras, OtherParas)->dict:
322
+ def optimizer_paras_dict(Paras, OtherParas)->dict:
284
323
  optimizer_dict = {
285
324
  # ----------------- ADAM --------------------
286
325
  "ADAM": {
@@ -289,7 +328,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
289
328
  "alpha": (
290
329
  [0.5 * 1e-3, 1e-3, 2 * 1e-3]
291
330
  if OtherParas["SeleParasOn"]
292
- else [0.0005]
331
+ else [1e-3]
293
332
  ),
294
333
  "epsilon": [1e-8],
295
334
  "beta1": [0.9],
@@ -314,7 +353,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
314
353
  "delta": (
315
354
  [2**i for i in range(-8, 9)]
316
355
  if OtherParas["SeleParasOn"]
317
- else [0.25]
356
+ else [0.01]
318
357
  ),
319
358
  "cutting_number": [10],
320
359
  },
@@ -323,7 +362,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
323
362
  "SGD": {
324
363
  "params": {
325
364
  "alpha": (
326
- [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.5]
365
+ [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.001]
327
366
  )
328
367
  }
329
368
  },
@@ -386,6 +425,19 @@ def optimizer_dict(Paras, OtherParas)->dict:
386
425
  "cutting_number": [10],
387
426
  },
388
427
  },
428
+ # ----------- SPBM-TR-NoneCut -----------
429
+ "SPBM-TR-NoneCut": {
430
+ "params": {
431
+ "M": [1e-5],
432
+ "delta": (
433
+ [2**i for i in range(-8, 9)]
434
+ if OtherParas["SeleParasOn"]
435
+ else [1]
436
+ ),
437
+ "cutting_number": [10],
438
+ },
439
+ },
440
+
389
441
  # ------------- SPBM-PF-NoneLower -----------
390
442
  "SPBM-PF-NoneLower": {
391
443
  "params": {
@@ -398,13 +450,23 @@ def optimizer_dict(Paras, OtherParas)->dict:
398
450
  "cutting_number": [10],
399
451
  },
400
452
  },
453
+ # ----------- SPBM-PF-NoneCut -----------
454
+ "SPBM-PF-NoneCut": {
455
+ "params": {
456
+ "M": [1e-5],
457
+ "delta": (
458
+ [2**i for i in range(-8, 9)]
459
+ if OtherParas["SeleParasOn"]
460
+ else [1]
461
+ ),
462
+ "cutting_number": [10],
463
+ },
464
+ },
401
465
  }
402
466
 
403
- Paras["optimizer_dict"] = optimizer_dict
467
+ Paras["optimizer_search_grid"] = optimizer_dict
404
468
  return Paras
405
469
 
406
-
407
-
408
470
  def metrics()->dict:
409
471
  metrics = {
410
472
  "epoch_loss": [],
@@ -416,4 +478,213 @@ def metrics()->dict:
416
478
  "grad_norm": [],
417
479
  "per_epoch_loss": []
418
480
  }
419
- return metrics
481
+ return metrics
482
+
483
+
484
def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
    """Split a hyper-parameter grid and prepare the results directory for a run.

    Args:
        Paras: Run configuration dict. Reads ``results_folder_name``, ``seed``,
            ``train_data_num``, ``test_data_num``, ``batch_size``, ``epochs``
            and ``time_str``; ``Results_folder`` is written back into it.
        model_name: Model component of the results path.
        data_name: Dataset component of the results path.
        optimizer_name: Optimizer component of the results path.
        params_gird: Mapping of hyper-parameter name -> candidate values.

    Returns:
        ``(keys, values, Paras)`` where ``keys`` and ``values`` are the grid's
        names and candidate values as lists, and ``Paras`` is the same dict
        object that was passed in.
    """
    grid_names = [*params_gird.keys()]
    grid_candidates = [*params_gird.values()]

    # One path component per entry; joined with "/" below.
    components = (
        ".",
        f'{Paras["results_folder_name"]}',
        f'seed_{Paras["seed"]}',
        f"{model_name}",
        f"{data_name}",
        f"{optimizer_name}",
        f'train_{Paras["train_data_num"]}_test_{Paras["test_data_num"]}',
        f'Batch_size_{Paras["batch_size"]}',
        f'epoch_{Paras["epochs"]}',
        f'{Paras["time_str"]}',
    )
    Paras["Results_folder"] = "/".join(components)

    # The directory is created eagerly; an existing one is not an error.
    os.makedirs(Paras["Results_folder"], exist_ok=True)

    return grid_names, grid_candidates, Paras
492
+
493
+
494
def fig_ylabel(str_name):
    """Return the human-readable y-axis label for a metric key.

    Args:
        str_name: One of the metric keys listed below.

    Returns:
        The label string associated with ``str_name``.

    Raises:
        KeyError: If ``str_name`` is not a known metric key.
    """
    metric_labels = dict(
        training_loss="training loss",
        test_loss="test loss",
        training_acc="training accuracy",
        test_acc="test accuracy",
        grad_norm="grad norm",
        per_epoch_loss="per epoch loss",
        epoch_loss="epoch loss",
    )
    label = metric_labels[str_name]
    return label
506
+
507
def model_abbr(model_name):
    """Return the short tag for a full model name.

    This is the reverse of ``model_full_name``. The full name for the
    least-squares model is ``"LeastSquares"``; the misspelled ``"LstSquares"``
    key is kept only so existing callers that pass it keep working.

    Args:
        model_name: Full model name, e.g. ``"LogRegressionBinaryL2"``.

    Returns:
        The abbreviation, e.g. ``"LRBL2"``.

    Raises:
        KeyError: If ``model_name`` is not a known model.
    """
    name_map = {
        "LogRegressionBinaryL2": "LRBL2",
        "ResNet18": "ResNet18",
        "ResNet34": "ResNet34",
        "LeastSquares": "LS",
        # Legacy misspelling, retained for backward compatibility.
        "LstSquares": "LS",
    }
    return name_map[model_name]
515
+
516
def dataset_abbr(model_name):
    """Return the short tag for a full dataset name.

    This is the reverse of ``dataset_full_name`` and covers every full name
    that function can produce.

    Args:
        model_name: Full dataset name, e.g. ``"Adult_Income_Prediction"``.
            The parameter name is historical (it carries a dataset name) and
            is kept so keyword callers are not broken.

    Returns:
        The abbreviation, e.g. ``"AIP"``.

    Raises:
        KeyError: If the dataset name is not known.
    """
    name_map = {
        "MNIST": "MNIST",
        "CIFAR100": "CIFAR100",
        "Caltech101_Resize_32": "Caltech101",
        "Duke": "Duke",
        "Ijcnn": "Ijcnn",
        "Adult_Income_Prediction": "AIP",
        "Credit_Card_Fraud_Detection": "CCFD",
        # Legacy misspelling, retained for backward compatibility.
        "Credit_Card_Frau_Detection": "CCFD",
        "Diabetes_Health_Indicators": "DHI",
        "Electric_Vehicle_Population": "EVP",
        "Global_House_Purchase": "GHP",
        "Health_Lifestyle": "HL",
        "Homesite_Quote_Conversion": "HQC",
        "TN_Weather_2020_2025": "TN_Weather",
    }
    return name_map[model_name]
530
+
531
def model_full_name(model_name):
    """Expand a model abbreviation into its full model name.

    Args:
        model_name: Abbreviation such as ``"LS"`` or ``"LRBL2"``.

    Returns:
        The corresponding full name.

    Raises:
        KeyError: If the abbreviation is not known.
    """
    abbreviation_pairs = (
        ("LS", "LeastSquares"),
        ("LRBL2", "LogRegressionBinaryL2"),
        ("ResNet18", "ResNet18"),
    )
    full_names = dict(abbreviation_pairs)
    return full_names[model_name]
538
+ # <optimizers_full_name>
539
def optimizers_full_name(optimizer_name):
    """Translate a command-line optimizer tag into its display name.

    Args:
        optimizer_name: Underscore-separated tag, e.g. ``"SPBM_TR_NoneCut"``.

    Returns:
        The hyphenated display name, e.g. ``"SPBM-TR-NoneCut"``.

    Raises:
        KeyError: If the tag is not known.
    """
    tag_to_display = dict(
        (
            ("ADAM", "ADAM"),
            ("SGD", "SGD"),
            ("Bundle", "Bundle"),
            ("ALR_SMAG", "ALR-SMAG"),
            ("SPBM_TR", "SPBM-TR"),
            ("SPBM_PF", "SPBM-PF"),
            ("SPSmax", "SPSmax"),
            ("SPBM_TR_NoneSpecial", "SPBM-TR-NoneSpecial"),
            ("SPBM_TR_NoneLower", "SPBM-TR-NoneLower"),
            ("SPBM_TR_NoneCut", "SPBM-TR-NoneCut"),
            ("SPBM_PF_NoneSpecial", "SPBM-PF-NoneSpecial"),
            ("SPBM_PF_NoneLower", "SPBM-PF-NoneLower"),
            ("SPBM_PF_NoneCut", "SPBM-PF-NoneCut"),
        )
    )
    return tag_to_display[optimizer_name]
556
+ # <optimizers_full_name>
557
+
558
+ # <dataset_full_name>
559
def dataset_full_name(dataset_name):
    """Expand a dataset abbreviation into its full dataset name.

    Args:
        dataset_name: Abbreviation such as ``"AIP"`` or ``"TN_Weather"``.

    Returns:
        The corresponding full dataset name.

    Raises:
        KeyError: If the abbreviation is not known.
    """
    abbreviation_pairs = (
        ("MNIST", "MNIST"),
        ("CIFAR100", "CIFAR100"),
        ("Caltech101", "Caltech101_Resize_32"),
        ("Duke", "Duke"),
        ("AIP", "Adult_Income_Prediction"),
        ("CCFD", "Credit_Card_Fraud_Detection"),
        ("Ijcnn", "Ijcnn"),
        ("DHI", "Diabetes_Health_Indicators"),
        ("EVP", "Electric_Vehicle_Population"),
        ("GHP", "Global_House_Purchase"),
        ("HL", "Health_Lifestyle"),
        ("HQC", "Homesite_Quote_Conversion"),
        ("TN_Weather", "TN_Weather_2020_2025"),
    )
    full_names = dict(abbreviation_pairs)
    return full_names[dataset_name]
576
+ # <dataset_full_name>
577
+
578
def opt_paras_str(opt_paras_dict):
    """Flatten optimizer parameters into a ``k1_v1_k2_v2_...`` string.

    The ``"ID"`` entry, if present, is left out. Entries appear in the
    dict's iteration order.

    Args:
        opt_paras_dict: Mapping of parameter name -> value.

    Returns:
        The underscore-joined string; empty when nothing remains after
        dropping ``"ID"``.
    """
    fragments = []
    for name, value in opt_paras_dict.items():
        if name != "ID":
            fragments.append(f"{name}_{value}")
    return "_".join(fragments)
587
+ # <set_marker_point>
588
def set_marker_point(epoch_num: int) -> list:
    """Return the preset marker positions for a given total epoch count.

    Only the epoch counts listed in the table below are supported.

    Args:
        epoch_num: Total number of epochs of the run.

    Returns:
        The list of epoch indices associated with ``epoch_num``.

    Raises:
        ValueError: If no preset exists for ``epoch_num``.
    """
    presets = {
        1: [0],
        4: [0, 2, 4],
        6: [0, 2, 4, 6],
        8: [0, 2, 4, 6, 8],
        10: [0, 2, 4, 6, 8, 10],
        100: [0, 20, 40, 60, 80, 100],
        200: [0, 40, 80, 120, 160, 200],
    }
    markers = presets.get(epoch_num)
    if markers is None:
        raise ValueError(f"No marker defined for epoch {epoch_num}")
    return markers
602
+
603
+ # <set_marker_point>
604
+ # <results_path_to_info>
605
def results_path_to_info(path_list):
    """Build a nested lookup of run metadata parsed from result paths.

    Each path is read positionally after splitting on the path separator:
    component 1 is ``seed_<s>``, 2 the model, 3 the dataset, 4 the optimizer,
    5 ``train_<n>_test_<m>``, 6 ``Batch_size_<b>``, 7 ``epoch_<e>`` and 8 the
    run ID. Component 0 (the results root) is ignored, so the root must be a
    single path component.

    Args:
        path_list: Iterable of result path strings.

    Returns:
        ``info[model][dataset][optimizer][ID]`` -> dict with ``seed``,
        ``epochs`` (int), ``train_test`` (pair of strings), ``batch_size``,
        ``marker`` and an ``optimizer`` sub-dict holding the run ID.

    Raises:
        IndexError: If a path has fewer components than the layout requires.
        ValueError: If the epoch field is not an integer, or
            ``set_marker_point`` has no preset for that epoch count.
    """
    info_dict = {}

    for path in path_list:
        # Paths built by glob / os.path.join use backslashes on Windows;
        # normalise so the positional layout works on every platform.
        parts = path.replace(os.sep, "/").split("/")
        seed = parts[1]
        model_name = parts[2]
        data_name = parts[3]
        optimizer = parts[4]
        train_test = parts[5].split("_")
        batch_size = parts[6].split("_")[2]
        epochs = parts[7].split("_")[1]
        ID = parts[8]

        # Create the intermediate levels on first sight of each key.
        runs = (
            info_dict
            .setdefault(model_name, {})
            .setdefault(data_name, {})
            .setdefault(optimizer, {})
        )

        runs[ID] = {
            "seed": seed.split("_")[1],
            "epochs": int(epochs),
            "train_test": (train_test[1], train_test[3]),
            "batch_size": batch_size,
            "marker": set_marker_point(int(epochs)),
            "optimizer": {
                optimizer: {
                    "ID": ID,
                }
            },
        }

    return info_dict
642
+ # <results_path_to_info>
643
+
644
+ # <update_info_dict>
645
def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
    """Merge the selected runs' metadata and optimizer settings into ``info_dict``.

    For every ``(optimizer_name, ID, Opt_Paras)`` entry of each dataset, the
    run is checked against ``results_dict``. The first run seen for a dataset
    seeds ``info_dict[data_name]`` with a shallow copy of that run's record.
    Every run then registers its optimizer parameters, and the dataset's
    metric key is attached.

    Args:
        draw_data_list: Dataset names to process.
        draw_data: Mapping dataset name -> iterable of
            ``(optimizer_name, ID, Opt_Paras)`` triples.
        results_dict: ``results[model][dataset][optimizer][ID]`` -> run record.
        model_name: Model whose results are consulted.
        info_dict: Dict to update; modified in place.
        metric_key_dict: Mapping dataset name -> metric key.

    Returns:
        The updated ``info_dict`` (same object that was passed in).

    Raises:
        AssertionError: If a dataset, optimizer or run ID is missing from
            ``results_dict``. It is raised explicitly instead of through an
            ``assert`` statement, so the check still runs under ``python -O``.

    Note:
        ``Opt_Paras`` is stored by reference and gains an ``"ID"`` key, so the
        caller's dict is modified.
    """
    for data_name in draw_data_list:
        for optimizer_name, ID, Opt_Paras in draw_data[data_name]:
            model_results = results_dict[model_name]

            if data_name not in model_results:
                message = f'{data_name} not in results'
                print('*' * 40)
                print(message)
                print('*' * 40)
                raise AssertionError(message)

            # Check that the optimizer exists for this dataset.
            if optimizer_name not in model_results[data_name]:
                message = f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.'
                print('*' * 40)
                print(message)
                print('*' * 40)
                raise AssertionError(message)

            # Check that the run ID exists for this optimizer.
            if ID not in model_results[data_name][optimizer_name]:
                message = f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.'
                print('*' * 60)
                print(message)
                print('*' * 60)
                raise AssertionError(message)

            # First run for this dataset seeds the entry (shallow copy).
            if data_name not in info_dict:
                info_dict[data_name] = model_results[data_name][optimizer_name][ID].copy()

            # Register this optimizer's parameters together with its run ID.
            if "optimizer" not in info_dict[data_name]:
                info_dict[data_name]["optimizer"] = {}
            info_dict[data_name]["optimizer"][optimizer_name] = Opt_Paras
            info_dict[data_name]["optimizer"][optimizer_name]["ID"] = ID

            # Attach the metric key chosen for this dataset.
            info_dict[data_name]["metric_key"] = metric_key_dict[data_name]

    return info_dict
684
+ # <update_info_dict>
685
+
686
def get_results_all_pkl_path(results_folder):
    """List every ``.pkl`` file found under ``results_folder`` at any depth.

    Args:
        results_folder: Directory to search.

    Returns:
        A list of matching paths, each prefixed with ``results_folder``;
        empty when nothing matches.
    """
    # "**" with recursive=True also matches zero directories, so files
    # directly inside results_folder are included.
    return glob.glob(os.path.join(results_folder, "**", "*.pkl"), recursive=True)