junshan-kit 2.5.1__py2.py3-none-any.whl → 2.8.5__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,60 +1,110 @@
1
1
  import numpy as np
2
- import sys, os, torch, random
2
+ import sys, os, torch, random, glob
3
3
  import argparse
4
- import junshan_kit.ModelsHub as ModelsHub
4
+ from datetime import datetime
5
+ script_dir = os.path.dirname(os.path.abspath(__file__))
6
+ sys.path.append(os.path.join(script_dir, 'src'))
7
+ from junshan_kit import ModelsHub, Check_Info
5
8
 
6
9
 
7
- class check_args:
10
+ class args:
8
11
  def __init__(self):
9
12
  pass
10
-
13
+
14
+ # <args>
11
15
  def get_args(self):
12
16
  parser = argparse.ArgumentParser(description="Combined config argument example")
13
17
 
14
- allowed_models = ["LS", "LRL2","ResNet18"]
15
- allowed_optimizers = ["ADAM", "SGD", "Bundle"]
18
+ # <allowed_models>
19
+ allowed_models = ["LS", "LRBL2", "ResNet18"]
20
+ # <allowed_models>
21
+
22
+ # <allowed_optimizers>
23
+ allowed_optimizers = [
24
+ "ADAM",
25
+ "ALR_SMAG",
26
+ "Bundle",
27
+ "SGD",
28
+ "SPBM_TR",
29
+ "SPBM_PF",
30
+ "SPSmax",
31
+ "SPBM_TR_NoneSpecial",
32
+ "SPBM_TR_NoneLower",
33
+ "SPBM_PF_NoneLower",
34
+ ]
35
+ # <allowed_optimizers>
36
+
37
+ # <allowed_datasets>
38
+ allowed_datasets = [
39
+ "MNIST",
40
+ "CIFAR100",
41
+ "Caltech101",
42
+ "AIP",
43
+ "CCFD",
44
+ "Duke",
45
+ "Ijcnn",
46
+ "DHI",
47
+ "EVP",
48
+ "GHP",
49
+ "HL",
50
+ "HQC",
51
+ "TN_Weather",
52
+ ],
53
+ # <allowed_datasets>
54
+ data_name_mapping = {
55
+ "MNIST": "MNIST",
56
+ "CIFAR100": "CIFAR100",
57
+ "Caltech101": "Caltech101_Resize_32",
58
+ "Duke": "Duke",
59
+ "AIP": "Adult_Income_Prediction",
60
+ "CCFD": "Credit_Card_Fraud_Detection",
61
+ "Ijcnn": "Ijcnn",
62
+ "RCV1": "RCV1",
63
+ "w8a": "w8a",
64
+ "DHI":"Diabetes_Health_Indicators",
65
+ "EVP": "Electric_Vehicle_Population",
66
+ "GHP": "Global_House_Purchase",
67
+ "HL": "Health_Lifestyle",
68
+ "HQC": "Homesite_Quote_Conversion",
69
+ "TN_Weather": "TN_Weather_2020_2025",
70
+ }
16
71
 
17
- allowed_datasets = ["MNIST",
18
- "CIFAR100",
19
- "AIP",
20
- "CCFD",
21
- ]
22
-
23
72
  optimizers_mapping = {
24
73
  "ADAM": "ADAM",
25
74
  "SGD": "SGD",
26
- "Bundle": "Bundle"
75
+ "Bundle": "Bundle",
76
+ "ALR_SMAG": "ALR-SMAG",
77
+ "SPBM_TR": "SPBM-TR",
78
+ "SPBM_PF": "SPBM-PF",
79
+ "SPSmax": "SPSmax",
80
+ "SPBM_TR_NoneSpecial": "SPBM-TR-NoneSpecial",
81
+ "SPBM_TR_NoneLower": "SPBM-TR-NoneLower",
82
+ "SPBM_TR_NoneCut": "SPBM-TR-NoneCut",
83
+ "SPBM_PF_NoneSpecial": "SPBM-PF-NoneSpecial",
84
+ "SPBM_PF_NoneLower": "SPBM-PF-NoneLower",
85
+ "SPBM_PF_NoneCut": "SPBM-PF-NoneCut"
27
86
  }
