junshan-kit 2.3.8__py2.py3-none-any.whl → 2.4.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,470 @@
1
+ # Step 1: training_group
2
+ def set_training_group():
3
+ # <training_group>
4
+ training_group = [
5
+ # *********************************************************
6
+ # ----------------- MNIST (ResNet18) ----------------------
7
+ # ("ResNet18", "MNIST", "SGD"),
8
+ # ("ResNet18", "MNIST", "ADAM"),
9
+ # ("ResNet18", "MNIST", "SPSmax"),
10
+ # ("ResNet18", "MNIST", "Bundle"),
11
+ # ("ResNet18", "MNIST", "ALR-SMAG"),
12
+ # ("ResNet18", "MNIST", "SPBM-TR"),
13
+ # ("ResNet18", "MNIST", "SPBM-PF"),
14
+ # ("ResNet18", "MNIST", "SPBM-PF-NoneLower"),
15
+ # ("ResNet18", "MNIST", "SPBM-TR-NoneLower"),
16
+ # ("ResNet18", "MNIST", "SPBM-TR-NoneSpecial"),
17
+ # ---------------- CIFAR100 (ResNet18)---------------------
18
+ # ("ResNet18", "CIFAR100", "SGD"),
19
+ # ("ResNet18", "CIFAR100", "ADAM"),
20
+ # ("ResNet18", "CIFAR100", "SPSmax"),
21
+ # ("ResNet18", "CIFAR100", "Bundle"),
22
+ # ("ResNet18", "CIFAR100", "ALR-SMAG"),
23
+ # ("ResNet18", "CIFAR100", "SPBM-TR"),
24
+ # ("ResNet18", "CIFAR100", "SPBM-PF"),
25
+ # ("ResNet18", "CIFAR100", "SPBM-PF-NoneLower"),
26
+ # ("ResNet18", "CIFAR100", "SPBM-TR-NoneLower"),
27
+ # ----------- CALTECH101_Resize_32 (ResNet18) -------------
28
+ ("ResNet18", "CALTECH101_Resize_32", "SGD"),
29
+ # ("ResNet18", "CALTECH101_Resize_32", "ADAM"),
30
+ # ("ResNet18", "CALTECH101_Resize_32", "SPSmax"),
31
+ # ("ResNet18", "CALTECH101_Resize_32", "Bundle"),
32
+ # ("ResNet18", "CALTECH101_Resize_32", "ALR-SMAG"),
33
+ # ("ResNet18", "CALTECH101_Resize_32", "SPBM-TR"),
34
+ # ("ResNet18", "CALTECH101_Resize_32", "SPBM-PF"),
35
+ # ("ResNet18", "CALTECH101_Resize_32", "SPBM-PF-NoneLower"),
36
+ # ("ResNet18", "CALTECH101_Resize_32", "SPBM-TR-NoneLower"),
37
+
38
+ # *********************************************************
39
+ # ---------------- MNIST (ResNet34) -----------------------
40
+ # ("ResNet34" ,"MNIST", "SGD"),
41
+ # ("ResNet34" ,"MNIST", "ADAM"),
42
+ # ("ResNet34" ,"MNIST", "SPSmax"),
43
+ # ("ResNet34" ,"MNIST", "Bundle"),
44
+ # ("ResNet34" ,"MNIST", "ALR-SMAG"),
45
+ # ("ResNet34" ,"MNIST", "SPBM-TR"),
46
+ # ("ResNet34" ,"MNIST", "SPBM-PF"),
47
+ # ("ResNet34" ,"MNIST", "SPBM-PF-NoneLower"),
48
+ # ("ResNet34" ,"MNIST", "SPBM-TR-NoneLower"),
49
+ # ------------------ CIFAR100 (ResNet34)-------------------
50
+ # ("ResNet34" ,"CIFAR100", "SGD"),
51
+ # ("ResNet34" ,"CIFAR100", "ADAM"),
52
+ # ("ResNet34" ,"CIFAR100", "SPSmax"),
53
+ # ("ResNet34" ,"CIFAR100", "Bundle"),
54
+ # ("ResNet34" ,"CIFAR100", "ALR-SMAG"),
55
+ # ("ResNet34" ,"CIFAR100", "SPBM-TR"),
56
+ # ("ResNet34" ,"CIFAR100", "SPBM-PF"),
57
+ # ("ResNet34" ,"CIFAR100", "SPBM-PF-NoneLower"),
58
+ # ("ResNet34" ,"CIFAR100", "SPBM-TR-NoneLower"),
59
+ # ------------ CALTECH101_Resize_32 (ResNet34) ------------
60
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SGD"),
61
+ # ("ResNet34" ,"CALTECH101_Resize_32", "ADAM"),
62
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SPSmax"),
63
+ # ("ResNet34" ,"CALTECH101_Resize_32", "Bundle"),
64
+ # ("ResNet34" ,"CALTECH101_Resize_32", "ALR-SMAG"),
65
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SPBM-TR"),
66
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SPBM-PF"),
67
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SPBM-PF-NoneLower"),
68
+ # ("ResNet34" ,"CALTECH101_Resize_32", "SPBM-TR-NoneLower"),
69
+
70
+ # *********************************************************
71
+ # ------------------ MNIST (LeastSquares) -----------------
72
+ # ("LeastSquares" ,"MNIST", "SGD"),
73
+ # ("LeastSquares" ,"MNIST", "ADAM"),
74
+ # ("LeastSquares" ,"MNIST", "SPSmax"),
75
+ # ("LeastSquares" ,"MNIST", "Bundle"),
76
+ # ("LeastSquares" ,"MNIST", "ALR-SMAG"),
77
+ # ("LeastSquares" ,"MNIST", "SPBM-TR"),
78
+ # ("LeastSquares" ,"MNIST", "SPBM-PF"),
79
+ # ("LeastSquares" ,"MNIST", "SPBM-PF-NoneLower"),
80
+ # ("LeastSquares" ,"MNIST", "SPBM-TR-NoneLower"),
81
+ # ---------------- CIFAR100 (LeastSquares) ----------------
82
+ # ("LeastSquares" ,"CIFAR100", "SGD"),
83
+ # ("LeastSquares" ,"CIFAR100", "ADAM"),
84
+ # ("LeastSquares" ,"CIFAR100", "SPSmax"),
85
+ # ("LeastSquares" ,"CIFAR100", "Bundle"),
86
+ # ("LeastSquares" ,"CIFAR100", "ALR-SMAG"),
87
+ # ("LeastSquares" ,"CIFAR100", "SPBM-TR"),
88
+ # ("LeastSquares" ,"CIFAR100", "SPBM-PF"),
89
+ # ("LeastSquares" ,"CIFAR100", "SPBM-PF-NoneLower"),
90
+ # ("LeastSquares" ,"CIFAR100", "SPBM-TR-NoneLower"),
91
+ # ---------- CALTECH101_Resize_32 (LeastSquares) ----------
92
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SGD"),
93
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "ADAM"),
94
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SPSmax"),
95
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "Bundle"),
96
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "ALR-SMAG"),
97
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SPBM-TR"),
98
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SPBM-PF"),
99
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SPBM-PF-NoneLower"),
100
+ # ("LeastSquares" ,"CALTECH101_Resize_32", "SPBM-TR-NoneLower"),
101
+
102
+ # *********************************************************
103
+ # ------------- MNIST (LogRegressionBinary) ---------------
104
+ # ("LogRegressionBinary" ,"MNIST", "SGD"),
105
+ # ("LogRegressionBinary" ,"MNIST", "ADAM"),
106
+ # ("LogRegressionBinary" ,"MNIST", "SPSmax"),
107
+ # ("LogRegressionBinary" ,"MNIST", "Bundle"),
108
+ # ("LogRegressionBinary" ,"MNIST", "ALR-SMAG"),
109
+ # ("LogRegressionBinary" ,"MNIST", "SPBM-TR"),
110
+ # ("LogRegressionBinary" ,"MNIST", "SPBM-PF"),
111
+ # ("LogRegressionBinary" ,"MNIST", "SPBM-PF-NoneLower"),
112
+ # ("LogRegressionBinary" ,"MNIST", "SPBM-TR-NoneLower"),
113
+ # ------------- CIFAR100 (LogRegressionBinary) ------------
114
+ # ("LogRegressionBinary" ,"CIFAR100", "SGD"),
115
+ # ("LogRegressionBinary" ,"CIFAR100", "ADAM"),
116
+ # ("LogRegressionBinary" ,"CIFAR100", "SPSmax"),
117
+ # ("LogRegressionBinary" ,"CIFAR100", "Bundle"),
118
+ # ("LogRegressionBinary" ,"CIFAR100", "ALR-SMAG"),
119
+ # ("LogRegressionBinary" ,"CIFAR100", "SPBM-TR"),
120
+ # ("LogRegressionBinary" ,"CIFAR100", "SPBM-PF"),
121
+ # ("LogRegressionBinary" ,"CIFAR100", "SPBM-PF-NoneLower"),
122
+ # ("LogRegressionBinary" ,"CIFAR100", "SPBM-TR-NoneLower"),
123
+ # --------------- RCV1 (LogRegressionBinary) --------------
124
+ # ("LogRegressionBinary" ,"RCV1", "SGD"),
125
+ # ("LogRegressionBinary" ,"RCV1", "ADAM"),
126
+ # ("LogRegressionBinary" ,"RCV1", "SPSmax"),
127
+ # ("LogRegressionBinary" ,"RCV1", "Bundle"),
128
+ # ("LogRegressionBinary" ,"RCV1", "ALR-SMAG"),
129
+ # ("LogRegressionBinary" ,"RCV1", "SPBM-TR"),
130
+ # ("LogRegressionBinary" ,"RCV1", "SPBM-PF"),
131
+
132
+ # # *********************************************************
133
+ # # ------------ MNIST (LogRegressionBinaryL2) --------------
134
+ # ("LogRegressionBinaryL2" ,"MNIST", "SGD"),
135
+ # ("LogRegressionBinaryL2" ,"MNIST", "ADAM"),
136
+ # ("LogRegressionBinaryL2" ,"MNIST", "SPSmax"),
137
+ # ("LogRegressionBinaryL2" ,"MNIST", "Bundle"),
138
+ # ("LogRegressionBinaryL2" ,"MNIST", "ALR-SMAG"),
139
+ # ("LogRegressionBinaryL2" ,"MNIST", "SPBM-TR"),
140
+ # ("LogRegressionBinaryL2" ,"MNIST", "SPBM-PF"),
141
+ # # ------------- CIFAR100 (LogRegressionBinaryL2) ----------
142
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "SGD"),
143
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "ADAM"),
144
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "SPSmax"),
145
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "Bundle"),
146
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "ALR-SMAG"),
147
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "SPBM-TR"),
148
+ # ("LogRegressionBinaryL2" ,"CIFAR100", "SPBM-PF"),
149
+ # # --------------- RCV1 (LogRegressionBinaryL2) ------------
150
+ # ("LogRegressionBinaryL2", "RCV1", "SGD"),
151
+ # ("LogRegressionBinaryL2", "RCV1", "ADAM"),
152
+ # ("LogRegressionBinaryL2", "RCV1", "SPSmax"),
153
+ # ("LogRegressionBinaryL2", "RCV1", "Bundle"),
154
+ # ("LogRegressionBinaryL2", "RCV1", "ALR-SMAG"),
155
+ # ("LogRegressionBinaryL2", "RCV1", "SPBM-TR"),
156
+ # ("LogRegressionBinaryL2", "RCV1", "SPBM-PF"),
157
+ # # -------------- Duke (LogRegressionBinaryL2) -------------
158
+ # ("LogRegressionBinaryL2" ,"Duke", "SGD"),
159
+ # ("LogRegressionBinaryL2" ,"Duke", "ADAM"),
160
+ # ("LogRegressionBinaryL2" ,"Duke", "SPSmax"),
161
+ # ("LogRegressionBinaryL2" ,"Duke", "Bundle"),
162
+ # ("LogRegressionBinaryL2" ,"Duke", "ALR-SMAG"),
163
+ # ("LogRegressionBinaryL2" ,"Duke", "SPBM-TR"),
164
+ # ("LogRegressionBinaryL2" ,"Duke", "SPBM-PF"),
165
+ # # -------------- Ijcnn (LogRegressionBinaryL2) ------------
166
+ # ("LogRegressionBinaryL2", "Ijcnn", "SGD"),
167
+ # ("LogRegressionBinaryL2", "Ijcnn", "ADAM"),
168
+ # ("LogRegressionBinaryL2", "Ijcnn", "SPSmax"),
169
+ # ("LogRegressionBinaryL2", "Ijcnn", "Bundle"),
170
+ # ("LogRegressionBinaryL2", "Ijcnn", "ALR-SMAG"),
171
+ # ("LogRegressionBinaryL2", "Ijcnn", "SPBM-TR"),
172
+ # ("LogRegressionBinaryL2", "Ijcnn", "SPBM-PF"),
173
+ # # ----------------- w8a (LogRegressionBinaryL2) -----------
174
+ # ("LogRegressionBinaryL2", "w8a", "SGD"),
175
+ # ("LogRegressionBinaryL2", "w8a", "ADAM"),
176
+ # ("LogRegressionBinaryL2", "w8a", "SPSmax"),
177
+ # ("LogRegressionBinaryL2", "w8a", "Bundle"),
178
+ # ("LogRegressionBinaryL2", "w8a", "ALR-SMAG"),
179
+ # ("LogRegressionBinaryL2", "w8a", "SPBM-TR"),
180
+ # ("LogRegressionBinaryL2", "w8a", "SPBM-PF"),
181
+ ]
182
+ # <training_group>
183
+
184
+ return training_group
185
+
186
+ def batch_size() -> dict:
187
+ batch_size = {
188
+ # 15123/12560
189
+ "Shuttle": 256,
190
+ # 15000/5000
191
+ "Letter": 256,
192
+ # 528/462
193
+ "Vowel": 52,
194
+ # 60000/10000
195
+ "MNIST": 256,
196
+ # 50000/10000
197
+ "CIFAR100": 256,
198
+ # 8,677 (will be split into 7:3)---> 6073/2604
199
+ "CALTECH101_Resize_32": 256,
200
+ # 20,242 (will be split into 7:3)---> 14169/6073
201
+ "RCV1": 256,
202
+ # only 42 (38+4) examples (cancer data)
203
+ "Duke": 10,
204
+ # 35000 + 91701
205
+ "Ijcnn": 64,
206
+ # classes: 2 data: (49,749, 14,951) features: 300
207
+ "w8a": 128,
208
+ }
209
+ return batch_size
210
+
211
+
212
+ def epochs(OtherParas) -> dict:
213
+ epochs = {
214
+ # 15123/12560
215
+ "Shuttle": 10,
216
+ # 15000/5000
217
+ "Letter": 10,
218
+ # 528/462
219
+ "Vowel": 10,
220
+ # 60000/10000
221
+ "MNIST": 50,
222
+ # 50000/10000
223
+ "CIFAR100": 50,
224
+ # 8,677 (will be split into 7:3)---> 6073/2604
225
+ "CALTECH101_Resize_32": 50,
226
+ # 20,242 (will be split into 7:3)---> 14169/6073
227
+ "RCV1": 10,
228
+ # only 42 (38+4) examples (cancer data)
229
+ "Duke": 10,
230
+ # 35000 + 91701
231
+ "Ijcnn": 10,
232
+ # classes: 2 data: (49,749, 14,951) features: 300
233
+ "w8a": 10,
234
+ }
235
+ if OtherParas["debug"]:
236
+ epochs = {k: 2 for k in epochs}
237
+
238
+ return epochs
239
+
240
+
241
+ def split_train_data() -> dict:
242
+ split_train_data = {
243
+ # 20,242 + 0 (test data too large)
244
+ "RCV1": 0.7,
245
+ # only 42 (38+4) examples (no split needed)
246
+ "Duke": 1,
247
+ # classes: 2 data: (35000, 91701) features: 22
248
+ "Ijcnn": 1,
249
+ # classes: 2 data: (49,749, 14,951) features: 300
250
+ "w8a": 1,
251
+ }
252
+ return split_train_data
253
+
254
+
255
+ def select_subset():
256
+ select_subset = {
257
+ "CALTECH101_Resize_32": True,
258
+ "CIFAR100": True,
259
+ "Duke": False,
260
+ "Ijcnn": True,
261
+ "MNIST": True,
262
+ "RCV1": True,
263
+ "w8a": True,
264
+ }
265
+ return select_subset
266
+
267
+
268
+ def subset_number_dict(OtherParas):
269
+ subset_number_dict = {
270
+ # Max: 60,000/10,000
271
+ "MNIST": (1000, 5000),
272
+ # Max: 50,000
273
+ "CIFAR100": (2000, 10000),
274
+ # Max: 8,677 (6073/2604)
275
+ "CALTECH101_Resize_32": (2000, 2604), #test max: 2604
276
+ # classes: 2 data: (35000, 91701) features: 22
277
+ "Ijcnn": (1000, 1000),
278
+ # classes: 2 data: (14169, 6,073) features: 47,236
279
+ "RCV1": (1000, 1000),
280
+ # classes: 2 data: (49,749, 14,951) features: 300
281
+ "w8a": (1000, 1000),
282
+ }
283
+
284
+ if OtherParas["debug"]:
285
+ subset_number_dict = {k: (50, 50) for k in subset_number_dict}
286
+ return subset_number_dict
287
+
288
+
289
+ def validation() -> dict:
290
+ validation = {
291
+ # "MNIST": True,
292
+ # "CIFAR100": True,
293
+ # "CALTECH101_Resize_32": True
294
+ }
295
+ return validation
296
+
297
+
298
+ def validation_rate() -> dict:
299
+ validation_rate = {
300
+ "MNIST": 0.3, # Max: 60,000/10,000
301
+ "CIFAR100": 0.3, # Max: 50,000
302
+ "CALTECH101_Resize_32": 0.3, # Max: 8,677 (6073/2604)
303
+ }
304
+ return validation_rate
305
+
306
+
307
+ def model_list() -> list:
308
+ model_list = [
309
+ "ResNet18",
310
+ "ResNet34",
311
+ "LeastSquares",
312
+ "LogRegressionBinary",
313
+ "LogRegressionBinaryL2",
314
+ ]
315
+ return model_list
316
+
317
+
318
+ def model_type() -> dict:
319
+ model_type = {
320
+ "ResNet18": "multi",
321
+ "ResNet34": "multi",
322
+ "LeastSquares": "multi",
323
+ "LogRegressionBinary": "binary",
324
+ "LogRegressionBinaryL2": "binary",
325
+ }
326
+ return model_type
327
+
328
+
329
+ def data_list() -> list:
330
+ data_list = [
331
+ # classes: 2 data: 42 (38+4) features: 7,129
332
+ "Duke",
333
+ # classes: 2 data: (35000, 91701) features: 22
334
+ "Ijcnn",
335
+ # classes: 2 data: (49,749, 14,951) features: 300
336
+ "w8a",
337
+ #
338
+ "RCV1",
339
+ "Shuttle",
340
+ "Letter",
341
+ "Vowel",
342
+ "MNIST",
343
+ "CIFAR100",
344
+ "CALTECH101_Resize_32",
345
+ ]
346
+ return data_list
347
+
348
+ def optimizer_dict(OtherParas) -> dict:
349
+ optimizer_dict = {
350
+ # --------------------------- ADAM ----------------------------
351
+ "ADAM": {
352
+ "params": {
353
+ # "alpha": [2 * 1e-3],
354
+ "alpha": (
355
+ [0.5 * 1e-3, 1e-3, 2 * 1e-3]
356
+ if OtherParas["SeleParasOn"]
357
+ else [0.0005]
358
+ ),
359
+ "epsilon": [1e-8],
360
+ "beta1": [0.9],
361
+ "beta2": [0.999],
362
+ },
363
+ },
364
+ # ----------------------- ALR-SMAG ---------------------------
365
+ "ALR-SMAG": {
366
+ "params": {
367
+ "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
368
+ "eta_max": (
369
+ [2**i for i in range(-8, 9)]
370
+ if OtherParas["SeleParasOn"]
371
+ else [0.125]
372
+ ),
373
+ "beta": [0.9],
374
+ },
375
+ },
376
+ # ------------------------ Bundle -----------------------------
377
+ "Bundle": {
378
+ "params": {
379
+ "delta": (
380
+ [2**i for i in range(-8, 9)]
381
+ if OtherParas["SeleParasOn"]
382
+ else [0.25]
383
+ ),
384
+ "cutting_number": [10],
385
+ },
386
+ },
387
+ # --------------------------- SGD -----------------------------
388
+ "SGD": {
389
+ "params": {
390
+ "alpha": (
391
+ [2**i for i in range(-8, 9)] if OtherParas["SeleParasOn"] else [0.1]
392
+ )
393
+ }
394
+ },
395
+ # -------------------------- SPSmax ---------------------------
396
+ "SPSmax": {
397
+ "params": {
398
+ "c": ([0.1, 0.5, 1, 5, 10] if OtherParas["SeleParasOn"] else [0.1]),
399
+ "gamma": (
400
+ [2**i for i in range(-8, 9)]
401
+ if OtherParas["SeleParasOn"]
402
+ else [0.125]),
403
+ },
404
+ },
405
+ # ----------------------- SPBM-PF -----------------------------
406
+ "SPBM-PF": {
407
+ "params": {
408
+ "M": [1e-5],
409
+ "delta": (
410
+ [2**i for i in range(9, 20)]
411
+ if OtherParas["SeleParasOn"]
412
+ else [1]
413
+ ),
414
+ "cutting_number": [10],
415
+ },
416
+ },
417
+ # ----------------------- SPBM-TR -----------------------------
418
+ "SPBM-TR": {
419
+ "params": {
420
+ "M": [1e-5],
421
+ "delta": (
422
+ [2**i for i in range(9, 20)]
423
+ if OtherParas["SeleParasOn"]
424
+ else [256]
425
+ ),
426
+ "cutting_number": [10],
427
+ },
428
+ },
429
+
430
+ # ------------------- SPBM-TR-NoneLower -----------------------
431
+ "SPBM-TR-NoneLower": {
432
+ "params": {
433
+ "M": [1e-5],
434
+ "delta": (
435
+ [2**i for i in range(0, 9)]
436
+ if OtherParas["SeleParasOn"]
437
+ else [256]
438
+ ),
439
+ "cutting_number": [10],
440
+ },
441
+ },
442
+ # ------------------- SPBM-TR-NoneSpecial -----------------------
443
+ "SPBM-TR-NoneSpecial": {
444
+ "params": {
445
+ "M": [1e-5],
446
+ "delta": (
447
+ [2**i for i in range(-8, 9)]
448
+ if OtherParas["SeleParasOn"]
449
+ else [1]
450
+ ),
451
+ "cutting_number": [10],
452
+ },
453
+ },
454
+ # -------------------- SPBM-PF-NoneLower ----------------------
455
+ "SPBM-PF-NoneLower": {
456
+ "params": {
457
+ "M": [1e-5],
458
+ "delta": (
459
+ [2**i for i in range(0, 9)]
460
+ if OtherParas["SeleParasOn"]
461
+ else [0]
462
+ ),
463
+ "cutting_number": [10],
464
+ },
465
+ },
466
+
467
+
468
+ }
469
+ return optimizer_dict
470
+
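Taken together, these helpers describe the experiment grid: set_training_group() enumerates (model, dataset, optimizer) triples, and the remaining functions supply per-dataset and per-optimizer settings. Below is a minimal consumption sketch under stated assumptions: the OtherParas keys "debug" and "SeleParasOn" appear in the code above, but the driving loop and the OtherParas values are illustrative only, not part of the package.

```python
# Hypothetical sketch of how the configuration helpers above could be consumed.
# set_training_group, batch_size, epochs, and optimizer_dict are the functions
# defined above (assumed importable from the configuration module); the loop
# and the OtherParas values here are illustrative assumptions.
OtherParas = {"debug": False, "SeleParasOn": True}

for model, dataset, optimizer in set_training_group():
    bs = batch_size()[dataset]                               # per-dataset batch size
    n_epochs = epochs(OtherParas)[dataset]                   # per-dataset epoch budget
    grid = optimizer_dict(OtherParas)[optimizer]["params"]   # hyper-parameter grid
    print(f"{model} / {dataset} / {optimizer}: bs={bs}, epochs={n_epochs}, grid={grid}")
```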
@@ -0,0 +1,116 @@
1
+ import argparse
2
+ from junshan_kit import Models
3
+
4
+ def get_args():
5
+ parser = argparse.ArgumentParser(description="Combined config argument example")
6
+
7
+ allowed_models = ["LS", "LRL2","ResNet18"]
8
+ allowed_optimizers = ["Adam", "SGD",]
9
+ allowed_datasets = ["MNIST", "CIFAR100"]
10
+
11
+ model_mapping = {
12
+ "LS": "LeastSquares",
13
+ "LRL2": "LogRegressionBinaryL2",
14
+ "ResNet18": "ResNet18"
15
+ }
16
+
17
+ # Single combined argument that can appear multiple times
18
+ parser.add_argument(
19
+ "--train_group",
20
+ type=str,
21
+ nargs="+", # Allow multiple configs
22
+ required=True,
23
+ help = f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR10-Adam). model: {model_mapping}, \n datasets: {allowed_datasets}, optimizers: {allowed_optimizers},"
24
+ )
25
+
26
+ parser.add_argument(
27
+ "--e",
28
+ type=int,
29
+ required=True,
30
+ help="Number of training epochs. Example: --e 50"
31
+ )
32
+
33
+ parser.add_argument(
34
+ "--seed",
35
+ type=int,
36
+ default=42,
37
+ help="Random seed for experiment reproducibility. Default: 42"
38
+ )
39
+
40
+ parser.add_argument(
41
+ "--bs",
42
+ type=int,
43
+ required=True,
44
+ help="Batch size for training. Example: --bs 128"
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--cuda",
49
+ type=int,
50
+ default=0,
51
+ # required=True,
53
+ help="CUDA device index. Example: --cuda 1 (default: 0)"
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--s",
57
+ type=float,
58
+ default=1.0,
59
+ # required=True,
60
+ help="Proportion of dataset to use for training split. Example: --s 0.8 (default=1.0)"
61
+ )
62
+
63
+ parser.add_argument(
64
+ "--subset",
65
+ type=float,
66
+ nargs=2,
67
+ # required=True,
68
+ help = "Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500"
69
+ )
70
+
71
+ args = parser.parse_args()
72
+ args.model_mapping = model_mapping
73
+
74
+
75
+ if args.subset is not None:
76
+ check_subset_info(args, parser)
77
+
78
+
79
+ check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets, model_mapping)
80
+
81
+ return args
82
+
83
+ def check_subset_info(args, parser):
84
+ total = sum(args.subset)
85
+ if args.subset[0] > 1:
86
+ # absolute counts: each value must be at least 1
87
+ for i in args.subset:
88
+ if i < 1:
89
+ parser.error(f"Invalid --subset {args.subset}: each subset size must be >= 1")
90
+ else:
91
+ if abs(total - 1.0) > 1e-6:
92
+ parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f})")
93
+
94
+
95
+ def check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets, model_mapping):
96
+ # Parse and validate each train_group
97
+ for cfg in args.train_group:
98
+ try:
99
+ model, dataset, optimizer = cfg.split("-")
100
+
101
+ if model not in allowed_models:
102
+ parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
103
+ if optimizer not in allowed_optimizers:
104
+ parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
105
+ if dataset not in allowed_datasets:
106
+ parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
107
+
108
+ except ValueError:
109
+ parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
110
+
111
+ for cfg in args.train_group:
112
+ model_name, dataset_name, optimizer_name = cfg.split("-")
113
+ try:
114
+ f = getattr(Models, f"Build_{model_mapping[model_name]}_{dataset_name}")
115
+ except AttributeError:
116
+ parser.error(f"No model builder 'Build_{model_mapping[model_name]}_{dataset_name}' found in junshan_kit.Models")
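For orientation, a hedged usage sketch of the parser defined by get_args(). The entry-point name train.py is an assumption, and whether a given combination passes check_args depends on the corresponding Build_* functions actually existing in junshan_kit.Models.

```python
# Hypothetical invocation sketch, roughly equivalent to running:
#   python train.py --train_group ResNet18-MNIST-SGD LRL2-MNIST-Adam --e 50 --bs 128 --cuda 0
# get_args() is the function defined above; train.py is an assumed entry point.
import sys

sys.argv = ["train.py",
            "--train_group", "ResNet18-MNIST-SGD", "LRL2-MNIST-Adam",
            "--e", "50", "--bs", "128", "--cuda", "0"]
args = get_args()
# Each --train_group entry must be model-dataset-optimizer using the short
# model codes from model_mapping (LS, LRL2, ResNet18).
print(args.train_group, args.e, args.bs, args.cuda, args.seed)
```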