junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +214 -0
- junshan_kit/DataProcessor.py +306 -16
- junshan_kit/DataSets.py +330 -18
- junshan_kit/Evaluate_Metrics.py +113 -0
- junshan_kit/FiguresHub.py +286 -0
- junshan_kit/ModelsHub.py +239 -0
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +350 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +690 -0
- junshan_kit/Print_Info.py +109 -0
- junshan_kit/TrainingHub.py +324 -0
- junshan_kit/kit.py +83 -24
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/METADATA +6 -2
- junshan_kit-2.7.3.dist-info/RECORD +20 -0
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/WHEEL +1 -1
- junshan_kit-2.2.8.dist-info/RECORD +0 -7

junshan_kit/Print_Info.py
ADDED
@@ -0,0 +1,109 @@
+from junshan_kit import DataSets, ParametersHub
+
+
+# -------------------------------------------------------------
+def training_group(training_group):
+    print(f"--------------------- training_group ------------------")
+    for g in training_group:
+        print(g)
+    print(f"-------------------------------------------------------")
+
+
+def training_info(args, data_name, optimizer_name, hyperparams, Paras, model_name):
+    if Paras["use_color"]:
+        print("\033[90m" + "-" * 115 + "\033[0m")
+        print(
+            f"\033[32m✅ \033[34mDataset:\033[32m {data_name}, \t\033[34mBatch-size:\033[32m {args.bs}, \t\033[34m(training, test) = \033[32m ({Paras['train_data_num']}, {Paras['test_data_num']}), \t\033[34m device:\033[32m {Paras['device']}"
+        )
+        print(
+            f"\033[32m✅ \033[34mOptimizer:\033[32m {optimizer_name}, \t\033[34mParams:\033[32m {hyperparams}"
+        )
+        print(
+            f'\033[32m✅ \033[34mmodel:\033[32m {model_name}, \t\033[34mmodel type:\033[32m {Paras["model_type"][model_name]},\t\033[34m loss_fn:\033[32m {Paras["loss_fn"]}'
+        )
+        print(f'\033[32m✅ \033[34mResults_folder:\033[32m {Paras["Results_folder"]}')
+        print("\033[90m" + "-" * 115 + "\033[0m")
+
+    else:
+        print("-" * 115)
+        print(
+            f"✅ Dataset: {data_name}, \tBatch-size: {args.bs}, \t(training, test) = ({Paras['train_data_num']}, {Paras['test_data_num']}), \tdevice: {Paras['device']}"
+        )
+        print(f"✅ Optimizer: {optimizer_name}, \tParams: {hyperparams}")
+        print(
+            f"✅ model: {model_name}, \tmodel type: {Paras['model_type'][model_name]}, \tloss_fn: {Paras['loss_fn']}"
+        )
+        print(f"✅ Results_folder: {Paras['Results_folder']}")
+        print("-" * 115)
+
+# <Step_7_2>
+
+def per_epoch_info(Paras, epoch, metrics, time):
+    if Paras["use_color"]:
+        print(
+            f'\033[34m epoch = \033[32m{epoch+1}/{Paras["epochs"]}\033[0m,\t\b'
+            f'\033[34m training_loss = \033[32m{metrics["training_loss"][epoch+1]:.4e}\033[0m,\t\b'
+            f'\033[34m training_acc = \033[32m{100 * metrics["training_acc"][epoch+1]:.2f}\033[0m,\t\b'
+            f'\033[34m time = \033[32m{time:.2f}\033[0m,\t\b')
+
+    else:
+        print(
+            f"epoch = {epoch+1}/{Paras['epochs']},\t"
+            f"training_loss = {metrics['training_loss'][epoch+1]:.4e},\t"
+            f"training_acc = {100 * metrics['training_acc'][epoch+1]:.2f}%,\t"
+            f"time = {time:.2f}"
+        )
+
+def print_per_epoch_info(epoch, Paras, epoch_loss, training_loss, training_acc, test_loss, test_acc, run_time):
+    epochs = Paras["epochs"][Paras["data_name"]]
+    # result = [(k, f"{v:.4f}") for k, v in run_time.items()]
+    if Paras["use_color"]:
+        print(
+            f'\033[34m epoch = \033[32m{epoch+1}/{epochs}\033[0m,\t\b'
+            f'\033[34m epoch_loss = \033[32m{epoch_loss[epoch+1]:.4e}\033[0m,\t\b'
+            f'\033[34m train_loss = \033[32m{training_loss[epoch+1]:.4e}\033[0m,\t\b'
+            f'\033[34m train_acc = \033[32m{100 * training_acc[epoch+1]:.2f}%\033[0m,\t\b'
+            f'\033[34m test_acc = \033[32m{100 * test_acc[epoch+1]:.2f}%\033[0m,\t\b'
+            f'\033[34m time (ep, tr, te) = \033[32m({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})\033[0m')
+    else:
+        print(
+            f'epoch = {epoch+1}/{epochs},\t'
+            f'epoch_loss = {epoch_loss[epoch+1]:.4e},\t'
+            f'train_loss = {training_loss[epoch+1]:.4e},\t'
+            f'train_acc = {100 * training_acc[epoch+1]:.2f}%,\t'
+            f'test_acc = {100 * test_acc[epoch+1]:.2f}%,\t'
+            f'time (ep, tr, te) = ({run_time["epoch"]:.2f}, {run_time["train"]:.2f}, {run_time["test"]:.2f})')
+
+
+def all_data_info():
+    print(ParametersHub.data_list.__doc__)
+
+def data_info_DHI():
+    data = DataSets.adult_income_prediction(print_info=True, export_csv=False)
+
+def data_info_CCFD():
+    data = DataSets.credit_card_fraud_detection(print_info=True, export_csv=False)
+
+def data_info_AIP():
+    data = DataSets.adult_income_prediction(print_info=True, export_csv=False)
+
+def data_info_EVP():
+    data = DataSets.electric_vehicle_population(print_info=True, export_csv=False)
+
+def data_info_GHP():
+    data = DataSets.global_house_purchase(print_info=True, export_csv=False)
+
+def data_info_HL():
+    data = DataSets.health_lifestyle(print_info=True, export_csv=False)
+
+def data_info_HQC():
+    data = DataSets.Homesite_Quote_Conversion(print_info=True)
+
+def data_info_IEEE_CIS():
+    data = DataSets.IEEE_CIS_Fraud_Detection(print_info=True)
+
+def data_info_MICP():
+    data = DataSets.medical_insurance_cost_prediction(print_info=True)
+
+def data_info_PPE():
+    data = DataSets.particle_physics_event_classification(print_info=True)
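
For orientation, a minimal sketch (not part of the diff) of how the new Print_Info helpers can be exercised; the args and Paras values below are invented for illustration and mirror only the keys the functions above read:

    from types import SimpleNamespace
    from junshan_kit import Print_Info

    # Hypothetical stand-ins; training_info only reads args.bs and these Paras keys.
    args = SimpleNamespace(bs=128)
    Paras = {
        "use_color": False,          # take the plain-text branch, no ANSI escapes
        "train_data_num": 60000, "test_data_num": 10000,
        "device": "cpu",
        "model_type": {"LogRegressionBinaryL2": "binary"},
        "loss_fn": "BCEWithLogitsLoss",
        "Results_folder": "./Results",
        "epochs": 2,
    }

    Print_Info.training_info(args, "MNIST", "SGD", {"alpha": 0.1}, Paras, "LogRegressionBinaryL2")

    # per_epoch_info indexes metrics at epoch+1, so entry 0 is the pre-training value.
    metrics = {"training_loss": [0.693, 0.412], "training_acc": [0.50, 0.81]}
    Print_Info.per_epoch_info(Paras, 0, metrics, 1.23)

With use_color set to True, the same calls would emit the ANSI-colored variants shown in the first branch of each function.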

junshan_kit/TrainingHub.py
ADDED
@@ -0,0 +1,324 @@
+import torch, time, pickle
+import torch.nn as nn
+import numpy as np
+import torch.utils.data as Data
+from torch.nn.utils import parameters_to_vector
+from junshan_kit import DataHub, TrainingHub, Evaluate_Metrics, DataProcessor, Print_Info, ParametersHub
+
+from junshan_kit.OptimizerHup import OptimizerFactory, SPBM
+
+def chosen_loss_fn(model_name, Paras):
+    # ---------------------------------------
+    # There is an additional parameter here
+    if model_name == "LogRegressionBinaryL2":
+        Paras["lambda"] = 1e-3
+    # ---------------------------------------
+
+    if model_name in ["LeastSquares"]:
+        loss_fn = nn.MSELoss()
+
+    else:
+        if Paras["model_type"][model_name] == "binary":
+            loss_fn = nn.BCEWithLogitsLoss()
+
+        elif Paras["model_type"][model_name] == "multi":
+            loss_fn = nn.CrossEntropyLoss()
+
+        else:
+            loss_fn = nn.MSELoss()
+            print("\033[91m The loss function is invalid!\033[0m")
+            assert False
+
+    Paras["loss_fn"] = loss_fn
+    return loss_fn, Paras
+
+
+def load_data(model_name, data_name, Paras):
+    # load data
+    train_path = f"./exp_data/{data_name}/{data_name}_training"
+    test_path = f"./exp_data/{data_name}/{data_name}_test"
+
+    if data_name == "MNIST":
+        train_dataset, test_dataset, transform = DataHub.MNIST(Paras, model_name)
+
+    elif data_name == "CIFAR100":
+        train_dataset, test_dataset, transform = DataHub.CIFAR100(Paras, model_name)
+
+    elif data_name == "Adult_Income_Prediction":
+        train_dataset, test_dataset, transform = DataHub.Adult_Income_Prediction(Paras)
+
+    elif data_name == "Credit_Card_Fraud_Detection":
+        train_dataset, test_dataset, transform = DataHub.Credit_Card_Fraud_Detection(Paras)
+
+    elif data_name == "Diabetes_Health_Indicators":
+        train_dataset, test_dataset, transform = DataHub.Diabetes_Health_Indicators(Paras)
+
+    elif data_name == "Electric_Vehicle_Population":
+        train_dataset, test_dataset, transform = DataHub.Electric_Vehicle_Population(Paras)
+
+    elif data_name == "Global_House_Purchase":
+        train_dataset, test_dataset, transform = DataHub.Global_House_Purchase(Paras)
+
+    elif data_name == "Health_Lifestyle":
+        train_dataset, test_dataset, transform = DataHub.Health_Lifestyle(Paras)
+
+    elif data_name == "Homesite_Quote_Conversion":
+        train_dataset, test_dataset, transform = DataHub.Homesite_Quote_Conversion(Paras)
+
+    elif data_name == "TN_Weather_2020_2025":
+        train_dataset, test_dataset, transform = DataHub.TN_Weather_2020_2025(Paras)
+
+    elif data_name == "Caltech101_Resize_32":
+        train_dataset, test_dataset, transform = DataHub.Caltech101_Resize_32(
+            Paras, 0.7, split=True
+        )
+
+    # elif data_name in ["Vowel", "Letter", "Shuttle", "w8a"]:
+    #     Paras["train_ratio"] = Paras["split_train_data"][data_name]
+    #     train_dataset, test_dataset, transform = datahub.get_libsvm_data(
+    #         train_path + ".txt", test_path + ".txt", data_name
+    #     )
+
+    elif data_name in ["RCV1", "Duke", "Ijcnn"]:
+        train_dataset, test_dataset, transform = DataProcessor.get_libsvm_bz2_data(
+            train_path + ".bz2", test_path + ".bz2", data_name, Paras
+        )
+
+    else:
+        transform = None
+        print(f"The data_name is invalid!")
+        assert False
+
+    # Compute the number of data points
+    Paras["train_data_num"] = len(train_dataset)
+    Paras["test_data_num"] = len(test_dataset)
+
+    return train_dataset, test_dataset, Paras
+
+def get_dataloader(data_name, train_dataset, test_dataset, Paras):
+    ParametersHub.set_seed(Paras["seed"])
+    g = torch.Generator()
+    g.manual_seed(Paras["seed"])
+
+    train_loader = Data.DataLoader(
+        dataset=train_dataset,
+        shuffle=True,
+        batch_size=Paras["batch_size"],
+        generator=g,
+        num_workers=0,
+    )
+
+    test_loader = Data.DataLoader(
+        dataset=test_dataset,
+        shuffle=False,
+        batch_size=Paras["batch_size"],
+        generator=g,
+        num_workers=0,
+    )
+
+    return train_loader, test_loader
+
+def chosen_optimizer(optimizer_name, model, hyperparams, Paras):
+    if optimizer_name == "SGD":
+        optimizer = torch.optim.SGD(model.parameters(), lr=hyperparams["alpha"])
+
+    elif optimizer_name == "ADAM":
+        optimizer = torch.optim.Adam(
+            model.parameters(),
+            lr=hyperparams["alpha"],
+            betas=(hyperparams["beta1"], hyperparams["beta2"]),
+            eps=hyperparams["epsilon"],
+        )
+
+    elif optimizer_name in ["Bundle"]:
+        optimizer = OptimizerFactory.Bundle(
+            model.parameters(), model, hyperparams, Paras
+        )
+
+    elif optimizer_name in ["ALR-SMAG"]:
+        optimizer = OptimizerFactory.ALR_SMAG(
+            model.parameters(), model, hyperparams, Paras
+        )
+
+    elif optimizer_name in ["SPBM-TR"]:
+        optimizer = SPBM.TR(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-TR-NoneLower"]:
+        optimizer = SPBM.TR_NoneLower(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-TR-NoneSpecial"]:
+        optimizer = SPBM.TR_NoneSpecial(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-TR-NoneCut"]:
+        optimizer = SPBM.TR_NoneCut(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-PF-NoneLower"]:
+        optimizer = SPBM.PF_NoneLower(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-PF"]:
+        optimizer = SPBM.PF(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPBM-PF-NoneCut"]:
+        optimizer = SPBM.PF_NoneCut(model.parameters(), model, hyperparams, Paras)
+
+    elif optimizer_name in ["SPSmax"]:
+        optimizer = OptimizerFactory.SPSmax(
+            model.parameters(), model, hyperparams, Paras
+        )
+
+    else:
+        raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+    return optimizer
+
+def load_model_dataloader(base_model_fun, initial_state_dict, data_name, train_dataset, test_dataset, Paras):
+    ParametersHub.set_seed(Paras["seed"])
+    model = base_model_fun()
+    model.load_state_dict(initial_state_dict)
+    model.to(Paras["device"])
+    train_loader, test_loader = TrainingHub.get_dataloader(data_name, train_dataset, test_dataset, Paras)
+
+    return model, train_loader, test_loader
+# <training>
+def train(train_loader, optimizer_name, optimizer, model, loss_fn, Paras):
+    train_time = time.time()
+    metrics = ParametersHub.metrics()
+    for epoch in range(Paras["epochs"]):
+        epoch_time = time.time()
+        for index, (X, Y) in enumerate(train_loader):
+            X, Y = X.to(Paras["device"]), Y.to(Paras["device"])
+
+            if epoch == 0 and index == 0:
+                # # compute gradient norm
+                # with torch.no_grad():
+                #     g_k = parameters_to_vector(
+                #         [
+                #             p.grad if p.grad is not None else torch.zeros_like(p)
+                #             for p in model.parameters()
+                #         ]
+                #     )
+                # metrics["grad_norm"].append(torch.norm(g_k, p=2).detach().cpu().item())
+                # print(metrics["grad_norm"][-1])
+
+                # initial training loss
+                initial_time = time.time()
+                initial_loss, initial_correct = Evaluate_Metrics.get_loss_acc(train_loader, model, loss_fn, Paras)
+                metrics["training_loss"].append(initial_loss)
+                metrics["training_acc"].append(initial_correct)
+
+                Print_Info.per_epoch_info(Paras, -1, metrics, time.time() - initial_time)
+
+            # Update the model
+            if optimizer_name in ["SGD", "ADAM"]:
+                optimizer.zero_grad()
+                loss = Evaluate_Metrics.loss(X, Y, model, loss_fn, Paras)
+                loss.backward()
+                optimizer.step()
+
+            elif optimizer_name in [
+                "Bundle",
+                "SPBM-TR",
+                "SPBM-PF",
+                "ALR-SMAG",
+                "SPSmax",
+                "SPBM-TR-NoneSpecial",
+                "SPBM-TR-NoneLower",
+                "SPBM-TR-NoneCut",
+                "SPBM-PF-NoneCut",
+            ]:
+                def closure():
+                    optimizer.zero_grad()
+                    loss = Evaluate_Metrics.loss(X, Y, model, loss_fn, Paras)
+                    loss.backward()
+                    return loss
+
+                loss = optimizer.step(closure)
+
+            else:
+                loss = 0
+                raise NotImplementedError(f"{optimizer_name} is not supported.")
+
+
+        # Evaluation
+        training_loss, training_acc = Evaluate_Metrics.get_loss_acc(train_loader, model, loss_fn, Paras)
+
+
+        metrics["training_loss"].append(training_loss)
+        metrics["training_acc"].append(training_acc)
+
+        Print_Info.per_epoch_info(Paras, epoch, metrics, time.time() - epoch_time)
+
+    time_cost = time.time() - train_time
+    metrics["train_time"] = time_cost
+
+    return metrics
+# <training>
+
+def Record_Results(hyperparams, data_name, model_name, optimizer_name, metrics, Paras):
+
+    keys = list(hyperparams.keys())
+    values = list(hyperparams.values())
+
+    param_str = "_".join(f"{k}_{v}" for k, v in zip(keys, values))
+
+    if model_name not in Paras["Results_dict"]:
+        Paras["Results_dict"][model_name] = {}
+
+    if data_name not in Paras["Results_dict"][model_name]:
+        Paras["Results_dict"][model_name][data_name] = {}
+
+
+    if optimizer_name not in Paras["Results_dict"][model_name][data_name]:
+        Paras["Results_dict"][model_name][data_name][optimizer_name] = {}
+
+
+    Paras["Results_dict"][model_name][data_name][optimizer_name][param_str] = {
+        "training_acc": metrics["training_acc"],
+        "training_loss": metrics["training_loss"],
+        "train_time": metrics["train_time"]
+    }
+
+    return Paras
+
+
+def Save_Results(Paras, model_name, data_name, optimizer_name):
+    """
+    Save the result dictionary for a specific (model, dataset, optimizer) combination.
+
+    Parameters
+    ----------
+    Paras : dict or Namespace
+        A container holding all experiment-related information, where:
+        - Paras["Results_folder"] : str
+            Directory to save result files.
+        - Paras["Results_dict"] : dict
+            Nested dictionary storing experiment results.
+
+    model_name : str
+        Full name of the model (e.g., "LeastSquares").
+
+    data_name : str
+        Name of the dataset used in the experiment.
+
+    optimizer_name : str
+        Name of the optimizer for which the results are saved.
+
+    Notes
+    -----
+    The function generates a filename in the format:
+        Results_{model_abbr}_{dataset_abbr}_{optimizer}.pkl
+    and dumps the corresponding result dictionary to disk.
+    """
+
+    # Construct the output file path using model/dataset abbreviations
+    filename = (
+        f'{Paras["Results_folder"]}/'
+        f'Results_{ParametersHub.model_abbr(model_name)}_'
+        f'{data_name}_'
+        f'{optimizer_name}.pkl'
+    )
+
+    # Save the nested results dict to disk
+    with open(filename, "wb") as f:
+        pickle.dump(Paras["Results_dict"][model_name][data_name][optimizer_name], f)
+
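
Taken together, the helpers above form a pipeline: chosen_loss_fn → load_data → get_dataloader → chosen_optimizer → train → Record_Results. A sketch of that wiring under assumptions — the Paras skeleton is inferred from the keys these functions read, and a plain torch.nn.Linear stands in for a junshan_kit.ModelsHub model:

    import torch
    from junshan_kit import TrainingHub

    model_name, data_name, optimizer_name = "LogRegressionBinaryL2", "MNIST", "SGD"
    hyperparams = {"alpha": 0.1}

    # Assumed Paras skeleton; real runs may need further keys.
    Paras = {
        "model_type": {model_name: "binary"},
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "seed": 42,
        "batch_size": 128,
        "epochs": 3,
        "use_color": False,
        "Results_dict": {},
    }

    loss_fn, Paras = TrainingHub.chosen_loss_fn(model_name, Paras)
    train_set, test_set, Paras = TrainingHub.load_data(model_name, data_name, Paras)
    train_loader, test_loader = TrainingHub.get_dataloader(data_name, train_set, test_set, Paras)

    model = torch.nn.Linear(784, 1).to(Paras["device"])   # stand-in for a ModelsHub model
    optimizer = TrainingHub.chosen_optimizer(optimizer_name, model, hyperparams, Paras)

    metrics = TrainingHub.train(train_loader, optimizer_name, optimizer, model, loss_fn, Paras)
    Paras = TrainingHub.Record_Results(hyperparams, data_name, model_name, optimizer_name, metrics, Paras)

Treat this as the shape of the call sequence rather than a definitive recipe: train() also depends on ParametersHub.metrics() and Evaluate_Metrics, whose internals are not shown in this diff.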
junshan_kit/kit.py
CHANGED
@@ -4,9 +4,9 @@
 >>> Last Updated : 2025-10-13
 ----------------------------------------------------------------------
 """
-
+import subprocess, smtplib
 import zipfile
-import os, time
+import os, time, openml, pickle
 
 from selenium import webdriver
 from selenium.webdriver.common.by import By
@@ -35,7 +35,7 @@ def unzip_file(zip_path: str, unzip_folder: str):
     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
         zip_ref.extractall(unzip_folder)
 
-    print(f"
+    print(f"- Extracted '{zip_path}' to '{os.path.abspath(unzip_folder)}'")
 
 
 # =============================================================
@@ -73,13 +73,13 @@ class JianguoyunDownloaderChrome:
         self.driver = webdriver.Chrome(options=self.chrome_options)
 
     def open_page(self):
-        print(f"
+        print(f">>> Opening page: {self.url}")
         self.driver.get(self.url)
-        print(f"
+        print(f">>> Page loaded: {self.driver.title}")
 
     def click_download_button(self):
         """Find and click the 'Download' button (supports English and Chinese)."""
-        print("
+        print(">>> Searching for the download button...")
         wait = WebDriverWait(self.driver, 30)
 
         try:
@@ -97,33 +97,33 @@ class JianguoyunDownloaderChrome:
 
             # Click using JavaScript to avoid overlay or interaction issues
             self.driver.execute_script("arguments[0].click();", button)
-            print(f"
+            print(f">>> Download button clicked. Files will be saved to: {self.download_path}")
 
             # If the cloud service opens a new tab, switch to it
             time.sleep(3)
            if len(self.driver.window_handles) > 1:
                 self.driver.switch_to.window(self.driver.window_handles[-1])
-                print("
+                print(">>> Switched to the new download tab.")
 
         except Exception as e:
-            print("
+            print(">>> Failed to find or click the download button:", e)
             raise
 
 
     def wait_for_downloads(self, timeout=3600):
-        print("
+        print(">>> Waiting for downloads to finish...")
         start_time = time.time()
         while time.time() - start_time < timeout:
             downloading = [f for f in os.listdir(self.download_path) if f.endswith(".crdownload")]
             if not downloading:
-                print("
+                print(">>> Download completed!")
                 return
             time.sleep(2)
-        print("
+        print(">>> Timeout: download not completed within 1 hour")
 
     def close(self):
         self.driver.quit()
-        print("
+        print(">>> Browser closed.")
 
     def run(self):
         print('*' * 60)
@@ -132,7 +132,7 @@ class JianguoyunDownloaderChrome:
             self.click_download_button()
             self.wait_for_downloads()
         except Exception as e:
-            print("
+            print(">>> Error:", e)
         finally:
             self.close()
         print('*' * 60)
@@ -169,13 +169,13 @@ class JianguoyunDownloaderFirefox:
         self.driver = webdriver.Firefox(service=service, options=options)
 
     def open_page(self):
-        print(f"
+        print(f">>> Opening page: {self.url}")
         self.driver.get(self.url)
-        print(f"
+        print(f">>> Page loaded: {self.driver.title}")
 
     def click_download_button(self):
         """Find and click the 'Download' button (supports English and Chinese)."""
-        print("
+        print(">>> Searching for the download button...")
         wait = WebDriverWait(self.driver, 30)
 
         try:
@@ -193,21 +193,21 @@ class JianguoyunDownloaderFirefox:
 
             # Click using JavaScript to avoid overlay or interaction issues
             self.driver.execute_script("arguments[0].click();", button)
-            print(f"
+            print(f">>> Download button clicked. Files will be saved to: {self.download_path}")
 
             # If the cloud service opens a new tab, switch to it
             time.sleep(3)
             if len(self.driver.window_handles) > 1:
                 self.driver.switch_to.window(self.driver.window_handles[-1])
-                print("
+                print(">>> Switched to the new download tab.")
 
         except Exception as e:
-            print("
+            print(">>> Failed to find or click the download button:", e)
             raise
 
     def wait_for_download(self, timeout=3600):
         """Wait until all downloads are finished (auto-detects browser type)."""
-        print("
+        print(">>> Waiting for downloads to finish...")
         start_time = time.time()
 
         # Determine the temporary file extension based on the browser type
@@ -216,13 +216,13 @@ class JianguoyunDownloaderFirefox:
         while time.time() - start_time < timeout:
             downloading = [f for f in os.listdir(self.download_path) if f.endswith(temp_ext)]
             if not downloading:
-                print("
+                print(">>> Download completed!")
                 return True
             time.sleep(2)
 
 
     def close(self):
-        print("
+        print(">>> Closing browser...")
         self.driver.quit()
 
     def run(self):
@@ -232,11 +232,70 @@ class JianguoyunDownloaderFirefox:
             self.click_download_button()
             self.wait_for_download(timeout=3600)
         except Exception as e:
-            print("
+            print(">>> Error:", e)
         finally:
             self.close()
         print('*' * 60)
 
 
+def download_openml_data(data_name):
+    """
+    Returns
+    -------
+    X : ndarray, dataframe, or sparse matrix, shape (n_samples, n_columns)
+        Dataset
+    y : ndarray or pd.Series, shape (n_samples,) or None
+        Target column
+    categorical_indicator : boolean ndarray
+        Mask that indicates categorical features.
+    attribute_names : List[str]
+        List of attribute names.
+    """
+    openml.config.set_root_cache_directory(f"./exp_data/{data_name}")
+    dataset = openml.datasets.get_dataset(f'{data_name}', download_data=True)
+    X, y, categorical_indicator, attribute_names = dataset.get_data(dataset_format="dataframe")
+
+    return X, y, categorical_indicator, attribute_names
+
+
+def read_pkl_data(file_path):
+    """
+    Read data from a pickle file at the specified path.
+
+    Args:
+        file_path (str): Path to the pickle file
+
+    Returns:
+        object: Data object loaded from the pickle file
+    """
+    with open(file_path, 'rb') as f:
+        data = pickle.load(f)
+
+    return data
+
+
+def git_commit_push(commit_message, repo_path="."):
+    try:
+        subprocess.run(["git", "-C", repo_path, "add", "."], check=True)
+        subprocess.run(["git", "-C", repo_path, "commit", "-q", "-m", commit_message], check=True)
+        subprocess.run(["git", "-C", repo_path, "push", "-q"], check=True)
+        print("Committed and pushed successfully!")
+    except subprocess.CalledProcessError as e:
+        print(f"Git command execution failed: {e}")
+
+
+def seed_meg(meg, Subject, from_email, to_email, from_pwd):
+    from email.mime.text import MIMEText
+    msg = MIMEText(meg)
+    msg["Subject"] = Subject
+    msg["From"] = from_email
+    msg["To"] = to_email
+
+    server = smtplib.SMTP_SSL("smtp.qq.com", 465)
+    server.login(from_email, from_pwd)
+    server.sendmail(from_email, [to_email], msg.as_string())
+    server.quit()
+
+
 
 
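
The new kit helpers compose naturally: pull a dataset from OpenML, cache it locally, and push results when a run finishes. A sketch — the dataset name and paths below are examples, not values from the package:

    import os, pickle
    from junshan_kit import kit

    # Fetch by OpenML dataset name (an example; any OpenML name or id should work).
    X, y, cat_mask, names = kit.download_openml_data("adult")

    # Cache the frame so later runs can skip the download.
    os.makedirs("./exp_data/adult", exist_ok=True)
    with open("./exp_data/adult/adult.pkl", "wb") as f:
        pickle.dump((X, y), f)

    X_cached, y_cached = kit.read_pkl_data("./exp_data/adult/adult.pkl")

    # Optionally commit result files back to a tracking repo.
    kit.git_commit_push("add adult results", repo_path=".")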

{junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/METADATA
CHANGED
@@ -1,11 +1,15 @@
 Metadata-Version: 2.4
 Name: junshan_kit
-Version: 2.2.8
+Version: 2.7.3
 Summary: This is an optimization tool.
 Author-email: Junshan Yin <junshanyin@163.com>
+Requires-Dist: cvxpy==1.6.5
 Requires-Dist: kaggle==1.7.4.5
 Requires-Dist: kagglehub==0.3.13
+Requires-Dist: matplotlib==3.10.3
 Requires-Dist: numpy==2.2.6
-Requires-Dist:
+Requires-Dist: openml==0.15.1
 Requires-Dist: scikit-learn==1.7.1
 Requires-Dist: selenium==4.36.0
+Requires-Dist: torch==2.6.0
+Requires-Dist: torchvision==0.21.0

junshan_kit-2.7.3.dist-info/RECORD
ADDED
@@ -0,0 +1,20 @@
+junshan_kit/BenchmarkFunctions.py,sha256=tXgZGg-CjTNz78nMyVEQflVFIJDgmmePytXjY_RT9BM,120
+junshan_kit/Check_Info.py,sha256=Z6Ls2S7Fl4h8S9s0NB8jP_YpSLZInvQAeyjIXzq5Bpc,1872
+junshan_kit/DataHub.py,sha256=6RCNr8dBTqK-8ey4m-baMU1qOsJP6swOFkaraGdk0fM,6801
+junshan_kit/DataProcessor.py,sha256=W2bzugcYnwQC403GdvSmGDBhfz8X1KxJBkOAVg1vHHk,14385
+junshan_kit/DataSets.py,sha256=DcpwWRm1_B29hIDjOhvaeKAYYeBknEW2QqsS_qm8Hxs,13367
+junshan_kit/Evaluate_Metrics.py,sha256=PQBGU8fETIvDon1VMdouZ1dhG2n7XHYGbzs2EQUA9FM,3392
+junshan_kit/FiguresHub.py,sha256=116cvRUGUcBqIAs0_xiRzZCzgnPaqmgI5kvNu6cAd_Q,10181
+junshan_kit/ModelsHub.py,sha256=xM6cwLecq9vukrt1c9l7l9dy7mQn3yq0ZwrRg5f_CfM,7995
+junshan_kit/ParametersHub.py,sha256=RSgsSlH0bgehn27lleKfboT1MuLAyIMxZ5FWC-ANbhA,19822
+junshan_kit/Print_Info.py,sha256=uBLpeynOYSZTN8LbJupSH1SuLZ-7cMU3Yp3IlVJWB1s,4772
+junshan_kit/TrainingHub.py,sha256=WV3cUz4JsEdGTpbTqgnU3WmlKeob8RAOuL993EsADj0,11469
+junshan_kit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+junshan_kit/kit.py,sha256=tQGoJJQZW9BeadX2cuwhvOxX2riHBZG0iFExelS4MIY,11487
+junshan_kit/OptimizerHup/OptimizerFactory.py,sha256=x1_cE5ZSkKffdY0uCIirocBNj2X-u_R-V5jNawJ1EfA,4607
+junshan_kit/OptimizerHup/SPBM.py,sha256=2Yg8Fmc8OkYOrjevD8eAGfI-m-fefoOldybtlp4ZEEs,13730
+junshan_kit/OptimizerHup/SPBM_func.py,sha256=5Fz6eHYIVGMoR_CBDA_Xk_1dnPRq3K16DUNoNaWQ2Ag,17301
+junshan_kit/OptimizerHup/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+junshan_kit-2.7.3.dist-info/METADATA,sha256=_mIq2RLQUdQECGnjRK4qABiPr41BGiK-aCkk4EQVKik,455
+junshan_kit-2.7.3.dist-info/WHEEL,sha256=aha0VrrYvgDJ3Xxl3db_g_MDIW-ZexDdrc_m-Hk8YY4,105
+junshan_kit-2.7.3.dist-info/RECORD,,