28
87
 
29
88
  model_mapping = {
30
89
  "LS": "LeastSquares",
31
- "LRL2": "LogRegressionBinaryL2",
90
+ "LRBL2": "LogRegressionBinaryL2",
32
91
  "ResNet18": "ResNet18"
33
92
  }
34
-
35
- data_name_mapping = {
36
- "MNIST": "MNIST",
37
- "CIFAR100": "CIFAR100",
38
- "AIP": "Adult_Income_Prediction",
39
- "CCFD": "Credit_Card_Fraud_Detection"
40
- }
41
93
 
42
-
43
-
44
- # Single combined argument that can appear multiple times
94
+ # <args_from_command>
45
95
  parser.add_argument(
46
96
  "--train",
47
97
  type=str,
48
98
  nargs="+", # Allow multiple configs
49
99
  required=True,
50
- help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR10-Adam). model: {model_mapping}, \n datasets: {allowed_datasets}, optimizers: {allowed_optimizers},"
100
+ help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {allowed_models},\n datasets: {allowed_datasets},\n optimizers: {allowed_optimizers},"
51
101
  )
52
102
 
53
103
  parser.add_argument(
54
- "--e",
55
- type=int,
56
- required=True,
57
- help="Number of training epochs. Example: --e 50"
104
+ "--e",
105
+ type=int,
106
+ required=True,
107
+ help="Number of training epochs. Example: --e 50"
58
108
  )
59
109
 
60
110
  parser.add_argument(
@@ -88,63 +138,84 @@ class check_args:
88
138
  )
89
139
 
90
140
  parser.add_argument(
91
- "--subset",
92
- type=float,
93
- nargs=2,
94
- # required=True,
95
- help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
141
+ "--subset",
142
+ type=float,
143
+ nargs=2,
144
+ # required=True,
145
+ help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
96
146
  )
97
147
 
98
- args = parser.parse_args()
99
- args.model_name_mapping = model_mapping
100
- args.data_name_mapping = data_name_mapping
101
- args.optimizers_name_mapping = optimizers_mapping
148
+ parser.add_argument(
149
+ "--time_str",
150
+ type=str,
151
+ nargs=1,
152
+ # required=True,
153
+ help = "the str of time"
154
+ )
102
155
 
156
+ parser.add_argument(
157
+ "--send_email",
158
+ type=str,
159
+ nargs=3,
160
+ # required=True,
161
+ help = "from_email to_email, from_pwd"
162
+ )
103
163
 
104
- if args.subset is not None:
105
- self.check_subset_info(args, parser)
164
+ parser.add_argument(
165
+ "--user_search_grid",
166
+ type=int,
167
+ nargs=1,
168
+ # required=True,
169
+ help = "search_grid: 1: "
170
+ )
106
171
 
172
+ parser.add_argument(
173
+ "--OptParas",
174
+ type=int,
175
+ nargs=1,
176
+ help="Number of optimization steps for parameter tuning (default: 1)"
177
+ )
178
+ # <args_from_command>
107
179
 
108
- self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)
180
+ args = parser.parse_args()
181
+ args.model_name_mapping = model_mapping
182
+ args.data_name_mapping = data_name_mapping
183
+ args.optimizers_name_mapping = optimizers_mapping
109
184
 
110
185
  return args
186
+ # <args>
187
+
188
+ def UpdateOtherParas(args, OtherParas):
189
+ # <time_str>
190
+ if args.time_str is not None:
191
+ time_str = args.time_str[0]
192
+ else:
193
+ time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
111
194
 
