junshan-kit 2.5.1__py2.py3-none-any.whl → 2.8.5__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
+
+
+def rosenbrock(x, a=1.0, b=100.0):
+    # Optimal value: (a, a^2)
+    return (a - x[0])**2 + b * (x[1] - x[0]**2)**2
+
+
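For reference, a minimal sketch of exercising the new function (the module that will contain `rosenbrock` is not named in this hunk, so it is assumed to be importable or already in scope):

    x_opt = [1.0, 1.0]               # with a=1.0, b=100.0 the minimum sits at (a, a^2) = (1, 1)
    print(rosenbrock(x_opt))         # 0.0 -- the optimal value at the minimum
    print(rosenbrock([0.0, 0.0]))    # 1.0, away from the minimum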
@@ -0,0 +1,44 @@
+"""
+----------------------------------------------------------------------
+>>> Author : Junshan Yin
+>>> Last Updated : 2025-11-22
+----------------------------------------------------------------------
+"""
+
+from junshan_kit import ModelsHub
+
+def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
+    # Parse and validate each train_group
+    for cfg in args.train:
+        try:
+            model, dataset, optimizer = cfg.split("-")
+
+            if model not in allowed_models:
+                parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
+            if optimizer not in allowed_optimizers:
+                parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
+            if dataset not in allowed_datasets:
+                parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
+
+        except ValueError:
+            parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
+
+    for cfg in args.train:
+        model_name, dataset_name, optimizer_name = cfg.split("-")
+        try:
+            f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
+
+        except:
+            print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
+            assert False
+
+def check_subset_info(self, args, parser):
+    total = sum(args.subset)
+    if args.subset[0]>1:
+        # CHECK
+        for i in args.subset:
+            if i < 1:
+                parser.error(f"Invalid --subset {args.subset}: The number of subdata must > 1")
+    else:
+        if abs(total - 1.0) != 0.0:
+            parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f}))")
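For context, `check_args` expects every `--train` entry in `model-dataset-optimizer` form, and `check_subset_info` accepts either ratios summing to 1.0 or integer counts. A hypothetical invocation (the concrete model, dataset, and optimizer names below are placeholders, not the package's registered choices):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--train", nargs="+", help="entries of the form model-dataset-optimizer")
    parser.add_argument("--subset", nargs="+", type=float, default=[1.0])
    args = parser.parse_args(["--train", "LogRegressionBinary-Adult-SGD", "--subset", "0.7", "0.3"])
    # check_args / check_subset_info would then reject malformed entries via parser.error(...)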
junshan_kit/DataHub.py CHANGED
@@ -1,41 +1,106 @@
+"""
+----------------------------------------------------------------------
+>>> Author : Junshan Yin
+>>> Last Updated : 2025-10-28
+----------------------------------------------------------------------
+"""
+
 import torchvision, torch
 import torchvision.transforms as transforms
 import pandas as pd
+from torch.utils.data import random_split, Subset

-from junshan_kit import DataSets, DataProcessor
+from junshan_kit import DataSets, DataProcessor, ParametersHub

+def Adult_Income_Prediction(Paras):

+    df = DataSets.adult_income_prediction()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='income'

+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

+    return train_dataset, test_dataset, transform


+def Credit_Card_Fraud_Detection(Paras):
+    df = DataSets.credit_card_fraud_detection()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='Class'

+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

+    return train_dataset, test_dataset, transform

+def Diabetes_Health_Indicators(Paras):
+    df = DataSets.diabetes_health_indicators()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='diagnosed_diabetes'

+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

-def Adult_Income_Prediction(Paras):
+    return train_dataset, test_dataset, transform

-    df = DataSets.adult_income_prediction()
+def Electric_Vehicle_Population(Paras):
+    df = DataSets.electric_vehicle_population()
     transform = {
         "train_size": 0.7,
         "normalization": True
     }
-    label_col='income'
+    label_col='Electric Vehicle Type'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform

+def Global_House_Purchase(Paras):
+    df = DataSets.global_house_purchase()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='decision'
     train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

     return train_dataset, test_dataset, transform

+def Health_Lifestyle(Paras):
+    df = DataSets.health_lifestyle()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='disease_risk'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

-def Credit_Card_Fraud_Detection(Paras):
-    df = DataSets.credit_card_fraud_detection()
+    return train_dataset, test_dataset, transform
+
+def Homesite_Quote_Conversion(Paras):
+    df = DataSets.Homesite_Quote_Conversion()
     transform = {
         "train_size": 0.7,
         "normalization": True
     }
-    label_col='Class'
+    label_col='QuoteConversion_Flag'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform

+def TN_Weather_2020_2025(Paras):
+    df = DataSets.TamilNadu_weather_2020_2025()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col='rain_tomorrow'
     train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)

     return train_dataset, test_dataset, transform
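For orientation, a minimal sketch of consuming one of these tabular loaders (the keys that `Pandas_TO_Torch.to_torch` actually reads from `Paras` are not visible in this diff, so the empty dict below is a placeholder):

    from torch.utils.data import DataLoader
    from junshan_kit import DataHub

    Paras = {}  # placeholder; fill with whatever to_torch() expects in your setup
    train_dataset, test_dataset, transform = DataHub.Adult_Income_Prediction(Paras)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)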
@@ -101,6 +166,7 @@ def CIFAR100(Paras, model_name):
         download=True,
         transform=transform
     )
+
     if Paras["model_type"][model_name] == "binary":
         train_mask = (torch.tensor(train_dataset.targets) == 0) | (torch.tensor(train_dataset.targets) == 1)
         test_mask = (torch.tensor(test_dataset.targets) == 0) | (torch.tensor(test_dataset.targets) == 1)
@@ -111,4 +177,38 @@ def CIFAR100(Paras, model_name):
     train_dataset = torch.utils.data.Subset(train_dataset, train_indices.tolist())
     test_dataset = torch.utils.data.Subset(test_dataset, test_indices.tolist())

-    return train_dataset, test_dataset, transform
+    return train_dataset, test_dataset, transform
+
+
+def Caltech101_Resize_32(Paras, train_ratio=0.7, split=True):
+
+    transform = transforms.Compose([
+        # transforms.Lambda(convert_to_rgb),
+        transforms.Grayscale(num_output_channels=3),
+        transforms.Resize((32, 32)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    full_dataset = torchvision.datasets.Caltech101(
+        root='./exp_data/Caltech101',
+        download=True,
+        transform=transform
+    )
+
+    if split:
+        total_size = len(full_dataset)
+        train_size = int(train_ratio * total_size)
+        test_size = total_size - train_size
+
+        train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
+
+    else:
+        train_dataset = full_dataset
+        # Empty test dataset, keep the structure consistent
+        test_dataset = Subset(full_dataset, [])
+
+    return train_dataset, test_dataset, transform
+
+# <caltech101_Resize_32>
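A minimal sketch of calling the new Caltech101 loader (torchvision downloads the data to ./exp_data/Caltech101 on first use; `Paras` is accepted but not used by the body shown above):

    from junshan_kit import DataHub

    train_ds, test_ds, tfm = DataHub.Caltech101_Resize_32(Paras={}, train_ratio=0.7, split=True)
    print(len(train_ds), len(test_ds))   # roughly a 70/30 random split of the full dataset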
@@ -6,8 +6,13 @@
 """

 import pandas as pd
-import torch
+import numpy as np
+import torch, bz2
+from typing import Optional
+from torch.utils.data import random_split, Subset
+from sklearn.datasets import load_svmlight_file
 from sklearn.preprocessing import StandardScaler
+from junshan_kit import ParametersHub


 class CSV_TO_Pandas:
@@ -181,7 +186,7 @@ class CSV_TO_Pandas:
         if time_info is not None:
             df = self._trans_time_fea(df, time_info)

-        # Step 3: Map target label (to -1 and +1)
+        # Step 3: Map target label (to 0 and +1)
         df[label_col] = df[label_col].map(label_map)

         # Step 4: Encode categorical features (exclude label column)
@@ -200,7 +205,7 @@ class CSV_TO_Pandas:
         # print info
         if print_info:
             pos_count = (df[label_col] == 1).sum()
-            neg_count = (df[label_col] == -1).sum()
+            neg_count = (df[label_col] == 0).sum()

             # Step 6: Print dataset information
             print("\n" + "=" * 80)
@@ -214,8 +219,8 @@ class CSV_TO_Pandas:
             print(
                 f"{'Dropping NaN & non-feature cols:':<40} {m_encoded} rows x {n_encoded} cols"
             )
-            print(f"{'Positive samples (+1):':<40} {pos_count}")
-            print(f"{'Negative samples (-1):':<40} {neg_count}")
+            print(f"{'Positive samples (1):':<40} {pos_count}")
+            print(f"{'Negative samples (0):':<40} {neg_count}")
             print(
                 f"{'Size after one-hot encoding:':<40} {m_cleaned} rows x {n_cleaned} cols"
             )
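The -1/+1 → 0/1 relabelling in this file is consistent with the loss helpers later in this diff, which compare `(torch.sigmoid(pred) > 0.5).float()` against `y` and pass `y.long()` to `CrossEntropyLoss`; both assume 0/1 targets. A minimal sketch of the convention:

    import torch

    logits = torch.tensor([-2.0, 0.5, 3.0])
    y = torch.tensor([0.0, 1.0, 1.0])                   # targets must be 0/1, not -1/+1
    loss = torch.nn.BCEWithLogitsLoss()(logits, y)
    pred_label = (torch.sigmoid(logits) > 0.5).float()  # predictions in the same 0/1 convention
    acc = (pred_label == y).float().mean()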
@@ -252,7 +257,7 @@ from torch.utils.data import Dataset
 class Pandas_TO_Torch(Dataset):

     def __init__(self, df: pd.DataFrame,
-                 label_col: str,
+                 label_col: str,
                  ):
         self.df = df
         self.label_col = label_col
@@ -316,8 +321,6 @@ class bz2_To_Numpy:



-
-
 class StepByStep:
     def __init__(self):
         pass
@@ -332,3 +335,125 @@ class StepByStep:
         )


+class LibSVMDataset_bz2(Dataset):
+    def __init__(self, path, data_name = None, Paras = None):
+        with bz2.open(path, 'rb') as f:
+            X, y = load_svmlight_file(f)  # type: ignore
+
+        self.X, self.path = X, path
+
+        y = np.asanyarray(y)
+
+        if data_name is not None:
+            data_name = data_name.lower()
+
+            # Binary classification, with the label -1/1
+            if data_name in ["rcv1"]:
+                y = (y > 0).astype(int)  # Convert to 0/1
+
+            # Multi-category, labels usually start with 1
+            elif data_name in [""]:
+                y = y - 1  # Start with 0
+
+        else:
+            # Default policy: Try to avoid CrossEntropyLoss errors
+            if np.min(y) < 0:  # e.g. [-1, 1]
+                y = (y > 0).astype(int)
+            elif np.min(y) >= 1:
+                y = y - 1
+
+        self.y = y
+
+    def __len__(self):
+        return self.X.shape[0]
+
+    def __getitem__(self, idx):
+        xi = torch.tensor(self.X.getrow(idx).toarray(), dtype=torch.float32).squeeze(0)
+        yi = torch.tensor(self.y[idx], dtype=torch.float32)
+        return xi, yi
+
+    def __repr__(self):
+        num_samples = len(self.y)
+        num_features = self.X.shape[1]
+        num_classes = len(np.unique(self.y))
+        return (f"LibSVMDataset_bz2(\n"
+                f"  num_samples  = {num_samples},\n"
+                f"  num_features = {num_features},\n"
+                f"  num_classes  = {num_classes}\n"
+                f"  path         = {self.path}\n"
+                f")")
+
+def get_libsvm_bz2_data(train_path, test_path, data_name, Paras, split = True):
+
+    transform = "-1 → 0 for binary, y-1 for multi-class"
+    train_data = LibSVMDataset_bz2(train_path)
+
+    if data_name in ["Duke", "Ijcnn", "RCV1"]:
+        test_data = LibSVMDataset_bz2(test_path)
+        split = False
+    else:
+        test_data = Subset(train_data, [])
+
+
+    if split:
+        total_size = len(train_data)
+        train_size = int(Paras["train_ratio"] * total_size)
+        test_size = total_size - train_size
+
+        train_dataset, test_dataset = random_split(train_data, [train_size, test_size])
+
+    else:
+        train_dataset = train_data
+        # # Empty test dataset, keep the structure consistent
+        # test_dataset = Subset(train_data, [])
+        test_dataset = test_data
+
+    # print(test_dataset)
+    # assert False
+
+    return train_dataset, test_dataset, transform
+
+
+def subset(dataset, ratio_or_num, seed=None) -> Subset:
+    """
+    Randomly sample a subset from a dataset.
+
+    Parameters
+    ----------
+    dataset : torch.utils.data.Dataset
+        The dataset to sample from.
+    ratio_or_num : float or int
+        If float in (0, 1], treated as sampling ratio.
+        Otherwise, treated as absolute number of samples.
+    seed : int, optional
+        Random seed for reproducibility.
+
+    Returns
+    -------
+    torch.utils.data.Subset
+        A randomly sampled subset of the dataset.
+    """
+
+    if ratio_or_num < 0:
+        raise ValueError(f"ratio_or_num must be non-negative, got {ratio_or_num}")
+
+    dataset_len = len(dataset)
+
+    # Determine number of samples
+    if isinstance(ratio_or_num, float) and 0 < ratio_or_num <= 1:
+        num = max(1, int(round(dataset_len * ratio_or_num)))
+    else:
+        num = int(ratio_or_num)
+
+    # Clamp to valid range
+    num = min(max(num, 1), dataset_len)
+
+    # Create and seed generator
+    generator = torch.Generator()
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Random sampling
+    indices = torch.randperm(dataset_len, generator=generator)[:num].tolist()
+
+    return Subset(dataset, indices)
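A minimal sketch of the new helpers in use (the .bz2 paths below are placeholders for locally downloaded LIBSVM files):

    from junshan_kit import DataProcessor

    train_ds, test_ds, note = DataProcessor.get_libsvm_bz2_data(
        "exp_data/rcv1_train.binary.bz2", "exp_data/rcv1_test.binary.bz2",
        data_name="RCV1", Paras={"train_ratio": 0.7})

    small_train = DataProcessor.subset(train_ds, 0.1, seed=42)   # 10% random sample
    tiny_train  = DataProcessor.subset(train_ds, 512, seed=42)   # or an absolute count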
junshan_kit/DataSets.py CHANGED
@@ -147,12 +147,12 @@ def _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_
 ----------------------------------------------------------------------
 """

-def credit_card_fraud_detection(data_name = "Credit Card Fraud Detection", print_info = False, export_csv=False, drop_cols = []):
+def credit_card_fraud_detection(data_name = "Credit_Card_Fraud_Detection", print_info = False, export_csv=False, drop_cols = []):

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/creditcard.csv'
+    csv_path = f'exp_data/{data_name}/creditcard.csv'
     label_col = 'Class'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}


     df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
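With the renamed default, the loader now reads exp_data/Credit_Card_Fraud_Detection/creditcard.csv relative to the working directory and returns 0/1 labels instead of -1/+1. A minimal sketch (the CSV is assumed to be present already; this diff does not show a download step for it):

    from junshan_kit import DataSets

    df = DataSets.credit_card_fraud_detection(print_info=True)
    print(df['Class'].value_counts())   # labels are 0/1 under the new label_map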
@@ -161,26 +161,26 @@ def credit_card_fraud_detection(data_name = "Credit Card Fraud Detection", print
     return df


-def diabetes_health_indicators(data_name = "Diabetes Health Indicators", print_info = False, export_csv = False, drop_cols = [], Standard = False):
+def diabetes_health_indicators(data_name = "Diabetes_Health_Indicators", print_info = False, export_csv = False, drop_cols = [], Standard = False):
     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/diabetes_dataset.csv'
+    csv_path = f'exp_data/{data_name}/diabetes_dataset.csv'
     label_col = 'diagnosed_diabetes'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}

     df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)

     return df


-def electric_vehicle_population(data_name = "Electric Vehicle Population", print_info = False, export_csv = False, drop_cols = ['VIN (1-10)', 'DOL Vehicle ID', 'Vehicle Location'], Standard = False):
+def electric_vehicle_population(data_name = "Electric_Vehicle_Population", print_info = False, export_csv = False, drop_cols = ['VIN (1-10)', 'DOL Vehicle ID', 'Vehicle Location'], Standard = False):

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/Electric_Vehicle_Population_Data.csv'
+    csv_path = f'exp_data/{data_name}/Electric_Vehicle_Population_Data.csv'
     # drop_cols = ['VIN (1-10)', 'DOL Vehicle ID', 'Vehicle Location']
     label_col = 'Electric Vehicle Type'
     label_map = {
         'Battery Electric Vehicle (BEV)': 1,
-        'Plug-in Hybrid Electric Vehicle (PHEV)': -1
+        'Plug-in Hybrid Electric Vehicle (PHEV)': 0
     }


@@ -188,12 +188,12 @@ def electric_vehicle_population(data_name = "Electric Vehicle Population", print

     return df

-def global_house_purchase(data_name = "Global House Purchase", print_info = False, export_csv = False, drop_cols = ['property_id'], Standard =False):
+def global_house_purchase(data_name = "Global_House_Purchase", print_info = False, export_csv = False, drop_cols = ['property_id'], Standard =False):

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/global_house_purchase_dataset.csv'
+    csv_path = f'exp_data/{data_name}/global_house_purchase_dataset.csv'
     label_col = 'decision'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}


     df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
@@ -201,13 +201,13 @@ def global_house_purchase(data_name = "Global House Purchase", print_info = Fals
     return df


-def health_lifestyle(data_name = "Health Lifestyle", print_info = False, export_csv = False, drop_cols = ['id'], Standard =False):
+def health_lifestyle(data_name = "Health_Lifestyle", print_info = False, export_csv = False, drop_cols = ['id'], Standard =False):

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/health_lifestyle_dataset.csv'
+    csv_path = f'exp_data/{data_name}/health_lifestyle_dataset.csv'

     label_col = 'disease_risk'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}


     df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
@@ -215,7 +215,7 @@ def health_lifestyle(data_name = "Health Lifestyle", print_info = False, export_
     return df


-def medical_insurance_cost_prediction(data_name = "Medical Insurance Cost Prediction", print_info = False, export_csv = False, drop_cols = ['alcohol_freq'], Standard = False):
+def medical_insurance_cost_prediction(data_name = "Medical_Insurance_Cost Prediction", print_info = False, export_csv = False, drop_cols = ['alcohol_freq'], Standard = False):
     """
     1. The missing values in this dataset are handled by directly removing the corresponding column. Since the `alcohol_freq` column contains a large number of missing values, deleting the rows would result in significant data loss, so the entire column is dropped instead.

@@ -223,7 +223,7 @@ def medical_insurance_cost_prediction(data_name = "Medical Insurance Cost Predic
     """

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/medical_insurance.csv'
+    csv_path = f'exp_data/{data_name}/medical_insurance.csv'

     label_col = 'is_high_risk'
     label_map = {0: -1, 1: 1}
@@ -234,10 +234,10 @@ def medical_insurance_cost_prediction(data_name = "Medical Insurance Cost Predic
     return df


-def particle_physics_event_classification(data_name = "Particle Physics Event Classification", print_info = False, export_csv = False, drop_cols = [], Standard =False):
+def particle_physics_event_classification(data_name = "Particle_Physics_Event_Classification", print_info = False, export_csv = False, drop_cols = [], Standard =False):

     data_type = "binary"
-    csv_path = f'./exp_data/{data_name}/Particle Physics Event Classification.csv'
+    csv_path = f'exp_data/{data_name}/Particle Physics Event Classification.csv'

     label_col = 'Label'
     label_map = {'s': -1, 'b': 1}
@@ -249,13 +249,13 @@ def particle_physics_event_classification(data_name = "Particle Physics Event Cl



-def adult_income_prediction(data_name = "Adult Income Prediction", print_info = False, export_csv=False, drop_cols = [], Standard = False):
+def adult_income_prediction(data_name = "Adult_Income_Prediction", print_info = False, export_csv=False, drop_cols = [], Standard = False):

     data_type = "binary"
     csv_path = f'./exp_data/{data_name}/adult.csv'

     label_col = 'income'
-    label_map = {'<=50K': -1, '>50K': 1}
+    label_map = {'<=50K': 0, '>50K': 1}


     df = _run(csv_path, data_name, data_type, drop_cols, label_col, label_map, print_info, export_csv=export_csv)
@@ -263,13 +263,13 @@ def adult_income_prediction(data_name = "Adult Income Prediction", print_info =
     return df


-def TamilNadu_weather_2020_2025(data_name = "TN Weather 2020-2025", print_info = False, export_csv = False, drop_cols = ['Unnamed: 0'], Standard = False):
+def TamilNadu_weather_2020_2025(data_name = "TN_Weather_2020_2025", print_info = False, export_csv = False, drop_cols = ['Unnamed: 0'], Standard = False):

     data_type = "binary"
     csv_path = f'./exp_data/{data_name}/TNweather_1.8M.csv'

     label_col = 'rain_tomorrow'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}

     time_info = {
         'time_col_name': 'time',
@@ -281,7 +281,7 @@ def TamilNadu_weather_2020_2025(data_name = "TN Weather 2020-2025", print_info =

     return df

-def YouTube_Recommendation(data_name = "YouTube Recommendation", print_info = False, export_csv = False, drop_cols = ['user_id']):
+def YouTube_Recommendation(data_name = "YouTube_Recommendation", print_info = False, export_csv = False, drop_cols = ['user_id']):

     data_type = "binary"
     csv_path = f'./exp_data/{data_name}/youtube recommendation dataset.csv'
@@ -303,13 +303,13 @@ def YouTube_Recommendation(data_name = "YouTube Recommendation", print_info = Fa
     return df


-def Santander_Customer_Satisfaction(data_name = "SantanderCustomerSatisfaction", print_info = False, export_csv = False):
+def Santander_Customer_Satisfaction(data_name = "Santander_Customer_Satisfaction", print_info = False, export_csv = False):
     data_type = "binary"
     csv_path = None

     drop_cols = ['ID_code']
     label_col = 'target'
-    label_map = {False: -1, True: 1}
+    label_map = {False: 0, True: 1}

     df, y, categorical_indicator, attribute_names = junshan_kit.kit.download_openml_data(data_name)

@@ -324,7 +324,7 @@ def newsgroups_drift(data_name = "20_newsgroups.drift", print_info = False, expo

     drop_cols = ['ID_code']
     label_col = 'target'
-    label_map = {False: -1, True: 1}
+    label_map = {False: 0, True: 1}

     df, y, categorical_indicator, attribute_names = junshan_kit.kit.download_openml_data(data_name)

@@ -340,7 +340,7 @@ def Homesite_Quote_Conversion(data_name = "Homesite_Quote_Conversion", print_inf

     drop_cols = ['QuoteNumber']
     label_col = 'QuoteConversion_Flag'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}

     time_info = {
         'time_col_name': 'Original_Quote_Date',
@@ -353,7 +353,6 @@ def Homesite_Quote_Conversion(data_name = "Homesite_Quote_Conversion", print_inf

     return df

-
 def IEEE_CIS_Fraud_Detection(data_name = "IEEE-CIS_Fraud_Detection", print_info = False, export_csv = False, export_mat = False):
     data_type = "binary"
     csv_path = None
@@ -361,7 +360,7 @@ def IEEE_CIS_Fraud_Detection(data_name = "IEEE-CIS_Fraud_Detection", print_info

     drop_cols = ['TransactionID']
     label_col = 'isFraud'
-    label_map = {0: -1, 1: 1}
+    label_map = {0: 0, 1: 1}

     Paras = {
         "export_mat": export_mat
@@ -2,7 +2,7 @@ import torch
 from torch.nn.utils import parameters_to_vector
 import torch.nn.functional as F

-def compute_epoch_loss(X, y, model, loss_fn, Paras):
+def loss(X, y, model, loss_fn, Paras):
     pred = model(X)
     _, c = pred.shape

@@ -37,4 +37,77 @@ def compute_epoch_loss(X, y, model, loss_fn, Paras):
         )
         assert False

-    return loss
+    return loss
+
+def compute_loss_acc(X, y, model, loss_fn, Paras):
+    pred = model(X)
+    m, c = pred.shape
+
+    if c == 1:
+        # Logistic Regression (binary)
+        if isinstance(loss_fn, torch.nn.BCEWithLogitsLoss):
+            pred = pred.view(-1).float()
+            loss = loss_fn(pred, y).item()
+
+            if Paras["model_name"] == "LogRegressionBinaryL2":
+                x = parameters_to_vector(model.parameters())
+                lam = Paras["lambda"]
+                loss = (loss + 0.5 * lam * torch.norm(x, p=2) ** 2).item()
+
+            pred_label = (torch.sigmoid(pred) > 0.5).float()
+            correct = (pred_label == y).sum().item()
+
+        else:
+            assert False
+
+    else:
+
+        # Least Square (mutil)
+        if isinstance(loss_fn, torch.nn.MSELoss):
+            # loss
+            y_onehot = F.one_hot(y.long(), num_classes=c).float()
+            pred_label = pred.argmax(1).long()
+            pred_ont = F.one_hot(pred_label, num_classes=c).float()
+            loss = 0.5 * loss_fn(pred_ont, y_onehot).item() * c
+
+            # acc
+            correct = (pred_label == y).sum().item()
+
+        elif isinstance(loss_fn, torch.nn.CrossEntropyLoss):
+
+            # loss
+            loss = loss_fn(pred, y.long()).item()
+
+            # acc
+            # acc
+            pred_label = pred.argmax(1).long()
+            correct = (pred_label == y).sum().item()
+
+        else:
+            print(
+                f"\033[34m **** isinstance(loss_fn, torch.nn.MSELoss)? {isinstance(loss_fn, torch.nn.MSELoss)} **** \033[0m"
+            )
+            assert False
+
+    return loss, correct
+
+def get_loss_acc(dataloader, model, loss_fn, Paras):
+    # model.eval()
+    size = len(dataloader.dataset)
+    num_batches = len(dataloader)
+    loss, correct = 0, 0
+
+    device = Paras["device"]
+
+    with torch.no_grad():
+        for X, y in dataloader:
+            X, y = X.to(device).float(), y.to(device).float()
+            per_loss, per_acc = compute_loss_acc(X, y, model, loss_fn, Paras)
+
+            loss += per_loss
+            correct += per_acc
+
+    loss /= num_batches
+    correct /= size
+
+    return loss, correct
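A minimal sketch of evaluating a model with the new helpers (the module that defines them is not named in this hunk; the toy model and data below are placeholders, and "device" / "model_name" are the only Paras keys this code path reads):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    X = torch.randn(256, 10)
    y = (X.sum(dim=1) > 0).float()            # 0/1 targets for the binary branch
    loader = DataLoader(TensorDataset(X, y), batch_size=64)

    model = torch.nn.Linear(10, 1)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    Paras = {"device": "cpu", "model_name": "LogRegressionBinary"}

    avg_loss, acc = get_loss_acc(loader, model, loss_fn, Paras)
    print(f"loss={avg_loss:.4f}  acc={acc:.2%}")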