junshan-kit 2.2.8__py2.py3-none-any.whl → 2.7.3__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
+
+
+def rosenbrock(x, a=1.0, b=100.0):
+    # Optimal value: (a, a^2)
+    return (a - x[0])**2 + b * (x[1] - x[0]**2)**2
+
+
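The rosenbrock helper added above is the standard two-dimensional Rosenbrock benchmark, whose global minimum lies at (a, a^2) with value 0. A minimal usage sketch follows; the file containing the function is not named in this hunk, so how it is imported is an assumption.

    # Hedged sketch, assuming rosenbrock is in scope (its module is not named in this hunk).
    print(rosenbrock([1.0, 1.0]))  # 0.0 at the optimum (a, a^2) = (1, 1)
    print(rosenbrock([0.0, 0.0]))  # 1.0 with the defaults a=1, b=100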
@@ -0,0 +1,44 @@
+"""
+----------------------------------------------------------------------
+>>> Author : Junshan Yin
+>>> Last Updated : 2025-11-22
+----------------------------------------------------------------------
+"""
+
+from junshan_kit import ModelsHub
+
+def check_args(self, args, parser, allowed_models, allowed_optimizers, allowed_datasets):
+    # Parse and validate each train_group
+    for cfg in args.train:
+        try:
+            model, dataset, optimizer = cfg.split("-")
+
+            if model not in allowed_models:
+                parser.error(f"Invalid model '{model}'. Choose from {allowed_models}")
+            if optimizer not in allowed_optimizers:
+                parser.error(f"Invalid optimizer '{optimizer}'. Choose from {allowed_optimizers}")
+            if dataset not in allowed_datasets:
+                parser.error(f"Invalid dataset '{dataset}'. Choose from {allowed_datasets}")
+
+        except ValueError:
+            parser.error(f"Invalid format '{cfg}'. Use model-dataset-optimizer")
+
+    for cfg in args.train:
+        model_name, dataset_name, optimizer_name = cfg.split("-")
+        try:
+            f = getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}")
+
+        except:
+            print(getattr(ModelsHub, f"Build_{args.model_name_mapping[model_name]}_{args.data_name_mapping[dataset_name]}"))
+            assert False
+
+def check_subset_info(self, args, parser):
+    total = sum(args.subset)
+    if args.subset[0] > 1:
+        # CHECK
+        for i in args.subset:
+            if i < 1:
+                parser.error(f"Invalid --subset {args.subset}: the number of subdata must be > 1")
+    else:
+        if abs(total - 1.0) != 0.0:
+            parser.error(f"Invalid --subset {args.subset}: the values must sum to 1.0 (current sum = {total:.6f})")
junshan_kit/DataHub.py ADDED
@@ -0,0 +1,214 @@
+"""
+----------------------------------------------------------------------
+>>> Author : Junshan Yin
+>>> Last Updated : 2025-10-28
+----------------------------------------------------------------------
+"""
+
+import torchvision, torch
+import torchvision.transforms as transforms
+import pandas as pd
+from torch.utils.data import random_split, Subset
+
+from junshan_kit import DataSets, DataProcessor, ParametersHub
+
+def Adult_Income_Prediction(Paras):
+
+    df = DataSets.adult_income_prediction()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'income'
+
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+
+def Credit_Card_Fraud_Detection(Paras):
+    df = DataSets.credit_card_fraud_detection()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'Class'
+
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def Diabetes_Health_Indicators(Paras):
+    df = DataSets.diabetes_health_indicators()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'diagnosed_diabetes'
+
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def Electric_Vehicle_Population(Paras):
+    df = DataSets.electric_vehicle_population()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'Electric Vehicle Type'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def Global_House_Purchase(Paras):
+    df = DataSets.global_house_purchase()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'decision'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def Health_Lifestyle(Paras):
+    df = DataSets.health_lifestyle()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'disease_risk'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def Homesite_Quote_Conversion(Paras):
+    df = DataSets.Homesite_Quote_Conversion()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'QuoteConversion_Flag'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+def TN_Weather_2020_2025(Paras):
+    df = DataSets.TamilNadu_weather_2020_2025()
+    transform = {
+        "train_size": 0.7,
+        "normalization": True
+    }
+    label_col = 'rain_tomorrow'
+    train_dataset, test_dataset, transform = DataProcessor.Pandas_TO_Torch(df, label_col).to_torch(transform, Paras)
+
+    return train_dataset, test_dataset, transform
+
+
+
+def MNIST(Paras, model_name):
+    """
+    Load the MNIST dataset and return both the training and test sets,
+    along with the transformation applied (ToTensor).
+    """
+    transform = torchvision.transforms.ToTensor()
+
+    train_dataset = torchvision.datasets.MNIST(
+        root='./exp_data/MNIST',
+        train=True,
+        download=True,
+        transform=transform
+    )
+
+    test_dataset = torchvision.datasets.MNIST(
+        root='./exp_data/MNIST',
+        train=False,
+        download=True,
+        transform=transform
+    )
+
+    if Paras["model_type"][model_name] == "binary":
+
+        train_mask = (train_dataset.targets == 0) | (train_dataset.targets == 1)
+        test_mask = (test_dataset.targets == 0) | (test_dataset.targets == 1)
+
+        train_indices = torch.nonzero(train_mask, as_tuple=True)[0]
+        test_indices = torch.nonzero(test_mask, as_tuple=True)[0]
+
+        train_dataset = torch.utils.data.Subset(train_dataset, train_indices.tolist())
+        test_dataset = torch.utils.data.Subset(test_dataset, test_indices.tolist())
+
+    return train_dataset, test_dataset, transform
+
+
+def CIFAR100(Paras, model_name):
+    """
+    Load the CIFAR-100 dataset with standard normalization and return both
+    the training and test sets, along with the transformation applied.
+    """
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
+                             std=[0.2675, 0.2565, 0.2761])
+    ])
+
+    train_dataset = torchvision.datasets.CIFAR100(
+        root='./exp_data/CIFAR100',
+        train=True,
+        download=True,
+        transform=transform
+    )
+
+    test_dataset = torchvision.datasets.CIFAR100(
+        root='./exp_data/CIFAR100',
+        train=False,
+        download=True,
+        transform=transform
+    )
+
+    if Paras["model_type"][model_name] == "binary":
+        train_mask = (torch.tensor(train_dataset.targets) == 0) | (torch.tensor(train_dataset.targets) == 1)
+        test_mask = (torch.tensor(test_dataset.targets) == 0) | (torch.tensor(test_dataset.targets) == 1)
+
+        train_indices = torch.nonzero(train_mask, as_tuple=True)[0]
+        test_indices = torch.nonzero(test_mask, as_tuple=True)[0]
+
+        train_dataset = torch.utils.data.Subset(train_dataset, train_indices.tolist())
+        test_dataset = torch.utils.data.Subset(test_dataset, test_indices.tolist())
+
+    return train_dataset, test_dataset, transform
+
+
+def Caltech101_Resize_32(Paras, train_ratio=0.7, split=True):
+
+    transform = transforms.Compose([
+        # transforms.Lambda(convert_to_rgb),
+        transforms.Grayscale(num_output_channels=3),
+        transforms.Resize((32, 32)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
+    ])
+
+    full_dataset = torchvision.datasets.Caltech101(
+        root='./exp_data/Caltech101',
+        download=True,
+        transform=transform
+    )
+
+    if split:
+        total_size = len(full_dataset)
+        train_size = int(train_ratio * total_size)
+        test_size = total_size - train_size
+
+        train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])
+
+    else:
+        train_dataset = full_dataset
+        # Empty test dataset, keep the structure consistent
+        test_dataset = Subset(full_dataset, [])
+
+    return train_dataset, test_dataset, transform
+
+# <caltech101_Resize_32>
@@ -6,19 +6,112 @@
 """
 
 import pandas as pd
+import numpy as np
+import torch, bz2
+from torch.utils.data import random_split, Subset
+from sklearn.datasets import load_svmlight_file
+from sklearn.preprocessing import StandardScaler
 
 
 class CSV_TO_Pandas:
     def __init__(self):
         pass
+
+    def _trans_time_fea(self, df, time_info: dict):
+        """
+        Transform and extract time-based features from a specified datetime column.
+
+        This function converts a given column to pandas datetime format and
+        extracts different time-related features based on the specified mode.
+        It supports two extraction modes:
+        - type = 0: Extracts basic components (year, month, day, hour)
+        - type = 1: Extracts hour, day of week, and weekend indicator
+
+        Parameters
+        ----------
+        df : pandas.DataFrame
+            Input DataFrame containing the datetime column.
+        time_info:
+            - time_col_name : str
+              Name of the column containing time or datetime values.
+            - trans_type : int, optional, default=1
+              - 0 : Extract ['year', 'month', 'day', 'hour']
+              - 1 : Extract ['hour', 'dayofweek', 'is_weekend']
+
+        Returns
+        -------
+        pandas.DataFrame
+            The DataFrame with newly added time-based feature columns.
+
+        Notes
+        -----
+        - Rows that cannot be parsed as valid datetime will be dropped automatically.
+        - 'dayofweek' ranges from 0 (Monday) to 6 (Sunday).
+        - 'is_weekend' equals 1 if the day is Saturday or Sunday, otherwise 0.
+
+        Examples
+        --------
+        >>> import pandas as pd
+        >>> data = pd.DataFrame({
+        ...     'timestamp': ['2023-08-01 12:30:00', '2023-08-05 08:15:00', 'invalid_time']
+        ... })
+        >>> df = handler._trans_time_fea(data, {"time_col_name": "timestamp", "trans_type": 1})
+        >>> print(df)
+                    timestamp  hour  dayofweek  is_weekend
+        0 2023-08-01 12:30:00    12          1           0
+        1 2023-08-05 08:15:00     8          5           1
+        """
+
+        time_col_name, trans_type = time_info['time_col_name'], time_info['trans_type']
+
+        df[time_col_name] = pd.to_datetime(df[time_col_name], errors="coerce")
+
+        # Drop rows where the datetime conversion failed, and make an explicit copy
+        df = df.dropna(subset=[time_col_name]).copy()
+
+        if trans_type == 0:
+            df.loc[:, "year"] = df[time_col_name].dt.year
+            df.loc[:, "month"] = df[time_col_name].dt.month
+            df.loc[:, "day"] = df[time_col_name].dt.day
+            df.loc[:, "hour"] = df[time_col_name].dt.hour
+
+            user_text_fea = ['year', 'month', 'day', 'hour']
+            df = pd.get_dummies(df, columns=user_text_fea, dtype=int)
+
+        elif trans_type == 1:
+            df.loc[:, "hour"] = df[time_col_name].dt.hour
+            df.loc[:, "dayofweek"] = df[time_col_name].dt.dayofweek
+            df.loc[:, "is_weekend"] = df["dayofweek"].isin([5, 6]).astype(int)
+
+            user_text_fea = ['hour', 'dayofweek', 'is_weekend']
+            df = pd.get_dummies(df, columns=user_text_fea, dtype=int)
+
+        elif trans_type == 2:
+            df.loc[:, "year"] = df[time_col_name].dt.year
+            df.loc[:, "month"] = df[time_col_name].dt.month
+            df.loc[:, "day"] = df[time_col_name].dt.day
+
+            user_text_fea = ['year', 'month', 'day']
+            df = pd.get_dummies(df, columns=user_text_fea, dtype=int)
+        else:
+            print("error!")
+
+        df = df.drop(columns=[time_col_name])
+
+        return df
 
     def preprocess_dataset(
         self,
-        csv_path,
+        df,
         drop_cols: list,
         label_col: str,
         label_map: dict,
+        title_name: str,
+        user_one_hot_cols=[],
         print_info=False,
+        time_info: dict | None = None,
+        missing_strategy = 'drop',  # [drop, mode]
     ):
         """
         Preprocess a CSV dataset by performing data cleaning, label mapping, and feature encoding.
@@ -40,6 +133,9 @@ class CSV_TO_Pandas:
             print_info (bool, optional):
                 Whether to print preprocessing information and dataset statistics.
                 Defaults to False.
+            title_name (str):
+                Title used for the summary table or report that documents
+                the preprocessing steps and dataset statistics.
 
         Returns:
             pandas.DataFrame:
@@ -64,7 +160,8 @@ class CSV_TO_Pandas:
             ... )
         """
         # Step 0: Load the dataset
-        df = pd.read_csv(csv_path)
+        # df = pd.read_csv(csv_path)
+        columns = df.columns
 
         # Save original size
         m_original, n_original = df.shape
@@ -73,10 +170,21 @@ class CSV_TO_Pandas:
         df = df.drop(columns=drop_cols)
 
         # Step 2: Remove rows with missing values
-        df = df.dropna(axis=0, how="any")
+        if missing_strategy == 'drop':
+            df = df.dropna(axis=0, how="any")
+
+        elif missing_strategy == 'mode':
+            for col in df.columns:
+                if df[col].notna().any():
+                    mode_val = df[col].mode()[0]
+                    df[col] = df[col].fillna(mode_val)
+
         m_encoded, n_encoded = df.shape
 
-        # Step 3: Map target label (to -1 and +1)
+        if time_info is not None:
+            df = self._trans_time_fea(df, time_info)
+
+        # Step 3: Map target label (to 0 and +1)
         df[label_col] = df[label_col].map(label_map)
 
         # Step 4: Encode categorical features (exclude label column)
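The new missing_strategy option replaces the unconditional dropna: 'drop' keeps the old behaviour, while 'mode' fills each column's missing values with its most frequent value. A hedged illustration of the 'mode' branch on a toy frame (the column names and values are made up):

    # Hedged sketch of the 'mode' fill branch shown above, on made-up data.
    import pandas as pd

    toy = pd.DataFrame({"city": ["A", "A", None], "age": [30, None, 30]})
    for col in toy.columns:
        if toy[col].notna().any():
            toy[col] = toy[col].fillna(toy[col].mode()[0])
    print(toy)  # the missing city becomes "A" and the missing age becomes 30.0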
@@ -87,36 +195,218 @@ class CSV_TO_Pandas:
             col for col in text_feature_cols if col != label_col
         ]  # ✅ exclude label
 
-        df = pd.get_dummies(df, columns=text_feature_cols, dtype=int)
+        df = pd.get_dummies(
+            df, columns=text_feature_cols + user_one_hot_cols, dtype=int
+        )
         m_cleaned, n_cleaned = df.shape
 
         # print info
         if print_info:
             pos_count = (df[label_col] == 1).sum()
-            neg_count = (df[label_col] == -1).sum()
+            neg_count = (df[label_col] == 0).sum()
 
             # Step 6: Print dataset information
             print("\n" + "=" * 80)
-            print(f"{'Dataset Info':^70}")
+            print(f"{f'{title_name} - Summary':^70}")
             print("=" * 80)
             print(f"{'Original size:':<40} {m_original} rows x {n_original} cols")
             print(
-                f"{'Size after dropping NaN & non-feature cols:':<40} {m_cleaned} rows x {n_cleaned} cols"
+                f"{'Dropped non-feature columns:':<40} {', '.join(drop_cols) if drop_cols else 'None'}"
             )
-            print(f"{'Positive samples (+1):':<40} {pos_count}")
-            print(f"{'Negative samples (-1):':<40} {neg_count}")
+            print(f"{'missing_strategy:':<40} {missing_strategy}")
             print(
-                f"{'Size after one-hot encoding:':<40} {m_encoded} rows x {n_encoded} cols"
+                f"{'Dropping NaN & non-feature cols:':<40} {m_encoded} rows x {n_encoded} cols"
             )
-            print("-" * 80)
-            print(f"Note:")
-            print(f"{'Label column:':<40} {label_col}")
+            print(f"{'Positive samples (1):':<40} {pos_count}")
+            print(f"{'Negative samples (0):':<40} {neg_count}")
             print(
-                f"{'Dropped non-feature columns:':<40} {', '.join(drop_cols) if drop_cols else 'None'}"
+                f"{'Size after one-hot encoding:':<40} {m_cleaned} rows x {n_cleaned} cols"
             )
+            print("-" * 80)
+            print(f"{'More details about preprocessing':^70}")
+            print("-" * 80)
+            print(f"{'Label column:':<40} {label_col}")
+            print(f"{'label_map:':<40} {label_map}")
+            print(f"{'time column:':<40} {time_info}")
+            if time_info is not None:
+                if time_info["trans_type"] == 0:
+                    print("- 0 : Extract ['year', 'month', 'day', 'hour']")
+                elif time_info["trans_type"] == 1:
+                    print("- 1 : Extract ['hour', 'dayofweek', 'is_weekend']")
+                elif time_info["trans_type"] == 2:
+                    print("- 2 : Extract ['year', 'month', 'day']")
+                else:
+                    assert False
             print(
                 f"{'text fetaure columns:':<40} {', '.join(list(text_feature_cols)) if list(text_feature_cols) else 'None'}"
             )
-            print("=" * 80 + "\n")
+            # print("-" * 80)
+            # print("all columns:")
+            # print(list(columns))
+            # print("=" * 80 + "\n")
 
         return df
+
+
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from torch.utils.data import Dataset
+
+class Pandas_TO_Torch(Dataset):
+
+    def __init__(self, df: pd.DataFrame,
+                 label_col: str,
+                 ):
+        self.df = df
+        self.label_col = label_col
+
+        # Identify feature columns automatically (all except the label)
+        self.label_col = label_col
+        self.feature_cols = [col for col in self.df.columns if col != label_col]
+
+        # Extract features and labels
+        self.features = self.df[self.feature_cols].values.astype("float32")
+        self.labels = self.df[self.label_col].values.astype("int64")
+
+
+    def __len__(self):
+        """Return the total number of samples."""
+        return len(self.features)
+
+    def __getitem__(self, idx):
+        x = torch.tensor(self.features[idx], dtype=torch.float32)
+        y = torch.tensor(self.labels[idx], dtype=torch.long)
+
+        return x, y
+
+    def __repr__(self):
+        info = (
+            f"Dataset CustomNumericDataset\n"
+            f"    Number of datapoints: {len(self)}\n"
+            f"    Features: {self.features.shape[1]}\n"
+        )
+        return info
+
+    def to_torch(self, transform, Paras):
+        fea_cols = [col for col in self.df.columns if col != self.label_col]
+
+        if transform["normalization"]:
+            scaler = StandardScaler()
+            self.df[fea_cols] = scaler.fit_transform(self.df[fea_cols])
+
+        # Train/test split
+        train_df, test_df = train_test_split(self.df, train_size=transform["train_size"], random_state=Paras["seed"], stratify=self.df[self.label_col])
+
+        # Create datasets
+        train_dataset = Pandas_TO_Torch(train_df, self.label_col)
+        test_dataset = Pandas_TO_Torch(test_df, self.label_col)
+
+        return train_dataset, test_dataset, transform
+
+
+class TXT_TO_Numpy:
+    def __init__(self):
+        pass
+
+
+class bz2_To_Numpy:
+    def __init__(self):
+        pass
+
+
+
+class StepByStep:
+    def __init__(self):
+        pass
+
+    def print_text_fea(self, df, text_feature_cols):
+        for col in text_feature_cols:
+            print(f"\n{'-'*80}")
+            print(f'Feature: "{col}"')
+            print(f"{'-'*80}")
+            print(
+                f"Unique values ({len(df[col].unique())}): {df[col].unique().tolist()}"
+            )
+
+
+class LibSVMDataset_bz2(Dataset):
+    def __init__(self, path, data_name = None, Paras = None):
+        with bz2.open(path, 'rb') as f:
+            X, y = load_svmlight_file(f)  # type: ignore
+
+        self.X, self.path = X, path
+
+        y = np.asanyarray(y)
+
+        if data_name is not None:
+            data_name = data_name.lower()
+
+        # Binary classification, with the label -1/1
+        if data_name in ["rcv1"]:
+            y = (y > 0).astype(int)  # Convert to 0/1
+
+        # Multi-category, labels usually start with 1
+        elif data_name in [""]:
+            y = y - 1  # Start with 0
+
+        else:
+            # Default policy: try to avoid CrossEntropyLoss errors
+            if np.min(y) < 0:  # e.g. [-1, 1]
+                y = (y > 0).astype(int)
+            elif np.min(y) >= 1:
+                y = y - 1
+
+        self.y = y
+
+    def __len__(self):
+        return self.X.shape[0]
+
+    def __getitem__(self, idx):
+        xi = torch.tensor(self.X.getrow(idx).toarray(), dtype=torch.float32).squeeze(0)
+        yi = torch.tensor(self.y[idx], dtype=torch.float32)
+        return xi, yi
+
+    def __repr__(self):
+        num_samples = len(self.y)
+        num_features = self.X.shape[1]
+        num_classes = len(np.unique(self.y))
+        return (f"LibSVMDataset_bz2(\n"
+                f"  num_samples  = {num_samples},\n"
+                f"  num_features = {num_features},\n"
+                f"  num_classes  = {num_classes}\n"
+                f"  path = {self.path}\n"
+                f")")
+
+def get_libsvm_bz2_data(train_path, test_path, data_name, Paras, split = True):
+
+    transform = "-1 → 0 for binary, y-1 for multi-class"
+    train_data = LibSVMDataset_bz2(train_path)
+
+    if data_name in ["Duke", "Ijcnn"]:
+        test_data = LibSVMDataset_bz2(test_path)
+        split = False
+    else:
+        test_data = Subset(train_data, [])
+
+
+    if split:
+        total_size = len(train_data)
+        train_size = int(Paras["train_ratio"] * total_size)
+        test_size = total_size - train_size
+
+        train_dataset, test_dataset = random_split(train_data, [train_size, test_size])
+
+    else:
+        train_dataset = train_data
+        # # Empty test dataset, keep the structure consistent
+        # test_dataset = Subset(train_data, [])
+        test_dataset = test_data
+
+    # print(test_dataset)
+    # assert False
+
+    return train_dataset, test_dataset, transform
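Taken together, CSV_TO_Pandas.preprocess_dataset and Pandas_TO_Torch form the tabular pipeline that the DataHub loaders rely on. A hedged end-to-end sketch follows; the CSV path, column names, label mapping, and Paras contents are placeholders, not values confirmed by the diff.

    # Hedged sketch of the DataProcessor pipeline; all names below are illustrative.
    import pandas as pd
    from junshan_kit import DataProcessor

    df = pd.read_csv("some_dataset.csv")             # placeholder path
    df = DataProcessor.CSV_TO_Pandas().preprocess_dataset(
        df,
        drop_cols=["id"],                            # placeholder column
        label_col="label",
        label_map={"no": 0, "yes": 1},
        title_name="Demo dataset",
        print_info=True,
        missing_strategy="mode",
    )

    transform = {"train_size": 0.7, "normalization": True}
    train_ds, test_ds, _ = DataProcessor.Pandas_TO_Torch(df, "label").to_torch(
        transform, {"seed": 42}
    )
    print(train_ds)  # repr shows the number of datapoints and feature count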