112
- def check_subset_info(self, args, parser):
113
- total = sum(args.subset)
114
- if args.subset[0]>1:
115
- # CHECK
116
- for i in args.subset:
117
- if i < 1:
118
- parser.error(f"Invalid --subset {args.subset}: The number of subdata must > 1")
119
- else:
120
- if abs(total - 1.0) != 0.0:
121
- parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f}))")
122
-
123
-
124
- def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
125
- # Parse and validate each train_group
126
- for cfg in args.train:
127
- try:
128
- model, dataset, optimizer = cfg.split("-")
129
-
130
- if model not in allowed_models:
131
- parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
132
- if optimizer not in allowed_optimizers:
133
- parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
134
- if dataset not in allowed_datasets:
135
- parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
136
-
137
- except ValueError:
138
- parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
139
-
140
- for cfg in args.train:
141
- model_name, dataset_name, optimizer_name = cfg.split("-")
142
- try:
143
- f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
144
-
145
- except:
146
- print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
147
- assert False
195
+ # <user_search_grid>
196
+ if args.user_search_grid is not None:
197
+ OtherParas["user_search_grid"] = args.user_search_grid[0]
198
+ else:
199
+ OtherParas["user_search_grid"] = None
200
+
201
+ # <send_email>
202
+ if args.send_email is not None:
203
+ OtherParas["from_email"] = args.send_email[0]
204
+ OtherParas["to_email"] = args.send_email[1]
205
+ OtherParas["from_pwd"] = args.send_email[2]
206
+ OtherParas["send_email"] = True
207
+ else:
208
+ OtherParas["send_email"] = False
209
+
210
+ if args.OptParas is not None:
211
+ OtherParas["SeleParasOn"] = False
212
+ else:
213
+ OtherParas["SeleParasOn"] = True
214
+
215
+ OtherParas["time_str"] = time_str
216
+ OtherParas["results_folder_name"] = f'Results_{OtherParas["exp_name"]}'
217
+
218
+ return OtherParas
148
219
 
149
220
  def get_train_group(args):
150
221
  training_group = []
@@ -163,8 +234,10 @@ def set_paras(args, OtherParas):
163
234
  # Print loss every N epochs.
164
235
  "epoch_log_interval": 1,
165
236
 
237
+ "use_log_scale": True,
238
+
166
239
  # Timestamp string for result saving.
167
- "time_str": ["time_str"],
240
+ "time_str": OtherParas["time_str"],
168
241
 
169
242
  # Random seed
170
243
  "seed": args.seed,
@@ -182,18 +255,23 @@ def set_paras(args, OtherParas):
182
255
  "split_train_data": args.s,
183
256
 
184
257
  # select_subset
185
- "select_subset": args.subset
258
+ "select_subset": args.subset,
259
+
260
+ # Results_dict
261
+ "Results_dict": {},
262
+
263
+ # type: bool
264
+ "user_search_grid": OtherParas["user_search_grid"],
186
265
  }
187
266
 
188
267
  Paras = model_list(Paras)
189
268
  Paras = model_type(Paras)
190
269
  Paras = data_list(Paras)
191
- Paras = optimizer_dict(Paras, OtherParas)
270
+ Paras = optimizer_paras_dict(Paras, OtherParas)
192
271
  Paras = device(Paras)
193
272
 
194
273
  return Paras
195
274
 
196
-
197
275
  def set_seed(seed=42):
198
276
  torch.manual_seed(seed)
199
277
  torch.cuda.manual_seed_all(seed)
@@ -234,34 +312,6 @@ def model_type(Paras) -> dict:
234
312
  return Paras
235
313
 
236
314
  def data_list(Paras) -> dict:
