nextrec 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. nextrec/__init__.py +4 -4
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +9 -10
  4. nextrec/basic/callback.py +0 -1
  5. nextrec/basic/dataloader.py +127 -168
  6. nextrec/basic/features.py +27 -24
  7. nextrec/basic/layers.py +159 -328
  8. nextrec/basic/loggers.py +37 -50
  9. nextrec/basic/metrics.py +147 -255
  10. nextrec/basic/model.py +462 -817
  11. nextrec/data/__init__.py +5 -5
  12. nextrec/data/data_utils.py +12 -16
  13. nextrec/data/preprocessor.py +252 -276
  14. nextrec/loss/__init__.py +12 -12
  15. nextrec/loss/loss_utils.py +22 -30
  16. nextrec/loss/match_losses.py +83 -116
  17. nextrec/models/match/__init__.py +5 -5
  18. nextrec/models/match/dssm.py +61 -70
  19. nextrec/models/match/dssm_v2.py +51 -61
  20. nextrec/models/match/mind.py +71 -89
  21. nextrec/models/match/sdm.py +81 -93
  22. nextrec/models/match/youtube_dnn.py +53 -62
  23. nextrec/models/multi_task/esmm.py +43 -49
  24. nextrec/models/multi_task/mmoe.py +56 -65
  25. nextrec/models/multi_task/ple.py +65 -92
  26. nextrec/models/multi_task/share_bottom.py +42 -48
  27. nextrec/models/ranking/__init__.py +7 -7
  28. nextrec/models/ranking/afm.py +30 -39
  29. nextrec/models/ranking/autoint.py +57 -70
  30. nextrec/models/ranking/dcn.py +35 -43
  31. nextrec/models/ranking/deepfm.py +28 -34
  32. nextrec/models/ranking/dien.py +79 -115
  33. nextrec/models/ranking/din.py +60 -84
  34. nextrec/models/ranking/fibinet.py +35 -51
  35. nextrec/models/ranking/fm.py +26 -28
  36. nextrec/models/ranking/masknet.py +31 -31
  37. nextrec/models/ranking/pnn.py +31 -30
  38. nextrec/models/ranking/widedeep.py +31 -36
  39. nextrec/models/ranking/xdeepfm.py +39 -46
  40. nextrec/utils/__init__.py +9 -9
  41. nextrec/utils/embedding.py +1 -1
  42. nextrec/utils/initializer.py +15 -23
  43. nextrec/utils/optimizer.py +10 -14
  44. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/METADATA +16 -7
  45. nextrec-0.1.8.dist-info/RECORD +51 -0
  46. nextrec-0.1.4.dist-info/RECORD +0 -51
  47. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/WHEEL +0 -0
  48. {nextrec-0.1.4.dist-info → nextrec-0.1.8.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -23,64 +23,53 @@ from nextrec.basic.callback import EarlyStopper
23
23
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
24
24
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics
25
25
 
26
- from nextrec.loss import get_loss_fn
27
26
  from nextrec.data import get_column_data
28
27
  from nextrec.basic.loggers import setup_logger, colorize
29
28
  from nextrec.utils import get_optimizer_fn, get_scheduler_fn
30
-
29
+ from nextrec.loss import get_loss_fn
31
30
 
32
31
 
33
32
  class BaseModel(nn.Module):
34
33
  @property
35
34
  def model_name(self) -> str:
36
35
  raise NotImplementedError
37
-
36
+
38
37
  @property
39
38
  def task_type(self) -> str:
40
39
  raise NotImplementedError
41
40
 
42
- def __init__(
43
- self,
44
- dense_features: list[DenseFeature] | None = None,
45
- sparse_features: list[SparseFeature] | None = None,
46
- sequence_features: list[SequenceFeature] | None = None,
47
- target: list[str] | str | None = None,
48
- task: str | list[str] = "binary",
49
- device: str = "cpu",
50
- embedding_l1_reg: float = 0.0,
51
- dense_l1_reg: float = 0.0,
52
- embedding_l2_reg: float = 0.0,
53
- dense_l2_reg: float = 0.0,
54
- early_stop_patience: int = 20,
55
- model_id: str = "baseline",
56
- ):
57
-
41
+ def __init__(self,
42
+ dense_features: list[DenseFeature] | None = None,
43
+ sparse_features: list[SparseFeature] | None = None,
44
+ sequence_features: list[SequenceFeature] | None = None,
45
+ target: list[str] | str | None = None,
46
+ task: str|list[str] = 'binary',
47
+ device: str = 'cpu',
48
+ embedding_l1_reg: float = 0.0,
49
+ dense_l1_reg: float = 0.0,
50
+ embedding_l2_reg: float = 0.0,
51
+ dense_l2_reg: float = 0.0,
52
+ early_stop_patience: int = 20,
53
+ model_id: str = 'baseline'):
54
+
58
55
  super(BaseModel, self).__init__()
59
56
 
60
57
  try:
61
58
  self.device = torch.device(device)
62
59
  except Exception as e:
63
- logging.warning(
64
- colorize("Invalid device , defaulting to CPU.", color="yellow")
65
- )
66
- self.device = torch.device("cpu")
60
+ logging.warning(colorize("Invalid device , defaulting to CPU.", color='yellow'))
61
+ self.device = torch.device('cpu')
67
62
 
68
63
  self.dense_features = list(dense_features) if dense_features is not None else []
69
- self.sparse_features = (
70
- list(sparse_features) if sparse_features is not None else []
71
- )
72
- self.sequence_features = (
73
- list(sequence_features) if sequence_features is not None else []
74
- )
75
-
64
+ self.sparse_features = list(sparse_features) if sparse_features is not None else []
65
+ self.sequence_features = list(sequence_features) if sequence_features is not None else []
66
+
76
67
  if isinstance(target, str):
77
68
  self.target = [target]
78
69
  else:
79
70
  self.target = list(target) if target is not None else []
80
-
81
- self.target_index = {
82
- target_name: idx for idx, target_name in enumerate(self.target)
83
- }
71
+
72
+ self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
84
73
 
85
74
  self.task = task
86
75
  self.nums_task = len(task) if isinstance(task, list) else 1
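As an orientation aside, a minimal sketch of constructing one of the shipped models with the constructor arguments above; the feature keyword names and the DeepFM import are assumptions based on the file list, and DeepFM's own hyper-parameters are omitted:

    from nextrec.basic.features import DenseFeature, SparseFeature
    from nextrec.models.ranking import DeepFM  # export name assumed, not verified against 0.1.8

    model = DeepFM(
        dense_features=[DenseFeature("price")],  # feature names are placeholders
        sparse_features=[SparseFeature("item_id", vocab_size=10_000, embedding_dim=16)],
        target="label",
        task="binary",             # or a list of task types for multi-task models
        device="cuda:0",
        embedding_l2_reg=1e-6,
        early_stop_patience=5,
    )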
@@ -90,48 +79,33 @@ class BaseModel(nn.Module):
90
79
  self._embedding_l2_reg = embedding_l2_reg
91
80
  self._dense_l2_reg = dense_l2_reg
92
81
 
93
- self._regularization_weights = (
94
- []
95
- ) # list of dense weights for regularization, used to compute reg loss
96
- self._embedding_params = (
97
- []
98
- ) # list of embedding weights for regularization, used to compute reg loss
82
+ self._regularization_weights = [] # list of dense weights for regularization, used to compute reg loss
83
+ self._embedding_params = [] # list of embedding weights for regularization, used to compute reg loss
99
84
 
100
85
  self.early_stop_patience = early_stop_patience
101
- self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
86
+ self._max_gradient_norm = 1.0 # Maximum gradient norm for gradient clipping
102
87
 
103
88
  project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
104
89
  self.model_id = model_id
105
-
106
- checkpoint_dir = os.path.abspath(
107
- os.path.join(project_root, "..", "checkpoints")
108
- )
90
+
91
+ checkpoint_dir = os.path.abspath(os.path.join(project_root, "..", "checkpoints"))
109
92
  os.makedirs(checkpoint_dir, exist_ok=True)
110
- self.checkpoint = os.path.join(
111
- checkpoint_dir,
112
- f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model",
113
- )
114
- self.best = os.path.join(
115
- checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model"
116
- )
93
+ self.checkpoint = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.model")
94
+ self.best = os.path.join(checkpoint_dir, f"{self.model_name}_{self.model_id}_best.model")
117
95
 
118
96
  self._logger_initialized = False
119
97
  self._verbose = 1
120
98
 
121
- def _register_regularization_weights(
122
- self,
123
- embedding_attr: str = "embedding",
124
- exclude_modules: (
125
- list[str] | None
126
- ) = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
127
- include_modules: list[str] | None = [],
128
- ):
99
+ def _register_regularization_weights(self,
100
+ embedding_attr: str = 'embedding',
101
+ exclude_modules: list[str] | None = [], # modules wont add regularization, example: ['fm', 'lr'] / ['fm.fc'] / etc.
102
+ include_modules: list[str] | None = []):
129
103
 
130
104
  exclude_modules = exclude_modules or []
131
-
105
+
132
106
  if hasattr(self, embedding_attr):
133
107
  embedding_layer = getattr(self, embedding_attr)
134
- if hasattr(embedding_layer, "embed_dict"):
108
+ if hasattr(embedding_layer, 'embed_dict'):
135
109
  for embed in embedding_layer.embed_dict.values():
136
110
  self._embedding_params.append(embed.weight)
137
111
 
@@ -139,49 +113,40 @@ class BaseModel(nn.Module):
139
113
  # Skip self module
140
114
  if module is self:
141
115
  continue
142
-
116
+
143
117
  # Skip embedding layers
144
118
  if embedding_attr in name:
145
119
  continue
146
-
120
+
147
121
  # Skip BatchNorm and Dropout by checking module type
148
- if isinstance(
149
- module,
150
- (
151
- nn.BatchNorm1d,
152
- nn.BatchNorm2d,
153
- nn.BatchNorm3d,
154
- nn.Dropout,
155
- nn.Dropout2d,
156
- nn.Dropout3d,
157
- ),
158
- ):
122
+ if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
123
+ nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
159
124
  continue
160
-
125
+
161
126
  # White-list: only include modules whose names contain specific keywords
162
127
  if include_modules is not None:
163
128
  should_include = any(inc_name in name for inc_name in include_modules)
164
129
  if not should_include:
165
130
  continue
166
-
131
+
167
132
  # Black-list: exclude modules whose names contain specific keywords
168
133
  if any(exc_name in name for exc_name in exclude_modules):
169
134
  continue
170
-
135
+
171
136
  # Only add regularization for Linear layers
172
137
  if isinstance(module, nn.Linear):
173
138
  self._regularization_weights.append(module.weight)
174
139
 
175
140
  def add_reg_loss(self) -> torch.Tensor:
176
141
  reg_loss = torch.tensor(0.0, device=self.device)
177
-
142
+
178
143
  if self._embedding_l1_reg > 0 and len(self._embedding_params) > 0:
179
144
  for param in self._embedding_params:
180
145
  reg_loss += self._embedding_l1_reg * torch.sum(torch.abs(param))
181
146
 
182
147
  if self._embedding_l2_reg > 0 and len(self._embedding_params) > 0:
183
148
  for param in self._embedding_params:
184
- reg_loss += self._embedding_l2_reg * torch.sum(param**2)
149
+ reg_loss += self._embedding_l2_reg * torch.sum(param ** 2)
185
150
 
186
151
  if self._dense_l1_reg > 0 and len(self._regularization_weights) > 0:
187
152
  for param in self._regularization_weights:
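A hedged sketch of how the registration hook above is meant to be used (the subclass, layer sizes, and module names are invented for illustration): a model points `embedding_attr` at its embedding layer, opts selected Linear layers in via `include_modules`, and `add_reg_loss()` then folds the L1/L2 penalties into the training loss.

    import torch.nn as nn
    from nextrec.basic.model import BaseModel  # import path assumed from this file

    class ToyRanker(BaseModel):
        @property
        def model_name(self) -> str:
            return "toy_ranker"

        @property
        def task_type(self) -> str:
            return "ranking"

        def __init__(self, **kwargs):
            super().__init__(dense_l2_reg=1e-5, **kwargs)
            self.mlp = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
            # Only nn.Linear weights under modules whose names contain "mlp" are
            # collected; note that the default include_modules=[] matches nothing.
            self._register_regularization_weights(include_modules=["mlp"])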
@@ -189,16 +154,11 @@ class BaseModel(nn.Module):
189
154
 
190
155
  if self._dense_l2_reg > 0 and len(self._regularization_weights) > 0:
191
156
  for param in self._regularization_weights:
192
- reg_loss += self._dense_l2_reg * torch.sum(param**2)
193
-
157
+ reg_loss += self._dense_l2_reg * torch.sum(param ** 2)
158
+
194
159
  return reg_loss
195
160
 
196
- def _to_tensor(
197
- self,
198
- value,
199
- dtype: torch.dtype | None = None,
200
- device: str | torch.device | None = None,
201
- ) -> torch.Tensor:
161
+ def _to_tensor(self, value, dtype: torch.dtype | None = None, device: str | torch.device | None = None) -> torch.Tensor:
202
162
  if value is None:
203
163
  raise ValueError("Cannot convert None to tensor.")
204
164
  if isinstance(value, torch.Tensor):
@@ -210,13 +170,11 @@ class BaseModel(nn.Module):
210
170
  target_device = device if device is not None else self.device
211
171
  return tensor.to(target_device)
212
172
 
213
- def get_input(self, input_data: dict | pd.DataFrame):
173
+ def get_input(self, input_data: dict|pd.DataFrame):
214
174
  X_input = {}
215
-
216
- all_features = (
217
- self.dense_features + self.sparse_features + self.sequence_features
218
- )
219
-
175
+
176
+ all_features = self.dense_features + self.sparse_features + self.sequence_features
177
+
220
178
  for feature in all_features:
221
179
  if feature.name not in input_data:
222
180
  continue
@@ -240,12 +198,10 @@ class BaseModel(nn.Module):
240
198
  if target_data is None:
241
199
  continue
242
200
  target_tensor = self._to_tensor(target_data, dtype=torch.float32)
243
-
201
+
244
202
  if target_tensor.dim() > 1:
245
203
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
246
- target_tensors.extend(
247
- torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1)
248
- )
204
+ target_tensors.extend(torch.chunk(target_tensor, chunks=target_tensor.shape[1], dim=1))
249
205
  else:
250
206
  target_tensors.append(target_tensor.view(-1, 1))
251
207
 
@@ -260,14 +216,14 @@ class BaseModel(nn.Module):
260
216
 
261
217
  def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
262
218
  """Configure metrics for model evaluation using the metrics module."""
263
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
264
- configure_metrics(task=self.task, metrics=metrics, target_names=self.target)
265
- ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
266
-
267
- if not hasattr(self, "early_stopper") or self.early_stopper is None:
268
- self.early_stopper = EarlyStopper(
269
- patience=self.early_stop_patience, mode=self.best_metrics_mode
270
- )
219
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(
220
+ task=self.task,
221
+ metrics=metrics,
222
+ target_names=self.target
223
+ ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
224
+
225
+ if not hasattr(self, 'early_stopper') or self.early_stopper is None:
226
+ self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
271
227
 
272
228
  def _validate_task_configuration(self):
273
229
  """Validate that task type, number of tasks, targets, and loss functions are consistent."""
@@ -276,35 +232,35 @@ class BaseModel(nn.Module):
276
232
  num_tasks_from_task = len(self.task)
277
233
  else:
278
234
  num_tasks_from_task = 1
279
-
235
+
280
236
  num_targets = len(self.target)
281
-
237
+
282
238
  if self.nums_task != num_tasks_from_task:
283
239
  raise ValueError(
284
240
  f"Number of tasks mismatch: nums_task={self.nums_task}, "
285
241
  f"but task list has {num_tasks_from_task} tasks."
286
242
  )
287
-
243
+
288
244
  if self.nums_task != num_targets:
289
245
  raise ValueError(
290
246
  f"Number of tasks ({self.nums_task}) does not match number of target columns ({num_targets}). "
291
247
  f"Tasks: {self.task}, Targets: {self.target}"
292
248
  )
293
-
249
+
294
250
  # Check loss function consistency
295
- if hasattr(self, "loss_fn"):
251
+ if hasattr(self, 'loss_fn'):
296
252
  num_loss_fns = len(self.loss_fn)
297
253
  if num_loss_fns != self.nums_task:
298
254
  raise ValueError(
299
255
  f"Number of loss functions ({num_loss_fns}) does not match number of tasks ({self.nums_task})."
300
256
  )
301
-
257
+
302
258
  # Validate task types with metrics and loss functions
303
259
  from nextrec.loss import VALID_TASK_TYPES
304
260
  from nextrec.basic.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS
305
-
261
+
306
262
  tasks_to_check = self.task if isinstance(self.task, list) else [self.task]
307
-
263
+
308
264
  for i, task_type in enumerate(tasks_to_check):
309
265
  # Validate task type
310
266
  if task_type not in VALID_TASK_TYPES:
@@ -312,26 +268,26 @@ class BaseModel(nn.Module):
312
268
  f"Invalid task type '{task_type}' for task {i}. "
313
269
  f"Valid types: {VALID_TASK_TYPES}"
314
270
  )
315
-
271
+
316
272
  # Check metrics compatibility
317
- if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
273
+ if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
318
274
  target_name = self.target[i] if i < len(self.target) else f"task_{i}"
319
275
  task_metrics = self.task_specific_metrics.get(target_name, self.metrics)
320
-
276
+
321
277
  for metric in task_metrics:
322
278
  metric_lower = metric.lower()
323
279
  # Skip gauc as it's valid for both classification and regression in some contexts
324
- if metric_lower == "gauc":
280
+ if metric_lower == 'gauc':
325
281
  continue
326
-
327
- if task_type in ["binary", "multiclass"]:
282
+
283
+ if task_type in ['binary', 'multiclass']:
328
284
  # Classification task
329
285
  if metric_lower in REGRESSION_METRICS:
330
286
  raise ValueError(
331
287
  f"Metric '{metric}' is not compatible with classification task type '{task_type}' "
332
288
  f"for target '{target_name}'. Classification metrics: {CLASSIFICATION_METRICS}"
333
289
  )
334
- elif task_type in ["regression", "multivariate_regression"]:
290
+ elif task_type in ['regression', 'multivariate_regression']:
335
291
  # Regression task
336
292
  if metric_lower in CLASSIFICATION_METRICS:
337
293
  raise ValueError(
@@ -339,72 +295,62 @@ class BaseModel(nn.Module):
339
295
  f"for target '{target_name}'. Regression metrics: {REGRESSION_METRICS}"
340
296
  )
341
297
 
342
- def _handle_validation_split(
343
- self,
344
- train_data: dict | pd.DataFrame | DataLoader,
345
- validation_split: float,
346
- batch_size: int,
347
- shuffle: bool,
348
- ) -> tuple[DataLoader, dict | pd.DataFrame]:
298
+ def _handle_validation_split(self,
299
+ train_data: dict | pd.DataFrame | DataLoader,
300
+ validation_split: float,
301
+ batch_size: int,
302
+ shuffle: bool) -> tuple[DataLoader, dict | pd.DataFrame]:
349
303
  """Handle validation split logic for training data.
350
-
304
+
351
305
  Args:
352
306
  train_data: Training data (dict, DataFrame, or DataLoader)
353
307
  validation_split: Fraction of data to use for validation (0 < validation_split < 1)
354
308
  batch_size: Batch size for DataLoader
355
309
  shuffle: Whether to shuffle training data
356
-
310
+
357
311
  Returns:
358
312
  tuple: (train_loader, valid_data)
359
313
  """
360
314
  if not (0 < validation_split < 1):
361
- raise ValueError(
362
- f"validation_split must be between 0 and 1, got {validation_split}"
363
- )
364
-
315
+ raise ValueError(f"validation_split must be between 0 and 1, got {validation_split}")
316
+
365
317
  if isinstance(train_data, DataLoader):
366
318
  raise ValueError(
367
319
  "validation_split cannot be used when train_data is a DataLoader. "
368
320
  "Please provide dict or pd.DataFrame for train_data."
369
321
  )
370
-
322
+
371
323
  if isinstance(train_data, pd.DataFrame):
372
324
  # Shuffle and split DataFrame
373
- shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(
374
- drop=True
375
- )
325
+ shuffled_df = train_data.sample(frac=1.0, random_state=42).reset_index(drop=True)
376
326
  split_idx = int(len(shuffled_df) * (1 - validation_split))
377
327
  train_split = shuffled_df.iloc[:split_idx]
378
328
  valid_split = shuffled_df.iloc[split_idx:]
379
-
380
- train_loader = self._prepare_data_loader(
381
- train_split, batch_size=batch_size, shuffle=shuffle
382
- )
383
-
329
+
330
+ train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
331
+
384
332
  if self._verbose:
385
- logging.info(
386
- colorize(
387
- f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
388
- color="cyan",
389
- )
390
- )
391
-
333
+ logging.info(colorize(
334
+ f"Split data: {len(train_split)} training samples, {len(valid_split)} validation samples",
335
+ color="cyan"
336
+ ))
337
+
392
338
  return train_loader, valid_split
393
-
339
+
394
340
  elif isinstance(train_data, dict):
395
341
  # Get total length from any feature
396
342
  sample_key = list(train_data.keys())[0]
397
343
  total_length = len(train_data[sample_key])
398
-
344
+
399
345
  # Create indices and shuffle
400
346
  indices = np.arange(total_length)
401
347
  np.random.seed(42)
402
348
  np.random.shuffle(indices)
403
-
349
+
404
350
  split_idx = int(total_length * (1 - validation_split))
405
351
  train_indices = indices[:split_idx]
406
352
  valid_indices = indices[split_idx:]
407
-
353
+
408
354
  # Split dict
409
355
  train_split = {}
410
356
  valid_split = {}
@@ -422,112 +368,82 @@ class BaseModel(nn.Module):
422
368
  else:
423
369
  train_split[key] = [value[i] for i in train_indices]
424
370
  valid_split[key] = [value[i] for i in valid_indices]
425
-
426
- train_loader = self._prepare_data_loader(
427
- train_split, batch_size=batch_size, shuffle=shuffle
428
- )
429
-
371
+
372
+ train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
373
+
430
374
  if self._verbose:
431
- logging.info(
432
- colorize(
433
- f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
434
- color="cyan",
435
- )
436
- )
437
-
375
+ logging.info(colorize(
376
+ f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples",
377
+ color="cyan"
378
+ ))
379
+
438
380
  return train_loader, valid_split
439
-
381
+
440
382
  else:
441
383
  raise TypeError(f"Unsupported train_data type: {type(train_data)}")
442
384
 
443
- def compile(
444
- self,
445
- optimizer="adam",
446
- optimizer_params: dict | None = None,
447
- scheduler: (
448
- str
449
- | torch.optim.lr_scheduler._LRScheduler
450
- | type[torch.optim.lr_scheduler._LRScheduler]
451
- | None
452
- ) = None,
453
- scheduler_params: dict | None = None,
454
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
455
- ):
385
+
386
+ def compile(self,
387
+ optimizer = "adam",
388
+ optimizer_params: dict | None = None,
389
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
390
+ scheduler_params: dict | None = None,
391
+ loss: str | nn.Module | list[str | nn.Module] | None= "bce"):
456
392
  if optimizer_params is None:
457
393
  optimizer_params = {}
458
-
459
- self._optimizer_name = (
460
- optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
461
- )
394
+
395
+ self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
462
396
  self._optimizer_params = optimizer_params
463
397
  if isinstance(scheduler, str):
464
398
  self._scheduler_name = scheduler
465
399
  elif scheduler is not None:
466
400
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
467
- self._scheduler_name = getattr(
468
- scheduler,
469
- "__name__",
470
- getattr(scheduler.__class__, "__name__", str(scheduler)),
471
- )
401
+ self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
472
402
  else:
473
403
  self._scheduler_name = None
474
404
  self._scheduler_params = scheduler_params or {}
475
405
  self._loss_config = loss
476
-
406
+
477
407
  # set optimizer
478
408
  self.optimizer_fn = get_optimizer_fn(
479
- optimizer=optimizer, params=self.parameters(), **optimizer_params
409
+ optimizer=optimizer,
410
+ params=self.parameters(),
411
+ **optimizer_params
480
412
  )
481
-
413
+
482
414
  # set loss functions
483
415
  if self.nums_task == 1:
484
416
  task_type = self.task if isinstance(self.task, str) else self.task[0]
485
417
  loss_value = loss[0] if isinstance(loss, list) else loss
486
418
  # For ranking and multitask, use pointwise training
487
- training_mode = (
488
- "pointwise" if self.task_type in ["ranking", "multitask"] else None
489
- )
419
+ training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
490
420
  # Use task_type directly, not self.task_type for single task
491
- self.loss_fn = [
492
- get_loss_fn(
493
- task_type=task_type, training_mode=training_mode, loss=loss_value
494
- )
495
- ]
421
+ self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
496
422
  else:
497
423
  self.loss_fn = []
498
424
  for i in range(self.nums_task):
499
425
  task_type = self.task[i] if isinstance(self.task, list) else self.task
500
-
426
+
501
427
  if isinstance(loss, list):
502
428
  loss_value = loss[i] if i < len(loss) else None
503
429
  else:
504
430
  loss_value = loss
505
-
431
+
506
432
  # Multitask always uses pointwise training
507
- training_mode = "pointwise"
508
- self.loss_fn.append(
509
- get_loss_fn(
510
- task_type=task_type,
511
- training_mode=training_mode,
512
- loss=loss_value,
513
- )
514
- )
515
-
433
+ training_mode = 'pointwise'
434
+ self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
435
+
516
436
  # set scheduler
517
- self.scheduler_fn = (
518
- get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
519
- if scheduler
520
- else None
521
- )
437
+ self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
522
438
 
523
439
  def compute_loss(self, y_pred, y_true):
524
440
  if y_true is None:
525
441
  return torch.tensor(0.0, device=self.device)
526
-
442
+
527
443
  if self.nums_task == 1:
528
444
  loss = self.loss_fn[0](y_pred, y_true)
529
445
  return loss
530
-
446
+
531
447
  else:
532
448
  task_losses = []
533
449
  for i in range(self.nums_task):
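With the compile() signature rewritten in the hunk above, wiring up the optimizer and losses reduces to roughly the calls below (a sketch; the learning-rate key and the concrete loss names are assumptions about typical usage):

    # Single-task binary model: one loss function.
    model.compile(optimizer="adam", optimizer_params={"lr": 1e-3}, loss="bce")

    # Multi-task model: one loss per task, checked against self.nums_task.
    mtl_model.compile(optimizer="adam", loss=["bce", "mse"])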
@@ -535,69 +451,48 @@ class BaseModel(nn.Module):
535
451
  task_losses.append(task_loss)
536
452
  return torch.stack(task_losses)
537
453
 
538
- def _prepare_data_loader(
539
- self,
540
- data: dict | pd.DataFrame | DataLoader,
541
- batch_size: int = 32,
542
- shuffle: bool = True,
543
- ):
454
+
455
+ def _prepare_data_loader(self, data: dict|pd.DataFrame|DataLoader, batch_size: int = 32, shuffle: bool = True):
544
456
  if isinstance(data, DataLoader):
545
457
  return data
546
458
  tensors = []
547
- all_features = (
548
- self.dense_features + self.sparse_features + self.sequence_features
549
- )
550
-
459
+ all_features = self.dense_features + self.sparse_features + self.sequence_features
460
+
551
461
  for feature in all_features:
552
462
  column = get_column_data(data, feature.name)
553
463
  if column is None:
554
464
  raise KeyError(f"Feature {feature.name} not found in provided data.")
555
-
465
+
556
466
  if isinstance(feature, SequenceFeature):
557
467
  if isinstance(column, pd.Series):
558
468
  column = column.values
559
469
  if isinstance(column, np.ndarray) and column.dtype == object:
560
- column = np.array(
561
- [
562
- (
563
- np.array(seq, dtype=np.int64)
564
- if not isinstance(seq, np.ndarray)
565
- else seq
566
- )
567
- for seq in column
568
- ]
569
- )
570
- if (
571
- isinstance(column, np.ndarray)
572
- and column.ndim == 1
573
- and column.dtype == object
574
- ):
470
+ column = np.array([np.array(seq, dtype=np.int64) if not isinstance(seq, np.ndarray) else seq for seq in column])
471
+ if isinstance(column, np.ndarray) and column.ndim == 1 and column.dtype == object:
575
472
  column = np.vstack([c if isinstance(c, np.ndarray) else np.array(c) for c in column]) # type: ignore
576
- tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to("cpu")
473
+ tensor = torch.from_numpy(np.asarray(column, dtype=np.int64)).to('cpu')
577
474
  else:
578
- dtype = (
579
- torch.float32 if isinstance(feature, DenseFeature) else torch.long
580
- )
581
- tensor = self._to_tensor(column, dtype=dtype, device="cpu")
582
-
475
+ dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
476
+ tensor = self._to_tensor(column, dtype=dtype, device='cpu')
477
+
583
478
  tensors.append(tensor)
584
-
479
+
585
480
  label_tensors = []
586
481
  for target_name in self.target:
587
482
  column = get_column_data(data, target_name)
588
483
  if column is None:
589
484
  continue
590
- label_tensor = self._to_tensor(column, dtype=torch.float32, device="cpu")
591
-
485
+ label_tensor = self._to_tensor(column, dtype=torch.float32, device='cpu')
486
+
592
487
  if label_tensor.dim() == 1:
593
488
  # 1D tensor: (N,) -> (N, 1)
594
489
  label_tensor = label_tensor.view(-1, 1)
595
490
  elif label_tensor.dim() == 2:
596
491
  if label_tensor.shape[0] == 1 and label_tensor.shape[1] > 1:
597
492
  label_tensor = label_tensor.t()
598
-
493
+
599
494
  label_tensors.append(label_tensor)
600
-
495
+
601
496
  if label_tensors:
602
497
  if len(label_tensors) == 1 and label_tensors[0].shape[1] > 1:
603
498
  y_tensor = label_tensors[0]
@@ -607,23 +502,22 @@ class BaseModel(nn.Module):
607
502
  if y_tensor.shape[1] == 1:
608
503
  y_tensor = y_tensor.squeeze(1)
609
504
  tensors.append(y_tensor)
610
-
505
+
611
506
  dataset = TensorDataset(*tensors)
612
507
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
613
508
 
509
+
614
510
  def _batch_to_dict(self, batch_data: tuple) -> dict:
615
511
  result = {}
616
- all_features = (
617
- self.dense_features + self.sparse_features + self.sequence_features
618
- )
619
-
512
+ all_features = self.dense_features + self.sparse_features + self.sequence_features
513
+
620
514
  for i, feature in enumerate(all_features):
621
515
  if i < len(batch_data):
622
516
  result[feature.name] = batch_data[i]
623
-
517
+
624
518
  if len(batch_data) > len(all_features):
625
519
  labels = batch_data[-1]
626
-
520
+
627
521
  if self.nums_task == 1:
628
522
  result[self.target[0]] = labels
629
523
  else:
@@ -640,36 +534,28 @@ class BaseModel(nn.Module):
640
534
  for i, target_name in enumerate(self.target):
641
535
  if i < labels.shape[-1]:
642
536
  result[target_name] = labels[..., i]
643
-
537
+
644
538
  return result
645
539
 
646
- def fit(
647
- self,
648
- train_data: dict | pd.DataFrame | DataLoader,
649
- valid_data: dict | pd.DataFrame | DataLoader | None = None,
650
- metrics: (
651
- list[str] | dict[str, list[str]] | None
652
- ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
653
- epochs: int = 1,
654
- verbose: int = 1,
655
- shuffle: bool = True,
656
- batch_size: int = 32,
657
- user_id_column: str = "user_id",
658
- validation_split: float | None = None,
659
- ):
540
+
541
+ def fit(self,
542
+ train_data: dict|pd.DataFrame|DataLoader,
543
+ valid_data: dict|pd.DataFrame|DataLoader|None=None,
544
+ metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
545
+ epochs:int=1, verbose:int=1, shuffle:bool=True, batch_size:int=32,
546
+ user_id_column: str = 'user_id',
547
+ validation_split: float | None = None):
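A minimal sketch of calling the reflowed fit() signature above; `train_df` and its columns are placeholders rather than nextrec data:

    model.fit(
        train_data=train_df,            # dict, DataFrame, or DataLoader
        validation_split=0.1,           # split off validation data when no valid_data is given
        metrics=["auc", "logloss"],     # or a per-target dict for multi-task models
        epochs=5,
        batch_size=256,
    )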
660
548
 
661
549
  self.to(self.device)
662
550
  if not self._logger_initialized:
663
551
  setup_logger()
664
552
  self._logger_initialized = True
665
553
  self._verbose = verbose
666
- self._set_metrics(
667
- metrics
668
- ) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
669
-
554
+ self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
555
+
670
556
  # Assert before training
671
557
  self._validate_task_configuration()
672
-
558
+
673
559
  if self._verbose:
674
560
  self.summary()
675
561
 
@@ -680,85 +566,63 @@ class BaseModel(nn.Module):
680
566
  train_data=train_data,
681
567
  validation_split=validation_split,
682
568
  batch_size=batch_size,
683
- shuffle=shuffle,
569
+ shuffle=shuffle
684
570
  )
685
571
  else:
686
572
  if not isinstance(train_data, DataLoader):
687
- train_loader = self._prepare_data_loader(
688
- train_data, batch_size=batch_size, shuffle=shuffle
689
- )
573
+ train_loader = self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle)
690
574
  else:
691
575
  train_loader = train_data
576
+
692
577
 
693
578
  valid_user_ids: np.ndarray | None = None
694
579
  needs_user_ids = self._needs_user_ids_for_metrics()
695
580
 
696
581
  if valid_loader is None:
697
582
  if valid_data is not None and not isinstance(valid_data, DataLoader):
698
- valid_loader = self._prepare_data_loader(
699
- valid_data, batch_size=batch_size, shuffle=False
700
- )
583
+ valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
701
584
  # Extract user_ids only if needed for GAUC
702
585
  if needs_user_ids:
703
- if (
704
- isinstance(valid_data, pd.DataFrame)
705
- and user_id_column in valid_data.columns
706
- ):
586
+ if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
707
587
  valid_user_ids = np.asarray(valid_data[user_id_column].values)
708
588
  elif isinstance(valid_data, dict) and user_id_column in valid_data:
709
589
  valid_user_ids = np.asarray(valid_data[user_id_column])
710
590
  elif valid_data is not None:
711
591
  valid_loader = valid_data
712
-
592
+
713
593
  try:
714
594
  self._steps_per_epoch = len(train_loader)
715
595
  is_streaming = False
716
596
  except TypeError:
717
597
  self._steps_per_epoch = None
718
598
  is_streaming = True
719
-
599
+
720
600
  self._epoch_index = 0
721
601
  self._stop_training = False
722
- self._best_metric = (
723
- float("-inf") if self.best_metrics_mode == "max" else float("inf")
724
- )
725
-
602
+ self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
603
+
726
604
  if self._verbose:
727
605
  logging.info("")
728
606
  logging.info(colorize("=" * 80, color="bright_green", bold=True))
729
607
  if is_streaming:
730
- logging.info(
731
- colorize(
732
- f"Start training (Streaming Mode)",
733
- color="bright_green",
734
- bold=True,
735
- )
736
- )
608
+ logging.info(colorize(f"Start training (Streaming Mode)", color="bright_green", bold=True))
737
609
  else:
738
- logging.info(
739
- colorize(f"Start training", color="bright_green", bold=True)
740
- )
610
+ logging.info(colorize(f"Start training", color="bright_green", bold=True))
741
611
  logging.info(colorize("=" * 80, color="bright_green", bold=True))
742
612
  logging.info("")
743
613
  logging.info(colorize(f"Model device: {self.device}", color="bright_green"))
744
-
614
+
745
615
  for epoch in range(epochs):
746
616
  self._epoch_index = epoch
747
-
617
+
748
618
  # In streaming mode, print epoch header before progress bar
749
619
  if self._verbose and is_streaming:
750
620
  logging.info("")
751
- logging.info(
752
- colorize(
753
- f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True
754
- )
755
- )
621
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", color="bright_green", bold=True))
756
622
 
757
623
  # Train with metrics computation
758
- train_result = self.train_epoch(
759
- train_loader, is_streaming=is_streaming, compute_metrics=True
760
- )
761
-
624
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming, compute_metrics=True)
625
+
762
626
  # Unpack results
763
627
  if isinstance(train_result, tuple):
764
628
  train_loss, train_metrics = train_result
@@ -768,13 +632,9 @@ class BaseModel(nn.Module):
768
632
 
769
633
  if self._verbose:
770
634
  if self.nums_task == 1:
771
- log_str = (
772
- f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
773
- )
635
+ log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
774
636
  if train_metrics:
775
- metrics_str = ", ".join(
776
- [f"{k}={v:.4f}" for k, v in train_metrics.items()]
777
- )
637
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
778
638
  log_str += f", {metrics_str}"
779
639
  logging.info(colorize(log_str, color="white"))
780
640
  else:
@@ -784,12 +644,10 @@ class BaseModel(nn.Module):
784
644
  task_labels.append(self.target[i])
785
645
  else:
786
646
  task_labels.append(f"task_{i}")
787
-
647
+
788
648
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
789
- log_str = (
790
- f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
791
- )
792
-
649
+ log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
650
+
793
651
  if train_metrics:
794
652
  # Group metrics by task
795
653
  task_metrics = {}
@@ -798,50 +656,28 @@ class BaseModel(nn.Module):
798
656
  if metric_key.endswith(f"_{target_name}"):
799
657
  if target_name not in task_metrics:
800
658
  task_metrics[target_name] = {}
801
- metric_name = metric_key.rsplit(
802
- f"_{target_name}", 1
803
- )[0]
804
- task_metrics[target_name][
805
- metric_name
806
- ] = metric_value
659
+ metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
660
+ task_metrics[target_name][metric_name] = metric_value
807
661
  break
808
-
662
+
809
663
  if task_metrics:
810
664
  task_metric_strs = []
811
665
  for target_name in self.target:
812
666
  if target_name in task_metrics:
813
- metrics_str = ", ".join(
814
- [
815
- f"{k}={v:.4f}"
816
- for k, v in task_metrics[
817
- target_name
818
- ].items()
819
- ]
820
- )
821
- task_metric_strs.append(
822
- f"{target_name}[{metrics_str}]"
823
- )
667
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
668
+ task_metric_strs.append(f"{target_name}[{metrics_str}]")
824
669
  log_str += ", " + ", ".join(task_metric_strs)
825
-
670
+
826
671
  logging.info(colorize(log_str, color="white"))
827
-
672
+
828
673
  if valid_loader is not None:
829
674
  # Pass user_ids only if needed for GAUC metric
830
- val_metrics = self.evaluate(
831
- valid_loader, user_ids=valid_user_ids if needs_user_ids else None
832
- ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
833
-
675
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
676
+
834
677
  if self._verbose:
835
678
  if self.nums_task == 1:
836
- metrics_str = ", ".join(
837
- [f"{k}={v:.4f}" for k, v in val_metrics.items()]
838
- )
839
- logging.info(
840
- colorize(
841
- f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
842
- color="cyan",
843
- )
844
- )
679
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
680
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
845
681
  else:
846
682
  # multi task metrics
847
683
  task_metrics = {}
@@ -850,55 +686,33 @@ class BaseModel(nn.Module):
850
686
  if metric_key.endswith(f"_{target_name}"):
851
687
  if target_name not in task_metrics:
852
688
  task_metrics[target_name] = {}
853
- metric_name = metric_key.rsplit(
854
- f"_{target_name}", 1
855
- )[0]
856
- task_metrics[target_name][
857
- metric_name
858
- ] = metric_value
689
+ metric_name = metric_key.rsplit(f"_{target_name}", 1)[0]
690
+ task_metrics[target_name][metric_name] = metric_value
859
691
  break
860
692
 
861
693
  task_metric_strs = []
862
694
  for target_name in self.target:
863
695
  if target_name in task_metrics:
864
- metrics_str = ", ".join(
865
- [
866
- f"{k}={v:.4f}"
867
- for k, v in task_metrics[target_name].items()
868
- ]
869
- )
696
+ metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
870
697
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
871
-
872
- logging.info(
873
- colorize(
874
- f"Epoch {epoch + 1}/{epochs} - Valid: "
875
- + ", ".join(task_metric_strs),
876
- color="cyan",
877
- )
878
- )
879
-
698
+
699
+ logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
700
+
880
701
  # Handle empty validation metrics
881
702
  if not val_metrics:
882
703
  if self._verbose:
883
- logging.info(
884
- colorize(
885
- f"Warning: No validation metrics computed. Skipping validation for this epoch.",
886
- color="yellow",
887
- )
888
- )
704
+ logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
889
705
  continue
890
-
706
+
891
707
  if self.nums_task == 1:
892
708
  primary_metric_key = self.metrics[0]
893
709
  else:
894
710
  primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
895
-
896
- primary_metric = val_metrics.get(
897
- primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
898
- )
711
+
712
+ primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
899
713
  improved = False
900
-
901
- if self.best_metrics_mode == "max":
714
+
715
+ if self.best_metrics_mode == 'max':
902
716
  if primary_metric > self._best_metric:
903
717
  self._best_metric = primary_metric
904
718
  self.save_weights(self.best)
@@ -907,85 +721,56 @@ class BaseModel(nn.Module):
907
721
  if primary_metric < self._best_metric:
908
722
  self._best_metric = primary_metric
909
723
  improved = True
910
-
724
+
911
725
  if improved:
912
726
  if self._verbose:
913
- logging.info(
914
- colorize(
915
- f"Validation {primary_metric_key} improved to {self._best_metric:.4f}",
916
- color="yellow",
917
- )
918
- )
727
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}", color="yellow"))
919
728
  self.save_weights(self.checkpoint)
920
729
  self.early_stopper.trial_counter = 0
921
730
  else:
922
731
  self.early_stopper.trial_counter += 1
923
732
  if self._verbose:
924
- logging.info(
925
- colorize(
926
- f"No improvement for {self.early_stopper.trial_counter} epoch(s)",
927
- color="yellow",
928
- )
929
- )
930
-
733
+ logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)", color="yellow"))
734
+
931
735
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
932
736
  self._stop_training = True
933
737
  if self._verbose:
934
- logging.info(
935
- colorize(
936
- f"Early stopping triggered after {epoch + 1} epochs",
937
- color="bright_red",
938
- bold=True,
939
- )
940
- )
738
+ logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
941
739
  break
942
740
  else:
943
741
  self.save_weights(self.checkpoint)
944
-
742
+
945
743
  if self._stop_training:
946
744
  break
947
-
745
+
948
746
  if self.scheduler_fn is not None:
949
- if isinstance(
950
- self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
951
- ):
747
+ if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
952
748
  if valid_loader is not None:
953
749
  self.scheduler_fn.step(primary_metric)
954
750
  else:
955
751
  self.scheduler_fn.step()
956
-
752
+
957
753
  if self._verbose:
958
754
  logging.info("\n")
959
- logging.info(
960
- colorize("Training finished.", color="bright_green", bold=True)
961
- )
755
+ logging.info(colorize("Training finished.", color="bright_green", bold=True))
962
756
  logging.info("\n")
963
-
757
+
964
758
  if valid_loader is not None:
965
759
  if self._verbose:
966
- logging.info(
967
- colorize(
968
- f"Load best model from: {self.checkpoint}", color="bright_blue"
969
- )
970
- )
760
+ logging.info(colorize(f"Load best model from: {self.checkpoint}", color="bright_blue"))
971
761
  self.load_weights(self.checkpoint)
972
-
762
+
973
763
  return self
974
764
 
975
- def train_epoch(
976
- self,
977
- train_loader: DataLoader,
978
- is_streaming: bool = False,
979
- compute_metrics: bool = False,
980
- ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
765
+ def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False, compute_metrics: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
981
766
  if self.nums_task == 1:
982
767
  accumulated_loss = 0.0
983
768
  else:
984
769
  accumulated_loss = np.zeros(self.nums_task, dtype=np.float64)
985
-
770
+
986
771
  self.train()
987
772
  num_batches = 0
988
-
773
+
989
774
  # Lists to store predictions and labels for metric computation
990
775
  y_true_list = []
991
776
  y_pred_list = []
@@ -993,29 +778,19 @@ class BaseModel(nn.Module):
993
778
  if self._verbose:
994
779
  # For streaming datasets without known length, set total=None to show progress without percentage
995
780
  if self._steps_per_epoch is not None:
996
- batch_iter = enumerate(
997
- tqdm.tqdm(
998
- train_loader,
999
- desc=f"Epoch {self._epoch_index + 1}",
1000
- total=self._steps_per_epoch,
1001
- )
1002
- )
781
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
1003
782
  else:
1004
783
  # Streaming mode: show batch/file progress without epoch in desc
1005
784
  if is_streaming:
1006
- batch_iter = enumerate(
1007
- tqdm.tqdm(
1008
- train_loader,
1009
- desc="Batches",
1010
- # position=1,
1011
- # leave=False,
1012
- # unit="batch"
1013
- )
1014
- )
785
+ batch_iter = enumerate(tqdm.tqdm(
786
+ train_loader,
787
+ desc="Batches",
788
+ # position=1,
789
+ # leave=False,
790
+ # unit="batch"
791
+ ))
1015
792
  else:
1016
- batch_iter = enumerate(
1017
- tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}")
1018
- )
793
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
1019
794
  else:
1020
795
  batch_iter = enumerate(train_loader)
1021
796
 
@@ -1031,17 +806,17 @@ class BaseModel(nn.Module):
1031
806
  total_loss = loss + reg_loss
1032
807
  else:
1033
808
  total_loss = loss.sum() + reg_loss
1034
-
809
+
1035
810
  self.optimizer_fn.zero_grad()
1036
811
  total_loss.backward()
1037
812
  nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
1038
813
  self.optimizer_fn.step()
1039
-
814
+
1040
815
  if self.nums_task == 1:
1041
816
  accumulated_loss += loss.item()
1042
817
  else:
1043
818
  accumulated_loss += loss.detach().cpu().numpy()
1044
-
819
+
1045
820
  # Collect predictions and labels for metrics if requested
1046
821
  if compute_metrics:
1047
822
  if y_true is not None:
@@ -1049,52 +824,49 @@ class BaseModel(nn.Module):
1049
824
  # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
1050
825
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
1051
826
  y_pred_list.append(y_pred.detach().cpu().numpy())
1052
-
827
+
1053
828
  num_batches += 1
1054
829
 
1055
830
  if self.nums_task == 1:
1056
831
  avg_loss = accumulated_loss / num_batches
1057
832
  else:
1058
833
  avg_loss = accumulated_loss / num_batches
1059
-
834
+
1060
835
  # Compute metrics if requested
1061
836
  if compute_metrics and len(y_true_list) > 0 and len(y_pred_list) > 0:
1062
837
  y_true_all = np.concatenate(y_true_list, axis=0)
1063
838
  y_pred_all = np.concatenate(y_pred_list, axis=0)
1064
- metrics_dict = self.evaluate_metrics(
1065
- y_true_all, y_pred_all, self.metrics, user_ids=None
1066
- )
839
+ metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=None)
1067
840
  return avg_loss, metrics_dict
1068
-
841
+
1069
842
  return avg_loss
1070
843
 
844
+
1071
845
  def _needs_user_ids_for_metrics(self) -> bool:
1072
846
  """Check if any configured metric requires user_ids (e.g., gauc)."""
1073
847
  all_metrics = set()
1074
-
848
+
1075
849
  # Collect all metrics from different sources
1076
- if hasattr(self, "metrics") and self.metrics:
850
+ if hasattr(self, 'metrics') and self.metrics:
1077
851
  all_metrics.update(m.lower() for m in self.metrics)
1078
-
1079
- if hasattr(self, "task_specific_metrics") and self.task_specific_metrics:
852
+
853
+ if hasattr(self, 'task_specific_metrics') and self.task_specific_metrics:
1080
854
  for task_metrics in self.task_specific_metrics.values():
1081
855
  if isinstance(task_metrics, list):
1082
856
  all_metrics.update(m.lower() for m in task_metrics)
1083
-
857
+
1084
858
  # Check if gauc is in any of the metrics
1085
- return "gauc" in all_metrics
1086
-
1087
- def evaluate(
1088
- self,
1089
- data: dict | pd.DataFrame | DataLoader,
1090
- metrics: list[str] | dict[str, list[str]] | None = None,
1091
- batch_size: int = 32,
1092
- user_ids: np.ndarray | None = None,
1093
- user_id_column: str = "user_id",
1094
- ) -> dict:
859
+ return 'gauc' in all_metrics
860
+
861
+ def evaluate(self,
862
+ data: dict | pd.DataFrame | DataLoader,
863
+ metrics: list[str] | dict[str, list[str]] | None = None,
864
+ batch_size: int = 32,
865
+ user_ids: np.ndarray | None = None,
866
+ user_id_column: str = 'user_id') -> dict:
1095
867
  """
1096
868
  Evaluate the model on validation data.
1097
-
869
+
1098
870
  Args:
1099
871
  data: Evaluation data (dict, DataFrame, or DataLoader)
1100
872
  metrics: Optional metrics to use for evaluation. If None, uses metrics from fit()
@@ -1102,19 +874,17 @@ class BaseModel(nn.Module):
1102
874
  user_ids: Optional user IDs for computing GAUC metric. If None and gauc is needed,
1103
875
  will try to extract from data using user_id_column
1104
876
  user_id_column: Column name for user IDs (default: 'user_id')
1105
-
877
+
1106
878
  Returns:
1107
879
  Dictionary of metric values
1108
880
  """
1109
881
  self.eval()
1110
-
882
+
1111
883
  # Use provided metrics or fall back to configured metrics
1112
884
  eval_metrics = metrics if metrics is not None else self.metrics
1113
885
  if eval_metrics is None:
1114
- raise ValueError(
1115
- "No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
1116
- )
1117
-
886
+ raise ValueError("No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
887
+
1118
888
  # Prepare DataLoader if needed
1119
889
  if isinstance(data, DataLoader):
1120
890
  data_loader = data
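A sketch of standalone evaluation assuming a gauc metric is configured: per-user grouping needs user ids, which are read from a `user_id` column of a dict/DataFrame or must be passed explicitly when the data is already a DataLoader (`valid_df` and `valid_loader` are placeholders):

    scores = model.evaluate(valid_df, metrics=["auc", "gauc"], batch_size=512)

    # A DataLoader carries no user ids, so supply them by hand:
    scores = model.evaluate(valid_loader, metrics=["auc", "gauc"],
                            user_ids=valid_df["user_id"].to_numpy())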
@@ -1122,13 +892,11 @@ class BaseModel(nn.Module):
1122
892
  if user_ids is None and self._needs_user_ids_for_metrics():
1123
893
  # Cannot extract user_ids from DataLoader, user must provide them
1124
894
  if self._verbose:
1125
- logging.warning(
1126
- colorize(
1127
- "GAUC metric requires user_ids, but data is a DataLoader. "
1128
- "Please provide user_ids parameter or use dict/DataFrame format.",
1129
- color="yellow",
1130
- )
1131
- )
895
+ logging.warning(colorize(
896
+ "GAUC metric requires user_ids, but data is a DataLoader. "
897
+ "Please provide user_ids parameter or use dict/DataFrame format.",
898
+ color="yellow"
899
+ ))
1132
900
  else:
1133
901
  # Extract user_ids if needed and not provided
1134
902
  if user_ids is None and self._needs_user_ids_for_metrics():
@@ -1136,14 +904,12 @@ class BaseModel(nn.Module):
1136
904
  user_ids = np.asarray(data[user_id_column].values)
1137
905
  elif isinstance(data, dict) and user_id_column in data:
1138
906
  user_ids = np.asarray(data[user_id_column])
1139
-
1140
- data_loader = self._prepare_data_loader(
1141
- data, batch_size=batch_size, shuffle=False
1142
- )
1143
-
907
+
908
+ data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
909
+
1144
910
  y_true_list = []
1145
911
  y_pred_list = []
1146
-
912
+
1147
913
  batch_count = 0
1148
914
  with torch.no_grad():
1149
915
  for batch_data in data_loader:
@@ -1151,7 +917,7 @@ class BaseModel(nn.Module):
1151
917
  batch_dict = self._batch_to_dict(batch_data)
1152
918
  X_input, y_true = self.get_input(batch_dict)
1153
919
  y_pred = self.forward(X_input)
1154
-
920
+
1155
921
  if y_true is not None:
1156
922
  y_true_list.append(y_true.cpu().numpy())
1157
923
  # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
@@ -1159,40 +925,24 @@ class BaseModel(nn.Module):
1159
925
  y_pred_list.append(y_pred.cpu().numpy())
1160
926
 
1161
927
  if self._verbose:
1162
- logging.info(
1163
- colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1164
- )
1165
-
928
+ logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
929
+
1166
930
  if len(y_true_list) > 0:
1167
931
  y_true_all = np.concatenate(y_true_list, axis=0)
1168
932
  if self._verbose:
1169
- logging.info(
1170
- colorize(
1171
- f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"
1172
- )
1173
- )
933
+ logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
1174
934
  else:
1175
935
  y_true_all = None
1176
936
  if self._verbose:
1177
- logging.info(
1178
- colorize(
1179
- f" Warning: No y_true collected from evaluation data",
1180
- color="yellow",
1181
- )
1182
- )
1183
-
937
+ logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
938
+
1184
939
  if len(y_pred_list) > 0:
1185
940
  y_pred_all = np.concatenate(y_pred_list, axis=0)
1186
941
  else:
1187
942
  y_pred_all = None
1188
943
  if self._verbose:
1189
- logging.info(
1190
- colorize(
1191
- f" Warning: No y_pred collected from evaluation data",
1192
- color="yellow",
1193
- )
1194
- )
1195
-
944
+ logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
945
+
1196
946
  # Convert metrics to list if it's a dict
1197
947
  if isinstance(eval_metrics, dict):
1198
948
  # For dict metrics, we need to collect all unique metric names
@@ -1204,23 +954,16 @@ class BaseModel(nn.Module):
1204
954
  metrics_to_use = unique_metrics
1205
955
  else:
1206
956
  metrics_to_use = eval_metrics
1207
-
1208
- metrics_dict = self.evaluate_metrics(
1209
- y_true_all, y_pred_all, metrics_to_use, user_ids
1210
- )
1211
-
957
+
958
+ metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, user_ids)
959
+
1212
960
  return metrics_dict
1213
961
 
1214
- def evaluate_metrics(
1215
- self,
1216
- y_true: np.ndarray | None,
1217
- y_pred: np.ndarray | None,
1218
- metrics: list[str],
1219
- user_ids: np.ndarray | None = None,
1220
- ) -> dict:
1221
- """Evaluate metrics using the metrics module."""
1222
- task_specific_metrics = getattr(self, "task_specific_metrics", None)
1223
962
 
963
+ def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
964
+ """Evaluate metrics using the metrics module."""
965
+ task_specific_metrics = getattr(self, 'task_specific_metrics', None)
966
+
1224
967
  return evaluate_metrics(
1225
968
  y_true=y_true,
1226
969
  y_pred=y_pred,
@@ -1228,29 +971,24 @@ class BaseModel(nn.Module):
1228
971
  task=self.task,
1229
972
  target_names=self.target,
1230
973
  task_specific_metrics=task_specific_metrics,
1231
- user_ids=user_ids,
974
+ user_ids=user_ids
1232
975
  )
1233
976
 
1234
- def predict(
1235
- self, data: str | dict | pd.DataFrame | DataLoader, batch_size: int = 32
1236
- ) -> np.ndarray:
977
+
978
+ def predict(self, data: str|dict|pd.DataFrame|DataLoader, batch_size: int = 32) -> np.ndarray:
1237
979
  self.eval()
1238
980
  # todo: handle file path input later
1239
981
  if isinstance(data, (str, os.PathLike)):
1240
982
  pass
1241
983
  if not isinstance(data, DataLoader):
1242
- data_loader = self._prepare_data_loader(
1243
- data, batch_size=batch_size, shuffle=False
1244
- )
984
+ data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
1245
985
  else:
1246
986
  data_loader = data
1247
-
987
+
1248
988
  y_pred_list = []
1249
-
989
+
1250
990
  with torch.no_grad():
1251
- for batch_data in tqdm.tqdm(
1252
- data_loader, desc="Predicting", disable=self._verbose == 0
1253
- ):
991
+ for batch_data in tqdm.tqdm(data_loader, desc="Predicting", disable=self._verbose == 0):
1254
992
  batch_dict = self._batch_to_dict(batch_data)
1255
993
  X_input, _ = self.get_input(batch_dict)
1256
994
  y_pred = self.forward(X_input)
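For inference, the loop above reduces to a single call that returns a NumPy array of scores; `test_df` is a placeholder with the same feature columns used at training time:

    model.load_weights(model.best)   # best checkpoint written during fit() with validation
    y_pred = model.predict(test_df, batch_size=1024)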
@@ -1263,10 +1001,10 @@ class BaseModel(nn.Module):
1263
1001
  return y_pred_all
1264
1002
  else:
1265
1003
  return np.array([])
1266
-
1004
+
1267
1005
  def save_weights(self, model_path: str):
1268
1006
  torch.save(self.state_dict(), model_path)
1269
-
1007
+
1270
1008
  def load_weights(self, checkpoint):
1271
1009
  self.to(self.device)
1272
1010
  state_dict = torch.load(checkpoint, map_location="cpu")
@@ -1274,136 +1012,112 @@

     def summary(self):
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-        logger.info(
-            colorize(
-                f"Model Summary: {self.model_name}", color="bright_blue", bold=True
-            )
-        )
+        logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.dense_features:
             logger.info(f"Dense Features ({len(self.dense_features)}):")
             for i, feat in enumerate(self.dense_features, 1):
-                embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
+                embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 1
                 logger.info(f" {i}. {feat.name:20s}")
-
+
         if self.sparse_features:
             logger.info(f"Sparse Features ({len(self.sparse_features)}):")

             max_name_len = max(len(feat.name) for feat in self.sparse_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sparse_features
-            )
+            max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
-            )
-            logger.info(
-                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
-            )
+
+            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
+            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
             for i, feat in enumerate(self.sparse_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                logger.info(
-                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
-                )
-
+                vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
+                embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
+                logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
+
         if self.sequence_features:
             logger.info(f"Sequence Features ({len(self.sequence_features)}):")

             max_name_len = max(len(feat.name) for feat in self.sequence_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sequence_features
-            )
+            max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
-            )
-            logger.info(
-                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
-            )
+
+            logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
+            logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
             for i, feat in enumerate(self.sequence_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
-                logger.info(
-                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
-                )
-
+                vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
+                embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
+                max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
+                logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
+
         logger.info("")
         logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         # Model Architecture
         logger.info("Model Architecture:")
         logger.info(str(self))
         logger.info("")
-
+
         total_params = sum(p.numel() for p in self.parameters())
         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
         non_trainable_params = total_params - trainable_params
-
+
         logger.info(f"Total Parameters: {total_params:,}")
         logger.info(f"Trainable Parameters: {trainable_params:,}")
         logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
+
         logger.info("Layer-wise Parameters:")
         for name, module in self.named_children():
             layer_params = sum(p.numel() for p in module.parameters())
             if layer_params > 0:
                 logger.info(f" {name:30s}: {layer_params:,}")
-
+
         logger.info("")
         logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         logger.info(f"Task Type: {self.task}")
         logger.info(f"Number of Tasks: {self.nums_task}")
         logger.info(f"Metrics: {self.metrics}")
         logger.info(f"Target Columns: {self.target}")
         logger.info(f"Device: {self.device}")
-
-        if hasattr(self, "_optimizer_name"):
+
+        if hasattr(self, '_optimizer_name'):
             logger.info(f"Optimizer: {self._optimizer_name}")
             if self._optimizer_params:
                 for key, value in self._optimizer_params.items():
                     logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self, "_scheduler_name") and self._scheduler_name:
+
+        if hasattr(self, '_scheduler_name') and self._scheduler_name:
             logger.info(f"Scheduler: {self._scheduler_name}")
             if self._scheduler_params:
                 for key, value in self._scheduler_params.items():
                     logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self, "_loss_config"):
+
+        if hasattr(self, '_loss_config'):
             logger.info(f"Loss Function: {self._loss_config}")
-
+
         logger.info("Regularization:")
         logger.info(f" Embedding L1: {self._embedding_l1_reg}")
         logger.info(f" Embedding L2: {self._embedding_l2_reg}")
         logger.info(f" Dense L1: {self._dense_l1_reg}")
         logger.info(f" Dense L2: {self._dense_l2_reg}")
-
+
         logger.info("Other Settings:")
         logger.info(f" Early Stop Patience: {self.early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
         logger.info(f" Model ID: {self.model_id}")
         logger.info(f" Checkpoint Path: {self.checkpoint}")
-
+
         logger.info("")
         logger.info("")
 
@@ -1413,47 +1127,45 @@ class BaseMatchModel(BaseModel):
     Base class for match (retrieval/recall) models
     Supports pointwise, pairwise, and listwise training modes
     """
-
+
     @property
     def task_type(self) -> str:
-        return "match"
-
+        return 'match'
+
     @property
     def support_training_modes(self) -> list[str]:
         """
         Returns list of supported training modes for this model.
         Override in subclasses to restrict training modes.
-
+
         Returns:
             List of supported modes: ['pointwise', 'pairwise', 'listwise']
         """
-        return ["pointwise", "pairwise", "listwise"]
-
-    def __init__(
-        self,
-        user_dense_features: list[DenseFeature] | None = None,
-        user_sparse_features: list[SparseFeature] | None = None,
-        user_sequence_features: list[SequenceFeature] | None = None,
-        item_dense_features: list[DenseFeature] | None = None,
-        item_sparse_features: list[SparseFeature] | None = None,
-        item_sequence_features: list[SequenceFeature] | None = None,
-        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
-        num_negative_samples: int = 4,
-        temperature: float = 1.0,
-        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
-        device: str = "cpu",
-        embedding_l1_reg: float = 0.0,
-        dense_l1_reg: float = 0.0,
-        embedding_l2_reg: float = 0.0,
-        dense_l2_reg: float = 0.0,
-        early_stop_patience: int = 20,
-        model_id: str = "baseline",
-    ):
-
+        return ['pointwise', 'pairwise', 'listwise']
+
+    def __init__(self,
+                 user_dense_features: list[DenseFeature] | None = None,
+                 user_sparse_features: list[SparseFeature] | None = None,
+                 user_sequence_features: list[SequenceFeature] | None = None,
+                 item_dense_features: list[DenseFeature] | None = None,
+                 item_sparse_features: list[SparseFeature] | None = None,
+                 item_sequence_features: list[SequenceFeature] | None = None,
+                 training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
+                 num_negative_samples: int = 4,
+                 temperature: float = 1.0,
+                 similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
+                 device: str = 'cpu',
+                 embedding_l1_reg: float = 0.0,
+                 dense_l1_reg: float = 0.0,
+                 embedding_l2_reg: float = 0.0,
+                 dense_l2_reg: float = 0.0,
+                 early_stop_patience: int = 20,
+                 model_id: str = 'baseline'):
+
         all_dense_features = []
         all_sparse_features = []
         all_sequence_features = []
-
+
         if user_dense_features:
             all_dense_features.extend(user_dense_features)
         if item_dense_features:
@@ -1466,158 +1178,117 @@ class BaseMatchModel(BaseModel):
             all_sequence_features.extend(user_sequence_features)
         if item_sequence_features:
             all_sequence_features.extend(item_sequence_features)
-
+
         super(BaseMatchModel, self).__init__(
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=["label"],
-            task="binary",
+            target=['label'],
+            task='binary',
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            model_id=model_id,
-        )
-
-        self.user_dense_features = (
-            list(user_dense_features) if user_dense_features else []
-        )
-        self.user_sparse_features = (
-            list(user_sparse_features) if user_sparse_features else []
-        )
-        self.user_sequence_features = (
-            list(user_sequence_features) if user_sequence_features else []
-        )
-
-        self.item_dense_features = (
-            list(item_dense_features) if item_dense_features else []
+            model_id=model_id
         )
-        self.item_sparse_features = (
-            list(item_sparse_features) if item_sparse_features else []
-        )
-        self.item_sequence_features = (
-            list(item_sequence_features) if item_sequence_features else []
-        )
-
+
+        self.user_dense_features = list(user_dense_features) if user_dense_features else []
+        self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
+        self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
+
+        self.item_dense_features = list(item_dense_features) if item_dense_features else []
+        self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
+        self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
+
         self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
+
     def get_user_features(self, X_input: dict) -> dict:
         user_input = {}
-        all_user_features = (
-            self.user_dense_features
-            + self.user_sparse_features
-            + self.user_sequence_features
-        )
+        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
         for feature in all_user_features:
             if feature.name in X_input:
                 user_input[feature.name] = X_input[feature.name]
         return user_input
-
+
     def get_item_features(self, X_input: dict) -> dict:
         item_input = {}
-        all_item_features = (
-            self.item_dense_features
-            + self.item_sparse_features
-            + self.item_sequence_features
-        )
+        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
         for feature in all_item_features:
             if feature.name in X_input:
                 item_input[feature.name] = X_input[feature.name]
         return item_input
-
-    def compile(
-        self,
-        optimizer="adam",
-        optimizer_params: dict | None = None,
-        scheduler: (
-            str
-            | torch.optim.lr_scheduler._LRScheduler
-            | type[torch.optim.lr_scheduler._LRScheduler]
-            | None
-        ) = None,
-        scheduler_params: dict | None = None,
-        loss: str | nn.Module | list[str | nn.Module] | None = None,
-    ):
+
+    def compile(self,
+                optimizer = "adam",
+                optimizer_params: dict | None = None,
+                scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
+                scheduler_params: dict | None = None,
+                loss: str | nn.Module | list[str | nn.Module] | None= None):
         """
         Compile match model with optimizer, scheduler, and loss function.
         Validates that training_mode is supported by the model.
         """
         from nextrec.loss import validate_training_mode
-
+
         # Validate training mode is supported
         validate_training_mode(
             training_mode=self.training_mode,
             support_training_modes=self.support_training_modes,
-            model_name=self.model_name,
+            model_name=self.model_name
         )
-
+
         # Call parent compile with match-specific logic
         if optimizer_params is None:
             optimizer_params = {}
-
-        self._optimizer_name = (
-            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
-        )
+
+        self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self._optimizer_params = optimizer_params
         if isinstance(scheduler, str):
             self._scheduler_name = scheduler
         elif scheduler is not None:
             # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self._scheduler_name = getattr(
-                scheduler,
-                "__name__",
-                getattr(scheduler.__class__, "__name__", str(scheduler)),
-            )
+            self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
         else:
             self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
-
+
         # set optimizer
         self.optimizer_fn = get_optimizer_fn(
-            optimizer=optimizer, params=self.parameters(), **optimizer_params
+            optimizer=optimizer,
+            params=self.parameters(),
+            **optimizer_params
         )
-
+
         # Set loss function based on training mode
         loss_value = loss[0] if isinstance(loss, list) else loss
-        self.loss_fn = [
-            get_loss_fn(
-                task_type="match", training_mode=self.training_mode, loss=loss_value
-            )
-        ]
-
+        self.loss_fn = [get_loss_fn(
+            task_type='match',
+            training_mode=self.training_mode,
+            loss=loss_value
+        )]
+
         # set scheduler
-        self.scheduler_fn = (
-            get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {}))
-            if scheduler
-            else None
-        )
+        self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None

-    def compute_similarity(
-        self, user_emb: torch.Tensor, item_emb: torch.Tensor
-    ) -> torch.Tensor:
-        if self.similarity_metric == "dot":
+    def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
+        if self.similarity_metric == 'dot':
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
-                similarity = torch.sum(
-                    user_emb * item_emb, dim=-1
-                ) # [batch_size, num_items]
+                similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size, num_items]
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
                 # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
                 user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
-                    user_emb_expanded * item_emb, dim=-1
-                ) # [batch_size, num_items]
+                similarity = torch.sum(user_emb_expanded * item_emb, dim=-1) # [batch_size, num_items]
             else:
                 similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
-
-        elif self.similarity_metric == "cosine":
+
+        elif self.similarity_metric == 'cosine':
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1625,8 +1296,8 @@ class BaseMatchModel(BaseModel):
                 similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
             else:
                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
-        elif self.similarity_metric == "euclidean":
+
+        elif self.similarity_metric == 'euclidean':
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1634,96 +1305,84 @@ class BaseMatchModel(BaseModel):
                 distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
             else:
                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -distance
+
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
-
+
         similarity = similarity / self.temperature
-
+
         return similarity
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         raise NotImplementedError
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         raise NotImplementedError
-
-    def forward(
-        self, X_input: dict
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+
+    def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         user_input = self.get_user_features(X_input)
         item_input = self.get_item_features(X_input)
-
-        user_emb = self.user_tower(user_input) # [B, D]
-        item_emb = self.item_tower(item_input) # [B, D]
-
-        if self.training and self.training_mode in ["pairwise", "listwise"]:
+
+        user_emb = self.user_tower(user_input) # [B, D]
+        item_emb = self.item_tower(item_input) # [B, D]
+
+        if self.training and self.training_mode in ['pairwise', 'listwise']:
             return user_emb, item_emb

         similarity = self.compute_similarity(user_emb, item_emb) # [B]
-
-        if self.training_mode == "pointwise":
+
+        if self.training_mode == 'pointwise':
             return torch.sigmoid(similarity)
         else:
             return similarity
-
+
     def compute_loss(self, y_pred, y_true):
-        if self.training_mode == "pointwise":
+        if self.training_mode == 'pointwise':
            if y_true is None:
                 return torch.tensor(0.0, device=self.device)
             return self.loss_fn[0](y_pred, y_true)
-
+
         # pairwise / listwise using inbatch neg
-        elif self.training_mode in ["pairwise", "listwise"]:
+        elif self.training_mode in ['pairwise', 'listwise']:
             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
                 raise ValueError(
                     "For pairwise/listwise training, forward should return (user_emb, item_emb). "
                     "Please check BaseMatchModel.forward implementation."
                 )
-
+
             user_emb, item_emb = y_pred # [B, D], [B, D]
-
+
             logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
-            logits = logits / self.temperature
-
+            logits = logits / self.temperature
+
             batch_size = logits.size(0)
-            targets = torch.arange(
-                batch_size, device=logits.device
-            ) # [0, 1, 2, ..., B-1]
-
+            targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
+
             # Cross-Entropy = InfoNCE
             loss = F.cross_entropy(logits, targets)
             return loss
-
+
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")
-
+
     def _set_metrics(self, metrics: list[str] | None = None):
         if metrics is not None and len(metrics) > 0:
             self.metrics = [m.lower() for m in metrics]
         else:
-            self.metrics = ["auc", "logloss"]
-
-        self.best_metrics_mode = "max"
-
-        if not hasattr(self, "early_stopper") or self.early_stopper is None:
-            self.early_stopper = EarlyStopper(
-                patience=self.early_stop_patience, mode=self.best_metrics_mode
-            )
-
-    def encode_user(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
-    ) -> np.ndarray:
+            self.metrics = ['auc', 'logloss']
+
+        self.best_metrics_mode = 'max'
+
+        if not hasattr(self, 'early_stopper') or self.early_stopper is None:
+            self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+
+    def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
-
+
         if not isinstance(data, DataLoader):
             user_data = {}
-            all_user_features = (
-                self.user_dense_features
-                + self.user_sparse_features
-                + self.user_sequence_features
-            )
+            all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
             for feature in all_user_features:
                 if isinstance(data, dict):
                     if feature.name in data:
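The pairwise/listwise branch of `compute_loss` above treats every other item in the batch as a negative, so the cross-entropy over the [B, B] similarity matrix is an in-batch InfoNCE-style loss (the "# Cross-Entropy = InfoNCE" comment). A self-contained sketch of that computation with random embeddings standing in for the tower outputs (shapes and temperature are assumptions):

import torch
import torch.nn.functional as F

B, D = 8, 16
user_emb = torch.randn(B, D)
item_emb = torch.randn(B, D)        # item_emb[i] is the positive item for user i
temperature = 1.0

logits = user_emb @ item_emb.t() / temperature   # [B, B]: row i scores user i against every item in the batch
targets = torch.arange(B)                        # the diagonal entries are the positives

loss = F.cross_entropy(logits, targets)          # softmax over 1 positive + (B - 1) in-batch negatives per row
print(float(loss))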
@@ -1731,39 +1390,29 @@ class BaseMatchModel(BaseModel):
                 elif isinstance(data, pd.DataFrame):
                     if feature.name in data.columns:
                         user_data[feature.name] = data[feature.name].values
-
-            data_loader = self._prepare_data_loader(
-                user_data, batch_size=batch_size, shuffle=False
-            )
+
+            data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
         else:
             data_loader = data
-
+
         embeddings_list = []
-
+
         with torch.no_grad():
-            for batch_data in tqdm.tqdm(
-                data_loader, desc="Encoding users", disable=self._verbose == 0
-            ):
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users", disable=self._verbose == 0):
                 batch_dict = self._batch_to_dict(batch_data)
                 user_input = self.get_user_features(batch_dict)
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
-
+
         embeddings = np.concatenate(embeddings_list, axis=0)
         return embeddings
-
-    def encode_item(
-        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
-    ) -> np.ndarray:
+
+    def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
         self.eval()
-
+
         if not isinstance(data, DataLoader):
             item_data = {}
-            all_item_features = (
-                self.item_dense_features
-                + self.item_sparse_features
-                + self.item_sequence_features
-            )
+            all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
             for feature in all_item_features:
                 if isinstance(data, dict):
                     if feature.name in data:
@@ -1772,22 +1421,18 @@ class BaseMatchModel(BaseModel):
                     if feature.name in data.columns:
                         item_data[feature.name] = data[feature.name].values

-            data_loader = self._prepare_data_loader(
-                item_data, batch_size=batch_size, shuffle=False
-            )
+            data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
         else:
             data_loader = data
-
+
         embeddings_list = []
-
+
         with torch.no_grad():
-            for batch_data in tqdm.tqdm(
-                data_loader, desc="Encoding items", disable=self._verbose == 0
-            ):
+            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items", disable=self._verbose == 0):
                 batch_dict = self._batch_to_dict(batch_data)
                 item_input = self.get_item_features(batch_dict)
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
-
+
         embeddings = np.concatenate(embeddings_list, axis=0)
         return embeddings
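Downstream, the arrays returned by `encode_user` and `encode_item` are typically scored against each other for retrieval. A self-contained sketch with random embeddings standing in for the encoder outputs (the array shapes, dot-product scoring, and top-k value are assumptions; a real deployment would usually replace the full matrix product with an ANN index):

import numpy as np

num_users, num_items, dim = 5, 100, 32
user_embs = np.random.randn(num_users, dim).astype(np.float32)   # stand-in for encode_user(...)
item_embs = np.random.randn(num_items, dim).astype(np.float32)   # stand-in for encode_item(...)

scores = user_embs @ item_embs.T                      # [num_users, num_items] dot-product scores
top_k = 10
top_items = np.argsort(-scores, axis=1)[:, :top_k]    # indices of the 10 best-scoring items per user

print(top_items.shape)  # (5, 10)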