nextrec 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries, and is provided for informational purposes only.
Files changed (42)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +10 -23
  3. nextrec/basic/layers.py +18 -61
  4. nextrec/basic/loggers.py +1 -1
  5. nextrec/basic/metrics.py +55 -33
  6. nextrec/basic/model.py +258 -394
  7. nextrec/data/__init__.py +2 -2
  8. nextrec/data/data_utils.py +80 -4
  9. nextrec/data/dataloader.py +36 -57
  10. nextrec/data/preprocessor.py +5 -4
  11. nextrec/models/generative/__init__.py +5 -0
  12. nextrec/models/generative/hstu.py +399 -0
  13. nextrec/models/match/dssm.py +2 -2
  14. nextrec/models/match/dssm_v2.py +2 -2
  15. nextrec/models/match/mind.py +2 -2
  16. nextrec/models/match/sdm.py +2 -2
  17. nextrec/models/match/youtube_dnn.py +2 -2
  18. nextrec/models/multi_task/esmm.py +1 -1
  19. nextrec/models/multi_task/mmoe.py +1 -1
  20. nextrec/models/multi_task/ple.py +1 -1
  21. nextrec/models/multi_task/poso.py +1 -1
  22. nextrec/models/multi_task/share_bottom.py +1 -1
  23. nextrec/models/ranking/afm.py +1 -1
  24. nextrec/models/ranking/autoint.py +1 -1
  25. nextrec/models/ranking/dcn.py +1 -1
  26. nextrec/models/ranking/deepfm.py +1 -1
  27. nextrec/models/ranking/dien.py +1 -1
  28. nextrec/models/ranking/din.py +1 -1
  29. nextrec/models/ranking/fibinet.py +1 -1
  30. nextrec/models/ranking/fm.py +1 -1
  31. nextrec/models/ranking/masknet.py +2 -2
  32. nextrec/models/ranking/pnn.py +1 -1
  33. nextrec/models/ranking/widedeep.py +1 -1
  34. nextrec/models/ranking/xdeepfm.py +1 -1
  35. nextrec/utils/__init__.py +2 -1
  36. nextrec/utils/common.py +21 -2
  37. nextrec/utils/optimizer.py +7 -3
  38. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/METADATA +10 -4
  39. nextrec-0.3.3.dist-info/RECORD +57 -0
  40. nextrec-0.3.1.dist-info/RECORD +0 -56
  41. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
  42. {nextrec-0.3.1.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 02/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -21,21 +21,22 @@ from typing import Union, Literal, Any
21
21
  from torch.utils.data import DataLoader
22
22
 
23
23
  from nextrec.basic.callback import EarlyStopper
24
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
24
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
25
25
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
26
26
 
27
27
  from nextrec.basic.loggers import setup_logger, colorize
28
28
  from nextrec.basic.session import resolve_save_path, create_session
29
- from nextrec.basic.metrics import configure_metrics, evaluate_metrics
29
+ from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
30
30
 
31
- from nextrec.data import get_column_data, collate_fn
32
31
  from nextrec.data.dataloader import build_tensors_from_data
32
+ from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids
33
33
 
34
34
  from nextrec.loss import get_loss_fn, get_loss_kwargs
35
- from nextrec.utils import get_optimizer, get_scheduler
35
+ from nextrec.utils import get_optimizer, get_scheduler, to_tensor
36
+
36
37
  from nextrec import __version__
37
38
 
38
- class BaseModel(FeatureSpecMixin, nn.Module):
39
+ class BaseModel(FeatureSet, nn.Module):
39
40
  @property
40
41
  def model_name(self) -> str:
41
42
  raise NotImplementedError
@@ -69,72 +70,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
69
70
  self.session_id = session_id
70
71
  self.session = create_session(session_id)
71
72
  self.session_path = self.session.root # pwd/session_id, path for this session
72
- self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint"+".model")
73
- self.best_path = os.path.join(self.session_path, self.model_name+ "_best.model")
73
+ self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
74
+ self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
74
75
  self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
75
- self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)
76
- self.target = self.target_columns
77
- self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
76
+ self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
78
77
 
79
78
  self.task = task
80
79
  self.nums_task = len(task) if isinstance(task, list) else 1
81
80
 
82
- self._embedding_l1_reg = embedding_l1_reg
83
- self._dense_l1_reg = dense_l1_reg
84
- self._embedding_l2_reg = embedding_l2_reg
85
- self._dense_l2_reg = dense_l2_reg
86
- self._regularization_weights = []
87
- self._embedding_params = []
88
- self._loss_weights: float | list[float] | None = None
89
- self._early_stop_patience = early_stop_patience
90
- self._max_gradient_norm = 1.0
91
- self._logger_initialized = False
92
-
93
- def _register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
81
+ self.embedding_l1_reg = embedding_l1_reg
82
+ self.dense_l1_reg = dense_l1_reg
83
+ self.embedding_l2_reg = embedding_l2_reg
84
+ self.dense_l2_reg = dense_l2_reg
85
+ self.regularization_weights = []
86
+ self.embedding_params = []
87
+ self.loss_weight = None
88
+ self.early_stop_patience = early_stop_patience
89
+ self.max_gradient_norm = 1.0
90
+ self.logger_initialized = False
91
+
92
+ def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
94
93
  exclude_modules = exclude_modules or []
95
94
  include_modules = include_modules or []
96
- if hasattr(self, embedding_attr):
97
- embedding_layer = getattr(self, embedding_attr)
98
- if hasattr(embedding_layer, "embed_dict"):
99
- for embed in embedding_layer.embed_dict.values():
100
- self._embedding_params.append(embed.weight)
95
+ embedding_layer = getattr(self, embedding_attr, None)
96
+ embed_dict = getattr(embedding_layer, "embed_dict", None)
97
+ if embed_dict is not None:
98
+ self.embedding_params.extend(embed.weight for embed in embed_dict.values())
99
+ skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d)
101
100
  for name, module in self.named_modules():
102
- if module is self:
103
- continue
104
- if embedding_attr in name:
105
- continue
106
- if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
107
- continue
108
- if include_modules:
109
- if not any(inc_name in name for inc_name in include_modules):
110
- continue
111
- if any(exc_name in name for exc_name in exclude_modules):
101
+ if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
112
102
  continue
113
103
  if isinstance(module, nn.Linear):
114
- self._regularization_weights.append(module.weight)
104
+ self.regularization_weights.append(module.weight)
115
105
 
116
106
  def add_reg_loss(self) -> torch.Tensor:
117
107
  reg_loss = torch.tensor(0.0, device=self.device)
118
- if self._embedding_params:
119
- if self._embedding_l1_reg > 0:
120
- reg_loss += self._embedding_l1_reg * sum(param.abs().sum() for param in self._embedding_params)
121
- if self._embedding_l2_reg > 0:
122
- reg_loss += self._embedding_l2_reg * sum((param ** 2).sum() for param in self._embedding_params)
123
- if self._regularization_weights:
124
- if self._dense_l1_reg > 0:
125
- reg_loss += self._dense_l1_reg * sum(param.abs().sum() for param in self._regularization_weights)
126
- if self._dense_l2_reg > 0:
127
- reg_loss += self._dense_l2_reg * sum((param ** 2).sum() for param in self._regularization_weights)
108
+ if self.embedding_params:
109
+ if self.embedding_l1_reg > 0:
110
+ reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
111
+ if self.embedding_l2_reg > 0:
112
+ reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
113
+ if self.regularization_weights:
114
+ if self.dense_l1_reg > 0:
115
+ reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
116
+ if self.dense_l2_reg > 0:
117
+ reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
128
118
  return reg_loss
129
119
 
130
- def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
131
- tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
132
- if tensor.dtype != dtype:
133
- tensor = tensor.to(dtype=dtype)
134
- if tensor.device != self.device:
135
- tensor = tensor.to(self.device)
136
- return tensor
137
-
138
120
  def get_input(self, input_data: dict, require_labels: bool = True):
139
121
  feature_source = input_data.get("features", {})
140
122
  label_source = input_data.get("labels")
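
Note: the private BaseModel._to_tensor helper removed in this hunk is replaced by the to_tensor utility imported from nextrec.utils above. A minimal sketch of an equivalent helper, assuming the shared utility keeps the same dtype/device coercion as the deleted method:

import torch

def to_tensor(value, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # accept tensors or array-likes; convert dtype and device only when they differ
    tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
    if tensor.dtype != dtype:
        tensor = tensor.to(dtype=dtype)
    if tensor.device != device:
        tensor = tensor.to(device)
    return tensor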
@@ -143,12 +125,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
143
125
  if feature.name not in feature_source:
144
126
  raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
145
127
  feature_data = get_column_data(feature_source, feature.name)
146
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
147
- X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
128
+ X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
148
129
  y = None
149
- if (len(self.target) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target)))): # need labels: training or eval with labels
130
+ if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
150
131
  target_tensors = []
151
- for target_name in self.target:
132
+ for target_name in self.target_columns:
152
133
  if label_source is None or target_name not in label_source:
153
134
  if require_labels:
154
135
  raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
@@ -158,7 +139,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
158
139
  if require_labels:
159
140
  raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
160
141
  continue
161
- target_tensor = self._to_tensor(target_data, dtype=torch.float32)
142
+ target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
162
143
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
163
144
  target_tensors.append(target_tensor)
164
145
  if target_tensors:
@@ -169,11 +150,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
169
150
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
170
151
  return X_input, y
171
152
 
172
- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
173
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
174
- self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
175
-
176
- def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
153
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
154
+ """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
177
155
  if not (0 < validation_split < 1):
178
156
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
179
157
  if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -181,8 +159,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
181
159
  if isinstance(train_data, pd.DataFrame):
182
160
  total_length = len(train_data)
183
161
  else:
184
- sample_key = next(iter(train_data))
185
- total_length = len(train_data[sample_key])
162
+ sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
163
+ total_length = len(train_data[sample_key]) # len(train_data['user_id'])
186
164
  for k, v in train_data.items():
187
165
  if len(v) != total_length:
188
166
  raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
@@ -198,67 +176,62 @@ class BaseModel(FeatureSpecMixin, nn.Module):
198
176
  train_split = {}
199
177
  valid_split = {}
200
178
  for key, value in train_data.items():
201
- if isinstance(value, np.ndarray):
202
- train_split[key] = value[train_indices]
203
- valid_split[key] = value[valid_indices]
204
- elif isinstance(value, (list, tuple)):
205
- arr = np.asarray(value)
206
- train_split[key] = arr[train_indices].tolist()
207
- valid_split[key] = arr[valid_indices].tolist()
208
- elif isinstance(value, pd.Series):
209
- train_split[key] = value.iloc[train_indices].values
210
- valid_split[key] = value.iloc[valid_indices].values
211
- else:
212
- train_split[key] = [value[i] for i in train_indices]
213
- valid_split[key] = [value[i] for i in valid_indices]
214
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
179
+ arr = np.asarray(value)
180
+ train_split[key] = arr[train_indices]
181
+ valid_split[key] = arr[valid_indices]
182
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
215
183
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
216
184
  return train_loader, valid_split
217
185
 
218
186
  def compile(
219
- self, optimizer="adam", optimizer_params: dict | None = None,
220
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None, scheduler_params: dict | None = None,
221
- loss: str | nn.Module | list[str | nn.Module] | None = "bce", loss_params: dict | list[dict] | None = None,
222
- loss_weights: int | float | list[int | float] | None = None,):
187
+ self,
188
+ optimizer: str | torch.optim.Optimizer = "adam",
189
+ optimizer_params: dict | None = None,
190
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
191
+ scheduler_params: dict | None = None,
192
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
193
+ loss_params: dict | list[dict] | None = None,
194
+ loss_weights: int | float | list[int | float] | None = None,
195
+ ):
223
196
  optimizer_params = optimizer_params or {}
224
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
225
- self._optimizer_params = optimizer_params
197
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
198
+ self.optimizer_params = optimizer_params
226
199
  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
227
200
 
228
201
  scheduler_params = scheduler_params or {}
229
202
  if isinstance(scheduler, str):
230
- self._scheduler_name = scheduler
203
+ self.scheduler_name = scheduler
231
204
  elif scheduler is None:
232
- self._scheduler_name = None
233
- else:
234
- self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
235
- self._scheduler_params = scheduler_params
205
+ self.scheduler_name = None
206
+ else: # custom scheduler class or instance: derive its name for logging
207
+ self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
208
+ self.scheduler_params = scheduler_params
236
209
  self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
237
210
 
238
- self._loss_config = loss
239
- self._loss_params = loss_params or {}
211
+ self.loss_config = loss
212
+ self.loss_params = loss_params or {}
240
213
  self.loss_fn = []
241
- for i in range(self.nums_task):
242
- if isinstance(loss, list):
243
- loss_value = loss[i] if i < len(loss) else None
244
- else:
245
- loss_value = loss
246
- if self.nums_task == 1: # single task
247
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else self._loss_params[0]
248
- else:
249
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
250
- self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
251
- # Normalize loss weights for single-task and multi-task setups
214
+ if isinstance(loss, list): # e.g. ['bce', 'mse'] -> one loss per task; missing entries become None
215
+ loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
216
+ else: # for example: 'bce' -> ['bce', 'bce']
217
+ loss_list = [loss] * self.nums_task
218
+
219
+ if isinstance(self.loss_params, dict):
220
+ params_list = [self.loss_params] * self.nums_task
221
+ else: # list[dict]
222
+ params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
223
+ self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
224
+
252
225
  if loss_weights is None:
253
- self._loss_weights = None
226
+ self.loss_weights = None
254
227
  elif self.nums_task == 1:
255
228
  if isinstance(loss_weights, (list, tuple)):
256
- if len(loss_weights) != 1:
229
+ if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
257
230
  raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
258
231
  weight_value = loss_weights[0]
259
232
  else:
260
233
  weight_value = loss_weights
261
- self._loss_weights = float(weight_value)
234
+ self.loss_weights = float(weight_value)
262
235
  else:
263
236
  if isinstance(loss_weights, (int, float)):
264
237
  weights = [float(loss_weights)] * self.nums_task
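
Note: handle_validation_split now coerces every column to a NumPy array before indexing, instead of branching on list/Series/ndarray. A standalone sketch of that split logic; the column names, split ratio, and seed are illustrative only:

import numpy as np

def split_columns(data: dict, validation_split: float = 0.2, seed: int = 42):
    total = len(next(iter(data.values())))           # length of the first column
    indices = np.random.default_rng(seed).permutation(total)
    n_valid = int(total * validation_split)
    valid_idx, train_idx = indices[:n_valid], indices[n_valid:]
    train_split = {k: np.asarray(v)[train_idx] for k, v in data.items()}
    valid_split = {k: np.asarray(v)[valid_idx] for k, v in data.items()}
    return train_split, valid_split

train, valid = split_columns({"user_id": [1, 2, 3, 4, 5], "label": [0, 1, 0, 1, 1]})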
@@ -268,87 +241,68 @@ class BaseModel(FeatureSpecMixin, nn.Module):
268
241
  raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
269
242
  else:
270
243
  raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
271
- self._loss_weights = weights
244
+ self.loss_weights = weights
272
245
 
273
246
  def compute_loss(self, y_pred, y_true):
274
247
  if y_true is None:
275
248
  raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
276
249
  if self.nums_task == 1:
277
250
  loss = self.loss_fn[0](y_pred, y_true)
278
- if self._loss_weights is not None:
279
- loss = loss * self._loss_weights
251
+ if self.loss_weights is not None:
252
+ loss = loss * self.loss_weights
280
253
  return loss
281
254
  else:
282
255
  task_losses = []
283
256
  for i in range(self.nums_task):
284
257
  task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
285
- if isinstance(self._loss_weights, (list, tuple)):
286
- task_loss = task_loss * self._loss_weights[i]
258
+ if isinstance(self.loss_weights, (list, tuple)):
259
+ task_loss = task_loss * self.loss_weights[i]
287
260
  task_losses.append(task_loss)
288
261
  return torch.stack(task_losses).sum()
289
262
 
290
- def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
263
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
291
264
  if isinstance(data, DataLoader):
292
265
  return data
293
- tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
266
+ tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
294
267
  if tensors is None:
295
268
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
296
269
  dataset = TensorDictDataset(tensors)
297
270
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
298
271
 
299
- def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
300
- if not (isinstance(batch_data, dict) and "features" in batch_data):
301
- raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
302
- return {
303
- "features": batch_data.get("features", {}),
304
- "labels": batch_data.get("labels"),
305
- "ids": batch_data.get("ids") if include_ids else None,
306
- }
307
-
308
272
  def fit(self,
309
- train_data: dict|pd.DataFrame|DataLoader,
310
- valid_data: dict|pd.DataFrame|DataLoader|None=None,
311
- metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
273
+ train_data: dict | pd.DataFrame | DataLoader,
274
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
275
+ metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
312
276
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
313
- user_id_column: str = 'user_id',
277
+ user_id_column: str | None = None,
314
278
  validation_split: float | None = None):
315
279
  self.to(self.device)
316
- if not self._logger_initialized:
280
+ if not self.logger_initialized:
317
281
  setup_logger(session_id=self.session_id)
318
- self._logger_initialized = True
319
- self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
320
- self.summary()
321
- valid_loader = None
322
- valid_user_ids: np.ndarray | None = None
323
- needs_user_ids: bool = self._needs_user_ids_for_metrics()
282
+ self.logger_initialized = True
283
+
284
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
285
+ self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
286
+ self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # whether user_id is required (e.g. GAUC or @K ranking metrics)
287
+ self.epoch_index = 0
288
+ self.stop_training = False
289
+ self.best_checkpoint_path = self.best_path
290
+ self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
324
291
 
325
292
  if validation_split is not None and valid_data is None:
326
- train_loader, valid_data = self._handle_validation_split(
327
- train_data=train_data, # type: ignore
328
- validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
293
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
329
294
  else:
330
- train_loader = (train_data if isinstance(train_data, DataLoader) else self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
331
- if isinstance(valid_data, DataLoader):
332
- valid_loader = valid_data
333
- elif valid_data is not None:
334
- valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
335
- if needs_user_ids:
336
- if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
337
- valid_user_ids = np.asarray(valid_data[user_id_column].values)
338
- elif isinstance(valid_data, dict) and user_id_column in valid_data:
339
- valid_user_ids = np.asarray(valid_data[user_id_column])
295
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
296
+
297
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
340
298
  try:
341
- self._steps_per_epoch = len(train_loader)
299
+ self.steps_per_epoch = len(train_loader)
342
300
  is_streaming = False
343
- except TypeError: # len() not supported, e.g., streaming data loader
344
- self._steps_per_epoch = None
301
+ except TypeError: # streaming data loader does not support len()
302
+ self.steps_per_epoch = None
345
303
  is_streaming = True
346
304
 
347
- self._epoch_index = 0
348
- self._stop_training = False
349
- self._best_checkpoint_path = self.best_path
350
- self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
351
-
305
+ self.summary()
352
306
  logging.info("")
353
307
  logging.info(colorize("=" * 80, bold=True))
354
308
  if is_streaming:
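
Note: a hedged usage sketch of the compile() and fit() signatures shown above, for a two-target model; the model instance, DataFrame, and target/metric names are placeholders, not part of this diff:

model.compile(
    optimizer="adam", optimizer_params={"lr": 1e-3},
    loss=["bce", "mse"],          # one loss per task; a single string is broadcast
    loss_weights=[1.0, 0.5],      # must match the number of tasks in multi-task mode
)
model.fit(
    train_data=train_df,
    metrics={"click": ["auc", "logloss"], "watch_time": ["mse"]},
    epochs=3, batch_size=256,
    validation_split=0.1,         # only used when valid_data is None
)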
@@ -360,36 +314,34 @@ class BaseModel(FeatureSpecMixin, nn.Module):
360
314
  logging.info(colorize(f"Model device: {self.device}", bold=True))
361
315
 
362
316
  for epoch in range(epochs):
363
- self._epoch_index = epoch
317
+ self.epoch_index = epoch
364
318
  if is_streaming:
365
319
  logging.info("")
366
320
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
367
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
368
- if isinstance(train_result, tuple):
321
+
322
+ # handle train result
323
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
324
+ if isinstance(train_result, tuple): # (avg_loss, metrics_dict)
369
325
  train_loss, train_metrics = train_result
370
326
  else:
371
327
  train_loss = train_result
372
328
  train_metrics = None
329
+
330
+ # handle logging for single-task and multi-task
373
331
  if self.nums_task == 1:
374
332
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
375
333
  if train_metrics:
376
334
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
377
335
  log_str += f", {metrics_str}"
378
- logging.info(colorize(log_str, color="white"))
336
+ logging.info(colorize(log_str))
379
337
  else:
380
- task_labels = []
381
- for i in range(self.nums_task):
382
- if i < len(self.target):
383
- task_labels.append(self.target[i])
384
- else:
385
- task_labels.append(f"task_{i}")
386
338
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
387
339
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
388
340
  if train_metrics:
389
- # Group metrics by task
341
+ # group metrics by task
390
342
  task_metrics = {}
391
343
  for metric_key, metric_value in train_metrics.items():
392
- for target_name in self.target:
344
+ for target_name in self.target_columns:
393
345
  if metric_key.endswith(f"_{target_name}"):
394
346
  if target_name not in task_metrics:
395
347
  task_metrics[target_name] = {}
@@ -398,15 +350,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
398
350
  break
399
351
  if task_metrics:
400
352
  task_metric_strs = []
401
- for target_name in self.target:
353
+ for target_name in self.target_columns:
402
354
  if target_name in task_metrics:
403
355
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
404
356
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
405
357
  log_str += ", " + ", ".join(task_metric_strs)
406
- logging.info(colorize(log_str, color="white"))
358
+ logging.info(colorize(log_str))
407
359
  if valid_loader is not None:
408
- # Pass user_ids only if needed for GAUC metric
409
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
360
+ # pass user_ids only if needed for GAUC metric
361
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
410
362
  if self.nums_task == 1:
411
363
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
412
364
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -414,7 +366,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
414
366
  # multi task metrics
415
367
  task_metrics = {}
416
368
  for metric_key, metric_value in val_metrics.items():
417
- for target_name in self.target:
369
+ for target_name in self.target_columns:
418
370
  if metric_key.endswith(f"_{target_name}"):
419
371
  if target_name not in task_metrics:
420
372
  task_metrics[target_name] = {}
@@ -422,7 +374,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
422
374
  task_metrics[target_name][metric_name] = metric_value
423
375
  break
424
376
  task_metric_strs = []
425
- for target_name in self.target:
377
+ for target_name in self.target_columns:
426
378
  if target_name in task_metrics:
427
379
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
428
380
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
@@ -430,45 +382,42 @@ class BaseModel(FeatureSpecMixin, nn.Module):
430
382
  # Handle empty validation metrics
431
383
  if not val_metrics:
432
384
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
433
- self._best_checkpoint_path = self.checkpoint_path
385
+ self.best_checkpoint_path = self.checkpoint_path
434
386
  logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
435
387
  continue
436
388
  if self.nums_task == 1:
437
389
  primary_metric_key = self.metrics[0]
438
390
  else:
439
- primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
440
-
441
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
391
+ primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
392
+ primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
442
393
  improved = False
443
-
394
+ # early stopping check
444
395
  if self.best_metrics_mode == 'max':
445
- if primary_metric > self._best_metric:
446
- self._best_metric = primary_metric
447
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
396
+ if primary_metric > self.best_metric:
397
+ self.best_metric = primary_metric
448
398
  improved = True
449
399
  else:
450
- if primary_metric < self._best_metric:
451
- self._best_metric = primary_metric
400
+ if primary_metric < self.best_metric:
401
+ self.best_metric = primary_metric
452
402
  improved = True
453
- # Always keep the latest weights as a rolling checkpoint
454
403
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
455
404
  if improved:
456
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}"))
405
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
457
406
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
458
- self._best_checkpoint_path = self.best_path
407
+ self.best_checkpoint_path = self.best_path
459
408
  self.early_stopper.trial_counter = 0
460
409
  else:
461
410
  self.early_stopper.trial_counter += 1
462
411
  logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
463
412
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
464
- self._stop_training = True
413
+ self.stop_training = True
465
414
  logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
466
415
  break
467
416
  else:
468
417
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
469
418
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
470
- self._best_checkpoint_path = self.best_path
471
- if self._stop_training:
419
+ self.best_checkpoint_path = self.best_path
420
+ if self.stop_training:
472
421
  break
473
422
  if self.scheduler_fn is not None:
474
423
  if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -476,34 +425,29 @@ class BaseModel(FeatureSpecMixin, nn.Module):
476
425
  self.scheduler_fn.step(primary_metric)
477
426
  else:
478
427
  self.scheduler_fn.step()
479
- logging.info("\n")
480
- logging.info(colorize("Training finished.", color="bright_green", bold=True))
481
- logging.info("\n")
428
+ logging.info(" ")
429
+ logging.info(colorize("Training finished.", bold=True))
430
+ logging.info(" ")
482
431
  if valid_loader is not None:
483
- logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
484
- self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
432
+ logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
433
+ self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
485
434
  return self
486
435
 
487
436
  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
488
- if self.nums_task == 1:
489
- accumulated_loss = 0.0
490
- else:
491
- accumulated_loss = 0.0
437
+ accumulated_loss = 0.0
492
438
  self.train()
493
439
  num_batches = 0
494
440
  y_true_list = []
495
441
  y_pred_list = []
496
- needs_user_ids = self._needs_user_ids_for_metrics()
497
- user_ids_list = [] if needs_user_ids else None
498
- if self._steps_per_epoch is not None:
499
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
442
+
443
+ user_ids_list = [] if self.needs_user_ids else None
444
+ if self.steps_per_epoch is not None:
445
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
500
446
  else:
501
- if is_streaming:
502
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
503
- else:
504
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
447
+ desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
448
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
505
449
  for batch_index, batch_data in batch_iter:
506
- batch_dict = self._batch_to_dict(batch_data)
450
+ batch_dict = batch_to_dict(batch_data)
507
451
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
508
452
  y_pred = self.forward(X_input)
509
453
  loss = self.compute_loss(y_pred, y_true)
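
Note: the widened type hints accept a scheduler class as well as an instance or name, and ReduceLROnPlateau is stepped with the primary validation metric as shown above. A hedged compile() sketch; whether get_scheduler forwards these exact keyword arguments is an assumption:

import torch

model.compile(
    optimizer="adam",
    scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,   # class; its name is recorded for logging
    scheduler_params={"mode": "max", "patience": 2},         # assumed to be passed through to the scheduler
    loss="bce",
)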
@@ -511,66 +455,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
511
455
  total_loss = loss + reg_loss
512
456
  self.optimizer_fn.zero_grad()
513
457
  total_loss.backward()
514
- nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
458
+ nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
515
459
  self.optimizer_fn.step()
516
- if self.nums_task == 1:
517
- accumulated_loss += loss.item()
518
- else:
519
- accumulated_loss += loss.item()
460
+ accumulated_loss += loss.item()
520
461
  if y_true is not None:
521
- y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
522
- if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
523
- batch_user_id = None
524
- if self.id_columns:
525
- for id_name in self.id_columns:
526
- if id_name in batch_dict["ids"]:
527
- batch_user_id = batch_dict["ids"][id_name]
528
- break
529
- if batch_user_id is None and batch_dict["ids"]:
530
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
462
+ y_true_list.append(y_true.detach().cpu().numpy())
463
+ if self.needs_user_ids and user_ids_list is not None:
464
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
531
465
  if batch_user_id is not None:
532
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
533
- user_ids_list.append(ids_np.reshape(ids_np.shape[0]))
534
- if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
466
+ user_ids_list.append(batch_user_id)
467
+ if y_pred is not None and isinstance(y_pred, torch.Tensor):
535
468
  y_pred_list.append(y_pred.detach().cpu().numpy())
536
469
  num_batches += 1
537
- avg_loss = accumulated_loss / num_batches
470
+ avg_loss = accumulated_loss / max(num_batches, 1)
538
471
  if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
539
472
  y_true_all = np.concatenate(y_true_list, axis=0)
540
473
  y_pred_all = np.concatenate(y_pred_list, axis=0)
541
474
  combined_user_ids = None
542
- if needs_user_ids and user_ids_list:
475
+ if self.needs_user_ids and user_ids_list:
543
476
  combined_user_ids = np.concatenate(user_ids_list, axis=0)
544
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=combined_user_ids)
477
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
545
478
  return avg_loss, metrics_dict
546
479
  return avg_loss
547
480
 
548
- def _needs_user_ids_for_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None) -> bool:
549
- """Check if any configured metric requires user_ids (e.g., gauc, ranking @K)."""
550
- metric_names = set()
551
- sources = [metrics if metrics is not None else getattr(self, "metrics", None), getattr(self, "task_specific_metrics", None),]
552
- for src in sources:
553
- stack = [src]
554
- while stack:
555
- item = stack.pop()
556
- if not item:
557
- continue
558
- if isinstance(item, dict):
559
- stack.extend(item.values())
560
- elif isinstance(item, str):
561
- metric_names.add(item.lower())
562
- else:
563
- try:
564
- for m in item:
565
- metric_names.add(m.lower())
566
- except TypeError:
567
- continue
568
- for name in metric_names:
569
- if name == "gauc":
570
- return True
571
- if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
572
- return True
573
- return False
481
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
482
+ if valid_data is None:
483
+ return None, None
484
+ if isinstance(valid_data, DataLoader):
485
+ return valid_data, None
486
+ valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
487
+ valid_user_ids = None
488
+ if needs_user_ids:
489
+ if user_id_column is None:
490
+ raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
491
+ valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
492
+ return valid_loader, valid_user_ids
574
493
 
575
494
  def evaluate(self,
576
495
  data: dict | pd.DataFrame | DataLoader,
@@ -582,18 +501,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
582
501
  eval_metrics = metrics if metrics is not None else self.metrics
583
502
  if eval_metrics is None:
584
503
  raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
585
- needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
504
+ needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
586
505
 
587
506
  if isinstance(data, DataLoader):
588
507
  data_loader = data
589
508
  else:
590
- # Extract user_ids if needed and not provided
591
509
  if user_ids is None and needs_user_ids:
592
- if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
593
- user_ids = np.asarray(data[user_id_column].values)
594
- elif isinstance(data, dict) and user_id_column in data:
595
- user_ids = np.asarray(data[user_id_column])
596
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
510
+ user_ids = get_user_ids(data=data, id_columns=user_id_column)
511
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
597
512
  y_true_list = []
598
513
  y_pred_list = []
599
514
  collected_user_ids = []
@@ -601,26 +516,17 @@ class BaseModel(FeatureSpecMixin, nn.Module):
601
516
  with torch.no_grad():
602
517
  for batch_data in data_loader:
603
518
  batch_count += 1
604
- batch_dict = self._batch_to_dict(batch_data)
519
+ batch_dict = batch_to_dict(batch_data)
605
520
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
606
521
  y_pred = self.forward(X_input)
607
522
  if y_true is not None:
608
523
  y_true_list.append(y_true.cpu().numpy())
609
- # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
610
524
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
611
525
  y_pred_list.append(y_pred.cpu().numpy())
612
- if needs_user_ids and user_ids is None and batch_dict.get("ids"):
613
- batch_user_id = None
614
- if self.id_columns:
615
- for id_name in self.id_columns:
616
- if id_name in batch_dict["ids"]:
617
- batch_user_id = batch_dict["ids"][id_name]
618
- break
619
- if batch_user_id is None and batch_dict["ids"]:
620
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
526
+ if needs_user_ids and user_ids is None:
527
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
621
528
  if batch_user_id is not None:
622
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
623
- collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
529
+ collected_user_ids.append(batch_user_id)
624
530
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
625
531
  if len(y_true_list) > 0:
626
532
  y_true_all = np.concatenate(y_true_list, axis=0)
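
Note: evaluate() now resolves user IDs through get_user_ids when a metric such as GAUC needs them. A hedged call sketch; the frame and column names are placeholders:

val_metrics = model.evaluate(
    valid_df,                     # dict, DataFrame, or an existing DataLoader
    metrics=["auc", "gauc"],      # gauc triggers the user-id lookup above
    user_id_column="user_id",
    batch_size=512,
)
# e.g. {'auc': 0.75, 'gauc': 0.71}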
@@ -649,23 +555,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
649
555
  final_user_ids = user_ids
650
556
  if final_user_ids is None and collected_user_ids:
651
557
  final_user_ids = np.concatenate(collected_user_ids, axis=0)
652
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
558
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
653
559
  return metrics_dict
654
560
 
655
- def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
656
- """Evaluate metrics using the metrics module."""
657
- task_specific_metrics = getattr(self, 'task_specific_metrics', None)
658
-
659
- return evaluate_metrics(
660
- y_true=y_true,
661
- y_pred=y_pred,
662
- metrics=metrics,
663
- task=self.task,
664
- target_names=self.target,
665
- task_specific_metrics=task_specific_metrics,
666
- user_ids=user_ids
667
- )
668
-
669
561
  def predict(
670
562
  self,
671
563
  data: str | dict | pd.DataFrame | DataLoader,
@@ -676,28 +568,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
676
568
  return_dataframe: bool = True,
677
569
  streaming_chunk_size: int = 10000,
678
570
  ) -> pd.DataFrame | np.ndarray:
679
- """
680
- Run inference and optionally return ID-aligned predictions.
681
-
682
- When ``id_columns`` are configured and ``include_ids`` is True (default),
683
- the returned object will include those IDs to keep a one-to-one mapping
684
- between each prediction and its source row.
685
- If ``save_path`` is provided and ``return_dataframe`` is False, predictions
686
- stream to disk batch-by-batch to avoid holding all outputs in memory.
687
- """
688
571
  self.eval()
689
572
  if include_ids is None:
690
573
  include_ids = bool(self.id_columns)
691
574
  include_ids = include_ids and bool(self.id_columns)
692
575
 
693
- # if saving to disk without returning dataframe, use streaming prediction
694
576
  if save_path is not None and not return_dataframe:
695
577
  return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
696
578
  if isinstance(data, (str, os.PathLike)):
697
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
579
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
698
580
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
699
581
  elif not isinstance(data, DataLoader):
700
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
582
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
701
583
  else:
702
584
  data_loader = data
703
585
 
@@ -707,7 +589,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
707
589
 
708
590
  with torch.no_grad():
709
591
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
710
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
592
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
711
593
  X_input, _ = self.get_input(batch_dict, require_labels=False)
712
594
  y_pred = self.forward(X_input)
713
595
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -717,10 +599,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
717
599
  if id_name not in batch_dict["ids"]:
718
600
  continue
719
601
  id_tensor = batch_dict["ids"][id_name]
720
- if isinstance(id_tensor, torch.Tensor):
721
- id_np = id_tensor.detach().cpu().numpy()
722
- else:
723
- id_np = np.asarray(id_tensor)
602
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
724
603
  id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
725
604
  if len(y_pred_list) > 0:
726
605
  y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -730,12 +609,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
730
609
  if y_pred_all.ndim == 1:
731
610
  y_pred_all = y_pred_all.reshape(-1, 1)
732
611
  if y_pred_all.size == 0:
733
- num_outputs = len(self.target) if self.target else 1
612
+ num_outputs = len(self.target_columns) if self.target_columns else 1
734
613
  y_pred_all = y_pred_all.reshape(0, num_outputs)
735
614
  num_outputs = y_pred_all.shape[1]
736
615
  pred_columns: list[str] = []
737
- if self.target:
738
- for name in self.target[:num_outputs]:
616
+ if self.target_columns:
617
+ for name in self.target_columns[:num_outputs]:
739
618
  pred_columns.append(f"{name}_pred")
740
619
  while len(pred_columns) < num_outputs:
741
620
  pred_columns.append(f"pred_{len(pred_columns)}")
@@ -789,10 +668,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
789
668
  return_dataframe: bool,
790
669
  ) -> pd.DataFrame:
791
670
  if isinstance(data, (str, os.PathLike)):
792
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
671
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
793
672
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
794
673
  elif not isinstance(data, DataLoader):
795
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
674
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
796
675
  else:
797
676
  data_loader = data
798
677
 
@@ -807,35 +686,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):
807
686
 
808
687
  with torch.no_grad():
809
688
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
810
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
689
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
811
690
  X_input, _ = self.get_input(batch_dict, require_labels=False)
812
691
  y_pred = self.forward(X_input)
813
692
  if y_pred is None or not isinstance(y_pred, torch.Tensor):
814
693
  continue
815
-
816
694
  y_pred_np = y_pred.detach().cpu().numpy()
817
695
  if y_pred_np.ndim == 1:
818
696
  y_pred_np = y_pred_np.reshape(-1, 1)
819
-
820
697
  if pred_columns is None:
821
698
  num_outputs = y_pred_np.shape[1]
822
699
  pred_columns = []
823
- if self.target:
824
- for name in self.target[:num_outputs]:
700
+ if self.target_columns:
701
+ for name in self.target_columns[:num_outputs]:
825
702
  pred_columns.append(f"{name}_pred")
826
703
  while len(pred_columns) < num_outputs:
827
704
  pred_columns.append(f"pred_{len(pred_columns)}")
828
-
705
+
829
706
  id_arrays_batch: dict[str, np.ndarray] = {}
830
707
  if include_ids and self.id_columns and batch_dict.get("ids"):
831
708
  for id_name in self.id_columns:
832
709
  if id_name not in batch_dict["ids"]:
833
710
  continue
834
711
  id_tensor = batch_dict["ids"][id_name]
835
- if isinstance(id_tensor, torch.Tensor):
836
- id_np = id_tensor.detach().cpu().numpy()
837
- else:
838
- id_np = np.asarray(id_tensor)
712
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
839
713
  id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
840
714
 
841
715
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
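
Note: predict() streams results to disk batch by batch when save_path is set and return_dataframe is False, and otherwise returns one "<target>_pred" column per output. A hedged sketch of both paths; the file names and formats are placeholders:

preds = model.predict(test_df, batch_size=1024)   # in-memory DataFrame of predictions

model.predict(
    "test_data.parquet",            # string paths are read through RecDataLoader
    batch_size=1024,
    save_path="predictions.parquet",
    return_dataframe=False,         # stream to disk instead of holding all outputs
    streaming_chunk_size=10000,
)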
@@ -876,7 +750,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
876
750
  config_path = self.features_config_path
877
751
  features_config = {
878
752
  "all_features": self.all_features,
879
- "target": self.target,
753
+ "target": self.target_columns,
880
754
  "id_columns": self.id_columns,
881
755
  "version": __version__,
882
756
  }
@@ -916,9 +790,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
916
790
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
917
791
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
918
792
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
919
- self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
920
- self.target = self.target_columns
921
- self.target_index = {name: idx for idx, name in enumerate(self.target)}
793
+ self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
794
+
922
795
  cfg_version = features_config.get("version")
923
796
  if verbose:
924
797
  logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1051,36 +924,37 @@ class BaseModel(FeatureSpecMixin, nn.Module):
1051
924
  logger.info(f"Task Type: {self.task}")
1052
925
  logger.info(f"Number of Tasks: {self.nums_task}")
1053
926
  logger.info(f"Metrics: {self.metrics}")
1054
- logger.info(f"Target Columns: {self.target}")
927
+ logger.info(f"Target Columns: {self.target_columns}")
1055
928
  logger.info(f"Device: {self.device}")
1056
929
 
1057
- if hasattr(self, '_optimizer_name'):
1058
- logger.info(f"Optimizer: {self._optimizer_name}")
1059
- if self._optimizer_params:
1060
- for key, value in self._optimizer_params.items():
930
+ if hasattr(self, 'optimizer_name'):
931
+ logger.info(f"Optimizer: {self.optimizer_name}")
932
+ if self.optimizer_params:
933
+ for key, value in self.optimizer_params.items():
1061
934
  logger.info(f" {key:25s}: {value}")
1062
935
 
1063
- if hasattr(self, '_scheduler_name') and self._scheduler_name:
1064
- logger.info(f"Scheduler: {self._scheduler_name}")
1065
- if self._scheduler_params:
1066
- for key, value in self._scheduler_params.items():
936
+ if hasattr(self, 'scheduler_name') and self.scheduler_name:
937
+ logger.info(f"Scheduler: {self.scheduler_name}")
938
+ if self.scheduler_params:
939
+ for key, value in self.scheduler_params.items():
1067
940
  logger.info(f" {key:25s}: {value}")
1068
941
 
1069
- if hasattr(self, '_loss_config'):
1070
- logger.info(f"Loss Function: {self._loss_config}")
1071
- if hasattr(self, '_loss_weights'):
1072
- logger.info(f"Loss Weights: {self._loss_weights}")
942
+ if hasattr(self, 'loss_config'):
943
+ logger.info(f"Loss Function: {self.loss_config}")
944
+ if hasattr(self, 'loss_weights'):
945
+ logger.info(f"Loss Weights: {self.loss_weights}")
1073
946
 
1074
947
  logger.info("Regularization:")
1075
- logger.info(f" Embedding L1: {self._embedding_l1_reg}")
1076
- logger.info(f" Embedding L2: {self._embedding_l2_reg}")
1077
- logger.info(f" Dense L1: {self._dense_l1_reg}")
1078
- logger.info(f" Dense L2: {self._dense_l2_reg}")
948
+ logger.info(f" Embedding L1: {self.embedding_l1_reg}")
949
+ logger.info(f" Embedding L2: {self.embedding_l2_reg}")
950
+ logger.info(f" Dense L1: {self.dense_l1_reg}")
951
+ logger.info(f" Dense L2: {self.dense_l2_reg}")
1079
952
 
1080
953
  logger.info("Other Settings:")
1081
- logger.info(f" Early Stop Patience: {self._early_stop_patience}")
1082
- logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
954
+ logger.info(f" Early Stop Patience: {self.early_stop_patience}")
955
+ logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1083
956
  logger.info(f" Session ID: {self.session_id}")
957
+ logger.info(f" Features Config Path: {self.features_config_path}")
1084
958
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
1085
959
 
1086
960
  logger.info("")
@@ -1195,7 +1069,7 @@ class BaseMatchModel(BaseModel):
1195
1069
  def compile(self,
1196
1070
  optimizer: str | torch.optim.Optimizer = "adam",
1197
1071
  optimizer_params: dict | None = None,
1198
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
1072
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
1199
1073
  scheduler_params: dict | None = None,
1200
1074
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1201
1075
  loss_params: dict | list[dict] | None = None):
@@ -1208,18 +1082,18 @@ class BaseMatchModel(BaseModel):
1208
1082
  # Call parent compile with match-specific logic
1209
1083
  optimizer_params = optimizer_params or {}
1210
1084
 
1211
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1212
- self._optimizer_params = optimizer_params
1085
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1086
+ self.optimizer_params = optimizer_params
1213
1087
  if isinstance(scheduler, str):
1214
- self._scheduler_name = scheduler
1088
+ self.scheduler_name = scheduler
1215
1089
  elif scheduler is not None:
1216
1090
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
1217
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
1091
+ self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
1218
1092
  else:
1219
- self._scheduler_name = None
1220
- self._scheduler_params = scheduler_params or {}
1221
- self._loss_config = loss
1222
- self._loss_params = loss_params or {}
1093
+ self.scheduler_name = None
1094
+ self.scheduler_params = scheduler_params or {}
1095
+ self.loss_config = loss
1096
+ self.loss_params = loss_params or {}
1223
1097
 
1224
1098
  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
1225
1099
  # Set loss function based on training mode
@@ -1239,7 +1113,7 @@ class BaseMatchModel(BaseModel):
1239
1113
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1240
1114
  if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
1241
1115
  loss_value = default_losses.get(self.training_mode, loss_value)
1242
- loss_kwargs = get_loss_kwargs(self._loss_params, 0)
1116
+ loss_kwargs = get_loss_kwargs(self.loss_params, 0)
1243
1117
  self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1244
1118
  # set scheduler
1245
1119
  self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1323,57 +1197,47 @@ class BaseMatchModel(BaseModel):
1323
1197
  return loss
1324
1198
  else:
1325
1199
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1200
+
1326
1201
 
1327
- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
1328
- """Reuse BaseModel metric configuration (mode + early stopper)."""
1329
- super()._set_metrics(metrics)
1330
-
1202
+ def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
1203
+ """Prepare data loader for specific features."""
1204
+ if isinstance(data, DataLoader):
1205
+ return data
1206
+
1207
+ feature_data = {}
1208
+ for feature in features:
1209
+ if isinstance(data, dict):
1210
+ if feature.name in data:
1211
+ feature_data[feature.name] = data[feature.name]
1212
+ elif isinstance(data, pd.DataFrame):
1213
+ if feature.name in data.columns:
1214
+ feature_data[feature.name] = data[feature.name].values
1215
+ return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
1216
+
1331
1217
  def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1332
- self.eval()
1333
- if not isinstance(data, DataLoader):
1334
- user_data = {}
1335
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1336
- for feature in all_user_features:
1337
- if isinstance(data, dict):
1338
- if feature.name in data:
1339
- user_data[feature.name] = data[feature.name]
1340
- elif isinstance(data, pd.DataFrame):
1341
- if feature.name in data.columns:
1342
- user_data[feature.name] = data[feature.name].values
1343
- data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
1344
- else:
1345
- data_loader = data
1218
+ self.eval()
1219
+ all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1220
+ data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
1221
+
1346
1222
  embeddings_list = []
1347
1223
  with torch.no_grad():
1348
1224
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
1349
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1225
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
1350
1226
  user_input = self.get_user_features(batch_dict["features"])
1351
1227
  user_emb = self.user_tower(user_input)
1352
1228
  embeddings_list.append(user_emb.cpu().numpy())
1353
- embeddings = np.concatenate(embeddings_list, axis=0)
1354
- return embeddings
1229
+ return np.concatenate(embeddings_list, axis=0)
1355
1230
 
1356
1231
  def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1357
1232
  self.eval()
1358
- if not isinstance(data, DataLoader):
1359
- item_data = {}
1360
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1361
- for feature in all_item_features:
1362
- if isinstance(data, dict):
1363
- if feature.name in data:
1364
- item_data[feature.name] = data[feature.name]
1365
- elif isinstance(data, pd.DataFrame):
1366
- if feature.name in data.columns:
1367
- item_data[feature.name] = data[feature.name].values
1368
- data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
1369
- else:
1370
- data_loader = data
1233
+ all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1234
+ data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
1235
+
1371
1236
  embeddings_list = []
1372
1237
  with torch.no_grad():
1373
1238
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
1374
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1239
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
1375
1240
  item_input = self.get_item_features(batch_dict["features"])
1376
1241
  item_emb = self.item_tower(item_input)
1377
1242
  embeddings_list.append(item_emb.cpu().numpy())
1378
- embeddings = np.concatenate(embeddings_list, axis=0)
1379
- return embeddings
1243
+ return np.concatenate(embeddings_list, axis=0)
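
Note: encode_user and encode_item now share prepare_feature_data for input handling. A hedged retrieval sketch built on the two encoders; the frames and the top-k step are illustrative:

import numpy as np

user_emb = model.encode_user(user_df, batch_size=512)   # (num_users, dim)
item_emb = model.encode_item(item_df, batch_size=512)   # (num_items, dim)

scores = user_emb @ item_emb.T                           # dot-product relevance
top_k = np.argsort(-scores, axis=1)[:, :50]              # 50 candidates per user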