237
- """
238
- Attach a predefined list of dataset names to the parameter dictionary.
239
-
240
- The predefined datasets include:
241
- - Duke:
242
- - classes: 2
243
- - data: 42 (38 + 4)
244
- - features: 7,129
245
- - Ijcnn:
246
- - classes: 2
247
- - data: (35,000 + 91,701)
248
- - features: 22
249
- - w8a:
250
- - classes: 2
251
- - data: (49,749 + 14,951)
252
- - features: 300
253
- - RCV1
254
- - Shuttle
255
- - Letter
256
- - Vowel
257
- - MNIST
258
- - CIFAR100
259
- - CALTECH101_Resize_32
260
- - Adult Income Prediction
261
- -
262
- - Credit_Card_Fraud_Detection
263
- """
264
-
265
315
  data_list = [
266
316
  "Duke",
267
317
  "Ijcnn",
@@ -272,15 +322,21 @@ def data_list(Paras) -> dict:
272
322
  "Vowel",
273
323
  "MNIST",
274
324
  "CIFAR100",
275
- "CALTECH101_Resize_32",
325
+ "Caltech101_Resize_32",
276
326
  "Adult_Income_Prediction",
277
- "Credit_Card_Fraud_Detection"
327
+ "Credit_Card_Fraud_Detection",
328
+ "Diabetes_Health_Indicators",
329
+ "Electric_Vehicle_Population",
330
+ "Global_House_Purchase",
331
+ "Health_Lifestyle",
332
+ "Homesite_Quote_Conversion",
333
+ "TN_Weather_2020_2025"
278
334
  ]
279
335
  Paras["data_list"] = data_list
280
336
  return Paras
281
337
 
282
338
 
283
- def optimizer_dict(Paras, OtherParas)->dict:
339
+ def optimizer_paras_dict(Paras, OtherParas)->dict:
284
340
  optimizer_dict = {
285
341
  # ----------------- ADAM --------------------
286
342
  "ADAM": {
@@ -289,7 +345,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
289
345
  "alpha": (
290
346
  [0.5 * 1e-3, 1e-3, 2 * 1e-3]
291
347
  if OtherParas["SeleParasOn"]
292
- else [0.0005]
348
+ else [1e-3]
293
349
  ),
294
350
  "epsilon": [1e-8],
295
351
  "beta1": [0.9],
@@ -314,7 +370,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
314
370
  "delta": (
315
371
  [2**i for i in range(-8, 9)]
316
372
  if OtherParas["SeleParasOn"]
317
- else [0.25]
373
+ else [0.01]
318
374
  ),
319
375
  "cutting_number": [10],
320
376
  },
@@ -323,7 +379,7 @@ def optimizer_dict(Paras, OtherParas)->dict:
323
379
  "SGD": {
324
380
  "params": {
325
381
  "alpha": (
326
- [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.5]
382
+ [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.001]
327
383
  )
328
384
  }
329
385
  },
@@ -386,6 +442,18 @@ def optimizer_dict(Paras, OtherParas)->dict:
386
442
  "cutting_number": [10],
387
443
  },
388
444
  },
445
+ # ----------- SPBM-TR-NoneCut -----------
446
+ "SPBM-TR-NoneCut": {
447
+ "params": {
448
+ "delta": (
449
+ [2**i for i in range(-8, 9)]
450
+ if OtherParas["SeleParasOn"]
451
+ else [1]
452
+ ),
453
+ "cutting_number": [10],
454
+ },
455
+ },
456
+
389
457
  # ------------- SPBM-PF-NoneLower -----------
390
458
  "SPBM-PF-NoneLower": {
391
459
  "params": {
@@ -398,13 +466,22 @@ def optimizer_dict(Paras, OtherParas)->dict:
398
466
  "cutting_number": [10],
399
467
  },
400
468
  },
469
+ # ----------- SPBM-PF-NoneCut -----------
470
+ "SPBM-PF-NoneCut": {
471
+ "params": {
472
+ "delta": (
473
+ [2**i for i in range(-8, 9)]
474
+ if OtherParas["SeleParasOn"]
475
+ else [1]
476
+ ),
477
+ "cutting_number": [10],
478
+ },
479
+ },
401
480
  }
402
481
 
403
- Paras["optimizer_dict"] = optimizer_dict
482
+ Paras["optimizer_search_grid"] = optimizer_dict
404
483
  return Paras
405
484
 
