nextrec 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +10 -23
  3. nextrec/basic/layers.py +18 -61
  4. nextrec/basic/metrics.py +55 -33
  5. nextrec/basic/model.py +247 -389
  6. nextrec/data/__init__.py +2 -2
  7. nextrec/data/data_utils.py +80 -4
  8. nextrec/data/dataloader.py +36 -57
  9. nextrec/data/preprocessor.py +5 -4
  10. nextrec/models/generative/hstu.py +1 -1
  11. nextrec/models/match/dssm.py +2 -2
  12. nextrec/models/match/dssm_v2.py +2 -2
  13. nextrec/models/match/mind.py +2 -2
  14. nextrec/models/match/sdm.py +2 -2
  15. nextrec/models/match/youtube_dnn.py +2 -2
  16. nextrec/models/multi_task/esmm.py +1 -1
  17. nextrec/models/multi_task/mmoe.py +1 -1
  18. nextrec/models/multi_task/ple.py +1 -1
  19. nextrec/models/multi_task/poso.py +1 -1
  20. nextrec/models/multi_task/share_bottom.py +1 -1
  21. nextrec/models/ranking/afm.py +1 -1
  22. nextrec/models/ranking/autoint.py +1 -1
  23. nextrec/models/ranking/dcn.py +1 -1
  24. nextrec/models/ranking/deepfm.py +1 -1
  25. nextrec/models/ranking/dien.py +1 -1
  26. nextrec/models/ranking/din.py +1 -1
  27. nextrec/models/ranking/fibinet.py +1 -1
  28. nextrec/models/ranking/fm.py +1 -1
  29. nextrec/models/ranking/masknet.py +2 -2
  30. nextrec/models/ranking/pnn.py +1 -1
  31. nextrec/models/ranking/widedeep.py +1 -1
  32. nextrec/models/ranking/xdeepfm.py +1 -1
  33. nextrec/utils/__init__.py +2 -1
  34. nextrec/utils/common.py +21 -2
  35. {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/METADATA +3 -3
  36. nextrec-0.3.3.dist-info/RECORD +57 -0
  37. nextrec-0.3.2.dist-info/RECORD +0 -57
  38. {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/WHEEL +0 -0
  39. {nextrec-0.3.2.dist-info → nextrec-0.3.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 02/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -21,21 +21,22 @@ from typing import Union, Literal, Any
21
21
  from torch.utils.data import DataLoader
22
22
 
23
23
  from nextrec.basic.callback import EarlyStopper
24
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
24
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
25
25
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
26
26
 
27
27
  from nextrec.basic.loggers import setup_logger, colorize
28
28
  from nextrec.basic.session import resolve_save_path, create_session
29
- from nextrec.basic.metrics import configure_metrics, evaluate_metrics
29
+ from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
30
30
 
31
- from nextrec.data import get_column_data, collate_fn
32
31
  from nextrec.data.dataloader import build_tensors_from_data
32
+ from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids
33
33
 
34
34
  from nextrec.loss import get_loss_fn, get_loss_kwargs
35
- from nextrec.utils import get_optimizer, get_scheduler
35
+ from nextrec.utils import get_optimizer, get_scheduler, to_tensor
36
+
36
37
  from nextrec import __version__
37
38
 
38
- class BaseModel(FeatureSpecMixin, nn.Module):
39
+ class BaseModel(FeatureSet, nn.Module):
39
40
  @property
40
41
  def model_name(self) -> str:
41
42
  raise NotImplementedError
@@ -69,72 +70,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
69
70
  self.session_id = session_id
70
71
  self.session = create_session(session_id)
71
72
  self.session_path = self.session.root # pwd/session_id, path for this session
72
- self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint"+".model")
73
- self.best_path = os.path.join(self.session_path, self.model_name+ "_best.model")
73
+ self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
74
+ self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
74
75
  self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
75
- self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)
76
- self.target = self.target_columns
77
- self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
76
+ self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
78
77
 
79
78
  self.task = task
80
79
  self.nums_task = len(task) if isinstance(task, list) else 1
81
80
 
82
- self._embedding_l1_reg = embedding_l1_reg
83
- self._dense_l1_reg = dense_l1_reg
84
- self._embedding_l2_reg = embedding_l2_reg
85
- self._dense_l2_reg = dense_l2_reg
86
- self._regularization_weights = []
87
- self._embedding_params = []
88
- self._loss_weights: float | list[float] | None = None
89
- self._early_stop_patience = early_stop_patience
90
- self._max_gradient_norm = 1.0
91
- self._logger_initialized = False
92
-
93
- def _register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
81
+ self.embedding_l1_reg = embedding_l1_reg
82
+ self.dense_l1_reg = dense_l1_reg
83
+ self.embedding_l2_reg = embedding_l2_reg
84
+ self.dense_l2_reg = dense_l2_reg
85
+ self.regularization_weights = []
86
+ self.embedding_params = []
87
+ self.loss_weight = None
88
+ self.early_stop_patience = early_stop_patience
89
+ self.max_gradient_norm = 1.0
90
+ self.logger_initialized = False
91
+
92
+ def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
94
93
  exclude_modules = exclude_modules or []
95
94
  include_modules = include_modules or []
96
- if hasattr(self, embedding_attr):
97
- embedding_layer = getattr(self, embedding_attr)
98
- if hasattr(embedding_layer, "embed_dict"):
99
- for embed in embedding_layer.embed_dict.values():
100
- self._embedding_params.append(embed.weight)
95
+ embedding_layer = getattr(self, embedding_attr, None)
96
+ embed_dict = getattr(embedding_layer, "embed_dict", None)
97
+ if embed_dict is not None:
98
+ self.embedding_params.extend(embed.weight for embed in embed_dict.values())
99
+ skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
101
100
  for name, module in self.named_modules():
102
- if module is self:
103
- continue
104
- if embedding_attr in name:
105
- continue
106
- if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
107
- continue
108
- if include_modules:
109
- if not any(inc_name in name for inc_name in include_modules):
110
- continue
111
- if any(exc_name in name for exc_name in exclude_modules):
101
+ if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
112
102
  continue
113
103
  if isinstance(module, nn.Linear):
114
- self._regularization_weights.append(module.weight)
104
+ self.regularization_weights.append(module.weight)
115
105
 
116
106
  def add_reg_loss(self) -> torch.Tensor:
117
107
  reg_loss = torch.tensor(0.0, device=self.device)
118
- if self._embedding_params:
119
- if self._embedding_l1_reg > 0:
120
- reg_loss += self._embedding_l1_reg * sum(param.abs().sum() for param in self._embedding_params)
121
- if self._embedding_l2_reg > 0:
122
- reg_loss += self._embedding_l2_reg * sum((param ** 2).sum() for param in self._embedding_params)
123
- if self._regularization_weights:
124
- if self._dense_l1_reg > 0:
125
- reg_loss += self._dense_l1_reg * sum(param.abs().sum() for param in self._regularization_weights)
126
- if self._dense_l2_reg > 0:
127
- reg_loss += self._dense_l2_reg * sum((param ** 2).sum() for param in self._regularization_weights)
108
+ if self.embedding_params:
109
+ if self.embedding_l1_reg > 0:
110
+ reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
111
+ if self.embedding_l2_reg > 0:
112
+ reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
113
+ if self.regularization_weights:
114
+ if self.dense_l1_reg > 0:
115
+ reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
116
+ if self.dense_l2_reg > 0:
117
+ reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
128
118
  return reg_loss
129
119
 
130
- def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
131
- tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
132
- if tensor.dtype != dtype:
133
- tensor = tensor.to(dtype=dtype)
134
- if tensor.device != self.device:
135
- tensor = tensor.to(self.device)
136
- return tensor
137
-
138
120
  def get_input(self, input_data: dict, require_labels: bool = True):
139
121
  feature_source = input_data.get("features", {})
140
122
  label_source = input_data.get("labels")
@@ -143,12 +125,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
143
125
  if feature.name not in feature_source:
144
126
  raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
145
127
  feature_data = get_column_data(feature_source, feature.name)
146
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
147
- X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
128
+ X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
148
129
  y = None
149
- if (len(self.target) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target)))): # need labels: training or eval with labels
130
+ if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
150
131
  target_tensors = []
