junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junshan_kit/BenchmarkFunctions.py +7 -0
- junshan_kit/Check_Info.py +44 -0
- junshan_kit/DataHub.py +214 -0
- junshan_kit/DataProcessor.py +306 -16
- junshan_kit/DataSets.py +330 -18
- junshan_kit/Evaluate_Metrics.py +113 -0
- junshan_kit/FiguresHub.py +286 -0
- junshan_kit/ModelsHub.py +239 -0
- junshan_kit/OptimizerHup/OptimizerFactory.py +130 -0
- junshan_kit/OptimizerHup/SPBM.py +350 -0
- junshan_kit/OptimizerHup/SPBM_func.py +602 -0
- junshan_kit/OptimizerHup/__init__.py +0 -0
- junshan_kit/ParametersHub.py +690 -0
- junshan_kit/Print_Info.py +109 -0
- junshan_kit/TrainingHub.py +324 -0
- junshan_kit/kit.py +83 -24
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/METADATA +6 -2
- junshan_kit-2.7.3.dist-info/RECORD +20 -0
- {junshan_kit-2.2.8.dist-info → junshan_kit-2.7.3.dist-info}/WHEEL +1 -1
- junshan_kit-2.2.8.dist-info/RECORD +0 -7
junshan_kit/DataSets.py
CHANGED
|
@@ -1,18 +1,60 @@
|
|
|
1
1
|
"""
|
|
2
2
|
----------------------------------------------------------------------
|
|
3
3
|
>>> Author : Junshan Yin
|
|
4
|
-
>>> Last Updated : 2025-
|
|
4
|
+
>>> Last Updated : 2025-10-16
|
|
5
5
|
----------------------------------------------------------------------
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
import os
|
|
8
|
+
import os
|
|
9
9
|
import pandas as pd
|
|
10
|
+
from scipy.sparse import csr_matrix
|
|
11
|
+
from scipy.io import savemat
|
|
10
12
|
import junshan_kit.DataProcessor
|
|
11
13
|
import junshan_kit.kit
|
|
12
14
|
from sklearn.preprocessing import StandardScaler
|
|
13
15
|
|
|
16
|
+
#----------------------------------------------------------
|
|
17
|
+
def _download_data(data_name, data_type):
|
|
18
|
+
"""
|
|
19
|
+
Download and extract a dataset from Jianguoyun using either Firefox or Chrome automation.
|
|
14
20
|
|
|
15
|
-
|
|
21
|
+
This helper function allows the user to manually provide a Jianguoyun download link,
|
|
22
|
+
choose a browser (Firefox or Chrome) for automated downloading, and automatically unzip the downloaded dataset into a structured local directory.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
data_name (str):
|
|
26
|
+
The name of the dataset (used as a folder name for storage).
|
|
27
|
+
|
|
28
|
+
data_type (str):
|
|
29
|
+
The dataset category, e.g., "binary" or "multi".
|
|
30
|
+
Determines the subdirectory under './exp_data/'.
|
|
31
|
+
|
|
32
|
+
Raises:
|
|
33
|
+
ValueError:
|
|
34
|
+
If `data_type` is not one of the allowed options: ["binary", "multi"].
|
|
35
|
+
|
|
36
|
+
Behavior:
|
|
37
|
+
- Prompts the user to input a Jianguoyun download URL.
|
|
38
|
+
- Lets the user select a download method (Firefox or Chrome).
|
|
39
|
+
- Downloads the `.zip` file into `./exp_data/{data_name}/`.
|
|
40
|
+
- Automatically extracts the zip file in the same directory.
|
|
41
|
+
- Prints progress and completion messages.
|
|
42
|
+
|
|
43
|
+
Example:
|
|
44
|
+
>>> _download_data("mnist", "binary")
|
|
45
|
+
Enter the Jianguoyun download URL: https://www.jianguoyun.com/p/abcd1234
|
|
46
|
+
Select download method:
|
|
47
|
+
1. Firefox
|
|
48
|
+
2. Chrome
|
|
49
|
+
Enter the number of your choice (1 or 2):
|
|
50
|
+
|
|
51
|
+
Note:
|
|
52
|
+
Requires `junshan_kit` with `JianguoyunDownloaderFirefox`,
|
|
53
|
+
`JianguoyunDownloaderChrome`, and `unzip_file` utilities available.
|
|
54
|
+
"""
|
|
55
|
+
allowed_types = ["binary", "multi"]
|
|
56
|
+
if data_type not in allowed_types:
|
|
57
|
+
raise ValueError(f"Invalid data_type: {data_type!r}. Must be one of {allowed_types}.")
|
|
16
58
|
from junshan_kit.kit import JianguoyunDownloaderFirefox, JianguoyunDownloaderChrome
|
|
17
59
|
|
|
18
60
|
# User selects download method
|
|
@@ -27,38 +69,308 @@ def download_data(data_name):
|
|
|
27
69
|
|
|
28
70
|
if choice == "1":
|
|
29
71
|
JianguoyunDownloaderFirefox(url, f"./exp_data/{data_name}").run()
|
|
30
|
-
print("
|
|
72
|
+
print("*** Download completed using Firefox ***")
|
|
31
73
|
break
|
|
32
74
|
elif choice == "2":
|
|
33
75
|
JianguoyunDownloaderChrome(url, f"./exp_data/{data_name}").run()
|
|
34
|
-
print("
|
|
76
|
+
print("*** Download completed using Chrome ***")
|
|
35
77
|
break
|
|
36
78
|
else:
|
|
37
|
-
print("
|
|
79
|
+
print("*** Invalid choice. Please enter 1 or 2 ***\n")
|
|
38
80
|
|
|
81
|
+
# unzip file
|
|
82
|
+
junshan_kit.kit.unzip_file(f'./exp_data/{data_name}/{data_name}.zip', f'./exp_data/{data_name}')
|
|
39
83
|
|
|
40
|
-
def
|
|
84
|
+
def _export_csv(df, data_name, data_type):
|
|
85
|
+
path = f'./exp_data/{data_name}/'
|
|
86
|
+
os.makedirs(path, exist_ok=True)
|
|
87
|
+
df.to_csv(path + f'{data_name}_num.csv', index=False)
|
|
88
|
+
print(path + f'{data_name}.csv')
|
|
41
89
|
|
|
42
|
-
csv_path = f'./exp_data/{data_name}/creditcard.csv'
|
|
43
|
-
drop_cols = []
|
|
44
|
-
label_col = 'Class'
|
|
45
|
-
label_map = {0: -1, 1: 1}
|
|
46
90
|
|
|
47
|
-
|
|
91
|
+
def _export_mat(df, data_name, label_col):
|
|
92
|
+
# Extract label and feature matrices
|
|
93
|
+
y = df[label_col].values # Target column
|
|
94
|
+
X = df.drop(columns=[label_col]).values # Feature matrix
|
|
95
|
+
|
|
96
|
+
# Convert to sparse matrices
|
|
97
|
+
X_sparse = csr_matrix(X)
|
|
98
|
+
Y_sparse = csr_matrix(y.reshape(-1, 1)) # Convert target to column sparse matrix
|
|
99
|
+
|
|
100
|
+
# Get number of samples and features
|
|
101
|
+
m, n = X.shape
|
|
102
|
+
|
|
103
|
+
# Save as a MAT file (supports large datasets)
|
|
104
|
+
save_path = f'exp_data/{data_name}/{data_name}.mat'
|
|
105
|
+
savemat(save_path, {'X': X_sparse, 'Y': Y_sparse, 'm': m, 'n': n}, do_compression=True)
|
|
106
|
+
|
|
107
|
+
# Print confirmation
|
|
108
|
+
print("Sparse MAT file saved to:", save_path)
|
|
109
|
+
print("Number of samples (m):", m)
|
|
110
|
+
print("Number of features (n):", n)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, user_one_hot_cols = [], export_csv = False, time_info = None, df = None, missing_strategy = 'drop', Paras = None):
|
|
114
|
+
|
|
115
|
+
if csv_path is not None and not os.path.exists(csv_path):
|
|
48
116
|
print('\n' + '*'*60)
|
|
49
117
|
print(f"Please download the data.")
|
|
50
118
|
print(csv_path)
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
119
|
+
_download_data(data_name, data_type=data_type)
|
|
120
|
+
|
|
121
|
+
if not os.path.exists(f"./exp_data/{data_name}"):
|
|
122
|
+
print('\n' + '*'*60)
|
|
123
|
+
print(f"Please download the data.")
|
|
124
|
+
print(f"./exp_data/{data_name}")
|
|
125
|
+
_download_data(data_name, data_type=data_type)
|
|
126
|
+
|
|
127
|
+
if df is None:
|
|
128
|
+
df = pd.read_csv(csv_path)
|
|
129
|
+
|
|
54
130
|
cleaner = junshan_kit.DataProcessor.CSV_TO_Pandas()
|
|
55
|
-
df = cleaner.preprocess_dataset(
|
|
131
|
+
df = cleaner.preprocess_dataset(df, drop_cols, label_col, label_map, title_name=data_name, user_one_hot_cols=user_one_hot_cols, print_info=print_info, time_info = time_info, missing_strategy = missing_strategy)
|
|
132
|
+
|
|
133
|
+
if export_csv:
|
|
134
|
+
_export_csv(df, data_name, data_type)
|
|
135
|
+
|
|
136
|
+
if Paras is not None and Paras["export_mat"]:
|
|
137
|
+
_export_mat(df, data_name, label_col)
|
|
138
|
+
|
|
139
|
+
return df
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# ********************************************************************
|
|
144
|
+
"""
|
|
145
|
+
----------------------------------------------------------------------
|
|
146
|
+
Datasets
|
|
147
|
+
----------------------------------------------------------------------
|
|
148
|
+
"""
|
|
149
|
+
|
|
150
|
+
def credit_card_fraud_detection(data_name = "Credit_Card_Fraud_Detection", print_info = False, export_csv=False, drop_cols = None):
    """Load and preprocess the Credit Card Fraud Detection dataset (binary).

    Reads ``exp_data/<data_name>/creditcard.csv``, offering a download if the
    file is missing. ``Class`` is the label column, and labels are kept as 0/1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` means none.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable `[]` default.
    if drop_cols is None:
        drop_cols = []

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/creditcard.csv'
    label_col = 'Class'
    label_map = {0: 0, 1: 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def diabetes_health_indicators(data_name = "Diabetes_Health_Indicators", print_info = False, export_csv = False, drop_cols = None, Standard = False):
    """Load and preprocess the Diabetes Health Indicators dataset (binary).

    Reads ``exp_data/<data_name>/diabetes_dataset.csv``, offering a download
    if the file is missing. ``diagnosed_diabetes`` is the label column, and
    labels are kept as 0/1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` means none.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable `[]` default.
    if drop_cols is None:
        drop_cols = []

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/diabetes_dataset.csv'
    label_col = 'diagnosed_diabetes'
    label_map = {0: 0, 1: 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def electric_vehicle_population(data_name = "Electric_Vehicle_Population", print_info = False, export_csv = False, drop_cols = None, Standard = False):
    """Load and preprocess the Electric Vehicle Population dataset (binary).

    Reads ``exp_data/<data_name>/Electric_Vehicle_Population_Data.csv``,
    offering a download if the file is missing. The label is
    ``Electric Vehicle Type``: BEV maps to 1 and PHEV maps to 0.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop. ``None`` drops the
            identifier columns ``'VIN (1-10)'``, ``'DOL Vehicle ID'`` and
            ``'Vehicle Location'``.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['VIN (1-10)', 'DOL Vehicle ID', 'Vehicle Location']

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/Electric_Vehicle_Population_Data.csv'
    label_col = 'Electric Vehicle Type'
    label_map = {
        'Battery Electric Vehicle (BEV)': 1,
        'Plug-in Hybrid Electric Vehicle (PHEV)': 0
    }

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
190
|
+
|
|
191
|
+
def global_house_purchase(data_name = "Global_House_Purchase", print_info = False, export_csv = False, drop_cols = None, Standard =False):
    """Load and preprocess the Global House Purchase dataset (binary).

    Reads ``exp_data/<data_name>/global_house_purchase_dataset.csv``, offering
    a download if the file is missing. ``decision`` is the label column, and
    labels are kept as 0/1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` drops
            ``'property_id'``.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['property_id']

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/global_house_purchase_dataset.csv'
    label_col = 'decision'
    label_map = {0: 0, 1: 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def health_lifestyle(data_name = "Health_Lifestyle", print_info = False, export_csv = False, drop_cols = None, Standard =False):
    """Load and preprocess the Health & Lifestyle dataset (binary).

    Reads ``exp_data/<data_name>/health_lifestyle_dataset.csv``, offering a
    download if the file is missing. ``disease_risk`` is the label column,
    and labels are kept as 0/1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` drops ``'id'``.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['id']

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/health_lifestyle_dataset.csv'
    label_col = 'disease_risk'
    label_map = {0: 0, 1: 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def medical_insurance_cost_prediction(data_name = "Medical_Insurance_Cost Prediction", print_info = False, export_csv = False, drop_cols = None, Standard = False):
    """Load and preprocess the Medical Insurance Cost dataset as a binary task.

    1. Missing values are handled by dropping the affected column. The
       `alcohol_freq` column contains a large number of missing values.
       Deleting those rows would lose a lot of data, so the whole column is
       dropped instead.

    2. Several columns could serve as binary labels, such as
       `is_high_risk`, `cardiovascular_disease` and `liver_disease`.
       `is_high_risk` is used here.

    The data is read from ``exp_data/<data_name>/medical_insurance.csv``, and
    a download is offered if the file is missing.

    Args:
        data_name (str): Folder name under ``exp_data/``. NOTE(review): the
            default contains a space ("Cost Prediction"), unlike the other
            loaders. It is kept as-is because it names existing data folders.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` drops
            ``'alcohol_freq'``.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset. Labels are mapped to -1/1.
        NOTE(review): most other loaders use 0/1, which is what a
        BCE-with-logits objective expects. Confirm the intended consumer.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['alcohol_freq']

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/medical_insurance.csv'
    label_col = 'is_high_risk'
    label_map = {0: -1, 1: 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
def particle_physics_event_classification(data_name = "Particle_Physics_Event_Classification", print_info = False, export_csv = False, drop_cols = None, Standard =False):
    """Load and preprocess the Particle Physics Event Classification dataset (binary).

    Reads ``exp_data/<data_name>/Particle Physics Event Classification.csv``,
    offering a download if the file is missing. ``Label`` is the label column.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` means none.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset, with ``'s'`` mapped to -1
        and ``'b'`` mapped to 1. NOTE(review): this -1/1 encoding differs from
        the 0/1 maps used by most loaders here. Confirm both the encoding and
        which class is meant to be positive.
    """
    # Build the default per call instead of sharing one mutable `[]` default.
    if drop_cols is None:
        drop_cols = []

    data_type = "binary"
    csv_path = f'exp_data/{data_name}/Particle Physics Event Classification.csv'
    label_col = 'Label'
    label_map = {'s': -1, 'b': 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def adult_income_prediction(data_name = "Adult_Income_Prediction", print_info = False, export_csv=False, drop_cols = None, Standard = False):
    """Load and preprocess the Adult Income dataset (binary).

    Reads ``./exp_data/<data_name>/adult.csv``, offering a download if the
    file is missing. ``income`` is the label column: ``'<=50K'`` maps to 0
    and ``'>50K'`` maps to 1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` means none.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable `[]` default.
    if drop_cols is None:
        drop_cols = []

    data_type = "binary"
    csv_path = f'./exp_data/{data_name}/adult.csv'
    label_col = 'income'
    label_map = {'<=50K': 0, '>50K': 1}

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
    return df
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def TamilNadu_weather_2020_2025(data_name = "TN_Weather_2020_2025", print_info = False, export_csv = False, drop_cols = None, Standard = False):
    """Load and preprocess the Tamil Nadu weather 2020-2025 dataset (binary).

    Reads ``./exp_data/<data_name>/TNweather_1.8M.csv``, offering a download
    if the file is missing. ``rain_tomorrow`` is the label column, and labels
    are kept as 0/1. The ``time`` column is passed to preprocessing with time
    transform mode 0.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` drops
            ``'Unnamed: 0'``.
        Standard (bool): Unused; accepted for API compatibility.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['Unnamed: 0']

    data_type = "binary"
    csv_path = f'./exp_data/{data_name}/TNweather_1.8M.csv'
    label_col = 'rain_tomorrow'
    label_map = {0: 0, 1: 1}

    time_info = {
        'time_col_name': 'time',
        'trans_type': 0
    }

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv, time_info=time_info)
    return df
|
|
283
|
+
|
|
284
|
+
def YouTube_Recommendation(data_name = "YouTube_Recommendation", print_info = False, export_csv = False, drop_cols = None):
    """Load and preprocess the YouTube Recommendation dataset (binary).

    Reads ``./exp_data/<data_name>/youtube recommendation dataset.csv``,
    offering a download if the file is missing. ``subscribed_after`` is the
    label column. The ``timestamp`` column is passed to preprocessing with
    time transform mode 1.

    Args:
        data_name (str): Folder name under ``exp_data/``.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        drop_cols (list[str] | None): Columns to drop; ``None`` drops
            ``'user_id'``.

    Returns:
        pandas.DataFrame: The preprocessed dataset. Labels are mapped to -1/1.
        NOTE(review): most other loaders use 0/1. Confirm the intended
        consumer.
    """
    # Build the default per call instead of sharing one mutable list default.
    if drop_cols is None:
        drop_cols = ['user_id']

    data_type = "binary"
    csv_path = f'./exp_data/{data_name}/youtube recommendation dataset.csv'
    label_col = 'subscribed_after'
    label_map = {0: -1, 1: 1}

    # Extraction mode.
    # - 0 : Extract ['year', 'month', 'day', 'hour']
    # - 1 : Extract ['hour', 'dayofweek', 'is_weekend']
    # - 2 : Extract ['year', 'month', 'day']
    time_info = {
        'time_col_name': 'timestamp',
        'trans_type': 1
    }

    df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv, time_info=time_info)
    return df
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def Santander_Customer_Satisfaction(data_name = "Santander_Customer_Satisfaction", print_info = False, export_csv = False):
    """Fetch the Santander Customer Satisfaction dataset from OpenML and preprocess it.

    The raw frame comes from ``junshan_kit.kit.download_openml_data``. The
    ``ID_code`` column is dropped, and ``target`` is used as the binary label
    (``False`` -> 0, ``True`` -> 1).

    Args:
        data_name (str): OpenML dataset name; also used as the local folder name.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    raw_df, _target, _categorical, _attributes = junshan_kit.kit.download_openml_data(data_name)

    # No CSV path: the data is handed to `_run` in memory.
    return _run(
        None,
        data_name,
        "binary",
        ['ID_code'],
        'target',
        {False: 0, True: 1},
        print_info,
        export_csv=export_csv,
        df=raw_df,
    )
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def newsgroups_drift(data_name = "20_newsgroups.drift", print_info = False, export_csv = False):
    """Fetch the ``20_newsgroups.drift`` dataset from OpenML and preprocess it.

    The raw frame comes from ``junshan_kit.kit.download_openml_data``. The
    ``ID_code`` column is dropped, and ``target`` is used as the binary label
    (``False`` -> 0, ``True`` -> 1).

    NOTE(review): the drop column and label settings are identical to
    ``Santander_Customer_Satisfaction``. Verify that they match this
    dataset's schema.

    Args:
        data_name (str): OpenML dataset name; also used as the local folder name.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    raw_df, _target, _categorical, _attributes = junshan_kit.kit.download_openml_data(data_name)

    # No CSV path: the data is handed to `_run` in memory.
    return _run(
        None,
        data_name,
        "binary",
        ['ID_code'],
        'target',
        {False: 0, True: 1},
        print_info,
        export_csv=export_csv,
        df=raw_df,
    )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def Homesite_Quote_Conversion(data_name = "Homesite_Quote_Conversion", print_info = False, export_csv = False):
    """Fetch the Homesite Quote Conversion dataset from OpenML and preprocess it.

    The settings are:

    * ``QuoteNumber`` is dropped.
    * ``QuoteConversion_Flag`` is the 0/1 label.
    * Missing values use the ``'mode'`` strategy.
    * ``Original_Quote_Date`` is passed as the time column with transform
      mode 2.

    Args:
        data_name (str): OpenML dataset name; also used as the local folder name.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    raw_df, _target, _categorical, _attributes = junshan_kit.kit.download_openml_data(data_name)

    options = dict(
        export_csv=export_csv,
        df=raw_df,
        time_info={'time_col_name': 'Original_Quote_Date', 'trans_type': 2},
        missing_strategy='mode',
    )
    return _run(None, data_name, "binary", ['QuoteNumber'], 'QuoteConversion_Flag', {0: 0, 1: 1}, print_info, **options)
|
|
355
|
+
|
|
356
|
+
def IEEE_CIS_Fraud_Detection(data_name = "IEEE-CIS_Fraud_Detection", print_info = False, export_csv = False, export_mat = False):
    """Fetch the IEEE-CIS Fraud Detection dataset from OpenML and preprocess it.

    ``TransactionID`` is dropped, ``isFraud`` is the 0/1 label, and missing
    values use the ``'mode'`` strategy.

    Args:
        data_name (str): OpenML dataset name; also used as the local folder name.
        print_info (bool): Forwarded to the preprocessing step.
        export_csv (bool): If True, also export the processed frame to CSV.
        export_mat (bool): Forwarded to ``_run`` as ``Paras["export_mat"]``.
            When truthy, a sparse ``.mat`` export is requested.

    Returns:
        pandas.DataFrame: The preprocessed dataset.
    """
    raw_df, _target, _categorical, _attributes = junshan_kit.kit.download_openml_data(data_name)

    return _run(
        None, data_name, "binary",
        ['TransactionID'], 'isFraud', {0: 0, 1: 1}, print_info,
        export_csv=export_csv,
        df=raw_df,
        missing_strategy='mode',
        Paras={"export_mat": export_mat},
    )
|
|
60
374
|
|
|
61
375
|
|
|
62
|
-
def wine_and_food_pairing_dataset():
|
|
63
|
-
pass
|
|
64
376
|
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.nn.utils import parameters_to_vector
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
|
|
5
|
+
def loss(X, y, model, loss_fn, Paras):
    """Compute the training objective of ``model`` on one batch.

    The branch is chosen by the width ``c`` of ``model(X)``, which is
    unpacked as a 2-D ``(batch, c)`` output:

    * ``c == 1`` (binary): ``loss_fn`` must be ``torch.nn.BCEWithLogitsLoss``.
      The logits are flattened and compared with ``y`` cast to float. If
      ``Paras["model_name"] == "LogRegressionBinaryL2"``, the penalty
      ``0.5 * Paras["lambda"] * ||w||_2^2`` over all model parameters is added.
    * ``c > 1`` (multi-class): ``torch.nn.MSELoss`` compares softmax
      probabilities with one-hot targets, scaled by ``0.5 * c``.
      ``torch.nn.CrossEntropyLoss`` is applied to the raw logits with ``y``
      as class indices.

    Args:
        X: Input batch passed to ``model``.
        y: Targets (0/1 for binary, class indices for multi-class).
        model: Callable returning a ``(batch, c)`` tensor.
        loss_fn: One of the supported torch loss modules listed above.
        Paras (dict): Needs ``"model_name"`` in the binary case, and
            ``"lambda"`` for ``"LogRegressionBinaryL2"``.

    Returns:
        torch.Tensor: The loss value. Autograd is not disabled here, so it can
        be back-propagated.

    Raises:
        ValueError: If ``loss_fn`` is not supported for the output width.
            (This replaces ``assert False``, which is stripped under ``python -O``.)
    """
    pred = model(X)
    _, c = pred.shape

    if c == 1:
        # Logistic regression (binary), optionally with L2 regularization.
        if not isinstance(loss_fn, torch.nn.BCEWithLogitsLoss):
            raise ValueError(f"Unsupported loss_fn for a single-output model: {loss_fn}")
        objective = loss_fn(pred.view(-1).float(), y.float())
        if Paras["model_name"] == "LogRegressionBinaryL2":
            w = parameters_to_vector(model.parameters())
            objective = objective + 0.5 * Paras["lambda"] * torch.norm(w, p=2) ** 2
        return objective

    # Multi-class outputs.
    if isinstance(loss_fn, torch.nn.MSELoss):
        # Least squares between softmax probabilities and one-hot targets.
        y_onehot = F.one_hot(y.long(), num_classes=c).float()
        pred_prob = torch.softmax(pred, dim=1)
        return 0.5 * loss_fn(pred_prob, y_onehot) * float(c)

    if isinstance(loss_fn, torch.nn.CrossEntropyLoss):
        return loss_fn(pred, y.long())

    raise ValueError(f"Unsupported loss_fn for a multi-output model: {loss_fn}")
|
|
41
|
+
|
|
42
|
+
def compute_loss_acc(X, y, model, loss_fn, Paras):
    """Evaluate ``model`` on one batch and return ``(loss, n_correct)``.

    The branch is chosen by the width ``c`` of ``model(X)``, which is
    unpacked as a 2-D ``(batch, c)`` output:

    * ``c == 1``: requires ``torch.nn.BCEWithLogitsLoss``. The loss is taken
      on the flattened logits, and ``0.5 * Paras["lambda"] * ||w||_2^2`` is
      added for ``"LogRegressionBinaryL2"``. A sample is predicted positive
      when ``sigmoid(logit) > 0.5``.
    * ``c > 1``: ``torch.nn.MSELoss`` compares softmax probabilities with
      one-hot targets, scaled by ``0.5 * c``. ``torch.nn.CrossEntropyLoss``
      uses the raw logits. The predicted label is the arg-max class.

    Args:
        X: Input batch passed to ``model``.
        y: Targets (0/1 for binary, class indices for multi-class).
        model: Callable returning a ``(batch, c)`` tensor.
        loss_fn: One of the supported torch loss modules listed above.
        Paras (dict): Needs ``"model_name"`` in the binary case, and
            ``"lambda"`` for ``"LogRegressionBinaryL2"``.

    Returns:
        tuple[float, int]: The batch loss and the number of correctly
        predicted samples.

    Raises:
        ValueError: If ``loss_fn`` is not supported for the output width.
    """
    pred = model(X)
    _, c = pred.shape

    if c == 1:
        # Logistic regression (binary).
        if not isinstance(loss_fn, torch.nn.BCEWithLogitsLoss):
            raise ValueError(f"Unsupported loss_fn for a single-output model: {loss_fn}")
        logits = pred.view(-1).float()
        batch_loss = loss_fn(logits, y.float())
        if Paras["model_name"] == "LogRegressionBinaryL2":
            w = parameters_to_vector(model.parameters())
            batch_loss = batch_loss + 0.5 * Paras["lambda"] * torch.norm(w, p=2) ** 2
        pred_label = (torch.sigmoid(logits) > 0.5).float()
        correct = (pred_label == y).sum().item()
        return batch_loss.item(), correct

    # Multi-class outputs.
    pred_label = pred.argmax(1).long()
    if isinstance(loss_fn, torch.nn.MSELoss):
        # Least squares on softmax probabilities, matching the training
        # objective. Scoring one-hot(argmax) predictions instead only
        # measured a scaled error count.
        y_onehot = F.one_hot(y.long(), num_classes=c).float()
        pred_prob = torch.softmax(pred, dim=1)
        batch_loss = 0.5 * loss_fn(pred_prob, y_onehot).item() * c
    elif isinstance(loss_fn, torch.nn.CrossEntropyLoss):
        batch_loss = loss_fn(pred, y.long()).item()
    else:
        raise ValueError(f"Unsupported loss_fn for a multi-output model: {loss_fn}")

    correct = (pred_label == y).sum().item()
    return batch_loss, correct
|
|
93
|
+
|
|
94
|
+
def get_loss_acc(dataloader, model, loss_fn, Paras):
    """Return ``(mean batch loss, accuracy)`` of ``model`` over a whole dataloader.

    Each batch is moved to ``Paras["device"]``, cast to float, and scored
    with ``compute_loss_acc`` under ``torch.no_grad()``. The per-batch losses
    are averaged over the number of batches. The correct-prediction counts
    are divided by the dataset size.

    NOTE(review): ``model.eval()`` is not called here, so layers such as
    dropout or batch-norm stay in whatever mode the caller left them.
    Confirm this is intended.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that exposes ``.dataset``
            and ``len()``.
        model: The model to evaluate.
        loss_fn: Loss module forwarded to ``compute_loss_acc``.
        Paras (dict): Must contain ``"device"``. It is also forwarded to
            ``compute_loss_acc``.

    Returns:
        tuple: ``(average loss, accuracy)``.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    device = Paras["device"]

    total_loss, total_correct = 0, 0
    with torch.no_grad():
        for features, targets in dataloader:
            features = features.to(device).float()
            targets = targets.to(device).float()
            batch_loss, batch_correct = compute_loss_acc(features, targets, model, loss_fn, Paras)
            total_loss += batch_loss
            total_correct += batch_correct

    return total_loss / n_batches, total_correct / n_samples
|