406
-
407
-
408
485
  def metrics()->dict:
409
486
  metrics = {
410
487
  "epoch_loss": [],
@@ -416,4 +493,214 @@ def metrics()->dict:
416
493
  "grad_norm": [],
417
494
  "per_epoch_loss": []
418
495
  }
419
- return metrics
496
+ return metrics
497
+
498
+
499
def hyperparas_and_path(Paras, model_name, data_name, optimizer_name, params_gird):
    """Split a hyper-parameter grid and create the run's results folder.

    The folder path is stored in ``Paras["Results_folder"]`` and created on
    disk (an existing folder is reused).

    Args:
        Paras: Experiment settings dict; reads ``results_folder_name``,
            ``seed``, ``train_data_num``, ``test_data_num``, ``batch_size``,
            ``epochs`` and ``time_str``, and writes ``Results_folder``.
        model_name: Model path component.
        data_name: Dataset path component.
        optimizer_name: Optimizer path component.
        params_gird: Mapping of hyper-parameter name -> candidate values.

    Returns:
        tuple: ``(keys, values, Paras)``; ``keys`` and ``values`` are lists
        taken from ``params_gird`` in its iteration order.
    """
    grid_keys = list(params_gird.keys())
    grid_values = list(params_gird.values())

    # One entry per directory level; joined with "/" below.
    segments = [
        ".",
        f"{Paras['results_folder_name']}",
        f"seed_{Paras['seed']}",
        f"{model_name}",
        f"{data_name}",
        f"{optimizer_name}",
        f"train_{Paras['train_data_num']}_test_{Paras['test_data_num']}",
        f"Batch_size_{Paras['batch_size']}",
        f"epoch_{Paras['epochs']}",
        f"{Paras['time_str']}",
    ]
    Paras["Results_folder"] = "/".join(segments)
    os.makedirs(Paras["Results_folder"], exist_ok=True)

    return grid_keys, grid_values, Paras
507
+
508
+
509
def fig_ylabel(str_name):
    """Return the human-readable y-axis label for a metric key.

    Args:
        str_name: Metric key such as ``"training_loss"`` or ``"test_acc"``.

    Returns:
        str: The label text used on plots.

    Raises:
        KeyError: If ``str_name`` is not a known metric key.
    """
    labels = dict(
        training_loss="training loss",
        test_loss="test loss",
        training_acc="training accuracy",
        test_acc="test accuracy",
        grad_norm="grad norm",
        per_epoch_loss="per epoch loss",
        epoch_loss="epoch loss",
    )
    return labels[str_name]
521
+
522
def model_abbr(model_name):
    """Return the short command-line abbreviation for a full model name.

    Args:
        model_name: Full model name, e.g. ``"LogRegressionBinaryL2"``.

    Returns:
        str: The abbreviation, e.g. ``"LRBL2"``.

    Raises:
        KeyError: If ``model_name`` is not a known model.
    """
    name_map = {
        "LogRegressionBinaryL2": "LRBL2",
        "ResNet18": "ResNet18",
        "ResNet34": "ResNet34",
        # Full name used by the model mapping ("LS" -> "LeastSquares").
        "LeastSquares": "LS",
        # Misspelled key from earlier releases, kept so existing callers work.
        "LstSquares": "LS",
    }
    return name_map[model_name]
530
+
531
def dataset_abbr(model_name):
    """Return the short command-line abbreviation for a full dataset name.

    This is the inverse of ``dataset_full_name``. The parameter is called
    ``model_name`` for historical reasons; it receives a *dataset* name.

    Args:
        model_name: Full dataset name, e.g. ``"Adult_Income_Prediction"``.

    Returns:
        str: The abbreviation, e.g. ``"AIP"``.

    Raises:
        KeyError: If the dataset name is unknown.
    """
    name_map = {
        "MNIST": "MNIST",
        "CIFAR100": "CIFAR100",
        "Caltech101_Resize_32": "Caltech101",
        "Duke": "Duke",
        "Ijcnn": "Ijcnn",
        "Adult_Income_Prediction": "AIP",
        "Credit_Card_Fraud_Detection": "CCFD",
        # Misspelled key from earlier releases, kept for compatibility.
        "Credit_Card_Frau_Detection": "CCFD",
        "Diabetes_Health_Indicators": "DHI",
        "Electric_Vehicle_Population": "EVP",
        "Global_House_Purchase": "GHP",
        "Health_Lifestyle": "HL",
        "Homesite_Quote_Conversion": "HQC",
        "TN_Weather_2020_2025": "TN_Weather",
    }
    return name_map[model_name]