151
- for target_name in self.target:
132
+ for target_name in self.target_columns:
152
133
  if label_source is None or target_name not in label_source:
153
134
  if require_labels:
154
135
  raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
@@ -158,7 +139,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
158
139
  if require_labels:
159
140
  raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
160
141
  continue
161
- target_tensor = self._to_tensor(target_data, dtype=torch.float32)
142
+ target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
162
143
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
163
144
  target_tensors.append(target_tensor)
164
145
  if target_tensors:
@@ -169,11 +150,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
169
150
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
170
151
  return X_input, y
171
152
 
172
- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
173
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
174
- self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
175
-
176
- def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
153
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
154
+ """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
177
155
  if not (0 < validation_split < 1):
178
156
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
179
157
  if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -181,8 +159,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
181
159
  if isinstance(train_data, pd.DataFrame):
182
160
  total_length = len(train_data)
183
161
  else:
184
- sample_key = next(iter(train_data))
185
- total_length = len(train_data[sample_key])
162
+ sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
163
+ total_length = len(train_data[sample_key]) # len(train_data['user_id'])
186
164
  for k, v in train_data.items():
187
165
  if len(v) != total_length:
188
166
  raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
@@ -198,20 +176,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
198
176
  train_split = {}
199
177
  valid_split = {}
200
178
  for key, value in train_data.items():
201
- if isinstance(value, np.ndarray):
202
- train_split[key] = value[train_indices]
203
- valid_split[key] = value[valid_indices]
204
- elif isinstance(value, (list, tuple)):
205
- arr = np.asarray(value)
206
- train_split[key] = arr[train_indices].tolist()
207
- valid_split[key] = arr[valid_indices].tolist()
208
- elif isinstance(value, pd.Series):
209
- train_split[key] = value.iloc[train_indices].values
210
- valid_split[key] = value.iloc[valid_indices].values
211
- else:
212
- train_split[key] = [value[i] for i in train_indices]
213
- valid_split[key] = [value[i] for i in valid_indices]
214
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
179
+ arr = np.asarray(value)
180
+ train_split[key] = arr[train_indices]
181
+ valid_split[key] = arr[valid_indices]
182
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
215
183
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
216
184
  return train_loader, valid_split
217
185
 
@@ -226,44 +194,44 @@ class BaseModel(FeatureSpecMixin, nn.Module):
226
194
  loss_weights: int | float | list[int | float] | None = None,
227
195
  ):
228
196
  optimizer_params = optimizer_params or {}
229
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
230
- self._optimizer_params = optimizer_params
197
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
198
+ self.optimizer_params = optimizer_params
231
199
  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
232
200
 
233
201
  scheduler_params = scheduler_params or {}
234
202
  if isinstance(scheduler, str):
235
- self._scheduler_name = scheduler
203
+ self.scheduler_name = scheduler
236
204
  elif scheduler is None:
237
- self._scheduler_name = None
238
- else:
239
- self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
240
- self._scheduler_params = scheduler_params
205
+ self.scheduler_name = None
206
+ else: # for custom scheduler instance, need to provide class name for logging
207
+ self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
208
+ self.scheduler_params = scheduler_params
241
209
  self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
242
210
 
243
- self._loss_config = loss
244
- self._loss_params = loss_params or {}
211
+ self.loss_config = loss
212
+ self.loss_params = loss_params or {}
245
213
  self.loss_fn = []
246
- for i in range(self.nums_task):
247
- if isinstance(loss, list):
248
- loss_value = loss[i] if i < len(loss) else None
249
- else:
250
- loss_value = loss
251
- if self.nums_task == 1: # single task
252
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else self._loss_params[0]
253
- else:
254
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
255
- self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
256
- # Normalize loss weights for single-task and multi-task setups
214
+ if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
215
+ loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
216
+ else: # for example: 'bce' -> ['bce', 'bce']
217
+ loss_list = [loss] * self.nums_task
218
+
219
+ if isinstance(self.loss_params, dict):
220
+ params_list = [self.loss_params] * self.nums_task
221
+ else: # list[dict]
222
+ params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
223
+ self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
224
+
257
225
  if loss_weights is None:
258
- self._loss_weights = None
226
+ self.loss_weights = None
259
227
  elif self.nums_task == 1:
260
228
  if isinstance(loss_weights, (list, tuple)):
261
- if len(loss_weights) != 1:
229
+ if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
262
230
  raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
263
231
  weight_value = loss_weights[0]
264
232
  else:
265
233
  weight_value = loss_weights
266
- self._loss_weights = float(weight_value)
234
+ self.loss_weights = float(weight_value)
267
235
  else:
268
236
  if isinstance(loss_weights, (int, float)):
269
237
  weights = [float(loss_weights)] * self.nums_task
@@ -273,87 +241,68 @@ class BaseModel(FeatureSpecMixin, nn.Module):
273
241
  raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
274
242
  else:
275
243
  raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
276
- self._loss_weights = weights
244
+ self.loss_weights = weights
277
245
 
278
246
  def compute_loss(self, y_pred, y_true):
279
247
  if y_true is None:
280
248
  raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
281
249
  if self.nums_task == 1:
282
250
  loss = self.loss_fn[0](y_pred, y_true)
283
- if self._loss_weights is not None:
284
- loss = loss * self._loss_weights
251
+ if self.loss_weights is not None:
252
+ loss = loss * self.loss_weights
285
253
  return loss
286
254
  else:
287
255
  task_losses = []
288
256
  for i in range(self.nums_task):
289
257
  task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
290
- if isinstance(self._loss_weights, (list, tuple)):
291
- task_loss = task_loss * self._loss_weights[i]
258
+ if isinstance(self.loss_weights, (list, tuple)):
259
+ task_loss = task_loss * self.loss_weights[i]
292
260
  task_losses.append(task_loss)
293
261
  return torch.stack(task_losses).sum()
294
262
 
295
- def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
263
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
296
264
  if isinstance(data, DataLoader):
297
265
  return data
298
- tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
266
+ tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
299
267
  if tensors is None:
300
268
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
301
269
  dataset = TensorDictDataset(tensors)
302
270
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
303
271
 
304
- def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
305
- if not (isinstance(batch_data, dict) and "features" in batch_data):
306
- raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
307
- return {
308
- "features": batch_data.get("features", {}),
309
- "labels": batch_data.get("labels"),
310
- "ids": batch_data.get("ids") if include_ids else None,
311
- }
312
-
313
272
  def fit(self,
314
- train_data: dict|pd.DataFrame|DataLoader,
315
- valid_data: dict|pd.DataFrame|DataLoader|None=None,
316
- metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
273
+ train_data: dict | pd.DataFrame | DataLoader,
274
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
275
+ metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
317
276
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
318
- user_id_column: str = 'user_id',
277
+ user_id_column: str | None = None,
319
278
  validation_split: float | None = None):
320
279
  self.to(self.device)
321
- if not self._logger_initialized:
280
+ if not self.logger_initialized:
322
281
  setup_logger(session_id=self.session_id)
323
- self._logger_initialized = True
324
- self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
325
- self.summary()
326
- valid_loader = None
327
- valid_user_ids: np.ndarray | None = None
328
- needs_user_ids: bool = self._needs_user_ids_for_metrics()
282
+ self.logger_initialized = True
283
+
284
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
285
+ self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
286
+ self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
287
+ self.epoch_index = 0
288
+ self.stop_training = False
289
+ self.best_checkpoint_path = self.best_path
290
+ self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
329
291
 
330
292
  if validation_split is not None and valid_data is None:
331
- train_loader, valid_data = self._handle_validation_split(
332
- train_data=train_data, # type: ignore
333
- validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
293
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
334
294
  else:
335
- train_loader = (train_data if isinstance(train_data, DataLoader) else self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
336
- if isinstance(valid_data, DataLoader):
337
- valid_loader = valid_data
338
- elif valid_data is not None:
339
- valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
340
- if needs_user_ids:
341
- if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
342
- valid_user_ids = np.asarray(valid_data[user_id_column].values)
343
- elif isinstance(valid_data, dict) and user_id_column in valid_data:
344
- valid_user_ids = np.asarray(valid_data[user_id_column])
295
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
296
+
297
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
345
298
  try:
346
- self._steps_per_epoch = len(train_loader)
299
+ self.steps_per_epoch = len(train_loader)
347
300
  is_streaming = False
348
- except TypeError: # len() not supported, e.g., streaming data loader
349
- self._steps_per_epoch = None
301
 + except TypeError: # streaming data loader does not support len()
302
+ self.steps_per_epoch = None
350
303
  is_streaming = True
351
304
 
352
- self._epoch_index = 0
353
- self._stop_training = False
354
- self._best_checkpoint_path = self.best_path
355
- self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
356
-
305
+ self.summary()
357
306
  logging.info("")
358
307
  logging.info(colorize("=" * 80, bold=True))
359
308
  if is_streaming:
@@ -365,36 +314,34 @@ class BaseModel(FeatureSpecMixin, nn.Module):
365
314
  logging.info(colorize(f"Model device: {self.device}", bold=True))
366
315
 
367
316
  for epoch in range(epochs):
368
- self._epoch_index = epoch
317
+ self.epoch_index = epoch
369
318
  if is_streaming:
370
319
  logging.info("")
371
320
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
372
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
373
- if isinstance(train_result, tuple):
321
+
322
+ # handle train result
323
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
324
+ if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
374
325
  train_loss, train_metrics = train_result
375
326
  else:
376
327
  train_loss = train_result
377
328
  train_metrics = None
329
+
330
+ # handle logging for single-task and multi-task
378
331
  if self.nums_task == 1:
379
332
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
380
333
  if train_metrics:
381
334
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
382
335
  log_str += f", {metrics_str}"
383
- logging.info(colorize(log_str, color="white"))
336
+ logging.info(colorize(log_str))
384
337
  else:
385
- task_labels = []
386
- for i in range(self.nums_task):
387
- if i < len(self.target):
388
- task_labels.append(self.target[i])
389
- else:
390
- task_labels.append(f"task_{i}")
391
338
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
392
339
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
393
340
  if train_metrics:
394
- # Group metrics by task
341
+ # group metrics by task
395
342
  task_metrics = {}
396
343
  for metric_key, metric_value in train_metrics.items():
397
- for target_name in self.target:
344
+ for target_name in self.target_columns:
398
345
  if metric_key.endswith(f"_{target_name}"):
399
346
  if target_name not in task_metrics:
400
347
  task_metrics[target_name] = {}
@@ -403,15 +350,15 @@ class BaseModel(FeatureSpecMixin, nn.Module):
403
350
  break
404
351
  if task_metrics:
405
352
  task_metric_strs = []
406
- for target_name in self.target:
353
+ for target_name in self.target_columns:
407
354
  if target_name in task_metrics:
408
355
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
409
356
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
410
357
  log_str += ", " + ", ".join(task_metric_strs)
411
- logging.info(colorize(log_str, color="white"))
358
+ logging.info(colorize(log_str))
412
359
  if valid_loader is not None:
413
- # Pass user_ids only if needed for GAUC metric
414
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
360
+ # pass user_ids only if needed for GAUC metric
361
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
415
362
  if self.nums_task == 1:
416
363
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
417
364
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
@@ -419,7 +366,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
419
366
  # multi task metrics
420
367
  task_metrics = {}
421
368
  for metric_key, metric_value in val_metrics.items():
422
- for target_name in self.target:
369
+ for target_name in self.target_columns:
423
370
  if metric_key.endswith(f"_{target_name}"):
424
371
  if target_name not in task_metrics:
425
372
  task_metrics[target_name] = {}
@@ -427,7 +374,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
427
374
  task_metrics[target_name][metric_name] = metric_value
428
375
  break
429
376
  task_metric_strs = []
430
- for target_name in self.target:
377
+ for target_name in self.target_columns:
431
378
  if target_name in task_metrics:
432
379
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
433
380
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
@@ -435,45 +382,42 @@ class BaseModel(FeatureSpecMixin, nn.Module):
435
382
  # Handle empty validation metrics
436
383
  if not val_metrics:
437
384
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
438
- self._best_checkpoint_path = self.checkpoint_path
385
+ self.best_checkpoint_path = self.checkpoint_path
439
386
  logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
440
387
  continue
441
388
  if self.nums_task == 1:
442
389
  primary_metric_key = self.metrics[0]
443
390
  else:
444
- primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
445
-
446
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
391
+ primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
392
+ primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
447
393
  improved = False
448
-
394
+ # early stopping check
449
395
  if self.best_metrics_mode == 'max':
450
- if primary_metric > self._best_metric:
451
- self._best_metric = primary_metric
452
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
396
+ if primary_metric > self.best_metric:
397
+ self.best_metric = primary_metric
453
398
  improved = True
454
399
  else:
455
- if primary_metric < self._best_metric:
456
- self._best_metric = primary_metric
400
+ if primary_metric < self.best_metric:
401
+ self.best_metric = primary_metric
457
402
  improved = True
458
- # Always keep the latest weights as a rolling checkpoint
459
403
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
460
404
  if improved:
461
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}"))
405
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
462
406
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
463
- self._best_checkpoint_path = self.best_path
407
+ self.best_checkpoint_path = self.best_path
464
408
  self.early_stopper.trial_counter = 0
465
409
  else:
466
410
  self.early_stopper.trial_counter += 1
467
411
  logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
468
412
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
469
- self._stop_training = True
413
+ self.stop_training = True
470
414
  logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
471
415
  break
472
416
  else:
473
417
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
474
418
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
475
- self._best_checkpoint_path = self.best_path
476
- if self._stop_training:
419
+ self.best_checkpoint_path = self.best_path
420
+ if self.stop_training:
477
421
  break
478
422
  if self.scheduler_fn is not None:
479
423
  if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -481,34 +425,29 @@ class BaseModel(FeatureSpecMixin, nn.Module):
481
425
  self.scheduler_fn.step(primary_metric)
482
426
  else:
483
427
  self.scheduler_fn.step()
484
- logging.info("\n")
485
- logging.info(colorize("Training finished.", color="bright_green", bold=True))
486
- logging.info("\n")
428
+ logging.info(" ")
429
+ logging.info(colorize("Training finished.", bold=True))
430
+ logging.info(" ")
487
431
  if valid_loader is not None:
488
- logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
489
- self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
432
+ logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
433
+ self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
490
434
  return self
491
435
 
492
436
  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
493
- if self.nums_task == 1:
494
- accumulated_loss = 0.0
495
- else:
496
- accumulated_loss = 0.0
437
+ accumulated_loss = 0.0
497
438
  self.train()
498
439
  num_batches = 0
499
440
  y_true_list = []
500
441
  y_pred_list = []
501
- needs_user_ids = self._needs_user_ids_for_metrics()
502
- user_ids_list = [] if needs_user_ids else None
503
- if self._steps_per_epoch is not None:
504
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
442
+
443
+ user_ids_list = [] if self.needs_user_ids else None
444
+ if self.steps_per_epoch is not None:
445
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
505
446
  else:
506
- if is_streaming:
507
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
508
- else:
509
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
447
+ desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
448
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
510
449
  for batch_index, batch_data in batch_iter:
511
- batch_dict = self._batch_to_dict(batch_data)
450
+ batch_dict = batch_to_dict(batch_data)
512
451
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
513
452
  y_pred = self.forward(X_input)
514
453
  loss = self.compute_loss(y_pred, y_true)
@@ -516,66 +455,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
516
455
  total_loss = loss + reg_loss
517
456
  self.optimizer_fn.zero_grad()
518
457
  total_loss.backward()
519
- nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
458
+ nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
520
459
  self.optimizer_fn.step()
521
- if self.nums_task == 1:
522
- accumulated_loss += loss.item()
523
- else:
524
- accumulated_loss += loss.item()
460
+ accumulated_loss += loss.item()
525
461
  if y_true is not None:
526
- y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
527
- if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
528
- batch_user_id = None
529
- if self.id_columns:
530
- for id_name in self.id_columns:
531
- if id_name in batch_dict["ids"]:
532
- batch_user_id = batch_dict["ids"][id_name]
533
- break
534
- if batch_user_id is None and batch_dict["ids"]:
535
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
462
+ y_true_list.append(y_true.detach().cpu().numpy())
463
+ if self.needs_user_ids and user_ids_list is not None:
464
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
536
465
  if batch_user_id is not None:
537
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
538
- user_ids_list.append(ids_np.reshape(ids_np.shape[0]))
539
- if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
466
+ user_ids_list.append(batch_user_id)
467
+ if y_pred is not None and isinstance(y_pred, torch.Tensor):
540
468
  y_pred_list.append(y_pred.detach().cpu().numpy())
541
469
  num_batches += 1
542
- avg_loss = accumulated_loss / num_batches
470
+ avg_loss = accumulated_loss / max(num_batches, 1)
543
471
  if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
544
472
  y_true_all = np.concatenate(y_true_list, axis=0)
545
473
  y_pred_all = np.concatenate(y_pred_list, axis=0)
546
474
  combined_user_ids = None
547
- if needs_user_ids and user_ids_list:
475
+ if self.needs_user_ids and user_ids_list:
548
476
  combined_user_ids = np.concatenate(user_ids_list, axis=0)
549
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=combined_user_ids)
477
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
550
478
  return avg_loss, metrics_dict
551
479
  return avg_loss
552
480
 
553
- def _needs_user_ids_for_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None) -> bool:
554
- """Check if any configured metric requires user_ids (e.g., gauc, ranking @K)."""
555
- metric_names = set()
556
- sources = [metrics if metrics is not None else getattr(self, "metrics", None), getattr(self, "task_specific_metrics", None),]
557
- for src in sources:
558
- stack = [src]
559
- while stack:
560
- item = stack.pop()
561
- if not item:
562
- continue
563
- if isinstance(item, dict):
564
- stack.extend(item.values())
565
- elif isinstance(item, str):
566
- metric_names.add(item.lower())
567
- else:
568
- try:
569
- for m in item:
570
- metric_names.add(m.lower())
571
- except TypeError:
572
- continue
573
- for name in metric_names:
574
- if name == "gauc":
575
- return True
576
- if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
577
- return True
578
- return False
481
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
482
+ if valid_data is None:
483
+ return None, None
484
+ if isinstance(valid_data, DataLoader):
485
+ return valid_data, None
486
+ valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
487
+ valid_user_ids = None
488
+ if needs_user_ids:
489
+ if user_id_column is None:
490
+ raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
491
+ valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
492
+ return valid_loader, valid_user_ids
579
493
 
580
494
  def evaluate(self,
581
495
  data: dict | pd.DataFrame | DataLoader,
@@ -587,18 +501,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
587
501
  eval_metrics = metrics if metrics is not None else self.metrics
588
502
  if eval_metrics is None:
589
503
  raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
590
- needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
504
+ needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
591
505
 
592
506
  if isinstance(data, DataLoader):
593
507
  data_loader = data
594
508
  else:
595
- # Extract user_ids if needed and not provided
596
509
  if user_ids is None and needs_user_ids:
597
- if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
598
- user_ids = np.asarray(data[user_id_column].values)
599
- elif isinstance(data, dict) and user_id_column in data:
600
- user_ids = np.asarray(data[user_id_column])
601
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
510
+ user_ids = get_user_ids(data=data, id_columns=user_id_column)
511
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
602
512
  y_true_list = []
603
513
  y_pred_list = []
604
514
  collected_user_ids = []
@@ -606,26 +516,17 @@ class BaseModel(FeatureSpecMixin, nn.Module):
606
516
  with torch.no_grad():
607
517
  for batch_data in data_loader:
608
518
  batch_count += 1
609
- batch_dict = self._batch_to_dict(batch_data)
519
+ batch_dict = batch_to_dict(batch_data)
610
520
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
611
521
  y_pred = self.forward(X_input)
612
522
  if y_true is not None:
613
523
  y_true_list.append(y_true.cpu().numpy())
614
- # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
615
524
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
616
525
  y_pred_list.append(y_pred.cpu().numpy())
617
- if needs_user_ids and user_ids is None and batch_dict.get("ids"):
618
- batch_user_id = None
619
- if self.id_columns:
620
- for id_name in self.id_columns:
621
- if id_name in batch_dict["ids"]:
622
- batch_user_id = batch_dict["ids"][id_name]
623
- break
624
- if batch_user_id is None and batch_dict["ids"]:
625
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
526
+ if needs_user_ids and user_ids is None:
527
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
626
528
  if batch_user_id is not None:
627
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
628
- collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
529
+ collected_user_ids.append(batch_user_id)
629
530
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
630
531
  if len(y_true_list) > 0:
631
532
  y_true_all = np.concatenate(y_true_list, axis=0)
@@ -654,23 +555,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
654
555
  final_user_ids = user_ids
655
556
  if final_user_ids is None and collected_user_ids:
656
557
  final_user_ids = np.concatenate(collected_user_ids, axis=0)
657
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
558
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
658
559
  return metrics_dict
659
560
 
660
- def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
661
- """Evaluate metrics using the metrics module."""
662
- task_specific_metrics = getattr(self, 'task_specific_metrics', None)
663
-
664
- return evaluate_metrics(
665
- y_true=y_true,
666
- y_pred=y_pred,
667
- metrics=metrics,
668
- task=self.task,
669
- target_names=self.target,
670
- task_specific_metrics=task_specific_metrics,
671
- user_ids=user_ids
672
- )
673
-
674
561
  def predict(
675
562
  self,
676
563
  data: str | dict | pd.DataFrame | DataLoader,
@@ -681,28 +568,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
681
568
  return_dataframe: bool = True,
682
569
  streaming_chunk_size: int = 10000,
683
570
  ) -> pd.DataFrame | np.ndarray:
684
- """
685
- Run inference and optionally return ID-aligned predictions.
686
-
687
- When ``id_columns`` are configured and ``include_ids`` is True (default),
688
- the returned object will include those IDs to keep a one-to-one mapping
689
- between each prediction and its source row.
690
- If ``save_path`` is provided and ``return_dataframe`` is False, predictions
691
- stream to disk batch-by-batch to avoid holding all outputs in memory.
692
- """
693
571
  self.eval()
694
572
  if include_ids is None:
695
573
  include_ids = bool(self.id_columns)
696
574
  include_ids = include_ids and bool(self.id_columns)
697
575
 
698
- # if saving to disk without returning dataframe, use streaming prediction
699
576
  if save_path is not None and not return_dataframe:
700
577
  return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
701
578
  if isinstance(data, (str, os.PathLike)):
702
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
579
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
703
580
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
704
581
  elif not isinstance(data, DataLoader):
705
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
582
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
706
583
  else:
707
584
  data_loader = data
708
585
 
@@ -712,7 +589,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
712
589
 
713
590
  with torch.no_grad():
714
591
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
715
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
592
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
716
593
  X_input, _ = self.get_input(batch_dict, require_labels=False)
717
594
  y_pred = self.forward(X_input)
718
595
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -722,10 +599,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
722
599
  if id_name not in batch_dict["ids"]:
723
600
  continue
724
601
  id_tensor = batch_dict["ids"][id_name]
725
- if isinstance(id_tensor, torch.Tensor):
726
- id_np = id_tensor.detach().cpu().numpy()
727
- else:
728
- id_np = np.asarray(id_tensor)
602
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
729
603
  id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
730
604
  if len(y_pred_list) > 0:
731
605
  y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -735,12 +609,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
735
609
  if y_pred_all.ndim == 1:
736
610
  y_pred_all = y_pred_all.reshape(-1, 1)
737
611
  if y_pred_all.size == 0:
738
- num_outputs = len(self.target) if self.target else 1
612
+ num_outputs = len(self.target_columns) if self.target_columns else 1
739
613
  y_pred_all = y_pred_all.reshape(0, num_outputs)
740
614
  num_outputs = y_pred_all.shape[1]
741
615
  pred_columns: list[str] = []
742
- if self.target:
743
- for name in self.target[:num_outputs]:
616
+ if self.target_columns:
617
+ for name in self.target_columns[:num_outputs]:
744
618
  pred_columns.append(f"{name}_pred")
745
619
  while len(pred_columns) < num_outputs:
746
620
  pred_columns.append(f"pred_{len(pred_columns)}")
@@ -794,10 +668,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
794
668
  return_dataframe: bool,
795
669
  ) -> pd.DataFrame:
796
670
  if isinstance(data, (str, os.PathLike)):
797
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
671
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
798
672
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
799
673
  elif not isinstance(data, DataLoader):
800
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
674
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
801
675
  else:
802
676
  data_loader = data
803
677
 
@@ -812,35 +686,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):
812
686
 
813
687
  with torch.no_grad():
814
688
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
815
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
689
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
816
690
  X_input, _ = self.get_input(batch_dict, require_labels=False)
817
691
  y_pred = self.forward(X_input)
818
692
  if y_pred is None or not isinstance(y_pred, torch.Tensor):
819
693
  continue
820
-
821
694
  y_pred_np = y_pred.detach().cpu().numpy()
822
695
  if y_pred_np.ndim == 1:
823
696
  y_pred_np = y_pred_np.reshape(-1, 1)
824
-
825
697
  if pred_columns is None:
826
698
  num_outputs = y_pred_np.shape[1]
827
699
  pred_columns = []
828
- if self.target:
829
- for name in self.target[:num_outputs]:
700
+ if self.target_columns:
701
+ for name in self.target_columns[:num_outputs]:
830
702
  pred_columns.append(f"{name}_pred")
831
703
  while len(pred_columns) < num_outputs:
832
704
  pred_columns.append(f"pred_{len(pred_columns)}")
833
-
705
+
834
706
  id_arrays_batch: dict[str, np.ndarray] = {}
835
707
  if include_ids and self.id_columns and batch_dict.get("ids"):
836
708
  for id_name in self.id_columns:
837
709
  if id_name not in batch_dict["ids"]:
838
710
  continue
839
711
  id_tensor = batch_dict["ids"][id_name]
840
- if isinstance(id_tensor, torch.Tensor):
841
- id_np = id_tensor.detach().cpu().numpy()
842
- else:
843
- id_np = np.asarray(id_tensor)
712
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
844
713
  id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
845
714
 
846
715
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
@@ -881,7 +750,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
881
750
  config_path = self.features_config_path
882
751
  features_config = {
883
752
  "all_features": self.all_features,
884
- "target": self.target,
753
+ "target": self.target_columns,
885
754
  "id_columns": self.id_columns,
886
755
  "version": __version__,
887
756
  }
@@ -921,9 +790,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
921
790
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
922
791
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
923
792
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
924
- self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
925
- self.target = self.target_columns
926
- self.target_index = {name: idx for idx, name in enumerate(self.target)}
793
+ self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
794
+
927
795
  cfg_version = features_config.get("version")
928
796
  if verbose:
929
797
  logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1056,35 +924,35 @@ class BaseModel(FeatureSpecMixin, nn.Module):
1056
924
  logger.info(f"Task Type: {self.task}")
1057
925
  logger.info(f"Number of Tasks: {self.nums_task}")
1058
926
  logger.info(f"Metrics: {self.metrics}")
1059
- logger.info(f"Target Columns: {self.target}")
927
+ logger.info(f"Target Columns: {self.target_columns}")
1060
928
  logger.info(f"Device: {self.device}")
1061
929
 
1062
- if hasattr(self, '_optimizer_name'):
1063
- logger.info(f"Optimizer: {self._optimizer_name}")
1064
- if self._optimizer_params:
1065
- for key, value in self._optimizer_params.items():
930
+ if hasattr(self, 'optimizer_name'):
931
+ logger.info(f"Optimizer: {self.optimizer_name}")
932
+ if self.optimizer_params:
933
+ for key, value in self.optimizer_params.items():
1066
934
  logger.info(f" {key:25s}: {value}")
1067
935
 
1068
- if hasattr(self, '_scheduler_name') and self._scheduler_name:
1069
- logger.info(f"Scheduler: {self._scheduler_name}")
1070
- if self._scheduler_params:
1071
- for key, value in self._scheduler_params.items():
936
+ if hasattr(self, 'scheduler_name') and self.scheduler_name:
937
+ logger.info(f"Scheduler: {self.scheduler_name}")
938
+ if self.scheduler_params:
939
+ for key, value in self.scheduler_params.items():
1072
940
  logger.info(f" {key:25s}: {value}")
1073
941
 
1074
- if hasattr(self, '_loss_config'):
1075
- logger.info(f"Loss Function: {self._loss_config}")
1076
- if hasattr(self, '_loss_weights'):
1077
- logger.info(f"Loss Weights: {self._loss_weights}")
942
+ if hasattr(self, 'loss_config'):
943
+ logger.info(f"Loss Function: {self.loss_config}")
944
+ if hasattr(self, 'loss_weights'):
945
+ logger.info(f"Loss Weights: {self.loss_weights}")
1078
946
 
1079
947
  logger.info("Regularization:")
1080
- logger.info(f" Embedding L1: {self._embedding_l1_reg}")
1081
- logger.info(f" Embedding L2: {self._embedding_l2_reg}")
1082
- logger.info(f" Dense L1: {self._dense_l1_reg}")
1083
- logger.info(f" Dense L2: {self._dense_l2_reg}")
948
+ logger.info(f" Embedding L1: {self.embedding_l1_reg}")
949
+ logger.info(f" Embedding L2: {self.embedding_l2_reg}")
950
+ logger.info(f" Dense L1: {self.dense_l1_reg}")
951
+ logger.info(f" Dense L2: {self.dense_l2_reg}")
1084
952
 
1085
953
  logger.info("Other Settings:")
1086
- logger.info(f" Early Stop Patience: {self._early_stop_patience}")
1087
- logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
954
+ logger.info(f" Early Stop Patience: {self.early_stop_patience}")
955
+ logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
1088
956
  logger.info(f" Session ID: {self.session_id}")
1089
957
  logger.info(f" Features Config Path: {self.features_config_path}")
1090
958
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
@@ -1214,18 +1082,18 @@ class BaseMatchModel(BaseModel):
1214
1082
  # Call parent compile with match-specific logic
1215
1083
  optimizer_params = optimizer_params or {}
1216
1084
 
1217
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1218
- self._optimizer_params = optimizer_params
1085
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
1086
+ self.optimizer_params = optimizer_params
1219
1087
  if isinstance(scheduler, str):
1220
- self._scheduler_name = scheduler
1088
+ self.scheduler_name = scheduler
1221
1089
  elif scheduler is not None:
1222
1090
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
1223
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
1091
+ self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
1224
1092
  else:
1225
- self._scheduler_name = None
1226
- self._scheduler_params = scheduler_params or {}
1227
- self._loss_config = loss
1228
- self._loss_params = loss_params or {}
1093
+ self.scheduler_name = None
1094
+ self.scheduler_params = scheduler_params or {}
1095
+ self.loss_config = loss
1096
+ self.loss_params = loss_params or {}
1229
1097
 
1230
1098
  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
1231
1099
  # Set loss function based on training mode
@@ -1245,7 +1113,7 @@ class BaseMatchModel(BaseModel):
1245
1113
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1246
1114
  if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
1247
1115
  loss_value = default_losses.get(self.training_mode, loss_value)
1248
- loss_kwargs = get_loss_kwargs(self._loss_params, 0)
1116
+ loss_kwargs = get_loss_kwargs(self.loss_params, 0)
1249
1117
  self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1250
1118
  # set scheduler
1251
1119
  self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1329,57 +1197,47 @@ class BaseMatchModel(BaseModel):
1329
1197
  return loss
1330
1198
  else:
1331
1199
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1200
+
1332
1201
 
1333
- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
1334
- """Reuse BaseModel metric configuration (mode + early stopper)."""
1335
- super()._set_metrics(metrics)
1336
-
1202
+ def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
1203
+ """Prepare data loader for specific features."""
1204
+ if isinstance(data, DataLoader):
1205
+ return data
1206
+
1207
+ feature_data = {}
1208
+ for feature in features:
1209
+ if isinstance(data, dict):
1210
+ if feature.name in data:
1211
+ feature_data[feature.name] = data[feature.name]
1212
+ elif isinstance(data, pd.DataFrame):
1213
+ if feature.name in data.columns:
1214
+ feature_data[feature.name] = data[feature.name].values
1215
+ return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
1216
+
1337
1217
  def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1338
- self.eval()
1339
- if not isinstance(data, DataLoader):
1340
- user_data = {}
1341
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1342
- for feature in all_user_features:
1343
- if isinstance(data, dict):
1344
- if feature.name in data:
1345
- user_data[feature.name] = data[feature.name]
1346
- elif isinstance(data, pd.DataFrame):
1347
- if feature.name in data.columns:
1348
- user_data[feature.name] = data[feature.name].values
1349
- data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
1350
- else:
1351
- data_loader = data
1218
+ self.eval()
1219
+ all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
1220
+ data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
1221
+
1352
1222
  embeddings_list = []
1353
1223
  with torch.no_grad():
1354
1224
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
1355
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1225
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
1356
1226
  user_input = self.get_user_features(batch_dict["features"])
1357
1227
  user_emb = self.user_tower(user_input)
1358
1228
  embeddings_list.append(user_emb.cpu().numpy())
1359
- embeddings = np.concatenate(embeddings_list, axis=0)
1360
- return embeddings
1229
+ return np.concatenate(embeddings_list, axis=0)
1361
1230
 
1362
1231
  def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
1363
1232
  self.eval()
1364
- if not isinstance(data, DataLoader):
1365
- item_data = {}
1366
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1367
- for feature in all_item_features:
1368
- if isinstance(data, dict):
1369
- if feature.name in data:
1370
- item_data[feature.name] = data[feature.name]
1371
- elif isinstance(data, pd.DataFrame):
1372
- if feature.name in data.columns:
1373
- item_data[feature.name] = data[feature.name].values
1374
- data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
1375
- else:
1376
- data_loader = data
1233
+ all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
1234
+ data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
1235
+
1377
1236
  embeddings_list = []
1378
1237
  with torch.no_grad():
1379
1238
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
1380
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
1239
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
1381
1240
  item_input = self.get_item_features(batch_dict["features"])
1382
1241
  item_emb = self.item_tower(item_input)
1383
1242
  embeddings_list.append(item_emb.cpu().numpy())
1384
- embeddings = np.concatenate(embeddings_list, axis=0)
1385
- return embeddings
1243
+ return np.concatenate(embeddings_list, axis=0)