junshan-kit 2.4.7__py2.py3-none-any.whl → 2.4.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of junshan-kit might be problematic. Click here for more details.
- junshan_kit/DataHub.py +114 -0
- junshan_kit/DataProcessor.py +114 -24
- junshan_kit/DataSets.py +186 -37
- junshan_kit/{Models.py → ModelsHub.py} +5 -0
- junshan_kit/ParametersHub.py +404 -0
- junshan_kit/Print_Info.py +6 -2
- junshan_kit/TrainingHub.py +75 -0
- junshan_kit/kit.py +94 -30
- {junshan_kit-2.4.7.dist-info → junshan_kit-2.4.9.dist-info}/METADATA +2 -2
- junshan_kit-2.4.9.dist-info/RECORD +12 -0
- junshan_kit/ComOptimizers.py +0 -126
- junshan_kit/ExperimentHub.py +0 -338
- junshan_kit/SPBM.py +0 -350
- junshan_kit/SPBM_func.py +0 -601
- junshan_kit/TrainingParas.py +0 -470
- junshan_kit/check_args.py +0 -116
- junshan_kit/datahub.py +0 -281
- junshan_kit-2.4.7.dist-info/RECORD +0 -16
- {junshan_kit-2.4.7.dist-info → junshan_kit-2.4.9.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,404 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import sys, os, torch, random
|
|
3
|
+
import argparse
|
|
4
|
+
import junshan_kit.ModelsHub as ModelsHub
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class check_args:
    """Command-line parsing and validation for training experiments.

    ``get_args`` builds the argument parser, parses the command line and
    validates every ``--train`` entry of the form ``model-dataset-optimizer``.
    All validation failures are reported through ``parser.error`` (which
    raises ``SystemExit``).
    """

    def __init__(self):
        pass

    def get_args(self, argv=None):
        """Parse and validate the command-line arguments.

        Args:
            argv: optional list of argument strings; ``None`` (the default)
                parses ``sys.argv`` exactly as before.

        Returns:
            argparse.Namespace: the parsed arguments, extended with the
            ``model_name_mapping``, ``data_name_mapping`` and
            ``optimizers_name_mapping`` tables (short CLI name -> internal name).
        """
        parser = argparse.ArgumentParser(description="Combined config argument example")

        # Short names accepted inside a ``--train`` entry.
        allowed_models = ["LS", "LRL2", "ResNet18"]
        allowed_optimizers = ["ADAM", "SGD", "Bundle"]
        allowed_datasets = [
            "MNIST",
            "CIFAR100",
            "AIP",
            "CCFD",
        ]

        # Short CLI name -> internal name used by the rest of the package.
        optimizers_mapping = {
            "ADAM": "ADAM",
            "SGD": "SGD",
            "Bundle": "Bundle",
        }
        model_mapping = {
            "LS": "LeastSquares",
            "LRL2": "LogRegressionBinaryL2",
            "ResNet18": "ResNet18",
        }
        data_name_mapping = {
            "MNIST": "MNIST",
            "CIFAR100": "CIFAR100",
            "AIP": "Adult_Income_Prediction",
            "CCFD": "Credit_Card_Fraud_Detection",
        }

        # Single combined argument that can appear multiple times
        parser.add_argument(
            "--train",
            type=str,
            nargs="+",  # Allow multiple configs
            required=True,
            # The example must itself pass validation (CIFAR100 / ADAM).
            help=f"Format: model-dataset-optimizer (e.g., ResNet18-CIFAR100-ADAM). model: {model_mapping}, \n datasets: {allowed_datasets}, optimizers: {allowed_optimizers},",
        )

        parser.add_argument(
            "--e",
            type=int,
            required=True,
            help="Number of training epochs. Example: --e 50",
        )

        parser.add_argument(
            "--seed",
            type=int,
            default=42,
            help="Random seed for experiment reproducibility. Default: 42",
        )

        parser.add_argument(
            "--bs",
            type=int,
            required=True,
            help="Batch size for training. Example: --bs 128",
        )

        # Optional: falls back to device index 0 when omitted (combining
        # ``required=True`` with a default would make the default unreachable).
        parser.add_argument(
            "--cuda",
            type=int,
            default=0,
            help="The number of cuda. Example: --cuda 1 (default=0) ",
        )

        parser.add_argument(
            "--s",
            type=float,
            default=1.0,
            help="Proportion of dataset to use for training split. Example: --s 0.8 (default=1.0)",
        )

        parser.add_argument(
            "--subset",
            type=float,
            nargs=2,
            help="Two subset ratios (train, test), e.g., --subset 0.7 0.3 or --subset 500 500",
        )

        args = parser.parse_args(argv)
        args.model_name_mapping = model_mapping
        args.data_name_mapping = data_name_mapping
        args.optimizers_name_mapping = optimizers_mapping

        if args.subset is not None:
            self.check_subset_info(args, parser)

        self.check_args(args, parser, allowed_models, allowed_optimizers, allowed_datasets)

        return args

    def check_subset_info(self, args, parser):
        """Validate ``args.subset`` (two values: train, test).

        The first value selects the mode:
          * counts (first value > 1): every value must be at least 1;
          * ratios (otherwise): every value must lie in [0, 1] and the two
            values must sum to 1 (within a small floating-point tolerance).
        """
        subset = args.subset
        if subset[0] > 1:
            for value in subset:
                if value < 1:
                    parser.error(f"Invalid --subset {subset}: The number of subdata must be >= 1")
        else:
            for value in subset:
                if not 0.0 <= value <= 1.0:
                    parser.error(f"Invalid --subset {subset}: ratios must lie in [0, 1]")
            total = sum(subset)
            # Exact float equality rejects inputs that only differ from 1.0 by
            # rounding error, so compare with a tolerance instead.
            if abs(total - 1.0) > 1e-9:
                parser.error(f"Invalid --subset {subset}: the values must sum to 1.0 (current sum = {total:.6f})")

    def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
        """Validate every ``--train`` entry and check a model builder exists.

        First pass: each entry must be ``model-dataset-optimizer`` with names
        from the allowed lists. Second pass: ``ModelsHub`` must expose a
        ``Build_<model>_<dataset>`` function for each combination.
        """
        for cfg in args.train:
            parts = cfg.split("-")
            if len(parts) != 3:
                parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
            model, dataset, optimizer = parts

            if model not in allowed_models:
                parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
            if optimizer not in allowed_optimizers:
                parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
            if dataset not in allowed_datasets:
                parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")

        for cfg in args.train:
            model_name, dataset_name, _ = cfg.split("-")
            builder_name = f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"
            # Report unsupported combinations as a CLI error instead of
            # re-raising the AttributeError from a bare ``except``.
            if not hasattr(ModelsHub, builder_name):
                parser.error(f"Unsupported combination '{cfg}': ModelsHub has no '{builder_name}'")
|
|
148
|
+
|
|
149
|
+
def get_train_group(args):
    """Translate every ``--train`` entry into internal names.

    Each ``model-dataset-optimizer`` entry of ``args.train`` is resolved
    through ``args.model_name_mapping``, ``args.data_name_mapping`` and
    ``args.optimizers_name_mapping``.

    Returns:
        list of ``(model_name, data_name, optimizer_name)`` tuples, in the
        same order as ``args.train``.
    """
    def resolve(entry):
        short_model, short_data, short_opt = entry.split("-")
        return (
            args.model_name_mapping[short_model],
            args.data_name_mapping[short_data],
            args.optimizers_name_mapping[short_opt],
        )

    return [resolve(entry) for entry in args.train]
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def set_paras(args, OtherParas):
    """Build the experiment parameter dictionary.

    Args:
        args: parsed CLI namespace; reads ``seed``, ``cuda``, ``bs``, ``e``,
            ``s`` and ``subset``.
        OtherParas: dict of extra settings. ``results_folder_name`` is
            required; ``time_str`` is used when present; the dict is also
            forwarded to ``optimizer_dict``.

    Returns:
        dict: the parameters, enriched by ``model_list``, ``model_type``,
        ``data_list``, ``optimizer_dict`` and ``device``.
    """
    Paras = {
        # Name of the folder where results will be saved.
        "results_folder_name": OtherParas["results_folder_name"],

        # Print loss every N epochs.
        "epoch_log_interval": 1,

        # Timestamp string for result saving.
        # NOTE(review): previously this was the literal ``["time_str"]``, which
        # looks like a dropped ``OtherParas`` lookup. Use the caller's value
        # when supplied; keep the old placeholder otherwise so existing
        # callers see no change.
        "time_str": OtherParas.get("time_str", ["time_str"]),

        # Random seed
        "seed": args.seed,

        # Device used for training.
        "cuda": f"cuda:{args.cuda}",

        # batch-size
        "batch_size": args.bs,

        # epochs
        "epochs": args.e,

        # split_train_data
        "split_train_data": args.s,

        # select_subset
        "select_subset": args.subset,
    }

    Paras = model_list(Paras)
    Paras = model_type(Paras)
    Paras = data_list(Paras)
    Paras = optimizer_dict(Paras, OtherParas)
    Paras = device(Paras)

    return Paras
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def set_seed(seed=42):
    """Seed all random number generators and force deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU plus every CUDA
    device), then enables deterministic cuDNN kernels and disables cuDNN
    benchmark autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
204
|
+
|
|
205
|
+
def device(Paras) -> dict:
    """Store the torch device and a terminal-colour flag in ``Paras``.

    Uses ``Paras['cuda']`` (a device string such as ``"cuda:0"``) when CUDA
    is available and ``"cpu"`` otherwise. ``"use_color"`` records whether
    stdout is a TTY. Mutates and returns ``Paras``.
    """
    if torch.cuda.is_available():
        target = f"{Paras['cuda']}"
    else:
        target = "cpu"

    Paras["device"] = torch.device(target)
    Paras["use_color"] = sys.stdout.isatty()
    return Paras
|
|
212
|
+
|
|
213
|
+
def model_list(Paras) -> dict:
    """Store the names of all supported models under ``Paras['model_list']``.

    Mutates and returns ``Paras``.
    """
    names = "ResNet18 ResNet34 LeastSquares LogRegressionBinary LogRegressionBinaryL2".split()
    Paras["model_list"] = names
    return Paras
|
|
223
|
+
|
|
224
|
+
def model_type(Paras) -> dict:
    """Tag each model as ``"multi"`` or ``"binary"`` under ``Paras['model_type']``.

    The tag is used when choosing a loss function (``"binary"`` selects a
    BCE-with-logits loss, ``"multi"`` a cross-entropy loss). Mutates and
    returns ``Paras``.
    """
    multi_models = ("ResNet18", "ResNet34", "LeastSquares")
    binary_models = ("LogRegressionBinary", "LogRegressionBinaryL2")

    Paras["model_type"] = {
        **{name: "multi" for name in multi_models},
        **{name: "binary" for name in binary_models},
    }
    return Paras
|
|
235
|
+
|
|
236
|
+
def data_list(Paras) -> dict:
    """
    Attach a predefined list of dataset names to the parameter dictionary.

    The predefined datasets include:
        - Duke:
            - classes: 2
            - data: 42 (38 + 4)
            - features: 7,129
        - Ijcnn:
            - classes: 2
            - data: (35,000 + 91,701)
            - features: 22
        - w8a:
            - classes: 2
            - data: (49,749 + 14,951)
            - features: 300
        - RCV1
        - Shuttle
        - Letter
        - Vowel
        - MNIST
        - CIFAR100
        - CALTECH101_Resize_32
        - Adult_Income_Prediction
        - Credit_Card_Fraud_Detection

    The list is stored under Paras["data_list"]; Paras is mutated in place
    and returned.
    """
    # NOTE: the docstring above is printed to users by Print_Info.data_info,
    # so keep its names identical to the entries below.
    data_list = [
        "Duke",
        "Ijcnn",
        "w8a",
        "RCV1",
        "Shuttle",
        "Letter",
        "Vowel",
        "MNIST",
        "CIFAR100",
        "CALTECH101_Resize_32",
        "Adult_Income_Prediction",
        "Credit_Card_Fraud_Detection",
    ]
    Paras["data_list"] = data_list
    return Paras
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def optimizer_dict(Paras, OtherParas)->dict:
    """Attach the hyper-parameter candidates of every supported optimizer.

    When ``OtherParas["SeleParasOn"]`` is truthy, each tunable entry holds the
    full search grid; otherwise it holds a single default value. Every entry
    is a list so callers can iterate over candidates uniformly. The table is
    stored under ``Paras["optimizer_dict"]``; ``Paras`` is mutated and
    returned.
    """
    searching = OtherParas["SeleParasOn"]

    def choose(grid, fallback):
        # Full grid while selecting parameters, single default otherwise.
        return grid if searching else fallback

    def powers_of_two(lo, hi):
        # [2**lo, 2**(lo + 1), ..., 2**(hi - 1)]
        return [2**i for i in range(lo, hi)]

    def spbm(delta):
        # All SPBM variants share the same M and cutting_number settings.
        return {"params": {"M": [1e-5], "delta": delta, "cutting_number": [10]}}

    table = {
        # ----------------- ADAM --------------------
        "ADAM": {
            "params": {
                "alpha": choose([0.5 * 1e-3, 1e-3, 2 * 1e-3], [0.0005]),
                "epsilon": [1e-8],
                "beta1": [0.9],
                "beta2": [0.999],
            },
        },
        # ------------- ALR-SMAG --------------------
        "ALR-SMAG": {
            "params": {
                "c": choose([0.1, 0.5, 1, 5, 10], [0.1]),
                "eta_max": choose(powers_of_two(-8, 9), [0.125]),
                "beta": [0.9],
            },
        },
        # ------------ Bundle -----------------------
        "Bundle": {
            "params": {
                "delta": choose(powers_of_two(-8, 9), [0.25]),
                "cutting_number": [10],
            },
        },
        # ------------------- SGD -------------------
        "SGD": {
            "params": {
                "alpha": choose(powers_of_two(-8, 9), [0.1]),
            },
        },
        # ------------------- SPSmax ----------------
        "SPSmax": {
            "params": {
                "c": choose([0.1, 0.5, 1, 5, 10], [0.1]),
                "gamma": choose(powers_of_two(-8, 9), [0.125]),
            },
        },
        # ------------- SPBM variants ---------------
        "SPBM-PF": spbm(choose(powers_of_two(9, 20), [1])),
        "SPBM-TR": spbm(choose(powers_of_two(9, 20), [256])),
        "SPBM-TR-NoneLower": spbm(choose(powers_of_two(0, 9), [256])),
        "SPBM-TR-NoneSpecial": spbm(choose(powers_of_two(-8, 9), [1])),
        "SPBM-PF-NoneLower": spbm(choose(powers_of_two(0, 9), [0])),
    }

    Paras["optimizer_dict"] = table
    return Paras
|
junshan_kit/Print_Info.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
|
|
1
|
+
from junshan_kit import ParametersHub
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
# -------------------------------------------------------------
|
|
@@ -56,4 +56,8 @@ def print_per_epoch_info(use_color, epoch, Paras, epoch_loss, training_loss, tra
|
|
|
56
56
|
f'train_loss = {training_loss[epoch+1]:.4e},\t'
|
|
57
57
|
f'train_acc = {100 * training_acc[epoch+1]:.2f}%,\t'
|
|
58
58
|
f'test_acc = {100 * test_acc[epoch+1]:.2f}%,\t'
|
|
59
|
-
f'time (ep, tr, te) = ({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})')
|
|
59
|
+
f'time (ep, tr, te) = ({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})')
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def data_info():
    """Print the dataset catalogue kept in ``ParametersHub.data_list``'s docstring."""
    catalogue = ParametersHub.data_list.__doc__
    print(catalogue)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
|
|
3
|
+
from junshan_kit import DataHub
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def chosen_loss_fn(model_name, Paras):
    """Select the loss function for ``model_name`` and record it in ``Paras``.

    * ``"LeastSquares"`` -> ``nn.MSELoss``
    * models tagged ``"binary"`` in ``Paras["model_type"]`` -> ``nn.BCEWithLogitsLoss``
    * models tagged ``"multi"`` -> ``nn.CrossEntropyLoss``

    ``"LogRegressionBinaryL2"`` additionally sets ``Paras["lambda"] = 1e-3``.

    Returns:
        tuple: ``(loss_fn, Paras)``; ``Paras["loss_fn"]`` is set as well.

    Raises:
        ValueError: if the model's type tag is neither ``"binary"`` nor ``"multi"``.
    """
    # ---------------------------------------
    # The L2-regularised model needs an additional parameter.
    if model_name == "LogRegressionBinaryL2":
        Paras["lambda"] = 1e-3
    # ---------------------------------------

    if model_name in ["LeastSquares"]:
        loss_fn = nn.MSELoss()
    else:
        kind = Paras["model_type"][model_name]
        if kind == "binary":
            loss_fn = nn.BCEWithLogitsLoss()
        elif kind == "multi":
            loss_fn = nn.CrossEntropyLoss()
        else:
            # An explicit exception survives `python -O`; a bare `assert`
            # would be stripped and silently return the wrong loss.
            raise ValueError(
                f"Unsupported model type {kind!r} for model {model_name!r}; "
                "expected 'binary' or 'multi'."
            )

    Paras["loss_fn"] = loss_fn

    return loss_fn, Paras
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def load_data(model_name, data_name, Paras):
    """Load the train/test datasets for ``data_name`` through ``DataHub``.

    Args:
        model_name: internal model name; forwarded to the MNIST and CIFAR100
            loaders.
        data_name: internal dataset name (``"MNIST"``, ``"CIFAR100"``,
            ``"Adult_Income_Prediction"`` or ``"Credit_Card_Fraud_Detection"``).
        Paras: experiment parameter dict forwarded to the loader.

    Returns:
        tuple: ``(train_dataset, test_dataset, transform)`` from the loader.

    Raises:
        ValueError: if ``data_name`` has no loader wired up here.
    """
    # NOTE: CALTECH101_Resize_32, Vowel, Letter, Shuttle, w8a, RCV1, Duke and
    # Ijcnn are listed in ParametersHub.data_list but their loaders are not
    # wired up in this function yet.
    if data_name == "MNIST":
        train_dataset, test_dataset, transform = DataHub.MNIST(Paras, model_name)

    elif data_name == "CIFAR100":
        train_dataset, test_dataset, transform = DataHub.CIFAR100(Paras, model_name)

    elif data_name == "Adult_Income_Prediction":
        train_dataset, test_dataset, transform = DataHub.Adult_Income_Prediction(Paras)

    elif data_name == "Credit_Card_Fraud_Detection":
        train_dataset, test_dataset, transform = DataHub.Credit_Card_Fraud_Detection(Paras)

    else:
        # Raise instead of `assert False`: under `python -O` the assert is
        # stripped and the return below would hit unbound locals.
        raise ValueError(f"The data_name is error! Unsupported dataset: {data_name!r}")

    return train_dataset, test_dataset, transform
|