545
+
546
def model_full_name(model_name):
    """Expand a command-line model abbreviation into the full model name.

    Args:
        model_name: Abbreviation such as ``"LS"`` or ``"LRBL2"``.

    Returns:
        str: The full model name.

    Raises:
        KeyError: If the abbreviation is unknown.
    """
    full_names = dict(
        LS="LeastSquares",
        LRBL2="LogRegressionBinaryL2",
        ResNet18="ResNet18",
    )
    return full_names[model_name]
553
+ # <optimizers_full_name>
554
def optimizers_full_name(optimizer_name):
    """Map an optimizer's command-line key to its display name.

    Every supported display name is the key with underscores replaced by
    hyphens (e.g. ``"SPBM_TR"`` -> ``"SPBM-TR"``, ``"ADAM"`` -> ``"ADAM"``).

    Args:
        optimizer_name: Command-line optimizer key.

    Returns:
        str: The hyphenated display name.

    Raises:
        KeyError: If ``optimizer_name`` is not a supported key.
    """
    supported = frozenset({
        "ADAM",
        "SGD",
        "Bundle",
        "ALR_SMAG",
        "SPBM_TR",
        "SPBM_PF",
        "SPSmax",
        "SPBM_TR_NoneSpecial",
        "SPBM_TR_NoneLower",
        "SPBM_TR_NoneCut",
        "SPBM_PF_NoneSpecial",
        "SPBM_PF_NoneLower",
        "SPBM_PF_NoneCut",
    })
    if optimizer_name not in supported:
        raise KeyError(optimizer_name)
    return optimizer_name.replace("_", "-")
571
+ # <optimizers_full_name>
572
+
573
+ # <dataset_full_name>
574
def dataset_full_name(dataset_name):
    """Expand a command-line dataset abbreviation into the full dataset name.

    Args:
        dataset_name: Abbreviation such as ``"AIP"`` or ``"Caltech101"``.

    Returns:
        str: The full dataset name.

    Raises:
        KeyError: If the abbreviation is unknown.
    """
    abbreviation_pairs = (
        ("MNIST", "MNIST"),
        ("CIFAR100", "CIFAR100"),
        ("Caltech101", "Caltech101_Resize_32"),
        ("Duke", "Duke"),
        ("AIP", "Adult_Income_Prediction"),
        ("CCFD", "Credit_Card_Fraud_Detection"),
        ("Ijcnn", "Ijcnn"),
        ("DHI", "Diabetes_Health_Indicators"),
        ("EVP", "Electric_Vehicle_Population"),
        ("GHP", "Global_House_Purchase"),
        ("HL", "Health_Lifestyle"),
        ("HQC", "Homesite_Quote_Conversion"),
        ("TN_Weather", "TN_Weather_2020_2025"),
    )
    lookup = dict(abbreviation_pairs)
    return lookup[dataset_name]
591
+ # <dataset_full_name>
592
+
593
def opt_paras_str(opt_paras_dict):
    """Flatten optimizer hyper-parameters into ``"k1_v1_k2_v2_..."``.

    The ``"ID"`` entry is skipped. Values are formatted with f-string
    conversion, in the dict's iteration order.

    Args:
        opt_paras_dict: Mapping of hyper-parameter name -> value.

    Returns:
        str: The underscore-joined string (empty if nothing remains).
    """
    pieces = []
    for name, value in opt_paras_dict.items():
        if name == "ID":
            continue
        pieces.append(f"{name}_{value}")
    return "_".join(pieces)
