nextrec 0.1.1-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -9
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/dataloader.py +168 -127
  6. nextrec/basic/features.py +24 -27
  7. nextrec/basic/layers.py +328 -159
  8. nextrec/basic/loggers.py +50 -37
  9. nextrec/basic/metrics.py +255 -147
  10. nextrec/basic/model.py +817 -462
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +16 -12
  13. nextrec/data/preprocessor.py +276 -252
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +30 -22
  16. nextrec/loss/match_losses.py +116 -83
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +70 -61
  19. nextrec/models/match/dssm_v2.py +61 -51
  20. nextrec/models/match/mind.py +89 -71
  21. nextrec/models/match/sdm.py +93 -81
  22. nextrec/models/match/youtube_dnn.py +62 -53
  23. nextrec/models/multi_task/esmm.py +49 -43
  24. nextrec/models/multi_task/mmoe.py +65 -56
  25. nextrec/models/multi_task/ple.py +92 -65
  26. nextrec/models/multi_task/share_bottom.py +48 -42
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +39 -30
  29. nextrec/models/ranking/autoint.py +70 -57
  30. nextrec/models/ranking/dcn.py +43 -35
  31. nextrec/models/ranking/deepfm.py +34 -28
  32. nextrec/models/ranking/dien.py +115 -79
  33. nextrec/models/ranking/din.py +84 -60
  34. nextrec/models/ranking/fibinet.py +51 -35
  35. nextrec/models/ranking/fm.py +28 -26
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +30 -31
  38. nextrec/models/ranking/widedeep.py +36 -31
  39. nextrec/models/ranking/xdeepfm.py +46 -39
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +23 -15
  43. nextrec/utils/optimizer.py +14 -10
  44. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/METADATA +6 -40
  45. nextrec-0.1.2.dist-info/RECORD +51 -0
  46. nextrec-0.1.1.dist-info/RECORD +0 -51
  47. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.1.dist-info → nextrec-0.1.2.dist-info}/licenses/LICENSE +0 -0
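The bulk of this release is the rework of nextrec/basic/model.py (+817 −462), which reorganizes the BaseModel training API (compile, fit, evaluate, predict) shown in the diff below. For orientation, here is a minimal, hypothetical usage sketch: only the compile()/fit()/predict() keyword arguments are taken from the signatures visible in that diff, while the DeepFM and feature constructors and the toy data are assumptions for illustration and are not part of the package diff.

# Hypothetical sketch, not part of the diff. The compile()/fit()/predict()
# keyword arguments come from the BaseModel signatures in the model.py diff
# below; the DeepFM/DenseFeature/SparseFeature constructor arguments and the
# toy DataFrame are assumptions for illustration.
import pandas as pd

from nextrec.basic.features import DenseFeature, SparseFeature
from nextrec.models.ranking.deepfm import DeepFM

df = pd.DataFrame({
    "price": [0.1, 0.5, 0.9, 0.2, 0.7, 0.3],
    "item_id": [1, 2, 3, 1, 2, 3],
    "click": [0, 1, 0, 1, 1, 0],
})

model = DeepFM(
    dense_features=[DenseFeature("price")],      # constructor args assumed
    sparse_features=[SparseFeature("item_id")],  # constructor args assumed
    target="click",
    task="binary",
    device="cpu",
)
model.compile(optimizer="adam", optimizer_params={"lr": 1e-3}, loss="bce")
model.fit(train_data=df, metrics=["auc", "logloss"], epochs=2, batch_size=3)
scores = model.predict(df, batch_size=3)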
nextrec/basic/model.py CHANGED
@@ -23,53 +23,64 @@ from nextrec.basic.callback import EarlyStopper
23
23
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
24
24
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics
25
25
 
26
+ from nextrec.loss import get_loss_fn
26
27
  from nextrec.data import get_column_data
27
28
  from nextrec.basic.loggers import setup_logger, colorize
28
29
  from nextrec.utils import get_optimizer_fn, get_scheduler_fn
29
- from nextrec.loss import get_loss_fn
30
+
30
31
 
31
32
 
32
33
  class BaseModel(nn.Module):
33
34
  @property
34
35
  def model_name(self) -> str:
35
36
  raise NotImplementedError
36
-
37
+
37
38
  @property
38
39
  def task_type(self) -> str:
39
40
  raise NotImplementedError
40
41
 
41
- def __init__(self,
42
- dense_features: list[DenseFeature] | None = None,
43
- sparse_features: list[SparseFeature] | None = None,
44
- sequence_features: list[SequenceFeature] | None = None,
45
- target: list[str] | str | None = None,
46
- task: str|list[str] = 'binary',
47
- device: str = 'cpu',
48
- embedding_l1_reg: float = 0.0,
49
- dense_l1_reg: float = 0.0,
50
- embedding_l2_reg: float = 0.0,
51
- dense_l2_reg: float = 0.0,
52
- early_stop_patience: int = 20,
53
- model_id: str = 'baseline'):
54
-
42
+ def __init__(
43
+ self,
44
+ dense_features: list[DenseFeature] | None = None,
45
+ sparse_features: list[SparseFeature] | None = None,
46
+ sequence_features: list[SequenceFeature] | None = None,
47
+ target: list[str] | str | None = None,
48
+ task: str | list[str] = "binary",
49
+ device: str = "cpu",
50
+ embedding_l1_reg: float = 0.0,
51
+ dense_l1_reg: float = 0.0,
52
+ embedding_l2_reg: float = 0.0,
53
+ dense_l2_reg: float = 0.0,
54
+ early_stop_patience: int = 20,
55
+ model_id: str = "baseline",
56
+ ):
57
+
55
58
  super(BaseModel, self).__init__()
56
59
 
57
60
  try:
58
61
  self.device = torch.device(device)
59
62
  except Exception as e:
60
- logging.warning(colorize("Invalid device , defaulting to CPU.", color='yellow'))
61
- self.device = torch.device('cpu')
63
+ logging.warning(
64
+ colorize("Invalid device , defaulting to CPU.", color="yellow")
65
+ )
66
+ self.device = torch.device("cpu")
62
67
 
63
68
  self.dense_features = list(dense_features) if dense_features is not None else []
64
- self.sparse_features = list(sparse_features) if sparse_features is not None else []
65
- self.sequence_features = list(sequence_features) if sequence_features is not None else []
66
-
69
+ self.sparse_features = (
70
+ list(sparse_features) if sparse_features is not None else []
71
+ )
72
+ self.sequence_features = (
73
+ list(sequence_features) if sequence_features is not None else []
74
+ )
75
+
67
76
  if isinstance(target, str):
68
77
  self.target = [target]
69
78
  else:
70
79
  self.target = list(target) if target is not None else []
71
-
72
- self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
80
+
81
+ self.target_index = {
82
+ target_name: idx for idx, target_name in enumerate(self.target)
83
+ }
73
84
 
74
85
  self.task = task
75
86
  self.nums_task = len(task) if isinstance(task, list) else 1
@@ -79,33 +90,48 @@ class BaseModel(nn.Module):
79
90
  self._embedding_l2_reg = embedding_l2_reg
80
91
  self._dense_l2_reg = dense_l2_reg
81
92
 
82
- self._regularization_weights = [] # list of dense weights for regularization, used to compute reg loss
83
- self._embedding_params = [] # list of embedding weights for regularization, used to compute reg loss
93
+ self._regularization_weights = (
94
+ []
95
+ ) # list of dense weights for regularization, used to compute reg loss
96
+ self._embedding_params = (
97
+ []
98
+ ) # list of embedding weights for regularization, used to compute reg loss
84
99
 
85
100
  self.early_stop_patience = early_stop_patience
86
- self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
101
+ self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
87
102
 
88
103
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
89
104
  self.model_id = model_id
90
-
91
- checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
105
+
106
+ checkpoint_dir = os.path.abspath(
107
+ os.path.join(project_root, "..", "checkpoints")
108
+ )
92
109
  os.makedirs(checkpoint_dir, exist_ok=True)
93
- self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
94
- self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
110
+ self.checkpoint = os.path.join(
111
+ checkpoint_dir,
112
+ f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model",
113
+ )
114
+ self.best = os.path.join(
115
+ checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model"
116
+ )
95
117
 
96
118
  self._logger_initialized = False
97
119
  self._verbose = 1
98
120
 
99
- def _register_regularization_weights(self,
100
- embedding_attr: str = 'embedding',
101
- exclude_modules: list[str] | None = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
102
- include_modules: list[str] | None = []):
121
+ def _register_regularization_weights(
122
+ self,
123
+ embedding_attr: str = "embedding",
124
+ exclude_modules: (
125
+ list[str] | None
126
+ ) = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
127
+ include_modules: list[str] | None = [],
128
+ ):
103
129
 
104
130
  exclude_modules = exclude_modules or []
105
-
131
+
106
132
  if hasattr(self, embedding_attr):
107
133
  embedding_layer = getattr(self, embedding_attr)
108
- if hasattr(embedding_layer, 'embed_dict'):
134
+ if hasattr(embedding_layer, "embed_dict"):
109
135
  for embed in embedding_layer.embed_dict.values():
110
136
  self._embedding_params.append(embed.weight)
111
137
 
@@ -113,40 +139,49 @@ class BaseModel(nn.Module):
113
139
  # Skip self module
114
140
  if module is self:
115
141
  continue
116
-
142
+
117
143
  # Skip embedding layers
118
144
  if embedding_attr in name:
119
145
  continue
120
-
146
+
121
147
  # Skip BatchNorm and Dropout by checking module type
122
- if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
123
- nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
148
+ if isinstance(
149
+ module,
150
+ (
151
+ nn.BatchNorm1d,
152
+ nn.BatchNorm2d,
153
+ nn.BatchNorm3d,
154
+ nn.Dropout,
155
+ nn.Dropout2d,
156
+ nn.Dropout3d,
157
+ ),
158
+ ):
124
159
  continue
125
-
160
+
126
161
  # White-list: only include modules whose names contain specific keywords
127
162
  if include_modules is not None:
128
163
  should_include = any(inc_name in name for inc_name in include_modules)
129
164
  if not should_include:
130
165
  continue
131
-
166
+
132
167
  # Black-list: exclude modules whose names contain specific keywords
133
168
  if any(exc_name in name for exc_name in exclude_modules):
134
169
  continue
135
-
170
+
136
171
  # Only add regularization for Linear layers
137
172
  if isinstance(module, nn.Linear):
138
173
  self._regularization_weights.append(module.weight)
139
174
 
140
175
  def add_reg_loss(self) -> torch.Tensor:
141
176
  reg_loss = torch.tensor(0.0, device=self.device)
142
-
177
+
143
178
  if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
144
179
  for param in self._embedding_params:
145
180
  reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
146
181
 
147
182
  if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
148
183
  for param in self._embedding_params:
149
- reg_loss += self._embedding_l2_reg * torch.sum(param ** 2)
184
+ reg_loss += self._embedding_l2_reg * torch.sum(param**2)
150
185
 
151
186
  if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
152
187
  for param in self._regularization_weights:
@@ -154,11 +189,16 @@ class BaseModel(nn.Module):
154
189
 
155
190
  if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
156
191
  for param in self._regularization_weights:
157
- reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
158
-
192
+ reg_loss += self._dense_l2_reg * torch.sum(param**2)
193
+
159
194
  return reg_loss
160
195
 
161
- def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
196
+ def _to_tensor(
197
+ self,
198
+ value,
199
+ dtype: torch.dtype | None = None,
200
+ device: str | torch.device | None = None,
201
+ ) -> torch.Tensor:
162
202
  if value is None:
163
203
  raise ValueError("Cannot convert None to tensor.")
164
204
  if isinstance(value, torch.Tensor):
@@ -170,11 +210,13 @@ class BaseModel(nn.Module):
170
210
  target_device = device if device is not None else self.device
171
211
  return tensor.to(target_device)
172
212
 
173
- def get_input(self, input_data: dict|pd.DataFrame):
213
+ def get_input(self, input_data: dict | pd.DataFrame):
174
214
  X_input = {}
175
-
176
- all_features = self.dense_features + self.sparse_features + self.sequence_features
177
-
215
+
216
+ all_features = (
217
+ self.dense_features + self.sparse_features + self.sequence_features
218
+ )
219
+
178
220
  for feature in all_features:
179
221
  if feature.name not in input_data:
180
222
  continue
@@ -198,10 +240,12 @@ class BaseModel(nn.Module):
198
240
  if target_data is None:
199
241
  continue
200
242
  target_tensor = self._to_tensor(target_data, dtype=torch.float32)
201
-
243
+
202
244
  if target_tensor.dim() > 1:
203
245
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
204
- target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
246
+ target_tensors.extend(
247
+ torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1)
248
+ )
205
249
  else:
206
250
  target_tensors.append(target_tensor.view(-1, 1))
207
251
 
@@ -216,14 +260,14 @@ class BaseModel(nn.Module):
216
260
 
217
261
  def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
218
262
  """Configure metrics for model evaluation using the metrics module."""
219
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
220
- task=self.task,
221
- metrics=metrics,
222
- target_names=self.target
223
- ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
224
-
225
- if not hasattr(self, 'early_stopper') or self.early_stopper is None:
226
- self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
263
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
264
+ configure_metrics(task=self.task, metrics=metrics, target_names=self.target)
265
+ ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
266
+
267
+ if not hasattr(self, "early_stopper") or self.early_stopper is None:
268
+ self.early_stopper = EarlyStopper(
269
+ patience=self.early_stop_patience, mode=self.best_metrics_mode
270
+ )
227
271
 
228
272
  def _validate_task_configuration(self):
229
273
  """Validate that task type, number of tasks, targets, and loss functions are consistent."""
@@ -232,35 +276,35 @@ class BaseModel(nn.Module):
232
276
  num_tasks_from_task = len(self.task)
233
277
  else:
234
278
  num_tasks_from_task = 1
235
-
279
+
236
280
  num_targets = len(self.target)
237
-
281
+
238
282
  if self.nums_task != num_tasks_from_task:
239
283
  raise ValueError(
240
284
  f"Number of tasks mismatch: nums_task={self.nums_task}, "
241
285
  f"but task list has {num_tasks_from_task} tasks."
242
286
  )
243
-
287
+
244
288
  if self.nums_task != num_targets:
245
289
  raise ValueError(
246
290
  f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
247
291
  f"Tasks: {self.task}, Targets: {self.target}"
248
292
  )
249
-
293
+
250
294
  # Check loss function consistency
251
- if hasattr(self, 'loss_fn'):
295
+ if hasattr(self, "loss_fn"):
252
296
  num_loss_fns = len(self.loss_fn)
253
297
  if num_loss_fns != self.nums_task:
254
298
  raise ValueError(
255
299
  f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
256
300
  )
257
-
301
+
258
302
  # Validate task types with metrics and loss functions
259
303
  from nextrec.loss import VALID_TASK_TYPES
260
304
  from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
261
-
305
+
262
306
  tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
263
-
307
+
264
308
  for i, task_type in enumerate(tasks_to_check):
265
309
  # Validate task type
266
310
  if task_type not in VALID_TASK_TYPES:
@@ -268,26 +312,26 @@ class BaseModel(nn.Module):
268
312
  f"Invalid task type '{task_type}' for task {i}. "
269
313
  f"Valid types: {VALID_TASK_TYPES}"
270
314
  )
271
-
315
+
272
316
  # Check metrics compatibility
273
- if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
317
+ if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
274
318
  target_name = self.target[i] if i < len(self.target) else f"task_{i}"
275
319
  task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
276
-
320
+
277
321
  for metric in task_metrics:
278
322
  metric_lower = metric.lower()
279
323
  # Skip gauc as it's valid for both classification and regression in some contexts
280
- if metric_lower == 'gauc':
324
+ if metric_lower == "gauc":
281
325
  continue
282
-
283
- if task_type in ['binary', 'multiclass']:
326
+
327
+ if task_type in ["binary", "multiclass"]:
284
328
  # Classification task
285
329
  if metric_lower in REGRESSION_METRICS:
286
330
  raise ValueError(
287
331
  f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
288
332
  f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
289
333
  )
290
- elif task_type in ['regression', 'multivariate_regression']:
334
+ elif task_type in ["regression", "multivariate_regression"]:
291
335
  # Regression task
292
336
  if metric_lower in CLASSIFICATION_METRICS:
293
337
  raise ValueError(
@@ -295,62 +339,72 @@ class BaseModel(nn.Module):
295
339
  f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
296
340
  )
297
341
 
298
- def _handle_validation_split(self,
299
- train_data: dict | pd.DataFrame | DataLoader,
300
- validation_split: float,
301
- batch_size: int,
302
- shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
342
+ def _handle_validation_split(
343
+ self,
344
+ train_data: dict | pd.DataFrame | DataLoader,
345
+ validation_split: float,
346
+ batch_size: int,
347
+ shuffle: bool,
348
+ ) -> tuple[DataLoader, dict | pd.DataFrame]:
303
349
  """Handle validation split logic for training data.
304
-
350
+
305
351
  Args:
306
352
  train_data: Training data (dict, DataFrame, or DataLoader)
307
353
  validation_split: Fraction of data to use for validation (0 < validation_split < 1)
308
354
  batch_size: Batch size for DataLoader
309
355
  shuffle: Whether to shuffle training data
310
-
356
+
311
357
  Returns:
312
358
  tuple: (train_loader, valid_data)
313
359
  """
314
360
  if not (0 < validation_split < 1):
315
- raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
316
-
361
+ raise ValueError(
362
+ f"validation_split must be between 0 and 1, got {validation_split}"
363
+ )
364
+
317
365
  if isinstance(train_data, DataLoader):
318
366
  raise ValueError(
319
367
  "validation_split cannot be used when train_data is a DataLoader. "
320
368
  "Please provide dict or pd.DataFrame for train_data."
321
369
  )
322
-
370
+
323
371
  if isinstance(train_data, pd.DataFrame):
324
372
  # Shuffle and split DataFrame
325
- shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
373
+ shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(
374
+ drop=True
375
+ )
326
376
  split_idx = int(len(shuffled_df) * (1 - validation_split))
327
377
  train_split = shuffled_df.iloc[:split_idx]
328
378
  valid_split = shuffled_df.iloc[split_idx:]
329
-
330
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
331
-
379
+
380
+ train_loader = self._prepare_data_loader(
381
+ train_split, batch_size=batch_size, shuffle=shuffle
382
+ )
383
+
332
384
  if self._verbose:
333
- logging.info(colorize(
334
- f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
335
- color="cyan"
336
- ))
337
-
385
+ logging.info(
386
+ colorize(
387
+ f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
388
+ color="cyan",
389
+ )
390
+ )
391
+
338
392
  return train_loader, valid_split
339
-
393
+
340
394
  elif isinstance(train_data, dict):
341
395
  # Get total length from any feature
342
396
  sample_key = list(train_data.keys())[0]
343
397
  total_length = len(train_data[sample_key])
344
-
398
+
345
399
  # Create indices and shuffle
346
400
  indices = np.arange(total_length)
347
401
  np.random.seed(42)
348
402
  np.random.shuffle(indices)
349
-
403
+
350
404
  split_idx = int(total_length * (1 - validation_split))
351
405
  train_indices = indices[:split_idx]
352
406
  valid_indices = indices[split_idx:]
353
-
407
+
354
408
  # Split dict
355
409
  train_split = {}
356
410
  valid_split = {}
@@ -368,82 +422,112 @@ class BaseModel(nn.Module):
368
422
  else:
369
423
  train_split[key] = [value[i] for i in train_indices]
370
424
  valid_split[key] = [value[i] for i in valid_indices]
371
-
372
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
373
-
425
+
426
+ train_loader = self._prepare_data_loader(
427
+ train_split, batch_size=batch_size, shuffle=shuffle
428
+ )
429
+
374
430
  if self._verbose:
375
- logging.info(colorize(
376
- f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
377
- color="cyan"
378
- ))
379
-
431
+ logging.info(
432
+ colorize(
433
+ f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
434
+ color="cyan",
435
+ )
436
+ )
437
+
380
438
  return train_loader, valid_split
381
-
439
+
382
440
  else:
383
441
  raise TypeError(f"Unsupported train_data type: {type(train_data)}")
384
442
 
385
-
386
- def compile(self,
387
- optimizer = "adam",
388
- optimizer_params: dict | None = None,
389
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
390
- scheduler_params: dict | None = None,
391
- loss: str | nn.Module | list[str | nn.Module] | None= "bce"):
443
+ def compile(
444
+ self,
445
+ optimizer="adam",
446
+ optimizer_params: dict | None = None,
447
+ scheduler: (
448
+ str
449
+ | torch.optim.lr_scheduler._LRScheduler
450
+ | type[torch.optim.lr_scheduler._LRScheduler]
451
+ | None
452
+ ) = None,
453
+ scheduler_params: dict | None = None,
454
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
455
+ ):
392
456
  if optimizer_params is None:
393
457
  optimizer_params = {}
394
-
395
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
458
+
459
+ self._optimizer_name = (
460
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
461
+ )
396
462
  self._optimizer_params = optimizer_params
397
463
  if isinstance(scheduler, str):
398
464
  self._scheduler_name = scheduler
399
465
  elif scheduler is not None:
400
466
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
401
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
467
+ self._scheduler_name = getattr(
468
+ scheduler,
469
+ "__name__",
470
+ getattr(scheduler.__class__, "__name__", str(scheduler)),
471
+ )
402
472
  else:
403
473
  self._scheduler_name = None
404
474
  self._scheduler_params = scheduler_params or {}
405
475
  self._loss_config = loss
406
-
476
+
407
477
  # set optimizer
408
478
  self.optimizer_fn = get_optimizer_fn(
409
- optimizer=optimizer,
410
- params=self.parameters(),
411
- **optimizer_params
479
+ optimizer=optimizer, params=self.parameters(), **optimizer_params
412
480
  )
413
-
481
+
414
482
  # set loss functions
415
483
  if self.nums_task == 1:
416
484
  task_type = self.task if isinstance(self.task, str) else self.task[0]
417
485
  loss_value = loss[0] if isinstance(loss, list) else loss
418
486
  # For ranking and multitask, use pointwise training
419
- training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
487
+ training_mode = (
488
+ "pointwise" if self.task_type in ["ranking", "multitask"] else None
489
+ )
420
490
  # Use task_type directly, not self.task_type for single task
421
- self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
491
+ self.loss_fn = [
492
+ get_loss_fn(
493
+ task_type=task_type, training_mode=training_mode, loss=loss_value
494
+ )
495
+ ]
422
496
  else:
423
497
  self.loss_fn = []
424
498
  for i in range(self.nums_task):
425
499
  task_type = self.task[i] if isinstance(self.task, list) else self.task
426
-
500
+
427
501
  if isinstance(loss, list):
428
502
  loss_value = loss[i] if i < len(loss) else None
429
503
  else:
430
504
  loss_value = loss
431
-
505
+
432
506
  # Multitask always uses pointwise training
433
- training_mode = 'pointwise'
434
- self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
435
-
507
+ training_mode = "pointwise"
508
+ self.loss_fn.append(
509
+ get_loss_fn(
510
+ task_type=task_type,
511
+ training_mode=training_mode,
512
+ loss=loss_value,
513
+ )
514
+ )
515
+
436
516
  # set scheduler
437
- self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
517
+ self.scheduler_fn = (
518
+ get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
519
+ if scheduler
520
+ else None
521
+ )
438
522
 
439
523
  def compute_loss(self, y_pred, y_true):
440
524
  if y_true is None:
441
525
  return torch.tensor(0.0, device=self.device)
442
-
526
+
443
527
  if self.nums_task == 1:
444
528
  loss = self.loss_fn[0](y_pred, y_true)
445
529
  return loss
446
-
530
+
447
531
  else:
448
532
  task_losses = []
449
533
  for i in range(self.nums_task):
@@ -451,48 +535,69 @@ class BaseModel(nn.Module):
451
535
  task_losses.append(task_loss)
452
536
  return torch.stack(task_losses)
453
537
 
454
-
455
- def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
538
+ def _prepare_data_loader(
539
+ self,
540
+ data: dict | pd.DataFrame | DataLoader,
541
+ batch_size: int = 32,
542
+ shuffle: bool = True,
543
+ ):
456
544
  if isinstance(data, DataLoader):
457
545
  return data
458
546
  tensors = []
459
- all_features = self.dense_features + self.sparse_features + self.sequence_features
460
-
547
+ all_features = (
548
+ self.dense_features + self.sparse_features + self.sequence_features
549
+ )
550
+
461
551
  for feature in all_features:
462
552
  column = get_column_data(data, feature.name)
463
553
  if column is None:
464
554
  raise KeyError(f"Feature {feature.name} not found in provided data.")
465
-
555
+
466
556
  if isinstance(feature, SequenceFeature):
467
557
  if isinstance(column, pd.Series):
468
558
  column = column.values
469
559
  if isinstance(column, np.ndarray) and column.dtype == object:
470
- column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
471
- if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
560
+ column = np.array(
561
+ [
562
+ (
563
+ np.array(seq, dtype=np.int64)
564
+ if not isinstance(seq, np.ndarray)
565
+ else seq
566
+ )
567
+ for seq in column
568
+ ]
569
+ )
570
+ if (
571
+ isinstance(column, np.ndarray)
572
+ and column.ndim == 1
573
+ and column.dtype == object
574
+ ):
472
575
  column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
473
- tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
576
+ tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to("cpu")
474
577
  else:
475
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
476
- tensor = self._to_tensor(column, dtype=dtype, device='cpu')
477
-
578
+ dtype = (
579
+ torch.float32 if isinstance(feature, DenseFeature) else torch.long
580
+ )
581
+ tensor = self._to_tensor(column, dtype=dtype, device="cpu")
582
+
478
583
  tensors.append(tensor)
479
-
584
+
480
585
  label_tensors = []
481
586
  for target_name in self.target:
482
587
  column = get_column_data(data, target_name)
483
588
  if column is None:
484
589
  continue
485
- label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
486
-
590
+ label_tensor = self._to_tensor(column, dtype=torch.float32, device="cpu")
591
+
487
592
  if label_tensor.dim() == 1:
488
593
  # 1D tensor: (N,) -> (N, 1)
489
594
  label_tensor = label_tensor.view(-1, 1)
490
595
  elif label_tensor.dim() == 2:
491
596
  if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
492
597
  label_tensor = label_tensor.t()
493
-
598
+
494
599
  label_tensors.append(label_tensor)
495
-
600
+
496
601
  if label_tensors:
497
602
  if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
498
603
  y_tensor = label_tensors[0]
@@ -502,22 +607,23 @@ class BaseModel(nn.Module):
502
607
  if y_tensor.shape[1] == 1:
503
608
  y_tensor = y_tensor.squeeze(1)
504
609
  tensors.append(y_tensor)
505
-
610
+
506
611
  dataset = TensorDataset(*tensors)
507
612
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
508
613
 
509
-
510
614
  def _batch_to_dict(self, batch_data: tuple) -> dict:
511
615
  result = {}
512
- all_features = self.dense_features + self.sparse_features + self.sequence_features
513
-
616
+ all_features = (
617
+ self.dense_features + self.sparse_features + self.sequence_features
618
+ )
619
+
514
620
  for i, feature in enumerate(all_features):
515
621
  if i < len(batch_data):
516
622
  result[feature.name] = batch_data[i]
517
-
623
+
518
624
  if len(batch_data) > len(all_features):
519
625
  labels = batch_data[-1]
520
-
626
+
521
627
  if self.nums_task == 1:
522
628
  result[self.target[0]] = labels
523
629
  else:
@@ -534,28 +640,36 @@ class BaseModel(nn.Module):
534
640
  for i, target_name in enumerate(self.target):
535
641
  if i < labels.shape[-1]:
536
642
  result[target_name] = labels[..., i]
537
-
538
- return result
539
643
 
644
+ return result
540
645
 
541
- def fit(self,
542
- train_data: dict|pd.DataFrame|DataLoader,
543
- valid_data: dict|pd.DataFrame|DataLoader|None=None,
544
- metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
545
- epochs:int=1, verbose:int=1, shuffle:bool=True, batch_size:int=32,
546
- user_id_column: str = 'user_id',
547
- validation_split: float | None = None):
646
+ def fit(
647
+ self,
648
+ train_data: dict | pd.DataFrame | DataLoader,
649
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
650
+ metrics: (
651
+ list[str] | dict[str, list[str]] | None
652
+ ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
653
+ epochs: int = 1,
654
+ verbose: int = 1,
655
+ shuffle: bool = True,
656
+ batch_size: int = 32,
657
+ user_id_column: str = "user_id",
658
+ validation_split: float | None = None,
659
+ ):
548
660
 
549
661
  self.to(self.device)
550
662
  if not self._logger_initialized:
551
663
  setup_logger()
552
664
  self._logger_initialized = True
553
665
  self._verbose = verbose
554
- self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
555
-
666
+ self._set_metrics(
667
+ metrics
668
+ ) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
669
+
556
670
  # Assert before training
557
671
  self._validate_task_configuration()
558
-
672
+
559
673
  if self._verbose:
560
674
  self.summary()
561
675
 
@@ -566,63 +680,85 @@ class BaseModel(nn.Module):
566
680
  train_data=train_data,
567
681
  validation_split=validation_split,
568
682
  batch_size=batch_size,
569
- shuffle=shuffle
683
+ shuffle=shuffle,
570
684
  )
571
685
  else:
572
686
  if not isinstance(train_data, DataLoader):
573
- train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
687
+ train_loader = self._prepare_data_loader(
688
+ train_data, batch_size=batch_size, shuffle=shuffle
689
+ )
574
690
  else:
575
691
  train_loader = train_data
576
-
577
692
 
578
693
  valid_user_ids: np.ndarray | None = None
579
694
  needs_user_ids = self._needs_user_ids_for_metrics()
580
695
 
581
696
  if valid_loader is None:
582
697
  if valid_data is not None and not isinstance(valid_data, DataLoader):
583
- valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
698
+ valid_loader = self._prepare_data_loader(
699
+ valid_data, batch_size=batch_size, shuffle=False
700
+ )
584
701
  # Extract user_ids only if needed for GAUC
585
702
  if needs_user_ids:
586
- if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
703
+ if (
704
+ isinstance(valid_data, pd.DataFrame)
705
+ and user_id_column in valid_data.columns
706
+ ):
587
707
  valid_user_ids = np.asarray(valid_data[user_id_column].values)
588
708
  elif isinstance(valid_data, dict) and user_id_column in valid_data:
589
709
  valid_user_ids = np.asarray(valid_data[user_id_column])
590
710
  elif valid_data is not None:
591
711
  valid_loader = valid_data
592
-
712
+
593
713
  try:
594
714
  self._steps_per_epoch = len(train_loader)
595
715
  is_streaming = False
596
716
  except TypeError:
597
717
  self._steps_per_epoch = None
598
718
  is_streaming = True
599
-
719
+
600
720
  self._epoch_index = 0
601
721
  self._stop_training = False
602
- self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
603
-
722
+ self._best_metric = (
723
+ float("-inf") if self.best_metrics_mode == "max" else float("inf")
724
+ )
725
+
604
726
  if self._verbose:
605
727
  logging.info("")
606
728
  logging.info(colorize("=" * 80, color="bright_green", bold=True))
607
729
  if is_streaming:
608
- logging.info(colorize(f"Start training (Streaming Mode)", color="bright_green", bold=True))
730
+ logging.info(
731
+ colorize(
732
+ f"Start training (Streaming Mode)",
733
+ color="bright_green",
734
+ bold=True,
735
+ )
736
+ )
609
737
  else:
610
- logging.info(colorize(f"Start training", color="bright_green", bold=True))
738
+ logging.info(
739
+ colorize(f"Start training", color="bright_green", bold=True)
740
+ )
611
741
  logging.info(colorize("=" * 80, color="bright_green", bold=True))
612
742
  logging.info("")
613
743
  logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
614
-
744
+
615
745
  for epoch in range(epochs):
616
746
  self._epoch_index = epoch
617
-
747
+
618
748
  # In streaming mode, print epoch header before progress bar
619
749
  if self._verbose and is_streaming:
620
750
  logging.info("")
621
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
751
+ logging.info(
752
+ colorize(
753
+ f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True
754
+ )
755
+ )
622
756
 
623
757
  # Train with metrics computation
624
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
625
-
758
+ train_result = self.train_epoch(
759
+ train_loader, is_streaming=is_streaming, compute_metrics=True
760
+ )
761
+
626
762
  # Unpack results
627
763
  if isinstance(train_result, tuple):
628
764
  train_loss, train_metrics = train_result
@@ -632,9 +768,13 @@ class BaseModel(nn.Module):
632
768
 
633
769
  if self._verbose:
634
770
  if self.nums_task == 1:
635
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
771
+ log_str = (
772
+ f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
773
+ )
636
774
  if train_metrics:
637
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
775
+ metrics_str = ", ".join(
776
+ [f"{k}={v:.4f}" for k, v in train_metrics.items()]
777
+ )
638
778
  log_str += f", {metrics_str}"
639
779
  logging.info(colorize(log_str, color="white"))
640
780
  else:
@@ -644,10 +784,12 @@ class BaseModel(nn.Module):
644
784
  task_labels.append(self.target[i])
645
785
  else:
646
786
  task_labels.append(f"task_{i}")
647
-
787
+
648
788
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
649
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
650
-
789
+ log_str = (
790
+ f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
791
+ )
792
+
651
793
  if train_metrics:
652
794
  # Group metrics by task
653
795
  task_metrics = {}
@@ -656,28 +798,50 @@ class BaseModel(nn.Module):
656
798
  if metric_key.endswith(f"_{target_name}"):
657
799
  if target_name not in task_metrics:
658
800
  task_metrics[target_name] = {}
659
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
660
- task_metrics[target_name][metric_name] = metric_value
801
+ metric_name = metric_key.rsplit(
802
+ f"_{target_name}", 1
803
+ )[0]
804
+ task_metrics[target_name][
805
+ metric_name
806
+ ] = metric_value
661
807
  break
662
-
808
+
663
809
  if task_metrics:
664
810
  task_metric_strs = []
665
811
  for target_name in self.target:
666
812
  if target_name in task_metrics:
667
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
668
- task_metric_strs.append(f"{target_name}[{metrics_str}]")
813
+ metrics_str = ", ".join(
814
+ [
815
+ f"{k}={v:.4f}"
816
+ for k, v in task_metrics[
817
+ target_name
818
+ ].items()
819
+ ]
820
+ )
821
+ task_metric_strs.append(
822
+ f"{target_name}[{metrics_str}]"
823
+ )
669
824
  log_str += ", " + ", ".join(task_metric_strs)
670
-
825
+
671
826
  logging.info(colorize(log_str, color="white"))
672
-
827
+
673
828
  if valid_loader is not None:
674
829
  # Pass user_ids only if needed for GAUC metric
675
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
676
-
830
+ val_metrics = self.evaluate(
831
+ valid_loader, user_ids=valid_user_ids if needs_user_ids else None
832
+ ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
833
+
677
834
  if self._verbose:
678
835
  if self.nums_task == 1:
679
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
680
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
836
+ metrics_str = ", ".join(
837
+ [f"{k}={v:.4f}" for k, v in val_metrics.items()]
838
+ )
839
+ logging.info(
840
+ colorize(
841
+ f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
842
+ color="cyan",
843
+ )
844
+ )
681
845
  else:
682
846
  # multi task metrics
683
847
  task_metrics = {}
@@ -686,33 +850,55 @@ class BaseModel(nn.Module):
686
850
  if metric_key.endswith(f"_{target_name}"):
687
851
  if target_name not in task_metrics:
688
852
  task_metrics[target_name] = {}
689
- metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
690
- task_metrics[target_name][metric_name] = metric_value
853
+ metric_name = metric_key.rsplit(
854
+ f"_{target_name}", 1
855
+ )[0]
856
+ task_metrics[target_name][
857
+ metric_name
858
+ ] = metric_value
691
859
  break
692
860
 
693
861
  task_metric_strs = []
694
862
  for target_name in self.target:
695
863
  if target_name in task_metrics:
696
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
864
+ metrics_str = ", ".join(
865
+ [
866
+ f"{k}={v:.4f}"
867
+ for k, v in task_metrics[target_name].items()
868
+ ]
869
+ )
697
870
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
698
-
699
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
700
-
871
+
872
+ logging.info(
873
+ colorize(
874
+ f"Epoch {epoch + 1}/{epochs} - Valid: "
875
+ + ", ".join(task_metric_strs),
876
+ color="cyan",
877
+ )
878
+ )
879
+
701
880
  # Handle empty validation metrics
702
881
  if not val_metrics:
703
882
  if self._verbose:
704
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
883
+ logging.info(
884
+ colorize(
885
+ f"Warning: No validation metrics computed. Skipping validation for this epoch.",
886
+ color="yellow",
887
+ )
888
+ )
705
889
  continue
706
-
890
+
707
891
  if self.nums_task == 1:
708
892
  primary_metric_key = self.metrics[0]
709
893
  else:
710
894
  primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
711
-
712
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
895
+
896
+ primary_metric = val_metrics.get(
897
+ primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
898
+ )
713
899
  improved = False
714
-
715
- if self.best_metrics_mode == 'max':
900
+
901
+ if self.best_metrics_mode == "max":
716
902
  if primary_metric > self._best_metric:
717
903
  self._best_metric = primary_metric
718
904
  self.save_weights(self.best)
@@ -721,56 +907,85 @@ class BaseModel(nn.Module):
721
907
  if primary_metric < self._best_metric:
722
908
  self._best_metric = primary_metric
723
909
  improved = True
724
-
910
+
725
911
  if improved:
726
912
  if self._verbose:
727
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
913
+ logging.info(
914
+ colorize(
915
+ f"Validation {primary_metric_key} improved to {self._best_metric:.4f}",
916
+ color="yellow",
917
+ )
918
+ )
728
919
  self.save_weights(self.checkpoint)
729
920
  self.early_stopper.trial_counter = 0
730
921
  else:
731
922
  self.early_stopper.trial_counter += 1
732
923
  if self._verbose:
733
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
734
-
924
+ logging.info(
925
+ colorize(
926
+ f"No improvement for {self.early_stopper.trial_counter} epoch(s)",
927
+ color="yellow",
928
+ )
929
+ )
930
+
735
931
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
736
932
  self._stop_training = True
737
933
  if self._verbose:
738
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
934
+ logging.info(
935
+ colorize(
936
+ f"Early stopping triggered after {epoch + 1} epochs",
937
+ color="bright_red",
938
+ bold=True,
939
+ )
940
+ )
739
941
  break
740
942
  else:
741
943
  self.save_weights(self.checkpoint)
742
-
944
+
743
945
  if self._stop_training:
744
946
  break
745
-
947
+
746
948
  if self.scheduler_fn is not None:
747
- if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
949
+ if isinstance(
950
+ self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
951
+ ):
748
952
  if valid_loader is not None:
749
953
  self.scheduler_fn.step(primary_metric)
750
954
  else:
751
955
  self.scheduler_fn.step()
752
-
956
+
753
957
  if self._verbose:
754
958
  logging.info("\n")
755
- logging.info(colorize("Training finished.", color="bright_green", bold=True))
959
+ logging.info(
960
+ colorize("Training finished.", color="bright_green", bold=True)
961
+ )
756
962
  logging.info("\n")
757
-
963
+
758
964
  if valid_loader is not None:
759
965
  if self._verbose:
760
- logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
966
+ logging.info(
967
+ colorize(
968
+ f"Load best model from: {self.checkpoint}", color="bright_blue"
969
+ )
970
+ )
761
971
  self.load_weights(self.checkpoint)
762
-
972
+
763
973
  return self
764
974
 
765
- def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
975
+ def train_epoch(
976
+ self,
977
+ train_loader: DataLoader,
978
+ is_streaming: bool = False,
979
+ compute_metrics: bool = False,
980
+ ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
766
981
  if self.nums_task == 1:
767
982
  accumulated_loss = 0.0
768
983
  else:
769
984
  accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
770
-
985
+
771
986
  self.train()
772
987
  num_batches = 0
773
-
988
+
774
989
  # Lists to store predictions and labels for metric computation
775
990
  y_true_list = []
776
991
  y_pred_list = []
@@ -778,19 +993,29 @@ class BaseModel(nn.Module):
778
993
  if self._verbose:
779
994
  # For streaming datasets without known length, set total=None to show progress without percentage
780
995
  if self._steps_per_epoch is not None:
781
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
996
+ batch_iter = enumerate(
997
+ tqdm.tqdm(
998
+ train_loader,
999
+ desc=f"Epoch {self._epoch_index + 1}",
1000
+ total=self._steps_per_epoch,
1001
+ )
1002
+ )
782
1003
  else:
783
1004
  # Streaming mode: show batch/file progress without epoch in desc
784
1005
  if is_streaming:
785
- batch_iter = enumerate(tqdm.tqdm(
786
- train_loader,
787
- desc="Batches",
788
- # position=1,
789
- # leave=False,
790
- # unit="batch"
791
- ))
1006
+ batch_iter = enumerate(
1007
+ tqdm.tqdm(
1008
+ train_loader,
1009
+ desc="Batches",
1010
+ # position=1,
1011
+ # leave=False,
1012
+ # unit="batch"
1013
+ )
1014
+ )
792
1015
  else:
793
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
1016
+ batch_iter = enumerate(
1017
+ tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}")
1018
+ )
794
1019
  else:
795
1020
  batch_iter = enumerate(train_loader)
796
1021
 
@@ -806,17 +1031,17 @@ class BaseModel(nn.Module):
806
1031
  total_loss = loss + reg_loss
807
1032
  else:
808
1033
  total_loss = loss.sum() + reg_loss
809
-
1034
+
810
1035
  self.optimizer_fn.zero_grad()
811
1036
  total_loss.backward()
812
1037
  nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
813
1038
  self.optimizer_fn.step()
814
-
1039
+
815
1040
  if self.nums_task == 1:
816
1041
  accumulated_loss += loss.item()
817
1042
  else:
818
1043
  accumulated_loss += loss.detach().cpu().numpy()
819
-
1044
+
820
1045
  # Collect predictions and labels for metrics if requested
821
1046
  if compute_metrics:
822
1047
  if y_true is not None:
@@ -824,49 +1049,52 @@ class BaseModel(nn.Module):
824
1049
  # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
825
1050
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
826
1051
  y_pred_list.append(y_pred.detach().cpu().numpy())
827
-
1052
+
828
1053
  num_batches += 1
829
1054
 
830
1055
  if self.nums_task == 1:
831
1056
  avg_loss = accumulated_loss / num_batches
832
1057
  else:
833
1058
  avg_loss = accumulated_loss / num_batches
834
-
1059
+
835
1060
  # Compute metrics if requested
836
1061
  if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
837
1062
  y_true_all = np.concatenate(y_true_list, axis=0)
838
1063
  y_pred_all = np.concatenate(y_pred_list, axis=0)
839
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
1064
+ metrics_dict = self.evaluate_metrics(
1065
+ y_true_all, y_pred_all, self.metrics, user_ids=None
1066
+ )
840
1067
  return avg_loss, metrics_dict
841
-
842
- return avg_loss
843
1068
 
1069
+ return avg_loss
844
1070
 
845
1071
  def _needs_user_ids_for_metrics(self) -> bool:
846
1072
  """Check if any configured metric requires user_ids (e.g., gauc)."""
847
1073
  all_metrics = set()
848
-
1074
+
849
1075
  # Collect all metrics from different sources
850
- if hasattr(self, 'metrics') and self.metrics:
1076
+ if hasattr(self, "metrics") and self.metrics:
851
1077
  all_metrics.update(m.lower() for m in self.metrics)
852
-
853
- if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
1078
+
1079
+ if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
854
1080
  for task_metrics in self.task_specific_metrics.values():
855
1081
  if isinstance(task_metrics, list):
856
1082
  all_metrics.update(m.lower() for m in task_metrics)
857
-
1083
+
858
1084
  # Check if gauc is in any of the metrics
859
- return 'gauc' in all_metrics
860
-
861
- def evaluate(self,
862
- data: dict | pd.DataFrame | DataLoader,
863
- metrics: list[str] | dict[str, list[str]] | None = None,
864
- batch_size: int = 32,
865
- user_ids: np.ndarray | None = None,
866
- user_id_column: str = 'user_id') -> dict:
1085
+ return "gauc" in all_metrics
1086
+
1087
+ def evaluate(
1088
+ self,
1089
+ data: dict | pd.DataFrame | DataLoader,
1090
+ metrics: list[str] | dict[str, list[str]] | None = None,
1091
+ batch_size: int = 32,
1092
+ user_ids: np.ndarray | None = None,
1093
+ user_id_column: str = "user_id",
1094
+ ) -> dict:
867
1095
  """
868
1096
  Evaluate the model on validation data.
869
-
1097
+
870
1098
  Args:
871
1099
  data: Evaluation data (dict, DataFrame, or DataLoader)
872
1100
  metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
@@ -874,17 +1102,19 @@ class BaseModel(nn.Module):
874
1102
  user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
875
1103
  will try to extract from data using user_id_column
876
1104
  user_id_column: Column name for user IDs (default: 'user_id')
877
-
1105
+
878
1106
  Returns:
879
1107
  Dictionary of metric values
880
1108
  """
881
1109
  self.eval()
882
-
1110
+
883
1111
  # Use provided metrics or fall back to configured metrics
884
1112
  eval_metrics = metrics if metrics is not None else self.metrics
885
1113
  if eval_metrics is None:
886
- raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
887
-
1114
+ raise ValueError(
1115
+ "No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
1116
+ )
1117
+
888
1118
  # Prepare DataLoader if needed
889
1119
  if isinstance(data, DataLoader):
890
1120
  data_loader = data
@@ -892,11 +1122,13 @@ class BaseModel(nn.Module):
892
1122
  if user_ids is None and self._needs_user_ids_for_metrics():
893
1123
  # Cannot extract user_ids from DataLoader, user must provide them
894
1124
  if self._verbose:
895
- logging.warning(colorize(
896
- "GAUC metric requires user_ids, but data is a DataLoader. "
897
- "Please provide user_ids parameter or use dict/DataFrame format.",
898
- color="yellow"
899
- ))
1125
+ logging.warning(
1126
+ colorize(
1127
+ "GAUC metric requires user_ids, but data is a DataLoader. "
1128
+ "Please provide user_ids parameter or use dict/DataFrame format.",
1129
+ color="yellow",
1130
+ )
1131
+ )
900
1132
  else:
901
1133
  # Extract user_ids if needed and not provided
902
1134
  if user_ids is None and self._needs_user_ids_for_metrics():
@@ -904,12 +1136,14 @@ class BaseModel(nn.Module):
904
1136
  user_ids = np.asarray(data[user_id_column].values)
905
1137
  elif isinstance(data, dict) and user_id_column in data:
906
1138
  user_ids = np.asarray(data[user_id_column])
907
-
908
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
909
-
1139
+
1140
+ data_loader = self._prepare_data_loader(
1141
+ data, batch_size=batch_size, shuffle=False
1142
+ )
1143
+
910
1144
  y_true_list = []
911
1145
  y_pred_list = []
912
-
1146
+
913
1147
  batch_count = 0
914
1148
  with torch.no_grad():
915
1149
  for batch_data in data_loader:
@@ -917,7 +1151,7 @@ class BaseModel(nn.Module):
917
1151
  batch_dict = self._batch_to_dict(batch_data)
918
1152
  X_input, y_true = self.get_input(batch_dict)
919
1153
  y_pred = self.forward(X_input)
920
-
1154
+
921
1155
  if y_true is not None:
922
1156
  y_true_list.append(y_true.cpu().numpy())
923
1157
  # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -925,24 +1159,40 @@ class BaseModel(nn.Module):
925
1159
  y_pred_list.append(y_pred.cpu().numpy())
926
1160
 
927
1161
  if self._verbose:
928
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
929
-
1162
+ logging.info(
1163
+ colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1164
+ )
1165
+
930
1166
  if len(y_true_list) > 0:
931
1167
  y_true_all = np.concatenate(y_true_list, axis=0)
932
1168
  if self._verbose:
933
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
1169
+ logging.info(
1170
+ colorize(
1171
+ f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"
1172
+ )
1173
+ )
934
1174
  else:
935
1175
  y_true_all = None
936
1176
  if self._verbose:
937
- logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
938
-
1177
+ logging.info(
1178
+ colorize(
1179
+ f" Warning: No y_true collected from evaluation data",
1180
+ color="yellow",
1181
+ )
1182
+ )
1183
+
939
1184
  if len(y_pred_list) > 0:
940
1185
  y_pred_all = np.concatenate(y_pred_list, axis=0)
941
1186
  else:
942
1187
  y_pred_all = None
943
1188
  if self._verbose:
944
- logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
945
-
1189
+ logging.info(
1190
+ colorize(
1191
+ f" Warning: No y_pred collected from evaluation data",
1192
+ color="yellow",
1193
+ )
1194
+ )
1195
+
946
1196
  # Convert metrics to list if it's a dict
947
1197
  if isinstance(eval_metrics, dict):
948
1198
  # For dict metrics, we need to collect all unique metric names
@@ -954,16 +1204,23 @@ class BaseModel(nn.Module):
954
1204
  metrics_to_use = unique_metrics
955
1205
  else:
956
1206
  metrics_to_use = eval_metrics
957
-
958
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
959
-
960
- return metrics_dict
961
1207
 
1208
+ metrics_dict = self.evaluate_metrics(
1209
+ y_true_all, y_pred_all, metrics_to_use, user_ids
1210
+ )
1211
+
1212
+ return metrics_dict
962
1213
 
963
- def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
1214
+ def evaluate_metrics(
1215
+ self,
1216
+ y_true: np.ndarray | None,
1217
+ y_pred: np.ndarray | None,
1218
+ metrics: list[str],
1219
+ user_ids: np.ndarray | None = None,
1220
+ ) -> dict:
964
1221
  """Evaluate metrics using the metrics module."""
965
- task_specific_metrics = getattr(self, 'task_specific_metrics', None)
966
-
1222
+ task_specific_metrics = getattr(self, "task_specific_metrics", None)
1223
+
967
1224
  return evaluate_metrics(
968
1225
  y_true=y_true,
969
1226
  y_pred=y_pred,
@@ -971,24 +1228,29 @@ class BaseModel(nn.Module):
971
1228
  task=self.task,
972
1229
  target_names=self.target,
973
1230
  task_specific_metrics=task_specific_metrics,
974
- user_ids=user_ids
1231
+ user_ids=user_ids,
975
1232
  )
976
1233
 
977
-
978
- def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
1234
+ def predict(
1235
+ self, data: str | dict | pd.DataFrame | DataLoader, batch_size: int = 32
1236
+ ) -> np.ndarray:
979
1237
  self.eval()
980
1238
  # todo: handle file path input later
981
1239
  if isinstance(data, (str, os.PathLike)):
982
1240
  pass
983
1241
  if not isinstance(data, DataLoader):
984
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
1242
+ data_loader = self._prepare_data_loader(
1243
+ data, batch_size=batch_size, shuffle=False
1244
+ )
985
1245
  else:
986
1246
  data_loader = data
987
-
1247
+
988
1248
  y_pred_list = []
989
-
1249
+
990
1250
  with torch.no_grad():
991
- for batch_data in tqdm.tqdm(data_loader, desc="Predicting", disable=self._verbose == 0):
1251
+ for batch_data in tqdm.tqdm(
1252
+ data_loader, desc="Predicting", disable=self._verbose == 0
1253
+ ):
992
1254
  batch_dict = self._batch_to_dict(batch_data)
993
1255
  X_input, _ = self.get_input(batch_dict)
994
1256
  y_pred = self.forward(X_input)
@@ -1001,10 +1263,10 @@ class BaseModel(nn.Module):
1001
1263
  return y_pred_all
1002
1264
  else:
1003
1265
  return np.array([])
1004
-
1266
+
1005
1267
  def save_weights(self, model_path: str):
1006
1268
  torch.save(self.state_dict(), model_path)
1007
-
1269
+
1008
1270
  def load_weights(self, checkpoint):
1009
1271
  self.to(self.device)
1010
1272
  state_dict = torch.load(checkpoint, map_location="cpu")
@@ -1012,112 +1274,136 @@ class BaseModel(nn.Module):
1012
1274
 
1013
1275
  def summary(self):
1014
1276
  logger = logging.getLogger()
1015
-
1277
+
1016
1278
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1017
- logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
1279
+ logger.info(
1280
+ colorize(
1281
+ f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1282
+ )
1283
+ )
1018
1284
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
1019
-
1285
+
1020
1286
  logger.info("")
1021
1287
  logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
1022
1288
  logger.info(colorize("-" * 80, color="cyan"))
1023
-
1289
+
1024
1290
  if self.dense_features:
1025
1291
  logger.info(f"Dense Features ({len(self.dense_features)}):")
1026
1292
  for i, feat in enumerate(self.dense_features, 1):
1027
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 1
1293
+ embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
1028
1294
  logger.info(f" {i}. {feat.name:20s}")
1029
-
1295
+
1030
1296
  if self.sparse_features:
1031
1297
  logger.info(f"Sparse Features ({len(self.sparse_features)}):")
1032
1298
 
1033
1299
  max_name_len = max(len(feat.name) for feat in self.sparse_features)
1034
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
1300
+ max_embed_name_len = max(
1301
+ len(feat.embedding_name) for feat in self.sparse_features
1302
+ )
1035
1303
  name_width = max(max_name_len, 10) + 2
1036
1304
  embed_name_width = max(max_embed_name_len, 15) + 2
1037
-
1038
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
1039
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
1305
+
1306
+ logger.info(
1307
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
1308
+ )
1309
+ logger.info(
1310
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
1311
+ )
1040
1312
  for i, feat in enumerate(self.sparse_features, 1):
1041
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
1042
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
1043
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
1044
-
1313
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1314
+ embed_dim = (
1315
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1316
+ )
1317
+ logger.info(
1318
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
1319
+ )
1320
+
1045
1321
  if self.sequence_features:
1046
1322
  logger.info(f"Sequence Features ({len(self.sequence_features)}):")
1047
1323
 
1048
1324
  max_name_len = max(len(feat.name) for feat in self.sequence_features)
1049
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
1325
+ max_embed_name_len = max(
1326
+ len(feat.embedding_name) for feat in self.sequence_features
1327
+ )
1050
1328
  name_width = max(max_name_len, 10) + 2
1051
1329
  embed_name_width = max(max_embed_name_len, 15) + 2
1052
-
1053
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
1054
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
1330
+
1331
+ logger.info(
1332
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
1333
+ )
1334
+ logger.info(
1335
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
1336
+ )
1055
1337
  for i, feat in enumerate(self.sequence_features, 1):
1056
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
1057
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
1058
- max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
1059
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
1060
-
1338
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1339
+ embed_dim = (
1340
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1341
+ )
1342
+ max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
1343
+ logger.info(
1344
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
1345
+ )
1346
+
1061
1347
  logger.info("")
1062
1348
  logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
1063
1349
  logger.info(colorize("-" * 80, color="cyan"))
1064
-
1350
+
1065
1351
  # Model Architecture
1066
1352
  logger.info("Model Architecture:")
1067
1353
  logger.info(str(self))
1068
1354
  logger.info("")
1069
-
1355
+
1070
1356
  total_params = sum(p.numel() for p in self.parameters())
1071
1357
  trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
1072
1358
  non_trainable_params = total_params - trainable_params
1073
-
1359
+
1074
1360
  logger.info(f"Total Parameters: {total_params:,}")
1075
1361
  logger.info(f"Trainable Parameters: {trainable_params:,}")
1076
1362
  logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
1077
-
1363
+
1078
1364
  logger.info("Layer-wise Parameters:")
1079
1365
  for name, module in self.named_children():
1080
1366
  layer_params = sum(p.numel() for p in module.parameters())
1081
1367
  if layer_params > 0:
1082
1368
  logger.info(f" {name:30s}: {layer_params:,}")
1083
-
1369
+
1084
1370
  logger.info("")
1085
1371
  logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
1086
1372
  logger.info(colorize("-" * 80, color="cyan"))
1087
-
1373
+
1088
1374
  logger.info(f"Task Type: {self.task}")
1089
1375
  logger.info(f"Number of Tasks: {self.nums_task}")
1090
1376
  logger.info(f"Metrics: {self.metrics}")
1091
1377
  logger.info(f"Target Columns: {self.target}")
1092
1378
  logger.info(f"Device: {self.device}")
1093
-
1094
- if hasattr(self, '_optimizer_name'):
1379
+
1380
+ if hasattr(self, "_optimizer_name"):
1095
1381
  logger.info(f"Optimizer: {self._optimizer_name}")
1096
1382
  if self._optimizer_params:
1097
1383
  for key, value in self._optimizer_params.items():
1098
1384
  logger.info(f" {key:25s}: {value}")
1099
-
1100
- if hasattr(self, '_scheduler_name') and self._scheduler_name:
1385
+
1386
+ if hasattr(self, "_scheduler_name") and self._scheduler_name:
1101
1387
  logger.info(f"Scheduler: {self._scheduler_name}")
1102
1388
  if self._scheduler_params:
1103
1389
  for key, value in self._scheduler_params.items():
1104
1390
  logger.info(f" {key:25s}: {value}")
1105
-
1106
- if hasattr(self, '_loss_config'):
1391
+
1392
+ if hasattr(self, "_loss_config"):
1107
1393
  logger.info(f"Loss Function: {self._loss_config}")
1108
-
1394
+
1109
1395
  logger.info("Regularization:")
1110
1396
  logger.info(f" Embedding L1: {self._embedding_l1_reg}")
1111
1397
  logger.info(f" Embedding L2: {self._embedding_l2_reg}")
1112
1398
  logger.info(f" Dense L1: {self._dense_l1_reg}")
1113
1399
  logger.info(f" Dense L2: {self._dense_l2_reg}")
1114
-
1400
+
1115
1401
  logger.info("Other Settings:")
1116
1402
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
1117
1403
  logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
1118
1404
  logger.info(f" Model ID: {self.model_id}")
1119
1405
  logger.info(f" Checkpoint Path: {self.checkpoint}")
1120
-
1406
+
1121
1407
  logger.info("")
1122
1408
  logger.info("")
1123
1409
 
@@ -1127,45 +1413,47 @@ class BaseMatchModel(BaseModel):
1127
1413
  Base class for match (retrieval/recall) models
1128
1414
  Supports pointwise, pairwise, and listwise training modes
1129
1415
  """
1130
-
1416
+
1131
1417
  @property
1132
1418
  def task_type(self) -> str:
1133
- return 'match'
1134
-
1419
+ return "match"
1420
+
1135
1421
  @property
1136
1422
  def support_training_modes(self) -> list[str]:
1137
1423
  """
1138
1424
  Returns list of supported training modes for this model.
1139
1425
  Override in subclasses to restrict training modes.
1140
-
1426
+
1141
1427
  Returns:
1142
1428
  List of supported modes: ['pointwise', 'pairwise', 'listwise']
1143
1429
  """
1144
- return ['pointwise', 'pairwise', 'listwise']
1145
-
1146
- def __init__(self,
1147
- user_dense_features: list[DenseFeature] | None = None,
1148
- user_sparse_features: list[SparseFeature] | None = None,
1149
- user_sequence_features: list[SequenceFeature] | None = None,
1150
- item_dense_features: list[DenseFeature] | None = None,
1151
- item_sparse_features: list[SparseFeature] | None = None,
1152
- item_sequence_features: list[SequenceFeature] | None = None,
1153
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
1154
- num_negative_samples: int = 4,
1155
- temperature: float = 1.0,
1156
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
1157
- device: str = 'cpu',
1158
- embedding_l1_reg: float = 0.0,
1159
- dense_l1_reg: float = 0.0,
1160
- embedding_l2_reg: float = 0.0,
1161
- dense_l2_reg: float = 0.0,
1162
- early_stop_patience: int = 20,
1163
- model_id: str = 'baseline'):
1164
-
1430
+ return ["pointwise", "pairwise", "listwise"]
1431
+
1432
+ def __init__(
1433
+ self,
1434
+ user_dense_features: list[DenseFeature] | None = None,
1435
+ user_sparse_features: list[SparseFeature] | None = None,
1436
+ user_sequence_features: list[SequenceFeature] | None = None,
1437
+ item_dense_features: list[DenseFeature] | None = None,
1438
+ item_sparse_features: list[SparseFeature] | None = None,
1439
+ item_sequence_features: list[SequenceFeature] | None = None,
1440
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
1441
+ num_negative_samples: int = 4,
1442
+ temperature: float = 1.0,
1443
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
1444
+ device: str = "cpu",
1445
+ embedding_l1_reg: float = 0.0,
1446
+ dense_l1_reg: float = 0.0,
1447
+ embedding_l2_reg: float = 0.0,
1448
+ dense_l2_reg: float = 0.0,
1449
+ early_stop_patience: int = 20,
1450
+ model_id: str = "baseline",
1451
+ ):
1452
+
1165
1453
  all_dense_features = []
1166
1454
  all_sparse_features = []
1167
1455
  all_sequence_features = []
1168
-
1456
+
1169
1457
  if user_dense_features:
1170
1458
  all_dense_features.extend(user_dense_features)
1171
1459
  if item_dense_features:
@@ -1178,117 +1466,158 @@ class BaseMatchModel(BaseModel):
1178
1466
  all_sequence_features.extend(user_sequence_features)
1179
1467
  if item_sequence_features:
1180
1468
  all_sequence_features.extend(item_sequence_features)
1181
-
1469
+
1182
1470
  super(BaseMatchModel, self).__init__(
1183
1471
  dense_features=all_dense_features,
1184
1472
  sparse_features=all_sparse_features,
1185
1473
  sequence_features=all_sequence_features,
1186
- target=['label'],
1187
- task='binary',
1474
+ target=["label"],
1475
+ task="binary",
1188
1476
  device=device,
1189
1477
  embedding_l1_reg=embedding_l1_reg,
1190
1478
  dense_l1_reg=dense_l1_reg,
1191
1479
  embedding_l2_reg=embedding_l2_reg,
1192
1480
  dense_l2_reg=dense_l2_reg,
1193
1481
  early_stop_patience=early_stop_patience,
1194
- model_id=model_id
1482
+ model_id=model_id,
1483
+ )
1484
+
1485
+ self.user_dense_features = (
1486
+ list(user_dense_features) if user_dense_features else []
1487
+ )
1488
+ self.user_sparse_features = (
1489
+ list(user_sparse_features) if user_sparse_features else []
1490
+ )
1491
+ self.user_sequence_features = (
1492
+ list(user_sequence_features) if user_sequence_features else []
1493
+ )
1494
+
1495
+ self.item_dense_features = (
1496
+ list(item_dense_features) if item_dense_features else []
1195
1497
  )
1196
-
1197
- self.user_dense_features = list(user_dense_features) if user_dense_features else []
1198
- self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
1199
- self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
1200
-
1201
- self.item_dense_features = list(item_dense_features) if item_dense_features else []
1202
- self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
1203
- self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
1204
-
1498
+ self.item_sparse_features = (
1499
+ list(item_sparse_features) if item_sparse_features else []
1500
+ )
1501
+ self.item_sequence_features = (
1502
+ list(item_sequence_features) if item_sequence_features else []
1503
+ )
1504
+
1205
1505
  self.training_mode = training_mode
1206
1506
  self.num_negative_samples = num_negative_samples
1207
1507
  self.temperature = temperature
1208
1508
  self.similarity_metric = similarity_metric
1209
-
1509
+
1210
1510
  def get_user_features(self, X_input: dict) -> dict:
1211
1511
  user_input = {}
1212
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1512
+ all_user_features = (
1513
+ self.user_dense_features
1514
+ + self.user_sparse_features
1515
+ + self.user_sequence_features
1516
+ )
1213
1517
  for feature in all_user_features:
1214
1518
  if feature.name in X_input:
1215
1519
  user_input[feature.name] = X_input[feature.name]
1216
1520
  return user_input
1217
-
1521
+
1218
1522
  def get_item_features(self, X_input: dict) -> dict:
1219
1523
  item_input = {}
1220
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1524
+ all_item_features = (
1525
+ self.item_dense_features
1526
+ + self.item_sparse_features
1527
+ + self.item_sequence_features
1528
+ )
1221
1529
  for feature in all_item_features:
1222
1530
  if feature.name in X_input:
1223
1531
  item_input[feature.name] = X_input[feature.name]
1224
1532
  return item_input
1225
-
1226
- def compile(self,
1227
- optimizer = "adam",
1228
- optimizer_params: dict | None = None,
1229
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
1230
- scheduler_params: dict | None = None,
1231
- loss: str | nn.Module | list[str | nn.Module] | None= None):
1533
+
1534
+ def compile(
1535
+ self,
1536
+ optimizer="adam",
1537
+ optimizer_params: dict | None = None,
1538
+ scheduler: (
1539
+ str
1540
+ | torch.optim.lr_scheduler._LRScheduler
1541
+ | type[torch.optim.lr_scheduler._LRScheduler]
1542
+ | None
1543
+ ) = None,
1544
+ scheduler_params: dict | None = None,
1545
+ loss: str | nn.Module | list[str | nn.Module] | None = None,
1546
+ ):
1232
1547
  """
1233
1548
  Compile match model with optimizer, scheduler, and loss function.
1234
1549
  Validates that training_mode is supported by the model.
1235
1550
  """
1236
1551
  from nextrec.loss import validate_training_mode
1237
-
1552
+
1238
1553
  # Validate training mode is supported
1239
1554
  validate_training_mode(
1240
1555
  training_mode=self.training_mode,
1241
1556
  support_training_modes=self.support_training_modes,
1242
- model_name=self.model_name
1557
+ model_name=self.model_name,
1243
1558
  )
1244
-
1559
+
1245
1560
  # Call parent compile with match-specific logic
1246
1561
  if optimizer_params is None:
1247
1562
  optimizer_params = {}
1248
-
1249
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1563
+
1564
+ self._optimizer_name = (
1565
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1566
+ )
1250
1567
  self._optimizer_params = optimizer_params
1251
1568
  if isinstance(scheduler, str):
1252
1569
  self._scheduler_name = scheduler
1253
1570
  elif scheduler is not None:
1254
1571
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
1255
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
1572
+ self._scheduler_name = getattr(
1573
+ scheduler,
1574
+ "__name__",
1575
+ getattr(scheduler.__class__, "__name__", str(scheduler)),
1576
+ )
1256
1577
  else:
1257
1578
  self._scheduler_name = None
1258
1579
  self._scheduler_params = scheduler_params or {}
1259
1580
  self._loss_config = loss
1260
-
1581
+
1261
1582
  # set optimizer
1262
1583
  self.optimizer_fn = get_optimizer_fn(
1263
- optimizer=optimizer,
1264
- params=self.parameters(),
1265
- **optimizer_params
1584
+ optimizer=optimizer, params=self.parameters(), **optimizer_params
1266
1585
  )
1267
-
1586
+
1268
1587
  # Set loss function based on training mode
1269
1588
  loss_value = loss[0] if isinstance(loss, list) else loss
1270
- self.loss_fn = [get_loss_fn(
1271
- task_type='match',
1272
- training_mode=self.training_mode,
1273
- loss=loss_value
1274
- )]
1275
-
1589
+ self.loss_fn = [
1590
+ get_loss_fn(
1591
+ task_type="match", training_mode=self.training_mode, loss=loss_value
1592
+ )
1593
+ ]
1594
+
1276
1595
  # set scheduler
1277
- self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
1596
+ self.scheduler_fn = (
1597
+ get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
1598
+ if scheduler
1599
+ else None
1600
+ )
1278
1601
 
1279
- def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
1280
- if self.similarity_metric == 'dot':
1602
+ def compute_similarity(
1603
+ self, user_emb: torch.Tensor, item_emb: torch.Tensor
1604
+ ) -> torch.Tensor:
1605
+ if self.similarity_metric == "dot":
1281
1606
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1282
1607
  # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
1283
- similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size, num_items]
1608
+ similarity = torch.sum(
1609
+ user_emb * item_emb, dim=-1
1610
+ ) # [batch_size, num_items]
1284
1611
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
1285
1612
  # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
1286
1613
  user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
1287
- similarity = torch.sum(user_emb_expanded * item_emb, dim=-1) # [batch_size, num_items]
1614
+ similarity = torch.sum(
1615
+ user_emb_expanded * item_emb, dim=-1
1616
+ ) # [batch_size, num_items]
1288
1617
  else:
1289
1618
  similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
1290
-
1291
- elif self.similarity_metric == 'cosine':
1619
+
1620
+ elif self.similarity_metric == "cosine":
1292
1621
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1293
1622
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1294
1623
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1296,8 +1625,8 @@ class BaseMatchModel(BaseModel):
1296
1625
  similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
1297
1626
  else:
1298
1627
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1299
-
1300
- elif self.similarity_metric == 'euclidean':
1628
+
1629
+ elif self.similarity_metric == "euclidean":
1301
1630
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1302
1631
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1303
1632
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1305,84 +1634,96 @@ class BaseMatchModel(BaseModel):
1305
1634
  distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
1306
1635
  else:
1307
1636
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1308
- similarity = -distance
1309
-
1637
+ similarity = -distance
1638
+
1310
1639
  else:
1311
1640
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
1312
-
1641
+
1313
1642
  similarity = similarity / self.temperature
1314
-
1643
+
1315
1644
  return similarity
1316
-
1645
+
1317
1646
  def user_tower(self, user_input: dict) -> torch.Tensor:
1318
1647
  raise NotImplementedError
1319
-
1648
+
1320
1649
  def item_tower(self, item_input: dict) -> torch.Tensor:
1321
1650
  raise NotImplementedError
1322
-
1323
- def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
1651
+
1652
+ def forward(
1653
+ self, X_input: dict
1654
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
1324
1655
  user_input = self.get_user_features(X_input)
1325
1656
  item_input = self.get_item_features(X_input)
1326
-
1327
- user_emb = self.user_tower(user_input) # [B, D]
1328
- item_emb = self.item_tower(item_input) # [B, D]
1329
-
1330
- if self.training and self.training_mode in ['pairwise', 'listwise']:
1657
+
1658
+ user_emb = self.user_tower(user_input) # [B, D]
1659
+ item_emb = self.item_tower(item_input) # [B, D]
1660
+
1661
+ if self.training and self.training_mode in ["pairwise", "listwise"]:
1331
1662
  return user_emb, item_emb
1332
1663
 
1333
1664
  similarity = self.compute_similarity(user_emb, item_emb) # [B]
1334
-
1335
- if self.training_mode == 'pointwise':
1665
+
1666
+ if self.training_mode == "pointwise":
1336
1667
  return torch.sigmoid(similarity)
1337
1668
  else:
1338
1669
  return similarity
1339
-
1670
+
1340
1671
  def compute_loss(self, y_pred, y_true):
1341
- if self.training_mode == 'pointwise':
1672
+ if self.training_mode == "pointwise":
1342
1673
  if y_true is None:
1343
1674
  return torch.tensor(0.0, device=self.device)
1344
1675
  return self.loss_fn[0](y_pred, y_true)
1345
-
1676
+
1346
1677
  # pairwise / listwise using inbatch neg
1347
- elif self.training_mode in ['pairwise', 'listwise']:
1678
+ elif self.training_mode in ["pairwise", "listwise"]:
1348
1679
  if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
1349
1680
  raise ValueError(
1350
1681
  "For pairwise/listwise training, forward should return (user_emb, item_emb). "
1351
1682
  "Please check BaseMatchModel.forward implementation."
1352
1683
  )
1353
-
1684
+
1354
1685
  user_emb, item_emb = y_pred # [B, D], [B, D]
1355
-
1686
+
1356
1687
  logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
1357
- logits = logits / self.temperature
1358
-
1688
+ logits = logits / self.temperature
1689
+
1359
1690
  batch_size = logits.size(0)
1360
- targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
1361
-
1691
+ targets = torch.arange(
1692
+ batch_size, device=logits.device
1693
+ ) # [0, 1, 2, ..., B-1]
1694
+
1362
1695
  # Cross-Entropy = InfoNCE
1363
1696
  loss = F.cross_entropy(logits, targets)
1364
1697
  return loss
1365
-
1698
+
1366
1699
  else:
1367
1700
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1368
-
1701
+
1369
1702
  def _set_metrics(self, metrics: list[str] | None = None):
1370
1703
  if metrics is not None and len(metrics) > 0:
1371
1704
  self.metrics = [m.lower() for m in metrics]
1372
1705
  else:
1373
- self.metrics = ['auc', 'logloss']
1374
-
1375
- self.best_metrics_mode = 'max'
1376
-
1377
- if not hasattr(self, 'early_stopper') or self.early_stopper is None:
1378
- self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
1379
-
1380
- def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1706
+ self.metrics = ["auc", "logloss"]
1707
+
1708
+ self.best_metrics_mode = "max"
1709
+
1710
+ if not hasattr(self, "early_stopper") or self.early_stopper is None:
1711
+ self.early_stopper = EarlyStopper(
1712
+ patience=self.early_stop_patience, mode=self.best_metrics_mode
1713
+ )
1714
+
1715
+ def encode_user(
1716
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
1717
+ ) -> np.ndarray:
1381
1718
  self.eval()
1382
-
1719
+
1383
1720
  if not isinstance(data, DataLoader):
1384
1721
  user_data = {}
1385
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1722
+ all_user_features = (
1723
+ self.user_dense_features
1724
+ + self.user_sparse_features
1725
+ + self.user_sequence_features
1726
+ )
1386
1727
  for feature in all_user_features:
1387
1728
  if isinstance(data, dict):
1388
1729
  if feature.name in data:
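(Note on the pairwise/listwise branch of compute_loss above: every other item in the batch is treated as a negative, so softmax cross-entropy over the B x B similarity matrix with diagonal targets is exactly the in-batch-negative InfoNCE objective. A self-contained sketch of the same computation with toy tensors, no package dependencies:

    import torch
    import torch.nn.functional as F

    batch_size, emb_dim, temperature = 4, 8, 1.0
    user_emb = torch.randn(batch_size, emb_dim)   # [B, D] from the user tower
    item_emb = torch.randn(batch_size, emb_dim)   # [B, D] from the item tower

    # Row i scores user i against all B items; the diagonal entries are the positives.
    logits = user_emb @ item_emb.t() / temperature   # [B, B]
    targets = torch.arange(batch_size)               # positive column index per row

    loss = F.cross_entropy(logits, targets)          # InfoNCE with in-batch negatives
)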
@@ -1390,29 +1731,39 @@ class BaseMatchModel(BaseModel):
1390
1731
  elif isinstance(data, pd.DataFrame):
1391
1732
  if feature.name in data.columns:
1392
1733
  user_data[feature.name] = data[feature.name].values
1393
-
1394
- data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
1734
+
1735
+ data_loader = self._prepare_data_loader(
1736
+ user_data, batch_size=batch_size, shuffle=False
1737
+ )
1395
1738
  else:
1396
1739
  data_loader = data
1397
-
1740
+
1398
1741
  embeddings_list = []
1399
-
1742
+
1400
1743
  with torch.no_grad():
1401
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0):
1744
+ for batch_data in tqdm.tqdm(
1745
+ data_loader, desc="Encoding users", disable=self._verbose == 0
1746
+ ):
1402
1747
  batch_dict = self._batch_to_dict(batch_data)
1403
1748
  user_input = self.get_user_features(batch_dict)
1404
1749
  user_emb = self.user_tower(user_input)
1405
1750
  embeddings_list.append(user_emb.cpu().numpy())
1406
-
1751
+
1407
1752
  embeddings = np.concatenate(embeddings_list, axis=0)
1408
1753
  return embeddings
1409
-
1410
- def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1754
+
1755
+ def encode_item(
1756
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
1757
+ ) -> np.ndarray:
1411
1758
  self.eval()
1412
-
1759
+
1413
1760
  if not isinstance(data, DataLoader):
1414
1761
  item_data = {}
1415
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1762
+ all_item_features = (
1763
+ self.item_dense_features
1764
+ + self.item_sparse_features
1765
+ + self.item_sequence_features
1766
+ )
1416
1767
  for feature in all_item_features:
1417
1768
  if isinstance(data, dict):
1418
1769
  if feature.name in data:
@@ -1421,18 +1772,22 @@ class BaseMatchModel(BaseModel):
1421
1772
  if feature.name in data.columns:
1422
1773
  item_data[feature.name] = data[feature.name].values
1423
1774
 
1424
- data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
1775
+ data_loader = self._prepare_data_loader(
1776
+ item_data, batch_size=batch_size, shuffle=False
1777
+ )
1425
1778
  else:
1426
1779
  data_loader = data
1427
-
1780
+
1428
1781
  embeddings_list = []
1429
-
1782
+
1430
1783
  with torch.no_grad():
1431
- for batch_data in tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0):
1784
+ for batch_data in tqdm.tqdm(
1785
+ data_loader, desc="Encoding items", disable=self._verbose == 0
1786
+ ):
1432
1787
  batch_dict = self._batch_to_dict(batch_data)
1433
1788
  item_input = self.get_item_features(batch_dict)
1434
1789
  item_emb = self.item_tower(item_input)
1435
1790
  embeddings_list.append(item_emb.cpu().numpy())
1436
-
1791
+
1437
1792
  embeddings = np.concatenate(embeddings_list, axis=0)
1438
1793
  return embeddings
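(Note on encode_user / encode_item above: both return NumPy embedding matrices, which is all that is needed for offline retrieval. A sketch, assuming a hypothetical trained BaseMatchModel subclass `model` and DataFrames `user_df` / `item_df` carrying the user-side and item-side feature columns; with the default 'dot' similarity, scoring reduces to a matrix product (temperature only rescales scores and is omitted):

    import numpy as np

    user_vecs = model.encode_user(user_df, batch_size=512)   # [num_users, D]
    item_vecs = model.encode_item(item_df, batch_size=512)   # [num_items, D]

    scores = user_vecs @ item_vecs.T                          # [num_users, num_items]
    top_k = 10
    top_items = np.argsort(-scores, axis=1)[:, :top_k]        # top-k item indices per user
)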