nextrec 0.3.2__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/features.py +10 -23
  3. nextrec/basic/layers.py +18 -61
  4. nextrec/basic/loggers.py +71 -8
  5. nextrec/basic/metrics.py +55 -33
  6. nextrec/basic/model.py +287 -397
  7. nextrec/data/__init__.py +2 -2
  8. nextrec/data/data_utils.py +80 -4
  9. nextrec/data/dataloader.py +38 -59
  10. nextrec/data/preprocessor.py +38 -73
  11. nextrec/models/generative/hstu.py +1 -1
  12. nextrec/models/match/dssm.py +2 -2
  13. nextrec/models/match/dssm_v2.py +2 -2
  14. nextrec/models/match/mind.py +2 -2
  15. nextrec/models/match/sdm.py +2 -2
  16. nextrec/models/match/youtube_dnn.py +2 -2
  17. nextrec/models/multi_task/esmm.py +1 -1
  18. nextrec/models/multi_task/mmoe.py +1 -1
  19. nextrec/models/multi_task/ple.py +1 -1
  20. nextrec/models/multi_task/poso.py +1 -1
  21. nextrec/models/multi_task/share_bottom.py +1 -1
  22. nextrec/models/ranking/afm.py +1 -1
  23. nextrec/models/ranking/autoint.py +1 -1
  24. nextrec/models/ranking/dcn.py +1 -1
  25. nextrec/models/ranking/deepfm.py +1 -1
  26. nextrec/models/ranking/dien.py +1 -1
  27. nextrec/models/ranking/din.py +1 -1
  28. nextrec/models/ranking/fibinet.py +1 -1
  29. nextrec/models/ranking/fm.py +1 -1
  30. nextrec/models/ranking/masknet.py +2 -2
  31. nextrec/models/ranking/pnn.py +1 -1
  32. nextrec/models/ranking/widedeep.py +1 -1
  33. nextrec/models/ranking/xdeepfm.py +1 -1
  34. nextrec/utils/__init__.py +2 -1
  35. nextrec/utils/common.py +21 -2
  36. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/METADATA +3 -3
  37. nextrec-0.3.4.dist-info/RECORD +57 -0
  38. nextrec-0.3.2.dist-info/RECORD +0 -57
  39. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/WHEEL +0 -0
  40. {nextrec-0.3.2.dist-info → nextrec-0.3.4.dist-info}/licenses/LICENSE +0 -0
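Taken together, the nextrec/basic/model.py changes below rename BaseModel's underscore-prefixed attributes and helpers to public names (e.g. _prepare_data_loader -> prepare_data_loader, _loss_weights -> loss_weights), move batch/tensor helpers out to nextrec.data.data_utils and nextrec.utils, and add a TrainingLogger with an optional TensorBoard flag on fit(). The following is a minimal, hypothetical sketch of a 0.3.4 training call: the compile()/fit() keyword names are read off the diff itself, while the optimizer/loss string values and the surrounding setup are assumptions rather than documented API.

def train_with_nextrec_034(model, train_df, valid_df):
    # `model` is any already-constructed nextrec model; construction is not part of this diff.
    model.compile(optimizer="adam", loss="bce")   # keyword names inferred from the compile() body below
    model.fit(
        train_data=train_df,
        valid_data=valid_df,
        metrics=["auc", "logloss"],               # list form shown in the fit() signature
        epochs=3,
        batch_size=256,
        tensorboard=True,                         # new in 0.3.4: enables the TrainingLogger's TensorBoard output
    )
    return model.predict(valid_df)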
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
  Base Model & Base Match Model Class

  Date: create on 27/10/2025
- Checkpoint: edit on 29/11/2025
+ Checkpoint: edit on 02/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

@@ -10,6 +10,8 @@ import os
  import tqdm
  import pickle
  import logging
+ import getpass
+ import socket
  import numpy as np
  import pandas as pd
  import torch
@@ -21,21 +23,22 @@ from typing import Union, Literal, Any
  from torch.utils.data import DataLoader

  from nextrec.basic.callback import EarlyStopper
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSpecMixin
+ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader

- from nextrec.basic.loggers import setup_logger, colorize
+ from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
  from nextrec.basic.session import resolve_save_path, create_session
- from nextrec.basic.metrics import configure_metrics, evaluate_metrics
+ from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id

- from nextrec.data import get_column_data, collate_fn
  from nextrec.data.dataloader import build_tensors_from_data
+ from nextrec.data.data_utils import get_column_data, collate_fn, batch_to_dict, get_user_ids

  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.utils import get_optimizer, get_scheduler
+ from nextrec.utils import get_optimizer, get_scheduler, to_tensor
+
  from nextrec import __version__

- class BaseModel(FeatureSpecMixin, nn.Module):
+ class BaseModel(FeatureSet, nn.Module):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
@@ -69,72 +72,54 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  self.session_id = session_id
  self.session = create_session(session_id)
  self.session_path = self.session.root # pwd/session_id, path for this session
- self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint"+".model")
- self.best_path = os.path.join(self.session_path, self.model_name+ "_best.model")
+ self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
+ self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
  self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
- self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)
- self.target = self.target_columns
- self.target_index = {target_name: idx for idx, target_name in enumerate(self.target)}
+ self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)

  self.task = task
  self.nums_task = len(task) if isinstance(task, list) else 1

- self._embedding_l1_reg = embedding_l1_reg
- self._dense_l1_reg = dense_l1_reg
- self._embedding_l2_reg = embedding_l2_reg
- self._dense_l2_reg = dense_l2_reg
- self._regularization_weights = []
- self._embedding_params = []
- self._loss_weights: float | list[float] | None = None
- self._early_stop_patience = early_stop_patience
- self._max_gradient_norm = 1.0
- self._logger_initialized = False
-
- def _register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
+ self.embedding_l1_reg = embedding_l1_reg
+ self.dense_l1_reg = dense_l1_reg
+ self.embedding_l2_reg = embedding_l2_reg
+ self.dense_l2_reg = dense_l2_reg
+ self.regularization_weights = []
+ self.embedding_params = []
+ self.loss_weight = None
+ self.early_stop_patience = early_stop_patience
+ self.max_gradient_norm = 1.0
+ self.logger_initialized = False
+ self.training_logger: TrainingLogger | None = None
+
+ def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
  exclude_modules = exclude_modules or []
  include_modules = include_modules or []
- if hasattr(self, embedding_attr):
- embedding_layer = getattr(self, embedding_attr)
- if hasattr(embedding_layer, "embed_dict"):
- for embed in embedding_layer.embed_dict.values():
- self._embedding_params.append(embed.weight)
+ embedding_layer = getattr(self, embedding_attr, None)
+ embed_dict = getattr(embedding_layer, "embed_dict", None)
+ if embed_dict is not None:
+ self.embedding_params.extend(embed.weight for embed in embed_dict.values())
+ skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
  for name, module in self.named_modules():
- if module is self:
- continue
- if embedding_attr in name:
- continue
- if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d),):
- continue
- if include_modules:
- if not any(inc_name in name for inc_name in include_modules):
- continue
- if any(exc_name in name for exc_name in exclude_modules):
+ if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
  continue
  if isinstance(module, nn.Linear):
- self._regularization_weights.append(module.weight)
+ self.regularization_weights.append(module.weight)

  def add_reg_loss(self) -> torch.Tensor:
  reg_loss = torch.tensor(0.0, device=self.device)
- if self._embedding_params:
- if self._embedding_l1_reg > 0:
- reg_loss += self._embedding_l1_reg * sum(param.abs().sum() for param in self._embedding_params)
- if self._embedding_l2_reg > 0:
- reg_loss += self._embedding_l2_reg * sum((param ** 2).sum() for param in self._embedding_params)
- if self._regularization_weights:
- if self._dense_l1_reg > 0:
- reg_loss += self._dense_l1_reg * sum(param.abs().sum() for param in self._regularization_weights)
- if self._dense_l2_reg > 0:
- reg_loss += self._dense_l2_reg * sum((param ** 2).sum() for param in self._regularization_weights)
+ if self.embedding_params:
+ if self.embedding_l1_reg > 0:
+ reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
+ if self.embedding_l2_reg > 0:
+ reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
+ if self.regularization_weights:
+ if self.dense_l1_reg > 0:
+ reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
+ if self.dense_l2_reg > 0:
+ reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
  return reg_loss

- def _to_tensor(self, value, dtype: torch.dtype) -> torch.Tensor:
- tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
- if tensor.dtype != dtype:
- tensor = tensor.to(dtype=dtype)
- if tensor.device != self.device:
- tensor = tensor.to(self.device)
- return tensor
-
  def get_input(self, input_data: dict, require_labels: bool = True):
  feature_source = input_data.get("features", {})
  label_source = input_data.get("labels")
@@ -143,12 +128,11 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if feature.name not in feature_source:
  raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
  feature_data = get_column_data(feature_source, feature.name)
- dtype = torch.float32 if isinstance(feature, DenseFeature) else torch.long
- X_input[feature.name] = self._to_tensor(feature_data, dtype=dtype)
+ X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
  y = None
- if (len(self.target) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target)))): # need labels: training or eval with labels
+ if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
  target_tensors = []
- for target_name in self.target:
+ for target_name in self.target_columns:
  if label_source is None or target_name not in label_source:
  if require_labels:
  raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
@@ -158,7 +142,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if require_labels:
  raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
  continue
- target_tensor = self._to_tensor(target_data, dtype=torch.float32)
+ target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
  target_tensor = target_tensor.view(target_tensor.size(0), -1)
  target_tensors.append(target_tensor)
  if target_tensors:
@@ -169,11 +153,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
  return X_input, y

- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
- self.early_stopper = EarlyStopper(patience=self._early_stop_patience, mode=self.best_metrics_mode)
-
- def _handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
+ """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
  if not (0 < validation_split < 1):
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
  if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -181,8 +162,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if isinstance(train_data, pd.DataFrame):
  total_length = len(train_data)
  else:
- sample_key = next(iter(train_data))
- total_length = len(train_data[sample_key])
+ sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+ total_length = len(train_data[sample_key]) # len(train_data['user_id'])
  for k, v in train_data.items():
  if len(v) != total_length:
  raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
@@ -198,20 +179,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  train_split = {}
  valid_split = {}
  for key, value in train_data.items():
- if isinstance(value, np.ndarray):
- train_split[key] = value[train_indices]
- valid_split[key] = value[valid_indices]
- elif isinstance(value, (list, tuple)):
- arr = np.asarray(value)
- train_split[key] = arr[train_indices].tolist()
- valid_split[key] = arr[valid_indices].tolist()
- elif isinstance(value, pd.Series):
- train_split[key] = value.iloc[train_indices].values
- valid_split[key] = value.iloc[valid_indices].values
- else:
- train_split[key] = [value[i] for i in train_indices]
- valid_split[key] = [value[i] for i in valid_indices]
- train_loader = self._prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+ arr = np.asarray(value)
+ train_split[key] = arr[train_indices]
+ valid_split[key] = arr[valid_indices]
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
  return train_loader, valid_split

@@ -226,44 +197,44 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  loss_weights: int | float | list[int | float] | None = None,
  ):
  optimizer_params = optimizer_params or {}
- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
- self._optimizer_params = optimizer_params
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+ self.optimizer_params = optimizer_params
  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)

  scheduler_params = scheduler_params or {}
  if isinstance(scheduler, str):
- self._scheduler_name = scheduler
+ self.scheduler_name = scheduler
  elif scheduler is None:
- self._scheduler_name = None
- else:
- self._scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
- self._scheduler_params = scheduler_params
+ self.scheduler_name = None
+ else: # for custom scheduler instance, need to provide class name for logging
+ self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
+ self.scheduler_params = scheduler_params
  self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)

- self._loss_config = loss
- self._loss_params = loss_params or {}
+ self.loss_config = loss
+ self.loss_params = loss_params or {}
  self.loss_fn = []
- for i in range(self.nums_task):
- if isinstance(loss, list):
- loss_value = loss[i] if i < len(loss) else None
- else:
- loss_value = loss
- if self.nums_task == 1: # single task
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else self._loss_params[0]
- else:
- loss_kwargs = self._loss_params if isinstance(self._loss_params, dict) else (self._loss_params[i] if i < len(self._loss_params) else {})
- self.loss_fn.append(get_loss_fn(loss=loss_value, **loss_kwargs,))
- # Normalize loss weights for single-task and multi-task setups
+ if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
+ loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
+ else: # for example: 'bce' -> ['bce', 'bce']
+ loss_list = [loss] * self.nums_task
+
+ if isinstance(self.loss_params, dict):
+ params_list = [self.loss_params] * self.nums_task
+ else: # list[dict]
+ params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
+ self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
+
  if loss_weights is None:
- self._loss_weights = None
+ self.loss_weights = None
  elif self.nums_task == 1:
  if isinstance(loss_weights, (list, tuple)):
- if len(loss_weights) != 1:
+ if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
  raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
  weight_value = loss_weights[0]
  else:
  weight_value = loss_weights
- self._loss_weights = float(weight_value)
+ self.loss_weights = float(weight_value)
  else:
  if isinstance(loss_weights, (int, float)):
  weights = [float(loss_weights)] * self.nums_task
@@ -273,87 +244,84 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
  else:
  raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
- self._loss_weights = weights
+ self.loss_weights = weights

  def compute_loss(self, y_pred, y_true):
  if y_true is None:
  raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
  if self.nums_task == 1:
  loss = self.loss_fn[0](y_pred, y_true)
- if self._loss_weights is not None:
- loss = loss * self._loss_weights
+ if self.loss_weights is not None:
+ loss = loss * self.loss_weights
  return loss
  else:
  task_losses = []
  for i in range(self.nums_task):
  task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
- if isinstance(self._loss_weights, (list, tuple)):
- task_loss = task_loss * self._loss_weights[i]
+ if isinstance(self.loss_weights, (list, tuple)):
+ task_loss = task_loss * self.loss_weights[i]
  task_losses.append(task_loss)
  return torch.stack(task_losses).sum()

- def _prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
  if isinstance(data, DataLoader):
  return data
- tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target, id_columns=self.id_columns,)
+ tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
  if tensors is None:
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
  dataset = TensorDictDataset(tensors)
  return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

- def _batch_to_dict(self, batch_data: Any, include_ids: bool = True) -> dict:
- if not (isinstance(batch_data, dict) and "features" in batch_data):
- raise TypeError("[BaseModel-batch_to_dict Error] Batch data must be a dict with 'features' produced by the current DataLoader.")
- return {
- "features": batch_data.get("features", {}),
- "labels": batch_data.get("labels"),
- "ids": batch_data.get("ids") if include_ids else None,
- }
-
  def fit(self,
- train_data: dict|pd.DataFrame|DataLoader,
- valid_data: dict|pd.DataFrame|DataLoader|None=None,
- metrics: list[str]|dict[str, list[str]]|None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+ train_data: dict | pd.DataFrame | DataLoader,
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
+ metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
- user_id_column: str = 'user_id',
- validation_split: float | None = None):
+ user_id_column: str | None = None,
+ validation_split: float | None = None,
+ tensorboard: bool = True,):
  self.to(self.device)
- if not self._logger_initialized:
+ if not self.logger_initialized:
  setup_logger(session_id=self.session_id)
- self._logger_initialized = True
- self._set_metrics(metrics) # add self.metrics, self.task_specific_metrics, self.best_metrics_mode, self.early_stopper
- self.summary()
- valid_loader = None
- valid_user_ids: np.ndarray | None = None
- needs_user_ids: bool = self._needs_user_ids_for_metrics()
+ self.logger_initialized = True
+ self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+ self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+ self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
+ self.epoch_index = 0
+ self.stop_training = False
+ self.best_checkpoint_path = self.best_path
+ self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')

  if validation_split is not None and valid_data is None:
- train_loader, valid_data = self._handle_validation_split(
- train_data=train_data, # type: ignore
- validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
  else:
- train_loader = (train_data if isinstance(train_data, DataLoader) else self._prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
- if isinstance(valid_data, DataLoader):
- valid_loader = valid_data
- elif valid_data is not None:
- valid_loader = self._prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
- if needs_user_ids:
- if isinstance(valid_data, pd.DataFrame) and user_id_column in valid_data.columns:
- valid_user_ids = np.asarray(valid_data[user_id_column].values)
- elif isinstance(valid_data, dict) and user_id_column in valid_data:
- valid_user_ids = np.asarray(valid_data[user_id_column])
+ train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
+
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
  try:
- self._steps_per_epoch = len(train_loader)
+ self.steps_per_epoch = len(train_loader)
  is_streaming = False
- except TypeError: # len() not supported, e.g., streaming data loader
- self._steps_per_epoch = None
+ except TypeError: # streaming data loader does not supported len()
+ self.steps_per_epoch = None
  is_streaming = True

- self._epoch_index = 0
- self._stop_training = False
- self._best_checkpoint_path = self.best_path
- self._best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
-
+ self.summary()
+ logging.info("")
+ if self.training_logger and self.training_logger.enable_tensorboard:
+ tb_dir = self.training_logger.tensorboard_logdir
+ if tb_dir:
+ user = getpass.getuser()
+ host = socket.gethostname()
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+ logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+ logging.info(colorize("To view logs, run:", color="cyan"))
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
  logging.info("")
  logging.info(colorize("=" * 80, bold=True))
  if is_streaming:
@@ -363,38 +331,40 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  logging.info(colorize("=" * 80, bold=True))
  logging.info("")
  logging.info(colorize(f"Model device: {self.device}", bold=True))
-
+
  for epoch in range(epochs):
- self._epoch_index = epoch
+ self.epoch_index = epoch
  if is_streaming:
  logging.info("")
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
- if isinstance(train_result, tuple):
+
+ # handle train result
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
+ if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
  train_loss, train_metrics = train_result
  else:
  train_loss = train_result
  train_metrics = None
+
+ train_log_payload: dict[str, float] = {}
+ # handle logging for single-task and multi-task
  if self.nums_task == 1:
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
  if train_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
  log_str += f", {metrics_str}"
- logging.info(colorize(log_str, color="white"))
+ logging.info(colorize(log_str))
+ train_log_payload["loss"] = float(train_loss)
+ if train_metrics:
+ train_log_payload.update(train_metrics)
  else:
- task_labels = []
- for i in range(self.nums_task):
- if i < len(self.target):
- task_labels.append(self.target[i])
- else:
- task_labels.append(f"task_{i}")
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
  if train_metrics:
- # Group metrics by task
+ # group metrics by task
  task_metrics = {}
  for metric_key, metric_value in train_metrics.items():
- for target_name in self.target:
+ for target_name in self.target_columns:
  if metric_key.endswith(f"_{target_name}"):
  if target_name not in task_metrics:
  task_metrics[target_name] = {}
@@ -403,23 +373,28 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  break
  if task_metrics:
  task_metric_strs = []
- for target_name in self.target:
+ for target_name in self.target_columns:
  if target_name in task_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
  log_str += ", " + ", ".join(task_metric_strs)
- logging.info(colorize(log_str, color="white"))
+ logging.info(colorize(log_str))
+ train_log_payload["loss"] = float(total_loss_val)
+ if train_metrics:
+ train_log_payload.update(train_metrics)
+ if self.training_logger:
+ self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
  if valid_loader is not None:
- # Pass user_ids only if needed for GAUC metric
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
+ # pass user_ids only if needed for GAUC metric
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
  if self.nums_task == 1:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
  else:
  # multi task metrics
  task_metrics = {}
  for metric_key, metric_value in val_metrics.items():
- for target_name in self.target:
+ for target_name in self.target_columns:
  if metric_key.endswith(f"_{target_name}"):
  if target_name not in task_metrics:
  task_metrics[target_name] = {}
@@ -427,53 +402,53 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  task_metrics[target_name][metric_name] = metric_value
  break
  task_metric_strs = []
- for target_name in self.target:
+ for target_name in self.target_columns:
  if target_name in task_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+ if val_metrics and self.training_logger:
+ self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
  # Handle empty validation metrics
  if not val_metrics:
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
- self._best_checkpoint_path = self.checkpoint_path
+ self.best_checkpoint_path = self.checkpoint_path
  logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
  continue
  if self.nums_task == 1:
  primary_metric_key = self.metrics[0]
  else:
- primary_metric_key = f"{self.metrics[0]}_{self.target[0]}"
-
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])
+ primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
+ primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
  improved = False
-
+ # early stopping check
  if self.best_metrics_mode == 'max':
- if primary_metric > self._best_metric:
- self._best_metric = primary_metric
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
+ if primary_metric > self.best_metric:
+ self.best_metric = primary_metric
  improved = True
  else:
- if primary_metric < self._best_metric:
- self._best_metric = primary_metric
+ if primary_metric < self.best_metric:
+ self.best_metric = primary_metric
  improved = True
- # Always keep the latest weights as a rolling checkpoint
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ logging.info(" ")
  if improved:
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self._best_metric:.4f}"))
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self._best_checkpoint_path = self.best_path
+ self.best_checkpoint_path = self.best_path
  self.early_stopper.trial_counter = 0
  else:
  self.early_stopper.trial_counter += 1
  logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
  if self.early_stopper.trial_counter >= self.early_stopper.patience:
- self._stop_training = True
+ self.stop_training = True
  logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
  break
  else:
  self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
  self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self._best_checkpoint_path = self.best_path
- if self._stop_training:
+ self.best_checkpoint_path = self.best_path
+ if self.stop_training:
  break
  if self.scheduler_fn is not None:
  if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -481,34 +456,31 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  self.scheduler_fn.step(primary_metric)
  else:
  self.scheduler_fn.step()
- logging.info("\n")
- logging.info(colorize("Training finished.", color="bright_green", bold=True))
- logging.info("\n")
+ logging.info(" ")
+ logging.info(colorize("Training finished.", bold=True))
+ logging.info(" ")
  if valid_loader is not None:
- logging.info(colorize(f"Load best model from: {self._best_checkpoint_path}", color="bright_blue"))
- self.load_model(self._best_checkpoint_path, map_location=self.device, verbose=False)
+ logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
+ self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
+ if self.training_logger:
+ self.training_logger.close()
  return self

  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
- if self.nums_task == 1:
- accumulated_loss = 0.0
- else:
- accumulated_loss = 0.0
+ accumulated_loss = 0.0
  self.train()
  num_batches = 0
  y_true_list = []
  y_pred_list = []
- needs_user_ids = self._needs_user_ids_for_metrics()
- user_ids_list = [] if needs_user_ids else None
- if self._steps_per_epoch is not None:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}", total=self._steps_per_epoch))
+
+ user_ids_list = [] if self.needs_user_ids else None
+ if self.steps_per_epoch is not None:
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
  else:
- if is_streaming:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc="Batches")) # Streaming mode: show batch/file progress without epoch in desc
- else:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self._epoch_index + 1}"))
+ desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
  for batch_index, batch_data in batch_iter:
- batch_dict = self._batch_to_dict(batch_data)
+ batch_dict = batch_to_dict(batch_data)
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
  y_pred = self.forward(X_input)
  loss = self.compute_loss(y_pred, y_true)
@@ -516,66 +488,41 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  total_loss = loss + reg_loss
  self.optimizer_fn.zero_grad()
  total_loss.backward()
- nn.utils.clip_grad_norm_(self.parameters(), self._max_gradient_norm)
+ nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
  self.optimizer_fn.step()
- if self.nums_task == 1:
- accumulated_loss += loss.item()
- else:
- accumulated_loss += loss.item()
+ accumulated_loss += loss.item()
  if y_true is not None:
- y_true_list.append(y_true.detach().cpu().numpy()) # Collect predictions and labels for metrics if requested
- if needs_user_ids and user_ids_list is not None and batch_dict.get("ids"):
- batch_user_id = None
- if self.id_columns:
- for id_name in self.id_columns:
- if id_name in batch_dict["ids"]:
- batch_user_id = batch_dict["ids"][id_name]
- break
- if batch_user_id is None and batch_dict["ids"]:
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
+ y_true_list.append(y_true.detach().cpu().numpy())
+ if self.needs_user_ids and user_ids_list is not None:
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
  if batch_user_id is not None:
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
- user_ids_list.append(ids_np.reshape(ids_np.shape[0]))
- if y_pred is not None and isinstance(y_pred, torch.Tensor): # For pairwise/listwise mode, y_pred is a tuple of embeddings, skip metric collection during training
+ user_ids_list.append(batch_user_id)
+ if y_pred is not None and isinstance(y_pred, torch.Tensor):
  y_pred_list.append(y_pred.detach().cpu().numpy())
  num_batches += 1
- avg_loss = accumulated_loss / num_batches
+ avg_loss = accumulated_loss / max(num_batches, 1)
  if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
  y_true_all = np.concatenate(y_true_list, axis=0)
  y_pred_all = np.concatenate(y_pred_list, axis=0)
  combined_user_ids = None
- if needs_user_ids and user_ids_list:
+ if self.needs_user_ids and user_ids_list:
  combined_user_ids = np.concatenate(user_ids_list, axis=0)
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, self.metrics, user_ids=combined_user_ids)
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
  return avg_loss, metrics_dict
  return avg_loss

- def _needs_user_ids_for_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None) -> bool:
- """Check if any configured metric requires user_ids (e.g., gauc, ranking @K)."""
- metric_names = set()
- sources = [metrics if metrics is not None else getattr(self, "metrics", None), getattr(self, "task_specific_metrics", None),]
- for src in sources:
- stack = [src]
- while stack:
- item = stack.pop()
- if not item:
- continue
- if isinstance(item, dict):
- stack.extend(item.values())
- elif isinstance(item, str):
- metric_names.add(item.lower())
- else:
- try:
- for m in item:
- metric_names.add(m.lower())
- except TypeError:
- continue
- for name in metric_names:
- if name == "gauc":
- return True
- if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
- return True
- return False
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+ if valid_data is None:
+ return None, None
+ if isinstance(valid_data, DataLoader):
+ return valid_data, None
+ valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
+ valid_user_ids = None
+ if needs_user_ids:
+ if user_id_column is None:
+ raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
+ valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
+ return valid_loader, valid_user_ids

  def evaluate(self,
  data: dict | pd.DataFrame | DataLoader,
@@ -587,18 +534,14 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  eval_metrics = metrics if metrics is not None else self.metrics
  if eval_metrics is None:
  raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
- needs_user_ids = self._needs_user_ids_for_metrics(eval_metrics)
+ needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)

  if isinstance(data, DataLoader):
  data_loader = data
  else:
- # Extract user_ids if needed and not provided
  if user_ids is None and needs_user_ids:
- if isinstance(data, pd.DataFrame) and user_id_column in data.columns:
- user_ids = np.asarray(data[user_id_column].values)
- elif isinstance(data, dict) and user_id_column in data:
- user_ids = np.asarray(data[user_id_column])
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+ user_ids = get_user_ids(data=data, id_columns=user_id_column)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
  y_true_list = []
  y_pred_list = []
  collected_user_ids = []
@@ -606,26 +549,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  with torch.no_grad():
  for batch_data in data_loader:
  batch_count += 1
- batch_dict = self._batch_to_dict(batch_data)
+ batch_dict = batch_to_dict(batch_data)
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
  y_pred = self.forward(X_input)
  if y_true is not None:
  y_true_list.append(y_true.cpu().numpy())
- # Skip if y_pred is not a tensor (e.g., tuple in pairwise mode, though this shouldn't happen in eval mode)
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
  y_pred_list.append(y_pred.cpu().numpy())
- if needs_user_ids and user_ids is None and batch_dict.get("ids"):
- batch_user_id = None
- if self.id_columns:
- for id_name in self.id_columns:
- if id_name in batch_dict["ids"]:
- batch_user_id = batch_dict["ids"][id_name]
- break
- if batch_user_id is None and batch_dict["ids"]:
- batch_user_id = next(iter(batch_dict["ids"].values()), None)
+ if needs_user_ids and user_ids is None:
+ batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
  if batch_user_id is not None:
- ids_np = batch_user_id.detach().cpu().numpy() if isinstance(batch_user_id, torch.Tensor) else np.asarray(batch_user_id)
- collected_user_ids.append(ids_np.reshape(ids_np.shape[0]))
+ collected_user_ids.append(batch_user_id)
+ logging.info(" ")
  logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
  if len(y_true_list) > 0:
  y_true_all = np.concatenate(y_true_list, axis=0)
@@ -654,23 +589,9 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  final_user_ids = user_ids
  if final_user_ids is None and collected_user_ids:
  final_user_ids = np.concatenate(collected_user_ids, axis=0)
- metrics_dict = self.evaluate_metrics(y_true_all, y_pred_all, metrics_to_use, final_user_ids)
+ metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
  return metrics_dict

- def evaluate_metrics(self, y_true: np.ndarray|None, y_pred: np.ndarray|None, metrics: list[str], user_ids: np.ndarray|None = None) -> dict:
- """Evaluate metrics using the metrics module."""
- task_specific_metrics = getattr(self, 'task_specific_metrics', None)
-
- return evaluate_metrics(
- y_true=y_true,
- y_pred=y_pred,
- metrics=metrics,
- task=self.task,
- target_names=self.target,
- task_specific_metrics=task_specific_metrics,
- user_ids=user_ids
- )
-
  def predict(
  self,
  data: str | dict | pd.DataFrame | DataLoader,
@@ -681,28 +602,18 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  return_dataframe: bool = True,
  streaming_chunk_size: int = 10000,
  ) -> pd.DataFrame | np.ndarray:
- """
- Run inference and optionally return ID-aligned predictions.
-
- When ``id_columns`` are configured and ``include_ids`` is True (default),
- the returned object will include those IDs to keep a one-to-one mapping
- between each prediction and its source row.
- If ``save_path`` is provided and ``return_dataframe`` is False, predictions
- stream to disk batch-by-batch to avoid holding all outputs in memory.
- """
  self.eval()
  if include_ids is None:
  include_ids = bool(self.id_columns)
  include_ids = include_ids and bool(self.id_columns)

- # if saving to disk without returning dataframe, use streaming prediction
  if save_path is not None and not return_dataframe:
  return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
  if isinstance(data, (str, os.PathLike)):
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns,)
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
  elif not isinstance(data, DataLoader):
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
  else:
  data_loader = data

@@ -712,7 +623,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):

  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
  X_input, _ = self.get_input(batch_dict, require_labels=False)
  y_pred = self.forward(X_input)
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -722,10 +633,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if id_name not in batch_dict["ids"]:
  continue
  id_tensor = batch_dict["ids"][id_name]
- if isinstance(id_tensor, torch.Tensor):
- id_np = id_tensor.detach().cpu().numpy()
- else:
- id_np = np.asarray(id_tensor)
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
  id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
  if len(y_pred_list) > 0:
  y_pred_all = np.concatenate(y_pred_list, axis=0)
@@ -735,12 +643,12 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  if y_pred_all.ndim == 1:
  y_pred_all = y_pred_all.reshape(-1, 1)
  if y_pred_all.size == 0:
- num_outputs = len(self.target) if self.target else 1
+ num_outputs = len(self.target_columns) if self.target_columns else 1
  y_pred_all = y_pred_all.reshape(0, num_outputs)
  num_outputs = y_pred_all.shape[1]
  pred_columns: list[str] = []
- if self.target:
- for name in self.target[:num_outputs]:
+ if self.target_columns:
+ for name in self.target_columns[:num_outputs]:
  pred_columns.append(f"{name}_pred")
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
@@ -794,10 +702,10 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  return_dataframe: bool,
  ) -> pd.DataFrame:
  if isinstance(data, (str, os.PathLike)):
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target, id_columns=self.id_columns)
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
  elif not isinstance(data, DataLoader):
- data_loader = self._prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
  else:
  data_loader = data

@@ -812,35 +720,30 @@ class BaseModel(FeatureSpecMixin, nn.Module):

  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
- batch_dict = self._batch_to_dict(batch_data, include_ids=include_ids)
+ batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
  X_input, _ = self.get_input(batch_dict, require_labels=False)
  y_pred = self.forward(X_input)
  if y_pred is None or not isinstance(y_pred, torch.Tensor):
  continue
-
  y_pred_np = y_pred.detach().cpu().numpy()
  if y_pred_np.ndim == 1:
  y_pred_np = y_pred_np.reshape(-1, 1)
-
  if pred_columns is None:
  num_outputs = y_pred_np.shape[1]
  pred_columns = []
- if self.target:
- for name in self.target[:num_outputs]:
+ if self.target_columns:
+ for name in self.target_columns[:num_outputs]:
  pred_columns.append(f"{name}_pred")
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
-
+
  id_arrays_batch: dict[str, np.ndarray] = {}
  if include_ids and self.id_columns and batch_dict.get("ids"):
  for id_name in self.id_columns:
  if id_name not in batch_dict["ids"]:
  continue
  id_tensor = batch_dict["ids"][id_name]
- if isinstance(id_tensor, torch.Tensor):
- id_np = id_tensor.detach().cpu().numpy()
- else:
- id_np = np.asarray(id_tensor)
+ id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
  id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])

  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
@@ -881,7 +784,7 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  config_path = self.features_config_path
  features_config = {
  "all_features": self.all_features,
- "target": self.target,
+ "target": self.target_columns,
  "id_columns": self.id_columns,
  "version": __version__,
  }
@@ -921,9 +824,8 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
- self._set_feature_config(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
- self.target = self.target_columns
- self.target_index = {name: idx for idx, name in enumerate(self.target)}
+ self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
+
  cfg_version = features_config.get("version")
  if verbose:
  logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
@@ -1056,41 +958,39 @@ class BaseModel(FeatureSpecMixin, nn.Module):
  logger.info(f"Task Type: {self.task}")
  logger.info(f"Number of Tasks: {self.nums_task}")
  logger.info(f"Metrics: {self.metrics}")
- logger.info(f"Target Columns: {self.target}")
+ logger.info(f"Target Columns: {self.target_columns}")
  logger.info(f"Device: {self.device}")

- if hasattr(self, '_optimizer_name'):
- logger.info(f"Optimizer: {self._optimizer_name}")
- if self._optimizer_params:
- for key, value in self._optimizer_params.items():
+ if hasattr(self, 'optimizer_name'):
+ logger.info(f"Optimizer: {self.optimizer_name}")
+ if self.optimizer_params:
+ for key, value in self.optimizer_params.items():
  logger.info(f" {key:25s}: {value}")

- if hasattr(self, '_scheduler_name') and self._scheduler_name:
- logger.info(f"Scheduler: {self._scheduler_name}")
- if self._scheduler_params:
- for key, value in self._scheduler_params.items():
+ if hasattr(self, 'scheduler_name') and self.scheduler_name:
+ logger.info(f"Scheduler: {self.scheduler_name}")
+ if self.scheduler_params:
+ for key, value in self.scheduler_params.items():
  logger.info(f" {key:25s}: {value}")

- if hasattr(self, '_loss_config'):
- logger.info(f"Loss Function: {self._loss_config}")
- if hasattr(self, '_loss_weights'):
- logger.info(f"Loss Weights: {self._loss_weights}")
+ if hasattr(self, 'loss_config'):
+ logger.info(f"Loss Function: {self.loss_config}")
+ if hasattr(self, 'loss_weights'):
+ logger.info(f"Loss Weights: {self.loss_weights}")

  logger.info("Regularization:")
- logger.info(f" Embedding L1: {self._embedding_l1_reg}")
- logger.info(f" Embedding L2: {self._embedding_l2_reg}")
- logger.info(f" Dense L1: {self._dense_l1_reg}")
- logger.info(f" Dense L2: {self._dense_l2_reg}")
+ logger.info(f" Embedding L1: {self.embedding_l1_reg}")
+ logger.info(f" Embedding L2: {self.embedding_l2_reg}")
+ logger.info(f" Dense L1: {self.dense_l1_reg}")
+ logger.info(f" Dense L2: {self.dense_l2_reg}")

  logger.info("Other Settings:")
- logger.info(f" Early Stop Patience: {self._early_stop_patience}")
- logger.info(f" Max Gradient Norm: {self._max_gradient_norm}")
+ logger.info(f" Early Stop Patience: {self.early_stop_patience}")
+ logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
  logger.info(f" Session ID: {self.session_id}")
  logger.info(f" Features Config Path: {self.features_config_path}")
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-
- logger.info("")
- logger.info("")
+


  class BaseMatchModel(BaseModel):
@@ -1214,18 +1114,18 @@ class BaseMatchModel(BaseModel):
  # Call parent compile with match-specific logic
  optimizer_params = optimizer_params or {}

- self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
- self._optimizer_params = optimizer_params
+ self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+ self.optimizer_params = optimizer_params
  if isinstance(scheduler, str):
- self._scheduler_name = scheduler
+ self.scheduler_name = scheduler
  elif scheduler is not None:
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
- self._scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
+ self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
  else:
- self._scheduler_name = None
- self._scheduler_params = scheduler_params or {}
- self._loss_config = loss
- self._loss_params = loss_params or {}
+ self.scheduler_name = None
+ self.scheduler_params = scheduler_params or {}
+ self.loss_config = loss
+ self.loss_params = loss_params or {}

  self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
  # Set loss function based on training mode
@@ -1245,7 +1145,7 @@ class BaseMatchModel(BaseModel):
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
  if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
  loss_value = default_losses.get(self.training_mode, loss_value)
- loss_kwargs = get_loss_kwargs(self._loss_params, 0)
+ loss_kwargs = get_loss_kwargs(self.loss_params, 0)
  self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
  # set scheduler
  self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
@@ -1329,57 +1229,47 @@ class BaseMatchModel(BaseModel):
  return loss
  else:
  raise ValueError(f"Unknown training mode: {self.training_mode}")
+

- def _set_metrics(self, metrics: list[str] | dict[str, list[str]] | None = None):
- """Reuse BaseModel metric configuration (mode + early stopper)."""
- super()._set_metrics(metrics)
-
+ def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
+ """Prepare data loader for specific features."""
+ if isinstance(data, DataLoader):
+ return data
+
+ feature_data = {}
+ for feature in features:
+ if isinstance(data, dict):
+ if feature.name in data:
+ feature_data[feature.name] = data[feature.name]
+ elif isinstance(data, pd.DataFrame):
+ if feature.name in data.columns:
+ feature_data[feature.name] = data[feature.name].values
+ return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
+
  def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
- self.eval()
- if not isinstance(data, DataLoader):
- user_data = {}
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
- for feature in all_user_features:
- if isinstance(data, dict):
- if feature.name in data:
- user_data[feature.name] = data[feature.name]
- elif isinstance(data, pd.DataFrame):
- if feature.name in data.columns:
- user_data[feature.name] = data[feature.name].values
- data_loader = self._prepare_data_loader(user_data, batch_size=batch_size, shuffle=False)
- else:
- data_loader = data
+ self.eval()
+ all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+ data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
+
  embeddings_list = []
  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
  user_input = self.get_user_features(batch_dict["features"])
  user_emb = self.user_tower(user_input)
  embeddings_list.append(user_emb.cpu().numpy())
- embeddings = np.concatenate(embeddings_list, axis=0)
- return embeddings
+ return np.concatenate(embeddings_list, axis=0)

  def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
  self.eval()
- if not isinstance(data, DataLoader):
- item_data = {}
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
- for feature in all_item_features:
- if isinstance(data, dict):
- if feature.name in data:
- item_data[feature.name] = data[feature.name]
- elif isinstance(data, pd.DataFrame):
- if feature.name in data.columns:
- item_data[feature.name] = data[feature.name].values
- data_loader = self._prepare_data_loader(item_data, batch_size=batch_size, shuffle=False)
- else:
- data_loader = data
+ all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+ data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
+
  embeddings_list = []
  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):
- batch_dict = self._batch_to_dict(batch_data, include_ids=False)
+ batch_dict = batch_to_dict(batch_data, include_ids=False)
  item_input = self.get_item_features(batch_dict["features"])
  item_emb = self.item_tower(item_input)
  embeddings_list.append(item_emb.cpu().numpy())
- embeddings = np.concatenate(embeddings_list, axis=0)
- return embeddings
+ return np.concatenate(embeddings_list, axis=0)
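Several hunks above replace inline logic with helpers from nextrec.data.data_utils and nextrec.utils (to_tensor, batch_to_dict, get_user_ids) whose implementations are not part of this diff. The sketch below only mirrors the inline code those calls replaced, to make the refactor easier to read; it is an assumption about their behaviour, not the library's actual source, and the real get_user_ids also appears to accept raw dict/DataFrame inputs (see prepare_validation_data and evaluate).

import numpy as np
import torch


def to_tensor(value, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Mirrors the removed BaseModel._to_tensor: convert, then align dtype and device.
    tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
    if tensor.dtype != dtype:
        tensor = tensor.to(dtype=dtype)
    return tensor.to(device) if tensor.device != device else tensor


def batch_to_dict(batch_data, include_ids: bool = True) -> dict:
    # Mirrors the removed BaseModel._batch_to_dict: normalize a collated batch.
    if not (isinstance(batch_data, dict) and "features" in batch_data):
        raise TypeError("Batch data must be a dict with 'features' produced by the DataLoader.")
    return {
        "features": batch_data.get("features", {}),
        "labels": batch_data.get("labels"),
        "ids": batch_data.get("ids") if include_ids else None,
    }


def get_user_ids(data: dict, id_columns=None):
    # Mirrors the removed per-batch user-id lookup: prefer the configured id columns,
    # fall back to the first id tensor, and return a flat numpy array (or None).
    ids = data.get("ids") or {}
    user_id = next((ids[name] for name in (id_columns or []) if name in ids), None)
    if user_id is None and ids:
        user_id = next(iter(ids.values()), None)
    if user_id is None:
        return None
    arr = user_id.detach().cpu().numpy() if isinstance(user_id, torch.Tensor) else np.asarray(user_id)
    return arr.reshape(arr.shape[0])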