602
+ # <set_marker_point>
603
def set_marker_point(epoch_num: int) -> list:
    """Return the epoch indices at which plot markers are drawn.

    Only a fixed set of epoch counts is supported; each maps to evenly spaced
    markers from 0 to ``epoch_num`` inclusive.

    Args:
        epoch_num: Total number of training epochs.

    Returns:
        list: Marker epoch indices.

    Raises:
        ValueError: If no markers are defined for ``epoch_num``.
    """
    # Marker spacing per supported epoch count; range(0, epoch_num + 1, step)
    # reproduces the marker lists (e.g. 1 -> [0], 50 -> [0, 10, ..., 50]).
    step_for_epochs = {
        1: 2,
        4: 2,
        6: 2,
        8: 2,
        10: 2,
        50: 10,
        100: 20,
        200: 40,
    }
    if epoch_num not in step_for_epochs:
        raise ValueError(f"No marker defined for epoch {epoch_num}")

    step = step_for_epochs[epoch_num]
    return list(range(0, int(epoch_num) + 1, step))
618
+
619
+ # <set_marker_point>
620
+ # <results_path_to_info>
621
def results_path_to_info(path_list):
    """Group result-file paths into a nested ``model -> data -> optimizer -> ID`` dict.

    Each path is expected to contain the results layout built by
    ``hyperparas_and_path``::

        <root>/seed_<s>/<model>/<data>/<optimizer>/train_<n>_test_<m>/Batch_size_<b>/epoch_<e>/<ID>/...

    The path is normalised first. The ``seed_`` component is then located by
    matching that layout, so a leading ``./``, an absolute prefix, or Windows
    separators no longer shift the component indices. If no component matches,
    the historical fixed layout (``seed_`` at index 1) is used.

    Args:
        path_list: Iterable of result-file path strings.

    Returns:
        dict: ``info[model][data][optimizer][ID]`` maps to a dict with keys
        ``seed`` (str), ``epochs`` (int), ``train_test`` (tuple of str),
        ``batch_size`` (str), ``marker`` (from ``set_marker_point``) and
        ``optimizer`` (``{optimizer: {"ID": ID}}``).

    Raises:
        ValueError: From ``set_marker_point`` when the epoch count has no markers.
    """
    info_dict = {}

    for path in path_list:
        # Collapse "./" and duplicate separators, then treat "\" like "/" so
        # splitting yields one entry per directory level on every platform.
        parts = os.path.normpath(path).replace("\\", "/").split("/")
        base = _seed_component_index(parts)

        seed = parts[base]
        model_name = parts[base + 1]
        data_name = parts[base + 2]
        optimizer = parts[base + 3]
        train_test = parts[base + 4].split("_")      # ["train", n, "test", m]
        batch_size = parts[base + 5].split("_")[2]   # "Batch_size_<b>"
        epochs = parts[base + 6].split("_")[1]       # "epoch_<e>"
        ID = parts[base + 7]

        runs = (
            info_dict.setdefault(model_name, {})
            .setdefault(data_name, {})
            .setdefault(optimizer, {})
        )
        runs[ID] = {
            "seed": seed.split("_")[1],
            "epochs": int(epochs),
            "train_test": (train_test[1], train_test[3]),
            "batch_size": batch_size,
            "marker": set_marker_point(int(epochs)),
            "optimizer": {
                f"{optimizer}": {
                    "ID": ID,
                }
            },
        }

    return info_dict


def _seed_component_index(parts):
    """Return the index of the ``seed_<s>`` component within split path *parts*.

    A component matches only if it is followed by the expected
    ``model/data/optimizer/train_*/Batch_size_*/epoch_*/<ID>`` levels. The
    function falls back to 1 (the original fixed layout) when nothing matches.
    """
    for i, part in enumerate(parts[:-7]):
        if (
            part.startswith("seed_")
            and parts[i + 4].startswith("train_")
            and parts[i + 5].startswith("Batch_size_")
            and parts[i + 6].startswith("epoch_")
        ):
            return i
    return 1
658
+ # <results_path_to_info>
659
+
660
+ # <update_info_dict>
661
def update_info_dict(draw_data_list, draw_data, results_dict, model_name, info_dict, metric_key_dict):
    """Merge the selected (optimizer, ID) runs into per-dataset plotting info.

    For every dataset in ``draw_data_list`` and every
    ``(optimizer_name, ID, Opt_Paras)`` triple in ``draw_data[data_name]``:

    1. The run is validated against
       ``results_dict[model_name][data_name][optimizer_name][ID]``.
    2. ``info_dict[data_name]`` is seeded from that entry if it does not exist.
    3. The run is recorded under ``info_dict[data_name]["optimizer"]``.

    Neither ``results_dict`` nor the caller's ``Opt_Paras`` dicts are modified.

    Args:
        draw_data_list: Dataset names to process.
        draw_data: Mapping data_name -> iterable of
            ``(optimizer_name, ID, Opt_Paras)`` triples.
        results_dict: Nested run metadata keyed model -> data -> optimizer -> ID.
        model_name: Top-level key into ``results_dict``.
        info_dict: Dict updated in place and returned.
        metric_key_dict: Mapping data_name -> value stored as ``"metric_key"``.

    Returns:
        dict: The updated ``info_dict``.

    Raises:
        AssertionError: If the dataset, optimizer or ID is missing from
            ``results_dict``. It is raised explicitly rather than via
            ``assert False``, so the check still fires under ``python -O``.
    """
    import copy

    for data_name in draw_data_list:
        for optimizer_name, ID, Opt_Paras in draw_data[data_name]:
            model_results = results_dict[model_name]

            if data_name not in model_results:
                _abort_missing_result(f'{data_name} not in results', 40)

            # Check if optimizer_name exists in results_dict
            if optimizer_name not in model_results[data_name]:
                _abort_missing_result(
                    f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {optimizer_name} is error.',
                    40,
                )

            # Check if ID exists in results_dict
            if ID not in model_results[data_name][optimizer_name]:
                _abort_missing_result(
                    f'({data_name}, {optimizer_name}, {ID}) not in results_dict and \n {ID} is error.',
                    60,
                )

            # Seed info_dict[data_name] with a deep copy. A shallow .copy()
            # would share the nested "optimizer" dict (and lists) with
            # results_dict, so the updates below would modify results_dict.
            if data_name not in info_dict:
                info_dict[data_name] = copy.deepcopy(
                    model_results[data_name][optimizer_name][ID]
                )
            entry = info_dict[data_name]

            # Store a copy of Opt_Paras tagged with its run ID, leaving the
            # caller's dict untouched.
            run_paras = dict(Opt_Paras)
            run_paras["ID"] = ID
            entry.setdefault("optimizer", {})[optimizer_name] = run_paras

            # Update metric_key
            entry["metric_key"] = metric_key_dict[data_name]

    return info_dict


def _abort_missing_result(message, width):
    """Print *message* between ``'*' * width`` rulers, then raise AssertionError."""
    print('*' * width)
    print(message)
    print('*' * width)
    raise AssertionError(message)
700
+ # <update_info_dict>
701
+
702
def get_results_all_pkl_path(results_folder):
    """Return every ``.pkl`` file under ``results_folder``, at any depth.

    Files directly inside ``results_folder`` are included too, because
    ``**`` with ``recursive=True`` also matches zero directories. The order is
    whatever ``glob.glob`` yields; it is not sorted.

    Args:
        results_folder: Root directory to search.

    Returns:
        list[str]: Matching file paths, prefixed with ``results_folder``.
    """
    return glob.glob(
        os.path.join(results_folder, "**", "*.pkl"),
        recursive=True,
    )