nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62) hide show
  1. nextrec/__init__.py +1 -1
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/activation.py +10 -5
  4. nextrec/basic/callback.py +1 -0
  5. nextrec/basic/features.py +30 -22
  6. nextrec/basic/layers.py +244 -113
  7. nextrec/basic/loggers.py +62 -43
  8. nextrec/basic/metrics.py +268 -119
  9. nextrec/basic/model.py +1373 -443
  10. nextrec/basic/session.py +10 -3
  11. nextrec/cli.py +498 -0
  12. nextrec/data/__init__.py +19 -25
  13. nextrec/data/batch_utils.py +11 -3
  14. nextrec/data/data_processing.py +42 -24
  15. nextrec/data/data_utils.py +26 -15
  16. nextrec/data/dataloader.py +303 -96
  17. nextrec/data/preprocessor.py +320 -199
  18. nextrec/loss/listwise.py +17 -9
  19. nextrec/loss/loss_utils.py +7 -8
  20. nextrec/loss/pairwise.py +2 -0
  21. nextrec/loss/pointwise.py +30 -12
  22. nextrec/models/generative/hstu.py +106 -40
  23. nextrec/models/match/dssm.py +82 -69
  24. nextrec/models/match/dssm_v2.py +72 -58
  25. nextrec/models/match/mind.py +175 -108
  26. nextrec/models/match/sdm.py +104 -88
  27. nextrec/models/match/youtube_dnn.py +73 -60
  28. nextrec/models/multi_task/esmm.py +53 -39
  29. nextrec/models/multi_task/mmoe.py +70 -47
  30. nextrec/models/multi_task/ple.py +107 -50
  31. nextrec/models/multi_task/poso.py +121 -41
  32. nextrec/models/multi_task/share_bottom.py +54 -38
  33. nextrec/models/ranking/afm.py +172 -45
  34. nextrec/models/ranking/autoint.py +84 -61
  35. nextrec/models/ranking/dcn.py +59 -42
  36. nextrec/models/ranking/dcn_v2.py +64 -23
  37. nextrec/models/ranking/deepfm.py +36 -26
  38. nextrec/models/ranking/dien.py +158 -102
  39. nextrec/models/ranking/din.py +88 -60
  40. nextrec/models/ranking/fibinet.py +55 -35
  41. nextrec/models/ranking/fm.py +32 -26
  42. nextrec/models/ranking/masknet.py +95 -34
  43. nextrec/models/ranking/pnn.py +34 -31
  44. nextrec/models/ranking/widedeep.py +37 -29
  45. nextrec/models/ranking/xdeepfm.py +63 -41
  46. nextrec/utils/__init__.py +61 -32
  47. nextrec/utils/config.py +490 -0
  48. nextrec/utils/device.py +52 -12
  49. nextrec/utils/distributed.py +141 -0
  50. nextrec/utils/embedding.py +1 -0
  51. nextrec/utils/feature.py +1 -0
  52. nextrec/utils/file.py +32 -11
  53. nextrec/utils/initializer.py +61 -16
  54. nextrec/utils/optimizer.py +25 -9
  55. nextrec/utils/synthetic_data.py +531 -0
  56. nextrec/utils/tensor.py +24 -13
  57. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
  58. nextrec-0.4.2.dist-info/RECORD +69 -0
  59. nextrec-0.4.2.dist-info/entry_points.txt +2 -0
  60. nextrec-0.3.6.dist-info/RECORD +0 -64
  61. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
  62. {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,7 +2,7 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 02/12/2025
5
+ Checkpoint: edit on 05/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
 
@@ -17,13 +17,21 @@ import pandas as pd
17
17
  import torch
18
18
  import torch.nn as nn
19
19
  import torch.nn.functional as F
20
+ import torch.distributed as dist
20
21
 
21
22
  from pathlib import Path
22
23
  from typing import Union, Literal, Any
23
24
  from torch.utils.data import DataLoader
25
+ from torch.utils.data.distributed import DistributedSampler
26
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
27
 
25
28
  from nextrec.basic.callback import EarlyStopper
26
- from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
29
+ from nextrec.basic.features import (
30
+ DenseFeature,
31
+ SparseFeature,
32
+ SequenceFeature,
33
+ FeatureSet,
34
+ )
27
35
  from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
28
36
 
29
37
  from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
@@ -31,79 +39,149 @@ from nextrec.basic.session import resolve_save_path, create_session
31
39
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
32
40
 
33
41
  from nextrec.data.dataloader import build_tensors_from_data
34
- from nextrec.data.data_processing import get_column_data, get_user_ids
35
42
  from nextrec.data.batch_utils import collate_fn, batch_to_dict
43
+ from nextrec.data.data_processing import get_column_data, get_user_ids
36
44
 
37
45
  from nextrec.loss import get_loss_fn, get_loss_kwargs
38
- from nextrec.utils import get_optimizer, get_scheduler
39
46
  from nextrec.utils.tensor import to_tensor
40
-
47
+ from nextrec.utils.device import configure_device
48
+ from nextrec.utils.optimizer import get_optimizer, get_scheduler
49
+ from nextrec.utils.distributed import (
50
+ gather_numpy,
51
+ init_process_group,
52
+ add_distributed_sampler,
53
+ )
41
54
  from nextrec import __version__
42
55
 
56
+
43
57
  class BaseModel(FeatureSet, nn.Module):
44
58
  @property
45
59
  def model_name(self) -> str:
46
60
  raise NotImplementedError
47
-
61
+
48
62
  @property
49
- def task_type(self) -> str:
63
+ def default_task(self) -> str | list[str]:
50
64
  raise NotImplementedError
51
65
 
52
- def __init__(self,
53
- dense_features: list[DenseFeature] | None = None,
54
- sparse_features: list[SparseFeature] | None = None,
55
- sequence_features: list[SequenceFeature] | None = None,
56
- target: list[str] | str | None = None,
57
- id_columns: list[str] | str | None = None,
58
- task: str|list[str] = 'binary',
59
- device: str = 'cpu',
60
- embedding_l1_reg: float = 0.0,
61
- dense_l1_reg: float = 0.0,
62
- embedding_l2_reg: float = 0.0,
63
- dense_l2_reg: float = 0.0,
64
- early_stop_patience: int = 20,
65
- session_id: str | None = None,):
66
-
66
+ def __init__(
67
+ self,
68
+ dense_features: list[DenseFeature] | None = None,
69
+ sparse_features: list[SparseFeature] | None = None,
70
+ sequence_features: list[SequenceFeature] | None = None,
71
+ target: list[str] | str | None = None,
72
+ id_columns: list[str] | str | None = None,
73
+ task: str | list[str] | None = None,
74
+ device: str = "cpu",
75
+ early_stop_patience: int = 20,
76
+ session_id: str | None = None,
77
+ embedding_l1_reg: float = 0.0,
78
+ dense_l1_reg: float = 0.0,
79
+ embedding_l2_reg: float = 0.0,
80
+ dense_l2_reg: float = 0.0,
81
+ distributed: bool = False,
82
+ rank: int | None = None,
83
+ world_size: int | None = None,
84
+ local_rank: int | None = None,
85
+ ddp_find_unused_parameters: bool = False,
86
+ ):
87
+ """
88
+ Initialize a base model.
89
+
90
+ Args:
91
+ dense_features: DenseFeature definitions.
92
+ sparse_features: SparseFeature definitions.
93
+ sequence_features: SequenceFeature definitions.
94
+ target: Target column name.
95
+ id_columns: Identifier column name; only needs to be specified if GAUC is required.
96
+ task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
97
+ device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
98
+ embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
99
+ dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
100
+ embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
101
+ dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
102
+ early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
103
+ session_id: Session id for logging. If None, a default id with timestamps will be created.
104
+ distributed: Set True to enable the DistributedDataParallel (DDP) training flow.
105
+ rank: Global rank (defaults to env RANK).
106
+ world_size: Number of processes (defaults to env WORLD_SIZE).
107
+ local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
108
+ ddp_find_unused_parameters: Default False; set it to True only when the DDP model has unused parameters. In most cases it should remain False.
109
+ """
67
110
  super(BaseModel, self).__init__()
68
- try:
69
- self.device = torch.device(device)
70
- except Exception as e:
71
- logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
72
- self.device = torch.device('cpu')
111
+
112
+ # distributed training settings
113
+ env_rank = int(os.environ.get("RANK", "0"))
114
+ env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
115
+ env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
116
+ self.distributed = distributed or (env_world_size > 1)
117
+ self.rank = env_rank if rank is None else rank
118
+ self.world_size = env_world_size if world_size is None else world_size
119
+ self.local_rank = env_local_rank if local_rank is None else local_rank
120
+ self.is_main_process = self.rank == 0
121
+ self.ddp_find_unused_parameters = ddp_find_unused_parameters
122
+ self.ddp_model: DDP | None = None
123
+ self.device = configure_device(self.distributed, self.local_rank, device)
73
124
 
74
125
  self.session_id = session_id
75
126
  self.session = create_session(session_id)
76
- self.session_path = self.session.root # pwd/session_id, path for this session
77
- self.checkpoint_path = os.path.join(self.session_path, self.model_name+"_checkpoint.model") # example: pwd/session_id/DeepFM_checkpoint.model
78
- self.best_path = os.path.join(self.session_path, self.model_name+"_best.model")
79
- self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
80
- self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
127
+ self.session_path = self.session.root # pwd/session_id, path for this session
128
+ self.checkpoint_path = os.path.join(
129
+ self.session_path, self.model_name + "_checkpoint.model"
130
+ ) # example: pwd/session_id/DeepFM_checkpoint.model
131
+ self.best_path = os.path.join(
132
+ self.session_path, self.model_name + "_best.model"
133
+ )
134
+ self.features_config_path = os.path.join(
135
+ self.session_path, "features_config.pkl"
136
+ )
137
+ self.set_all_features(
138
+ dense_features, sparse_features, sequence_features, target, id_columns
139
+ )
81
140
 
82
- self.task = task
83
- self.nums_task = len(task) if isinstance(task, list) else 1
141
+ self.task = self.default_task if task is None else task
142
+ self.nums_task = len(self.task) if isinstance(self.task, list) else 1
84
143
 
85
144
  self.embedding_l1_reg = embedding_l1_reg
86
145
  self.dense_l1_reg = dense_l1_reg
87
146
  self.embedding_l2_reg = embedding_l2_reg
88
147
  self.dense_l2_reg = dense_l2_reg
89
- self.regularization_weights = []
148
+ self.regularization_weights = []
90
149
  self.embedding_params = []
91
150
  self.loss_weight = None
151
+
92
152
  self.early_stop_patience = early_stop_patience
93
- self.max_gradient_norm = 1.0
153
+ self.max_gradient_norm = 1.0
94
154
  self.logger_initialized = False
95
- self.training_logger: TrainingLogger | None = None
155
+ self.training_logger = None
96
156
 
97
- def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
157
+ def register_regularization_weights(
158
+ self,
159
+ embedding_attr: str = "embedding",
160
+ exclude_modules: list[str] | None = None,
161
+ include_modules: list[str] | None = None,
162
+ ) -> None:
98
163
  exclude_modules = exclude_modules or []
99
164
  include_modules = include_modules or []
100
165
  embedding_layer = getattr(self, embedding_attr, None)
101
166
  embed_dict = getattr(embedding_layer, "embed_dict", None)
102
167
  if embed_dict is not None:
103
168
  self.embedding_params.extend(embed.weight for embed in embed_dict.values())
104
- skip_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.Dropout, nn.Dropout2d, nn.Dropout3d,)
169
+ skip_types = (
170
+ nn.BatchNorm1d,
171
+ nn.BatchNorm2d,
172
+ nn.BatchNorm3d,
173
+ nn.Dropout,
174
+ nn.Dropout2d,
175
+ nn.Dropout3d,
176
+ )
105
177
  for name, module in self.named_modules():
106
- if (module is self or embedding_attr in name or isinstance(module, skip_types) or (include_modules and not any(inc in name for inc in include_modules)) or any(exc in name for exc in exclude_modules)):
178
+ if (
179
+ module is self
180
+ or embedding_attr in name
181
+ or isinstance(module, skip_types)
182
+ or (include_modules and not any(inc in name for inc in include_modules))
183
+ or any(exc in name for exc in exclude_modules)
184
+ ):
107
185
  continue
108
186
  if isinstance(module, nn.Linear):
109
187
  self.regularization_weights.append(module.weight)
@@ -112,14 +190,22 @@ class BaseModel(FeatureSet, nn.Module):
112
190
  reg_loss = torch.tensor(0.0, device=self.device)
113
191
  if self.embedding_params:
114
192
  if self.embedding_l1_reg > 0:
115
- reg_loss += self.embedding_l1_reg * sum(param.abs().sum() for param in self.embedding_params)
193
+ reg_loss += self.embedding_l1_reg * sum(
194
+ param.abs().sum() for param in self.embedding_params
195
+ )
116
196
  if self.embedding_l2_reg > 0:
117
- reg_loss += self.embedding_l2_reg * sum((param ** 2).sum() for param in self.embedding_params)
197
+ reg_loss += self.embedding_l2_reg * sum(
198
+ (param**2).sum() for param in self.embedding_params
199
+ )
118
200
  if self.regularization_weights:
119
201
  if self.dense_l1_reg > 0:
120
- reg_loss += self.dense_l1_reg * sum(param.abs().sum() for param in self.regularization_weights)
202
+ reg_loss += self.dense_l1_reg * sum(
203
+ param.abs().sum() for param in self.regularization_weights
204
+ )
121
205
  if self.dense_l2_reg > 0:
122
- reg_loss += self.dense_l2_reg * sum((param ** 2).sum() for param in self.regularization_weights)
206
+ reg_loss += self.dense_l2_reg * sum(
207
+ (param**2).sum() for param in self.regularization_weights
208
+ )
123
209
  return reg_loss
124
210
 
125
211
  def get_input(self, input_data: dict, require_labels: bool = True):
@@ -128,47 +214,90 @@ class BaseModel(FeatureSet, nn.Module):
128
214
  X_input = {}
129
215
  for feature in self.all_features:
130
216
  if feature.name not in feature_source:
131
- raise KeyError(f"[BaseModel-input Error] Feature '{feature.name}' not found in input data.")
217
+ raise KeyError(
218
+ f"[BaseModel-input Error] Feature '{feature.name}' not found in input data."
219
+ )
132
220
  feature_data = get_column_data(feature_source, feature.name)
133
- X_input[feature.name] = to_tensor(feature_data, dtype=torch.float32 if isinstance(feature, DenseFeature) else torch.long, device=self.device)
221
+ X_input[feature.name] = to_tensor(
222
+ feature_data,
223
+ dtype=(
224
+ torch.float32 if isinstance(feature, DenseFeature) else torch.long
225
+ ),
226
+ device=self.device,
227
+ )
134
228
  y = None
135
- if (len(self.target_columns) > 0 and (require_labels or (label_source and any(name in label_source for name in self.target_columns)))): # need labels: training or eval with labels
229
+ if len(self.target_columns) > 0 and (
230
+ require_labels
231
+ or (
232
+ label_source
233
+ and any(name in label_source for name in self.target_columns)
234
+ )
235
+ ): # need labels: training or eval with labels
136
236
  target_tensors = []
137
237
  for target_name in self.target_columns:
138
238
  if label_source is None or target_name not in label_source:
139
239
  if require_labels:
140
- raise KeyError(f"[BaseModel-input Error] Target column '{target_name}' not found in input data.")
240
+ raise KeyError(
241
+ f"[BaseModel-input Error] Target column '{target_name}' not found in input data."
242
+ )
141
243
  continue
142
244
  target_data = get_column_data(label_source, target_name)
143
245
  if target_data is None:
144
246
  if require_labels:
145
- raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
247
+ raise ValueError(
248
+ f"[BaseModel-input Error] Target column '{target_name}' contains no data."
249
+ )
146
250
  continue
147
- target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
148
- target_tensor = target_tensor.view(target_tensor.size(0), -1)
251
+ target_tensor = to_tensor(
252
+ target_data, dtype=torch.float32, device=self.device
253
+ )
254
+ target_tensor = target_tensor.view(
255
+ target_tensor.size(0), -1
256
+ ) # always reshape to (batch_size, num_targets)
149
257
  target_tensors.append(target_tensor)
150
258
  if target_tensors:
151
259
  y = torch.cat(target_tensors, dim=1)
152
- if y.shape[1] == 1:
260
+ if y.shape[1] == 1: # no need to do that again
153
261
  y = y.view(-1)
154
262
  elif require_labels:
155
- raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
263
+ raise ValueError(
264
+ "[BaseModel-input Error] Labels are required but none were found in the input batch."
265
+ )
156
266
  return X_input, y
157
267
 
158
- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,) -> tuple[DataLoader, dict | pd.DataFrame]:
159
- """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
268
+ def handle_validation_split(
269
+ self,
270
+ train_data: dict | pd.DataFrame,
271
+ validation_split: float,
272
+ batch_size: int,
273
+ shuffle: bool,
274
+ num_workers: int = 0,
275
+ ):
276
+ """
277
+ This function will split training data into training and validation sets when:
278
+ 1. valid_data is None;
279
+ 2. validation_split is provided.
280
+ """
160
281
  if not (0 < validation_split < 1):
161
- raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
282
+ raise ValueError(
283
+ f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
284
+ )
162
285
  if not isinstance(train_data, (pd.DataFrame, dict)):
163
- raise TypeError(f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}")
286
+ raise TypeError(
287
+ f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
288
+ )
164
289
  if isinstance(train_data, pd.DataFrame):
165
290
  total_length = len(train_data)
166
291
  else:
167
- sample_key = next(iter(train_data)) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
168
- total_length = len(train_data[sample_key]) # len(train_data['user_id'])
292
+ sample_key = next(
293
+ iter(train_data)
294
+ ) # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
295
+ total_length = len(train_data[sample_key]) # len(train_data['user_id'])
169
296
  for k, v in train_data.items():
170
297
  if len(v) != total_length:
171
- raise ValueError(f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})")
298
+ raise ValueError(
299
+ f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
300
+ )
172
301
  rng = np.random.default_rng(42)
173
302
  indices = rng.permutation(total_length)
174
303
  split_idx = int(total_length * (1 - validation_split))
@@ -181,169 +310,444 @@ class BaseModel(FeatureSet, nn.Module):
181
310
  train_split = {}
182
311
  valid_split = {}
183
312
  for key, value in train_data.items():
184
- arr = np.asarray(value)
313
+ arr = np.asarray(value)
185
314
  train_split[key] = arr[train_indices]
186
315
  valid_split[key] = arr[valid_indices]
187
- train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
188
- logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
316
+ train_loader = self.prepare_data_loader(
317
+ train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
318
+ )
319
+ logging.info(
320
+ f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
321
+ )
189
322
  return train_loader, valid_split
190
323
 
191
324
  def compile(
192
325
  self,
193
326
  optimizer: str | torch.optim.Optimizer = "adam",
194
327
  optimizer_params: dict | None = None,
195
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
328
+ scheduler: (
329
+ str
330
+ | torch.optim.lr_scheduler._LRScheduler
331
+ | torch.optim.lr_scheduler.LRScheduler
332
+ | type[torch.optim.lr_scheduler._LRScheduler]
333
+ | type[torch.optim.lr_scheduler.LRScheduler]
334
+ | None
335
+ ) = None,
196
336
  scheduler_params: dict | None = None,
197
337
  loss: str | nn.Module | list[str | nn.Module] | None = "bce",
198
338
  loss_params: dict | list[dict] | None = None,
199
339
  loss_weights: int | float | list[int | float] | None = None,
200
340
  ):
341
+ """
342
+ Configure the model for training.
343
+ Args:
344
+ optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
345
+ optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
346
+ scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
347
+ scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
348
+ loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
349
+ loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
350
+ loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
351
+ """
352
+ if loss_params is None:
353
+ self.loss_params = {}
354
+ else:
355
+ self.loss_params = loss_params
201
356
  optimizer_params = optimizer_params or {}
202
- self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
357
+ self.optimizer_name = (
358
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
359
+ )
203
360
  self.optimizer_params = optimizer_params
204
- self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params,)
361
+ self.optimizer_fn = get_optimizer(
362
+ optimizer=optimizer,
363
+ params=self.parameters(),
364
+ **optimizer_params,
365
+ )
205
366
 
206
367
  scheduler_params = scheduler_params or {}
207
368
  if isinstance(scheduler, str):
208
369
  self.scheduler_name = scheduler
209
370
  elif scheduler is None:
210
371
  self.scheduler_name = None
211
- else: # for custom scheduler instance, need to provide class name for logging
212
- self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
372
+ else: # for custom scheduler instance, need to provide class name for logging
373
+ self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__) # type: ignore
213
374
  self.scheduler_params = scheduler_params
214
- self.scheduler_fn = (get_scheduler(scheduler, self.optimizer_fn, **scheduler_params) if scheduler else None)
375
+ self.scheduler_fn = (
376
+ get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
377
+ if scheduler
378
+ else None
379
+ )
215
380
 
216
381
  self.loss_config = loss
217
382
  self.loss_params = loss_params or {}
218
383
  self.loss_fn = []
219
- if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
220
- loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
221
- else: # for example: 'bce' -> ['bce', 'bce']
384
+ if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
385
+ if len(loss) != self.nums_task:
386
+ raise ValueError(
387
+ f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
388
+ )
389
+ loss_list = [loss[i] for i in range(self.nums_task)]
390
+ else: # for example: 'bce' -> ['bce', 'bce']
222
391
  loss_list = [loss] * self.nums_task
223
392
 
224
393
  if isinstance(self.loss_params, dict):
225
394
  params_list = [self.loss_params] * self.nums_task
226
395
  else: # list[dict]
227
- params_list = [self.loss_params[i] if i < len(self.loss_params) else {} for i in range(self.nums_task)]
228
- self.loss_fn = [get_loss_fn(loss=loss_list[i], **params_list[i]) for i in range(self.nums_task)]
396
+ params_list = [
397
+ self.loss_params[i] if i < len(self.loss_params) else {}
398
+ for i in range(self.nums_task)
399
+ ]
400
+ self.loss_fn = [
401
+ get_loss_fn(loss=loss_list[i], **params_list[i])
402
+ for i in range(self.nums_task)
403
+ ]
229
404
 
230
405
  if loss_weights is None:
231
406
  self.loss_weights = None
232
407
  elif self.nums_task == 1:
233
408
  if isinstance(loss_weights, (list, tuple)):
234
- if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
235
- raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
409
+ if len(loss_weights) != 1:
410
+ raise ValueError(
411
+ "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
412
+ )
236
413
  weight_value = loss_weights[0]
237
414
  else:
238
415
  weight_value = loss_weights
239
- self.loss_weights = float(weight_value)
416
+ self.loss_weights = [float(weight_value)]
240
417
  else:
241
418
  if isinstance(loss_weights, (int, float)):
242
419
  weights = [float(loss_weights)] * self.nums_task
243
420
  elif isinstance(loss_weights, (list, tuple)):
244
421
  weights = [float(w) for w in loss_weights]
245
422
  if len(weights) != self.nums_task:
246
- raise ValueError(f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task}).")
423
+ raise ValueError(
424
+ f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
425
+ )
247
426
  else:
248
- raise TypeError(f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}")
427
+ raise TypeError(
428
+ f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
429
+ )
249
430
  self.loss_weights = weights
250
431
 
251
432
  def compute_loss(self, y_pred, y_true):
252
433
  if y_true is None:
253
- raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
434
+ raise ValueError(
435
+ "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
436
+ )
254
437
  if self.nums_task == 1:
255
- loss = self.loss_fn[0](y_pred, y_true)
438
+ if y_pred.dim() == 1:
439
+ y_pred = y_pred.view(-1, 1)
440
+ if y_true.dim() == 1:
441
+ y_true = y_true.view(-1, 1)
442
+ if y_pred.shape != y_true.shape:
443
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
444
+ task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
445
+ if task_dim == 1:
446
+ loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
447
+ else:
448
+ loss = self.loss_fn[0](y_pred, y_true)
256
449
  if self.loss_weights is not None:
257
- loss = loss * self.loss_weights
450
+ loss *= self.loss_weights[0]
258
451
  return loss
452
+ # multi-task
453
+ if y_pred.shape != y_true.shape:
454
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
455
+ if hasattr(
456
+ self, "prediction_layer"
457
+ ): # we need to use registered task_slices for multi-task and multi-class
458
+ slices = self.prediction_layer._task_slices # type: ignore
259
459
  else:
260
- task_losses = []
261
- for i in range(self.nums_task):
262
- task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
263
- if isinstance(self.loss_weights, (list, tuple)):
264
- task_loss = task_loss * self.loss_weights[i]
265
- task_losses.append(task_loss)
266
- return torch.stack(task_losses).sum()
267
-
268
- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0,) -> DataLoader:
460
+ slices = [(i, i + 1) for i in range(self.nums_task)]
461
+ task_losses = []
462
+ for i, (start, end) in enumerate(slices): # type: ignore
463
+ y_pred_i = y_pred[:, start:end]
464
+ y_true_i = y_true[:, start:end]
465
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
466
+ if isinstance(self.loss_weights, (list, tuple)):
467
+ task_loss *= self.loss_weights[i]
468
+ task_losses.append(task_loss)
469
+ return torch.stack(task_losses).sum()
470
+
471
+ def prepare_data_loader(
472
+ self,
473
+ data: dict | pd.DataFrame | DataLoader,
474
+ batch_size: int = 32,
475
+ shuffle: bool = True,
476
+ num_workers: int = 0,
477
+ sampler=None,
478
+ return_dataset: bool = False,
479
+ ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
269
480
  if isinstance(data, DataLoader):
270
- return data
271
- tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
481
+ return (data, None) if return_dataset else data
482
+ tensors = build_tensors_from_data(
483
+ data=data,
484
+ raw_data=data,
485
+ features=self.all_features,
486
+ target_columns=self.target_columns,
487
+ id_columns=self.id_columns,
488
+ )
272
489
  if tensors is None:
273
- raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
490
+ raise ValueError(
491
+ "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
492
+ )
274
493
  dataset = TensorDictDataset(tensors)
275
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)
276
-
277
- def fit(self,
278
- train_data: dict | pd.DataFrame | DataLoader,
279
- valid_data: dict | pd.DataFrame | DataLoader | None = None,
280
- metrics: list[str] | dict[str, list[str]] | None = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
281
- epochs:int=1, shuffle:bool=True, batch_size:int=32,
282
- user_id_column: str | None = None,
283
- validation_split: float | None = None,
284
- num_workers: int = 0,
285
- tensorboard: bool = True,):
494
+ loader = DataLoader(
495
+ dataset,
496
+ batch_size=batch_size,
497
+ shuffle=False if sampler is not None else shuffle,
498
+ sampler=sampler,
499
+ collate_fn=collate_fn,
500
+ num_workers=num_workers,
501
+ )
502
+ return (loader, dataset) if return_dataset else loader
503
+
504
+ def fit(
505
+ self,
506
+ train_data: dict | pd.DataFrame | DataLoader,
507
+ valid_data: dict | pd.DataFrame | DataLoader | None = None,
508
+ metrics: (
509
+ list[str] | dict[str, list[str]] | None
510
+ ) = None, # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
511
+ epochs: int = 1,
512
+ shuffle: bool = True,
513
+ batch_size: int = 32,
514
+ user_id_column: str | None = None,
515
+ validation_split: float | None = None,
516
+ num_workers: int = 0,
517
+ tensorboard: bool = True,
518
+ auto_distributed_sampler: bool = True,
519
+ ):
520
+ """
521
+ Train the model.
522
+
523
+ Args:
524
+ train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
525
+ valid_data: Optional validation data; if None and validation_split is set, a split is created.
526
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
527
+ epochs: Training epochs.
528
+ shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
529
+ batch_size: Batch size (per process when distributed).
530
+ user_id_column: Column name for GAUC-style metrics;.
531
+ validation_split: Ratio to split training data when valid_data is None.
532
+ num_workers: DataLoader worker count.
533
+ tensorboard: Enable tensorboard logging.
534
+ auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
535
+
536
+ Notes:
537
+ - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
538
+ - All ranks must call evaluate() together because it performs collective ops.
539
+ """
540
+ device_id = self.local_rank if self.device.type == "cuda" else None
541
+ init_process_group(
542
+ self.distributed, self.rank, self.world_size, device_id=device_id
543
+ )
286
544
  self.to(self.device)
287
- if not self.logger_initialized:
545
+
546
+ if (
547
+ self.distributed
548
+ and dist.is_available()
549
+ and dist.is_initialized()
550
+ and self.ddp_model is None
551
+ ):
552
+ device_ids = (
553
+ [self.local_rank] if self.device.type == "cuda" else None
554
+ ) # device_ids means which device to use in ddp
555
+ output_device = (
556
+ self.local_rank if self.device.type == "cuda" else None
557
+ ) # output_device means which device to place the output in ddp
558
+ object.__setattr__(
559
+ self,
560
+ "ddp_model",
561
+ DDP(
562
+ self,
563
+ device_ids=device_ids,
564
+ output_device=output_device,
565
+ find_unused_parameters=self.ddp_find_unused_parameters,
566
+ ),
567
+ )
568
+
569
+ if (
570
+ not self.logger_initialized and self.is_main_process
571
+ ): # only main process initializes logger
288
572
  setup_logger(session_id=self.session_id)
289
573
  self.logger_initialized = True
290
- self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
574
+ self.training_logger = (
575
+ TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
576
+ if self.is_main_process
577
+ else None
578
+ )
291
579
 
292
- self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
293
- self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
294
- self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
580
+ self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
581
+ configure_metrics(
582
+ task=self.task, metrics=metrics, target_names=self.target_columns
583
+ )
584
+ ) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
585
+ self.early_stopper = EarlyStopper(
586
+ patience=self.early_stop_patience, mode=self.best_metrics_mode
587
+ )
588
+ self.best_metric = (
589
+ float("-inf") if self.best_metrics_mode == "max" else float("inf")
590
+ )
591
+
592
+ self.needs_user_ids = check_user_id(
593
+ self.metrics, self.task_specific_metrics
594
+ ) # check user_id needed for GAUC metrics
295
595
  self.epoch_index = 0
296
596
  self.stop_training = False
297
597
  self.best_checkpoint_path = self.best_path
298
- self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
299
598
 
599
+ if not auto_distributed_sampler and self.distributed and self.is_main_process:
600
+ logging.info(
601
+ colorize(
602
+ "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
603
+ color="yellow",
604
+ )
605
+ )
606
+
607
+ train_sampler: DistributedSampler | None = None
300
608
  if validation_split is not None and valid_data is None:
301
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
609
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
610
+ if (
611
+ auto_distributed_sampler
612
+ and self.distributed
613
+ and dist.is_available()
614
+ and dist.is_initialized()
615
+ ):
616
+ base_dataset = getattr(train_loader, "dataset", None)
617
+ if base_dataset is not None and not isinstance(
618
+ getattr(train_loader, "sampler", None), DistributedSampler
619
+ ):
620
+ train_sampler = DistributedSampler(
621
+ base_dataset,
622
+ num_replicas=self.world_size,
623
+ rank=self.rank,
624
+ shuffle=shuffle,
625
+ drop_last=True,
626
+ )
627
+ train_loader = DataLoader(
628
+ base_dataset,
629
+ batch_size=batch_size,
630
+ shuffle=False,
631
+ sampler=train_sampler,
632
+ collate_fn=collate_fn,
633
+ num_workers=num_workers,
634
+ drop_last=True,
635
+ )
302
636
  else:
303
- train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers))
304
-
305
- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers)
637
+ if isinstance(train_data, DataLoader):
638
+ if auto_distributed_sampler and self.distributed:
639
+ train_loader, train_sampler = add_distributed_sampler(
640
+ train_data,
641
+ distributed=self.distributed,
642
+ world_size=self.world_size,
643
+ rank=self.rank,
644
+ shuffle=shuffle,
645
+ drop_last=True,
646
+ default_batch_size=batch_size,
647
+ is_main_process=self.is_main_process,
648
+ )
649
+ # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
650
+ else:
651
+ train_loader = train_data
652
+ else:
653
+ loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
654
+ if (
655
+ auto_distributed_sampler
656
+ and self.distributed
657
+ and dataset is not None
658
+ and dist.is_available()
659
+ and dist.is_initialized()
660
+ ):
661
+ train_sampler = DistributedSampler(
662
+ dataset,
663
+ num_replicas=self.world_size,
664
+ rank=self.rank,
665
+ shuffle=shuffle,
666
+ drop_last=True,
667
+ )
668
+ loader = DataLoader(
669
+ dataset,
670
+ batch_size=batch_size,
671
+ shuffle=False,
672
+ sampler=train_sampler,
673
+ collate_fn=collate_fn,
674
+ num_workers=num_workers,
675
+ drop_last=True,
676
+ )
677
+ train_loader = loader
678
+
679
+ # If split-based loader was built without sampler, attach here when enabled
680
+ if (
681
+ self.distributed
682
+ and auto_distributed_sampler
683
+ and isinstance(train_loader, DataLoader)
684
+ and train_sampler is None
685
+ ):
686
+ raise NotImplementedError(
687
+ "[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
688
+ )
689
+ # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
690
+
691
+ valid_loader, valid_user_ids = self.prepare_validation_data(
692
+ valid_data=valid_data,
693
+ batch_size=batch_size,
694
+ needs_user_ids=self.needs_user_ids,
695
+ user_id_column=user_id_column,
696
+ num_workers=num_workers,
697
+ auto_distributed_sampler=auto_distributed_sampler,
698
+ )
306
699
  try:
307
700
  self.steps_per_epoch = len(train_loader)
308
701
  is_streaming = False
309
- except TypeError: # streaming data loader does not supported len()
702
+ except TypeError: # streaming data loader does not supported len()
310
703
  self.steps_per_epoch = None
311
704
  is_streaming = True
312
705
 
313
- self.summary()
314
- logging.info("")
315
- if self.training_logger and self.training_logger.enable_tensorboard:
316
- tb_dir = self.training_logger.tensorboard_logdir
317
- if tb_dir:
318
- user = getpass.getuser()
319
- host = socket.gethostname()
320
- tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
321
- ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
322
- logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
323
- logging.info(colorize("To view logs, run:", color="cyan"))
324
- logging.info(colorize(f" {tb_cmd}", color="cyan"))
325
- logging.info(colorize("Then SSH port forward:", color="cyan"))
326
- logging.info(colorize(f" {ssh_hint}", color="cyan"))
327
-
328
- logging.info("")
329
- logging.info(colorize("=" * 80, bold=True))
330
- if is_streaming:
331
- logging.info(colorize(f"Start streaming training", bold=True))
332
- else:
333
- logging.info(colorize(f"Start training", bold=True))
334
- logging.info(colorize("=" * 80, bold=True))
335
- logging.info("")
336
- logging.info(colorize(f"Model device: {self.device}", bold=True))
706
+ if self.is_main_process:
707
+ self.summary()
708
+ logging.info("")
709
+ if self.training_logger and self.training_logger.enable_tensorboard:
710
+ tb_dir = self.training_logger.tensorboard_logdir
711
+ if tb_dir:
712
+ user = getpass.getuser()
713
+ host = socket.gethostname()
714
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
715
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
716
+ logging.info(
717
+ colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
718
+ )
719
+ logging.info(colorize("To view logs, run:", color="cyan"))
720
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
721
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
722
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
723
+
724
+ logging.info("")
725
+ logging.info(colorize("=" * 80, bold=True))
726
+ if is_streaming:
727
+ logging.info(colorize("Start streaming training", bold=True))
728
+ else:
729
+ logging.info(colorize("Start training", bold=True))
730
+ logging.info(colorize("=" * 80, bold=True))
731
+ logging.info("")
732
+ logging.info(colorize(f"Model device: {self.device}", bold=True))
337
733
 
338
734
  for epoch in range(epochs):
339
735
  self.epoch_index = epoch
340
- if is_streaming:
736
+ if is_streaming and self.is_main_process:
341
737
  logging.info("")
342
- logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
738
+ logging.info(
739
+ colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)
740
+ ) # streaming mode, print epoch header before progress bar
343
741
 
344
742
  # handle train result
345
- train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
346
- if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
743
+ if (
744
+ self.distributed
745
+ and hasattr(train_loader, "sampler")
746
+ and isinstance(train_loader.sampler, DistributedSampler)
747
+ ):
748
+ train_loader.sampler.set_epoch(epoch)
749
+ train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
750
+ if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
347
751
  train_loss, train_metrics = train_result
348
752
  else:
349
753
  train_loss = train_result
@@ -354,15 +758,20 @@ class BaseModel(FeatureSet, nn.Module):
354
758
  if self.nums_task == 1:
355
759
  log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
356
760
  if train_metrics:
357
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
761
+ metrics_str = ", ".join(
762
+ [f"{k}={v:.4f}" for k, v in train_metrics.items()]
763
+ )
358
764
  log_str += f", {metrics_str}"
359
- logging.info(colorize(log_str))
765
+ if self.is_main_process:
766
+ logging.info(colorize(log_str))
360
767
  train_log_payload["loss"] = float(train_loss)
361
768
  if train_metrics:
362
769
  train_log_payload.update(train_metrics)
363
770
  else:
364
771
  total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
365
- log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
772
+ log_str = (
773
+ f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
774
+ )
366
775
  if train_metrics:
367
776
  # group metrics by task
368
777
  task_metrics = {}
@@ -378,21 +787,41 @@ class BaseModel(FeatureSet, nn.Module):
378
787
  task_metric_strs = []
379
788
  for target_name in self.target_columns:
380
789
  if target_name in task_metrics:
381
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
790
+ metrics_str = ", ".join(
791
+ [
792
+ f"{k}={v:.4f}"
793
+ for k, v in task_metrics[target_name].items()
794
+ ]
795
+ )
382
796
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
383
797
  log_str += ", " + ", ".join(task_metric_strs)
384
- logging.info(colorize(log_str))
798
+ if self.is_main_process:
799
+ logging.info(colorize(log_str))
385
800
  train_log_payload["loss"] = float(total_loss_val)
386
801
  if train_metrics:
387
802
  train_log_payload.update(train_metrics)
388
803
  if self.training_logger:
389
- self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
804
+ self.training_logger.log_metrics(
805
+ train_log_payload, step=epoch + 1, split="train"
806
+ )
390
807
  if valid_loader is not None:
391
808
  # pass user_ids only if needed for GAUC metric
392
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
809
+ val_metrics = self.evaluate(
810
+ valid_loader,
811
+ user_ids=valid_user_ids if self.needs_user_ids else None,
812
+ num_workers=num_workers,
813
+ ) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
393
814
  if self.nums_task == 1:
394
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
395
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
815
+ metrics_str = ", ".join(
816
+ [f"{k}={v:.4f}" for k, v in val_metrics.items()]
817
+ )
818
+ if self.is_main_process:
819
+ logging.info(
820
+ colorize(
821
+ f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
822
+ color="cyan",
823
+ )
824
+ )
396
825
  else:
397
826
  # multi task metrics
398
827
  task_metrics = {}
@@ -407,25 +836,58 @@ class BaseModel(FeatureSet, nn.Module):
407
836
  task_metric_strs = []
408
837
  for target_name in self.target_columns:
409
838
  if target_name in task_metrics:
410
- metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
839
+ metrics_str = ", ".join(
840
+ [
841
+ f"{k}={v:.4f}"
842
+ for k, v in task_metrics[target_name].items()
843
+ ]
844
+ )
411
845
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
412
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
846
+ if self.is_main_process:
847
+ logging.info(
848
+ colorize(
849
+ f" Epoch {epoch + 1}/{epochs} - Valid: "
850
+ + ", ".join(task_metric_strs),
851
+ color="cyan",
852
+ )
853
+ )
413
854
  if val_metrics and self.training_logger:
414
- self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
855
+ self.training_logger.log_metrics(
856
+ val_metrics, step=epoch + 1, split="valid"
857
+ )
415
858
  # Handle empty validation metrics
416
859
  if not val_metrics:
417
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
418
- self.best_checkpoint_path = self.checkpoint_path
419
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
860
+ if self.is_main_process:
861
+ self.save_model(
862
+ self.checkpoint_path, add_timestamp=False, verbose=False
863
+ )
864
+ self.best_checkpoint_path = self.checkpoint_path
865
+ logging.info(
866
+ colorize(
867
+ "Warning: No validation metrics computed. Skipping validation for this epoch.",
868
+ color="yellow",
869
+ )
870
+ )
420
871
  continue
421
872
  if self.nums_task == 1:
422
873
  primary_metric_key = self.metrics[0]
423
874
  else:
424
875
  primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
425
- primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
876
+ primary_metric = val_metrics.get(
877
+ primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
878
+ ) # get primary metric value, default to first metric if not found
879
+
880
+ # In distributed mode, broadcast primary_metric to ensure all processes use the same value
881
+ if self.distributed and dist.is_available() and dist.is_initialized():
882
+ metric_tensor = torch.tensor(
883
+ [primary_metric], device=self.device, dtype=torch.float32
884
+ )
885
+ dist.broadcast(metric_tensor, src=0)
886
+ primary_metric = float(metric_tensor.item())
887
+
426
888
  improved = False
427
889
  # early stopping check
428
- if self.best_metrics_mode == 'max':
890
+ if self.best_metrics_mode == "max":
429
891
  if primary_metric > self.best_metric:
430
892
  self.best_metric = primary_metric
431
893
  improved = True
@@ -433,119 +895,287 @@ class BaseModel(FeatureSet, nn.Module):
433
895
  if primary_metric < self.best_metric:
434
896
  self.best_metric = primary_metric
435
897
  improved = True
436
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
437
- logging.info(" ")
438
- if improved:
439
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
440
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
441
- self.best_checkpoint_path = self.best_path
442
- self.early_stopper.trial_counter = 0
898
+
899
+ # save checkpoint and best model for main process
900
+ if self.is_main_process:
901
+ self.save_model(
902
+ self.checkpoint_path, add_timestamp=False, verbose=False
903
+ )
904
+ logging.info(" ")
905
+ if improved:
906
+ logging.info(
907
+ colorize(
908
+ f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
909
+ )
910
+ )
911
+ self.save_model(
912
+ self.best_path, add_timestamp=False, verbose=False
913
+ )
914
+ self.best_checkpoint_path = self.best_path
915
+ self.early_stopper.trial_counter = 0
916
+ else:
917
+ self.early_stopper.trial_counter += 1
918
+ logging.info(
919
+ colorize(
920
+ f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
921
+ )
922
+ )
923
+ if self.early_stopper.trial_counter >= self.early_stopper.patience:
924
+ self.stop_training = True
925
+ logging.info(
926
+ colorize(
927
+ f"Early stopping triggered after {epoch + 1} epochs",
928
+ color="bright_red",
929
+ bold=True,
930
+ )
931
+ )
443
932
  else:
444
- self.early_stopper.trial_counter += 1
445
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
446
- if self.early_stopper.trial_counter >= self.early_stopper.patience:
447
- self.stop_training = True
448
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
449
- break
933
+ # Non-main processes also update trial_counter to keep in sync
934
+ if improved:
935
+ self.early_stopper.trial_counter = 0
936
+ else:
937
+ self.early_stopper.trial_counter += 1
450
938
  else:
451
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
452
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
453
- self.best_checkpoint_path = self.best_path
939
+ if self.is_main_process:
940
+ self.save_model(
941
+ self.checkpoint_path, add_timestamp=False, verbose=False
942
+ )
943
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
944
+ self.best_checkpoint_path = self.best_path
945
+
946
+ # Broadcast stop_training flag to all processes (always, regardless of validation)
947
+ if self.distributed and dist.is_available() and dist.is_initialized():
948
+ stop_tensor = torch.tensor(
949
+ [int(self.stop_training)], device=self.device
950
+ )
951
+ dist.broadcast(stop_tensor, src=0)
952
+ self.stop_training = bool(stop_tensor.item())
953
+
454
954
  if self.stop_training:
455
955
  break
456
956
  if self.scheduler_fn is not None:
457
- if isinstance(self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau):
957
+ if isinstance(
958
+ self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
959
+ ):
458
960
  if valid_loader is not None:
459
961
  self.scheduler_fn.step(primary_metric)
460
962
  else:
461
- self.scheduler_fn.step()
462
- logging.info(" ")
463
- logging.info(colorize("Training finished.", bold=True))
464
- logging.info(" ")
963
+ self.scheduler_fn.step()
964
+ if self.distributed and dist.is_available() and dist.is_initialized():
965
+ dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
966
+ if self.is_main_process:
967
+ logging.info(" ")
968
+ logging.info(colorize("Training finished.", bold=True))
969
+ logging.info(" ")
465
970
  if valid_loader is not None:
466
- logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
467
- self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
971
+ if self.is_main_process:
972
+ logging.info(
973
+ colorize(f"Load best model from: {self.best_checkpoint_path}")
974
+ )
975
+ self.load_model(
976
+ self.best_checkpoint_path, map_location=self.device, verbose=False
977
+ )
468
978
  if self.training_logger:
469
979
  self.training_logger.close()
470
980
  return self
471
981
 
472
- def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
982
+ def train_epoch(
983
+ self, train_loader: DataLoader, is_streaming: bool = False
984
+ ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
985
+ # use ddp model for distributed training
986
+ model = self.ddp_model if getattr(self, "ddp_model") is not None else self
473
987
  accumulated_loss = 0.0
474
- self.train()
988
+ model.train() # type: ignore
475
989
  num_batches = 0
476
990
  y_true_list = []
477
991
  y_pred_list = []
478
992
 
479
993
  user_ids_list = [] if self.needs_user_ids else None
994
+ tqdm_disable = not self.is_main_process
480
995
  if self.steps_per_epoch is not None:
481
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
996
+ batch_iter = enumerate(
997
+ tqdm.tqdm(
998
+ train_loader,
999
+ desc=f"Epoch {self.epoch_index + 1}",
1000
+ total=self.steps_per_epoch,
1001
+ disable=tqdm_disable,
1002
+ )
1003
+ )
482
1004
  else:
483
1005
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
484
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
1006
+ batch_iter = enumerate(
1007
+ tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
1008
+ )
485
1009
  for batch_index, batch_data in batch_iter:
486
1010
  batch_dict = batch_to_dict(batch_data)
487
1011
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
488
- y_pred = self.forward(X_input)
1012
+ # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
1013
+ y_pred = model(X_input) # type: ignore
1014
+
489
1015
  loss = self.compute_loss(y_pred, y_true)
490
1016
  reg_loss = self.add_reg_loss()
491
1017
  total_loss = loss + reg_loss
492
1018
  self.optimizer_fn.zero_grad()
493
1019
  total_loss.backward()
494
- nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
1020
+
1021
+ params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
1022
+ nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
495
1023
  self.optimizer_fn.step()
496
1024
  accumulated_loss += loss.item()
1025
+
497
1026
  if y_true is not None:
498
1027
  y_true_list.append(y_true.detach().cpu().numpy())
499
1028
  if self.needs_user_ids and user_ids_list is not None:
500
- batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
1029
+ batch_user_id = get_user_ids(
1030
+ data=batch_dict, id_columns=self.id_columns
1031
+ )
501
1032
  if batch_user_id is not None:
502
1033
  user_ids_list.append(batch_user_id)
503
1034
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
504
1035
  y_pred_list.append(y_pred.detach().cpu().numpy())
505
1036
  num_batches += 1
1037
+ if self.distributed and dist.is_available() and dist.is_initialized():
1038
+ loss_tensor = torch.tensor(
1039
+ [accumulated_loss, num_batches], device=self.device, dtype=torch.float32
1040
+ )
1041
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
1042
+ accumulated_loss = loss_tensor[0].item()
1043
+ num_batches = int(loss_tensor[1].item())
506
1044
  avg_loss = accumulated_loss / max(num_batches, 1)
507
- if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
508
- y_true_all = np.concatenate(y_true_list, axis=0)
509
- y_pred_all = np.concatenate(y_pred_list, axis=0)
510
- combined_user_ids = None
511
- if self.needs_user_ids and user_ids_list:
512
- combined_user_ids = np.concatenate(user_ids_list, axis=0)
513
- metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
1045
+
1046
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
1047
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
1048
+ combined_user_ids_local = (
1049
+ np.concatenate(user_ids_list, axis=0)
1050
+ if self.needs_user_ids and user_ids_list
1051
+ else None
1052
+ )
1053
+
1054
+ # gather across ranks even when local is empty to avoid DDP hang
1055
+ y_true_all = gather_numpy(self, y_true_all_local)
1056
+ y_pred_all = gather_numpy(self, y_pred_all_local)
1057
+ combined_user_ids = (
1058
+ gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
1059
+ )
1060
+
1061
+ if (
1062
+ y_true_all is not None
1063
+ and y_pred_all is not None
1064
+ and len(y_true_all) > 0
1065
+ and len(y_pred_all) > 0
1066
+ ):
1067
+ metrics_dict = evaluate_metrics(
1068
+ y_true=y_true_all,
1069
+ y_pred=y_pred_all,
1070
+ metrics=self.metrics,
1071
+ task=self.task,
1072
+ target_names=self.target_columns,
1073
+ task_specific_metrics=self.task_specific_metrics,
1074
+ user_ids=combined_user_ids,
1075
+ )
514
1076
  return avg_loss, metrics_dict
515
1077
  return avg_loss
516
1078
 
517
- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0,) -> tuple[DataLoader | None, np.ndarray | None]:
1079
+ def prepare_validation_data(
1080
+ self,
1081
+ valid_data: dict | pd.DataFrame | DataLoader | None,
1082
+ batch_size: int,
1083
+ needs_user_ids: bool,
1084
+ user_id_column: str | None = "user_id",
1085
+ num_workers: int = 0,
1086
+ auto_distributed_sampler: bool = True,
1087
+ ) -> tuple[DataLoader | None, np.ndarray | None]:
518
1088
  if valid_data is None:
519
1089
  return None, None
520
1090
  if isinstance(valid_data, DataLoader):
521
- return valid_data, None
522
- valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
1091
+ if auto_distributed_sampler and self.distributed:
1092
+ raise NotImplementedError(
1093
+ "[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
1094
+ )
1095
+ # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
1096
+ else:
1097
+ valid_loader = valid_data
1098
+ return valid_loader, None
1099
+ valid_sampler = None
1100
+ valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
1101
+ if (
1102
+ auto_distributed_sampler
1103
+ and self.distributed
1104
+ and valid_dataset is not None
1105
+ and dist.is_available()
1106
+ and dist.is_initialized()
1107
+ ):
1108
+ valid_sampler = DistributedSampler(
1109
+ valid_dataset,
1110
+ num_replicas=self.world_size,
1111
+ rank=self.rank,
1112
+ shuffle=False,
1113
+ drop_last=False,
1114
+ )
1115
+ valid_loader = DataLoader(
1116
+ valid_dataset,
1117
+ batch_size=batch_size,
1118
+ shuffle=False,
1119
+ sampler=valid_sampler,
1120
+ collate_fn=collate_fn,
1121
+ num_workers=num_workers,
1122
+ )
523
1123
  valid_user_ids = None
524
1124
  if needs_user_ids:
525
1125
  if user_id_column is None:
526
- raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
527
- valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
1126
+ raise ValueError(
1127
+ "[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics."
1128
+ )
1129
+ # In distributed mode, user_ids will be collected during evaluation from each batch
1130
+ # and gathered across all processes, so we don't pre-extract them here
1131
+ if not self.distributed:
1132
+ valid_user_ids = get_user_ids(
1133
+ data=valid_data, id_columns=user_id_column
1134
+ )
528
1135
  return valid_loader, valid_user_ids
529
1136
 
530
- def evaluate(self,
531
- data: dict | pd.DataFrame | DataLoader,
532
- metrics: list[str] | dict[str, list[str]] | None = None,
533
- batch_size: int = 32,
534
- user_ids: np.ndarray | None = None,
535
- user_id_column: str = 'user_id',
536
- num_workers: int = 0,) -> dict:
537
- self.eval()
1137
+ def evaluate(
1138
+ self,
1139
+ data: dict | pd.DataFrame | DataLoader,
1140
+ metrics: list[str] | dict[str, list[str]] | None = None,
1141
+ batch_size: int = 32,
1142
+ user_ids: np.ndarray | None = None,
1143
+ user_id_column: str = "user_id",
1144
+ num_workers: int = 0,
1145
+ ) -> dict:
1146
+ """
1147
+ **IMPORTANT for Distributed Training:**
1148
+ in distributed mode, this method uses collective communication operations (all_gather).
1149
+ all processes must call this method simultaneously, even if you only want results on rank 0.
1150
+ failing to do so will cause the program to hang indefinitely.
1151
+
1152
+ Evaluate the model on the given data.
1153
+
1154
+ Args:
1155
+ data: Evaluation data (dict/df/DataLoader).
1156
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
1157
+ batch_size: Batch size (per process when distributed).
1158
+ user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, will be extracted from data using user_id_column. e.g. np.array([...])
1159
+ user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
1160
+ num_workers: DataLoader worker count.
1161
+ """
1162
+ model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
1163
+ model.eval()
538
1164
  eval_metrics = metrics if metrics is not None else self.metrics
539
1165
  if eval_metrics is None:
540
- raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
1166
+ raise ValueError(
1167
+ "[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
1168
+ )
541
1169
  needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
542
-
1170
+
543
1171
  if isinstance(data, DataLoader):
544
1172
  data_loader = data
545
1173
  else:
546
1174
  if user_ids is None and needs_user_ids:
547
1175
  user_ids = get_user_ids(data=data, id_columns=user_id_column)
548
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
1176
+ data_loader = self.prepare_data_loader(
1177
+ data, batch_size=batch_size, shuffle=False, num_workers=num_workers
1178
+ )
549
1179
  y_true_list = []
550
1180
  y_pred_list = []
551
1181
  collected_user_ids = []
@@ -555,30 +1185,25 @@ class BaseModel(FeatureSet, nn.Module):
555
1185
  batch_count += 1
556
1186
  batch_dict = batch_to_dict(batch_data)
557
1187
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
558
- y_pred = self.forward(X_input)
1188
+ y_pred = model(X_input)
559
1189
  if y_true is not None:
560
1190
  y_true_list.append(y_true.cpu().numpy())
561
1191
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
562
1192
  y_pred_list.append(y_pred.cpu().numpy())
563
1193
  if needs_user_ids and user_ids is None:
564
- batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
1194
+ batch_user_id = get_user_ids(
1195
+ data=batch_dict, id_columns=self.id_columns
1196
+ )
565
1197
  if batch_user_id is not None:
566
1198
  collected_user_ids.append(batch_user_id)
567
- logging.info(" ")
568
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
569
- if len(y_true_list) > 0:
570
- y_true_all = np.concatenate(y_true_list, axis=0)
571
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
572
- else:
573
- y_true_all = None
574
- logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
575
-
576
- if len(y_pred_list) > 0:
577
- y_pred_all = np.concatenate(y_pred_list, axis=0)
578
- else:
579
- y_pred_all = None
580
- logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
581
-
1199
+ if self.is_main_process:
1200
+ logging.info(" ")
1201
+ logging.info(
1202
+ colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
1203
+ )
1204
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
1205
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
1206
+
582
1207
  # Convert metrics to list if it's a dict
583
1208
  if isinstance(eval_metrics, dict):
584
1209
  # For dict metrics, we need to collect all unique metric names
@@ -589,11 +1214,44 @@ class BaseModel(FeatureSet, nn.Module):
589
1214
  unique_metrics.append(m)
590
1215
  metrics_to_use = unique_metrics
591
1216
  else:
592
- metrics_to_use = eval_metrics
593
- final_user_ids = user_ids
594
- if final_user_ids is None and collected_user_ids:
595
- final_user_ids = np.concatenate(collected_user_ids, axis=0)
596
- metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
1217
+ metrics_to_use = eval_metrics
1218
+ final_user_ids_local = user_ids
1219
+ if final_user_ids_local is None and collected_user_ids:
1220
+ final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
1221
+
1222
+ # gather across ranks even when local arrays are empty to keep collectives aligned
1223
+ y_true_all = gather_numpy(self, y_true_all_local)
1224
+ y_pred_all = gather_numpy(self, y_pred_all_local)
1225
+ final_user_ids = (
1226
+ gather_numpy(self, final_user_ids_local) if needs_user_ids else None
1227
+ )
1228
+ if (
1229
+ y_true_all is None
1230
+ or y_pred_all is None
1231
+ or len(y_true_all) == 0
1232
+ or len(y_pred_all) == 0
1233
+ ):
1234
+ if self.is_main_process:
1235
+ logging.info(
1236
+ colorize(
1237
+ " Warning: Not enough evaluation data to compute metrics after gathering",
1238
+ color="yellow",
1239
+ )
1240
+ )
1241
+ return {}
1242
+ if self.is_main_process:
1243
+ logging.info(
1244
+ colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
1245
+ )
1246
+ metrics_dict = evaluate_metrics(
1247
+ y_true=y_true_all,
1248
+ y_pred=y_pred_all,
1249
+ metrics=metrics_to_use,
1250
+ task=self.task,
1251
+ target_names=self.target_columns,
1252
+ task_specific_metrics=self.task_specific_metrics,
1253
+ user_ids=final_user_ids,
1254
+ )
597
1255
  return metrics_dict
598
1256
 
599
1257
  def predict(
@@ -603,43 +1261,100 @@ class BaseModel(FeatureSet, nn.Module):
603
1261
  save_path: str | os.PathLike | None = None,
604
1262
  save_format: Literal["csv", "parquet"] = "csv",
605
1263
  include_ids: bool | None = None,
1264
+ id_columns: str | list[str] | None = None,
606
1265
  return_dataframe: bool = True,
607
1266
  streaming_chunk_size: int = 10000,
608
1267
  num_workers: int = 0,
609
1268
  ) -> pd.DataFrame | np.ndarray:
1269
+ """
1270
+ Note: predict does not support distributed mode currently, consider it as a single-process operation.
1271
+ Make predictions on the given data.
1272
+
1273
+ Args:
1274
+ data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
1275
+ batch_size: Batch size for prediction (per process when distributed).
1276
+ save_path: Optional path to save predictions; if None, predictions are not saved to disk.
1277
+ save_format: Format to save predictions ('csv' or 'parquet').
1278
+ include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
1279
+ id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
1280
+ return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
1281
+ streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
1282
+ num_workers: DataLoader worker count.
1283
+ """
610
1284
  self.eval()
1285
+ # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
1286
+ predict_id_columns = id_columns if id_columns is not None else self.id_columns
1287
+ if isinstance(predict_id_columns, str):
1288
+ predict_id_columns = [predict_id_columns]
1289
+
611
1290
  if include_ids is None:
612
- include_ids = bool(self.id_columns)
613
- include_ids = include_ids and bool(self.id_columns)
1291
+ include_ids = bool(predict_id_columns)
1292
+ include_ids = include_ids and bool(predict_id_columns)
614
1293
 
1294
+ # Use streaming mode for large file saves without loading all data into memory
615
1295
  if save_path is not None and not return_dataframe:
616
- return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
617
- if isinstance(data, (str, os.PathLike)):
618
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
619
- data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
620
- elif not isinstance(data, DataLoader):
621
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
622
- else:
1296
+ return self.predict_streaming(
1297
+ data=data,
1298
+ batch_size=batch_size,
1299
+ save_path=save_path,
1300
+ save_format=save_format,
1301
+ include_ids=include_ids,
1302
+ streaming_chunk_size=streaming_chunk_size,
1303
+ return_dataframe=return_dataframe,
1304
+ id_columns=predict_id_columns,
1305
+ )
1306
+
1307
+ # Create DataLoader based on data type
1308
+ if isinstance(data, DataLoader):
623
1309
  data_loader = data
624
-
625
- y_pred_list: list[np.ndarray] = []
626
- id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
627
- id_arrays: dict[str, np.ndarray] | None = None
628
-
1310
+ elif isinstance(data, (str, os.PathLike)):
1311
+ rec_loader = RecDataLoader(
1312
+ dense_features=self.dense_features,
1313
+ sparse_features=self.sparse_features,
1314
+ sequence_features=self.sequence_features,
1315
+ target=self.target_columns,
1316
+ id_columns=predict_id_columns,
1317
+ )
1318
+ data_loader = rec_loader.create_dataloader(
1319
+ data=data,
1320
+ batch_size=batch_size,
1321
+ shuffle=False,
1322
+ load_full=False,
1323
+ chunk_size=streaming_chunk_size,
1324
+ )
1325
+ else:
1326
+ data_loader = self.prepare_data_loader(
1327
+ data, batch_size=batch_size, shuffle=False, num_workers=num_workers
1328
+ )
1329
+
1330
+ y_pred_list = []
1331
+ id_buffers = (
1332
+ {name: [] for name in (predict_id_columns or [])} if include_ids else {}
1333
+ )
1334
+ id_arrays = None
1335
+
629
1336
  with torch.no_grad():
630
1337
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
631
1338
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
632
1339
  X_input, _ = self.get_input(batch_dict, require_labels=False)
633
- y_pred = self.forward(X_input)
1340
+ y_pred = self(X_input)
634
1341
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
635
1342
  y_pred_list.append(y_pred.detach().cpu().numpy())
636
- if include_ids and self.id_columns and batch_dict.get("ids"):
637
- for id_name in self.id_columns:
1343
+ if include_ids and predict_id_columns and batch_dict.get("ids"):
1344
+ for id_name in predict_id_columns:
638
1345
  if id_name not in batch_dict["ids"]:
639
1346
  continue
640
1347
  id_tensor = batch_dict["ids"][id_name]
641
- id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
642
- id_buffers[id_name].append(id_np.reshape(id_np.shape[0], -1) if id_np.ndim == 1 else id_np)
1348
+ id_np = (
1349
+ id_tensor.detach().cpu().numpy()
1350
+ if isinstance(id_tensor, torch.Tensor)
1351
+ else np.asarray(id_tensor)
1352
+ )
1353
+ id_buffers[id_name].append(
1354
+ id_np.reshape(id_np.shape[0], -1)
1355
+ if id_np.ndim == 1
1356
+ else id_np
1357
+ )
643
1358
  if len(y_pred_list) > 0:
644
1359
  y_pred_all = np.concatenate(y_pred_list, axis=0)
645
1360
  else:
@@ -657,11 +1372,13 @@ class BaseModel(FeatureSet, nn.Module):
657
1372
  pred_columns.append(f"{name}_pred")
658
1373
  while len(pred_columns) < num_outputs:
659
1374
  pred_columns.append(f"pred_{len(pred_columns)}")
660
- if include_ids and self.id_columns:
1375
+ if include_ids and predict_id_columns:
661
1376
  id_arrays = {}
662
1377
  for id_name, pieces in id_buffers.items():
663
1378
  if pieces:
664
- concatenated = np.concatenate([p.reshape(p.shape[0], -1) for p in pieces], axis=0)
1379
+ concatenated = np.concatenate(
1380
+ [p.reshape(p.shape[0], -1) for p in pieces], axis=0
1381
+ )
665
1382
  id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
666
1383
  else:
667
1384
  id_arrays[id_name] = np.array([], dtype=np.int64)
@@ -669,34 +1386,52 @@ class BaseModel(FeatureSet, nn.Module):
669
1386
  id_df = pd.DataFrame(id_arrays)
670
1387
  pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
671
1388
  if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
672
- raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)}).")
1389
+ raise ValueError(
1390
+ f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)})."
1391
+ )
673
1392
  output = pd.concat([id_df, pred_df], axis=1)
674
1393
  else:
675
1394
  output = y_pred_all
676
1395
  else:
677
- output = pd.DataFrame(y_pred_all, columns=pred_columns) if return_dataframe else y_pred_all
1396
+ output = (
1397
+ pd.DataFrame(y_pred_all, columns=pred_columns)
1398
+ if return_dataframe
1399
+ else y_pred_all
1400
+ )
678
1401
  if save_path is not None:
679
1402
  if save_format not in ("csv", "parquet"):
680
- raise ValueError(f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'.")
1403
+ raise ValueError(
1404
+ f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
1405
+ )
681
1406
  suffix = ".csv" if save_format == "csv" else ".parquet"
682
- target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False)
1407
+ target_path = resolve_save_path(
1408
+ path=save_path,
1409
+ default_dir=self.session.predictions_dir,
1410
+ default_name="predictions",
1411
+ suffix=suffix,
1412
+ add_timestamp=True if save_path is None else False,
1413
+ )
683
1414
  if isinstance(output, pd.DataFrame):
684
1415
  df_to_save = output
685
1416
  else:
686
1417
  df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
687
- if include_ids and self.id_columns and id_arrays is not None:
1418
+ if include_ids and predict_id_columns and id_arrays is not None:
688
1419
  id_df = pd.DataFrame(id_arrays)
689
1420
  if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
690
- raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
1421
+ raise ValueError(
1422
+ f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
1423
+ )
691
1424
  df_to_save = pd.concat([id_df, df_to_save], axis=1)
692
1425
  if save_format == "csv":
693
1426
  df_to_save.to_csv(target_path, index=False)
694
1427
  else:
695
1428
  df_to_save.to_parquet(target_path, index=False)
696
- logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
1429
+ logging.info(
1430
+ colorize(f"Predictions saved to: {target_path}", color="green")
1431
+ )
697
1432
  return output
698
1433
 
699
- def _predict_streaming(
1434
+ def predict_streaming(
700
1435
  self,
701
1436
  data: str | dict | pd.DataFrame | DataLoader,
702
1437
  batch_size: int,
@@ -705,23 +1440,46 @@ class BaseModel(FeatureSet, nn.Module):
705
1440
  include_ids: bool,
706
1441
  streaming_chunk_size: int,
707
1442
  return_dataframe: bool,
1443
+ id_columns: list[str] | None = None,
708
1444
  ) -> pd.DataFrame:
709
1445
  if isinstance(data, (str, os.PathLike)):
710
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
711
- data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
1446
+ rec_loader = RecDataLoader(
1447
+ dense_features=self.dense_features,
1448
+ sparse_features=self.sparse_features,
1449
+ sequence_features=self.sequence_features,
1450
+ target=self.target_columns,
1451
+ id_columns=id_columns,
1452
+ )
1453
+ data_loader = rec_loader.create_dataloader(
1454
+ data=data,
1455
+ batch_size=batch_size,
1456
+ shuffle=False,
1457
+ load_full=False,
1458
+ chunk_size=streaming_chunk_size,
1459
+ )
712
1460
  elif not isinstance(data, DataLoader):
713
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
1461
+ data_loader = self.prepare_data_loader(
1462
+ data,
1463
+ batch_size=batch_size,
1464
+ shuffle=False,
1465
+ )
714
1466
  else:
715
1467
  data_loader = data
716
1468
 
717
1469
  suffix = ".csv" if save_format == "csv" else ".parquet"
718
- target_path = resolve_save_path(path=save_path, default_dir=self.session.predictions_dir, default_name="predictions", suffix=suffix, add_timestamp=True if save_path is None else False,)
1470
+ target_path = resolve_save_path(
1471
+ path=save_path,
1472
+ default_dir=self.session.predictions_dir,
1473
+ default_name="predictions",
1474
+ suffix=suffix,
1475
+ add_timestamp=True if save_path is None else False,
1476
+ )
719
1477
  target_path.parent.mkdir(parents=True, exist_ok=True)
720
1478
  header_written = target_path.exists() and target_path.stat().st_size > 0
721
1479
  parquet_writer = None
722
1480
 
723
- pred_columns: list[str] | None = None
724
- collected_frames: list[pd.DataFrame] = []
1481
+ pred_columns = None
1482
+ collected_frames = [] # only used when return_dataframe is True
725
1483
 
726
1484
  with torch.no_grad():
727
1485
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -741,32 +1499,42 @@ class BaseModel(FeatureSet, nn.Module):
741
1499
  pred_columns.append(f"{name}_pred")
742
1500
  while len(pred_columns) < num_outputs:
743
1501
  pred_columns.append(f"pred_{len(pred_columns)}")
744
-
745
- id_arrays_batch: dict[str, np.ndarray] = {}
746
- if include_ids and self.id_columns and batch_dict.get("ids"):
747
- for id_name in self.id_columns:
1502
+
1503
+ id_arrays_batch = {}
1504
+ if include_ids and id_columns and batch_dict.get("ids"):
1505
+ for id_name in id_columns:
748
1506
  if id_name not in batch_dict["ids"]:
749
1507
  continue
750
1508
  id_tensor = batch_dict["ids"][id_name]
751
- id_np = id_tensor.detach().cpu().numpy() if isinstance(id_tensor, torch.Tensor) else np.asarray(id_tensor)
1509
+ id_np = (
1510
+ id_tensor.detach().cpu().numpy()
1511
+ if isinstance(id_tensor, torch.Tensor)
1512
+ else np.asarray(id_tensor)
1513
+ )
752
1514
  id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
753
1515
 
754
1516
  df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
755
1517
  if id_arrays_batch:
756
1518
  id_df = pd.DataFrame(id_arrays_batch)
757
1519
  if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
758
- raise ValueError(f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)}).")
1520
+ raise ValueError(
1521
+ f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
1522
+ )
759
1523
  df_batch = pd.concat([id_df, df_batch], axis=1)
760
1524
 
761
1525
  if save_format == "csv":
762
- df_batch.to_csv(target_path, mode="a", header=not header_written, index=False)
1526
+ df_batch.to_csv(
1527
+ target_path, mode="a", header=not header_written, index=False
1528
+ )
763
1529
  header_written = True
764
1530
  else:
765
1531
  try:
766
1532
  import pyarrow as pa
767
1533
  import pyarrow.parquet as pq
768
1534
  except ImportError as exc: # pragma: no cover
769
- raise ImportError("[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed.") from exc
1535
+ raise ImportError(
1536
+ "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
1537
+ ) from exc
770
1538
  table = pa.Table.from_pandas(df_batch, preserve_index=False)
771
1539
  if parquet_writer is None:
772
1540
  parquet_writer = pq.ParquetWriter(target_path, table.schema)
@@ -777,14 +1545,36 @@ class BaseModel(FeatureSet, nn.Module):
777
1545
  parquet_writer.close()
778
1546
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
779
1547
  if return_dataframe:
780
- return pd.concat(collected_frames, ignore_index=True) if collected_frames else pd.DataFrame(columns=pred_columns or [])
1548
+ return (
1549
+ pd.concat(collected_frames, ignore_index=True)
1550
+ if collected_frames
1551
+ else pd.DataFrame(columns=pred_columns or [])
1552
+ )
781
1553
  return pd.DataFrame(columns=pred_columns or [])
782
1554
 
783
- def save_model(self, save_path: str | Path | None = None, add_timestamp: bool | None = None, verbose: bool = True):
1555
+ def save_model(
1556
+ self,
1557
+ save_path: str | Path | None = None,
1558
+ add_timestamp: bool | None = None,
1559
+ verbose: bool = True,
1560
+ ):
784
1561
  add_timestamp = False if add_timestamp is None else add_timestamp
785
- target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
1562
+ target_path = resolve_save_path(
1563
+ path=save_path,
1564
+ default_dir=self.session_path,
1565
+ default_name=self.model_name,
1566
+ suffix=".model",
1567
+ add_timestamp=add_timestamp,
1568
+ )
786
1569
  model_path = Path(target_path)
787
- torch.save(self.state_dict(), model_path)
1570
+
1571
+ model_to_save = (
1572
+ self.ddp_model.module
1573
+ if getattr(self, "ddp_model", None) is not None
1574
+ else self
1575
+ )
1576
+ torch.save(model_to_save.state_dict(), model_path)
1577
+ # torch.save(self.state_dict(), model_path)
788
1578
 
789
1579
  config_path = self.features_config_path
790
1580
  features_config = {
@@ -797,29 +1587,47 @@ class BaseModel(FeatureSet, nn.Module):
797
1587
  pickle.dump(features_config, f)
798
1588
  self.features_config_path = str(config_path)
799
1589
  if verbose:
800
- logging.info(colorize(f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",color="green",))
801
-
802
- def load_model(self, save_path: str | Path, map_location: str | torch.device | None = "cpu", verbose: bool = True):
1590
+ logging.info(
1591
+ colorize(
1592
+ f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",
1593
+ color="green",
1594
+ )
1595
+ )
1596
+
1597
+ def load_model(
1598
+ self,
1599
+ save_path: str | Path,
1600
+ map_location: str | torch.device | None = "cpu",
1601
+ verbose: bool = True,
1602
+ ):
803
1603
  self.to(self.device)
804
1604
  base_path = Path(save_path)
805
1605
  if base_path.is_dir():
806
1606
  model_files = sorted(base_path.glob("*.model"))
807
1607
  if not model_files:
808
- raise FileNotFoundError(f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}")
1608
+ raise FileNotFoundError(
1609
+ f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
1610
+ )
809
1611
  model_path = model_files[-1]
810
1612
  config_dir = base_path
811
1613
  else:
812
- model_path = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1614
+ model_path = (
1615
+ base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1616
+ )
813
1617
  config_dir = model_path.parent
814
1618
  if not model_path.exists():
815
- raise FileNotFoundError(f"[BaseModel-load-model Error] Model file does not exist: {model_path}")
1619
+ raise FileNotFoundError(
1620
+ f"[BaseModel-load-model Error] Model file does not exist: {model_path}"
1621
+ )
816
1622
 
817
1623
  state_dict = torch.load(model_path, map_location=map_location)
818
1624
  self.load_state_dict(state_dict)
819
1625
 
820
1626
  features_config_path = config_dir / "features_config.pkl"
821
1627
  if not features_config_path.exists():
822
- raise FileNotFoundError(f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}")
1628
+ raise FileNotFoundError(
1629
+ f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}"
1630
+ )
823
1631
  with open(features_config_path, "rb") as f:
824
1632
  features_config = pickle.load(f)
825
1633
 
@@ -829,11 +1637,22 @@ class BaseModel(FeatureSet, nn.Module):
829
1637
  dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
830
1638
  sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
831
1639
  sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
832
- self.set_all_features(dense_features=dense_features, sparse_features=sparse_features, sequence_features=sequence_features, target=target, id_columns=id_columns)
1640
+ self.set_all_features(
1641
+ dense_features=dense_features,
1642
+ sparse_features=sparse_features,
1643
+ sequence_features=sequence_features,
1644
+ target=target,
1645
+ id_columns=id_columns,
1646
+ )
833
1647
 
834
1648
  cfg_version = features_config.get("version")
835
1649
  if verbose:
836
- logging.info(colorize(f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",color="green",))
1650
+ logging.info(
1651
+ colorize(
1652
+ f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",
1653
+ color="green",
1654
+ )
1655
+ )
837
1656
 
838
1657
  @classmethod
839
1658
  def from_checkpoint(
@@ -845,23 +1664,29 @@ class BaseModel(FeatureSet, nn.Module):
845
1664
  **kwargs: Any,
846
1665
  ) -> "BaseModel":
847
1666
  """
848
- Factory that reconstructs a model instance (including feature specs)
849
- from a saved checkpoint directory or *.model file.
1667
+ Load a model from a checkpoint path. The checkpoint path should contain:
1668
+ a .model file and a features_config.pkl file.
850
1669
  """
851
1670
  base_path = Path(checkpoint_path)
852
1671
  verbose = kwargs.pop("verbose", True)
853
1672
  if base_path.is_dir():
854
1673
  model_candidates = sorted(base_path.glob("*.model"))
855
1674
  if not model_candidates:
856
- raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}")
1675
+ raise FileNotFoundError(
1676
+ f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
1677
+ )
857
1678
  model_file = model_candidates[-1]
858
1679
  config_dir = base_path
859
1680
  else:
860
- model_file = base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1681
+ model_file = (
1682
+ base_path.with_suffix(".model") if base_path.suffix == "" else base_path
1683
+ )
861
1684
  config_dir = model_file.parent
862
1685
  features_config_path = config_dir / "features_config.pkl"
863
1686
  if not features_config_path.exists():
864
- raise FileNotFoundError(f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}")
1687
+ raise FileNotFoundError(
1688
+ f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}"
1689
+ )
865
1690
  with open(features_config_path, "rb") as f:
866
1691
  features_config = pickle.load(f)
867
1692
  all_features = features_config.get("all_features", [])
@@ -887,108 +1712,132 @@ class BaseModel(FeatureSet, nn.Module):
887
1712
 
888
1713
  def summary(self):
889
1714
  logger = logging.getLogger()
890
-
1715
+
891
1716
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
892
- logger.info(colorize(f"Model Summary: {self.model_name}", color="bright_blue", bold=True))
1717
+ logger.info(
1718
+ colorize(
1719
+ f"Model Summary: {self.model_name}", color="bright_blue", bold=True
1720
+ )
1721
+ )
893
1722
  logger.info(colorize("=" * 80, color="bright_blue", bold=True))
894
-
1723
+
895
1724
  logger.info("")
896
1725
  logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
897
1726
  logger.info(colorize("-" * 80, color="cyan"))
898
-
1727
+
899
1728
  if self.dense_features:
900
1729
  logger.info(f"Dense Features ({len(self.dense_features)}):")
901
1730
  for i, feat in enumerate(self.dense_features, 1):
902
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 1
1731
+ embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
903
1732
  logger.info(f" {i}. {feat.name:20s}")
904
-
1733
+
905
1734
  if self.sparse_features:
906
1735
  logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
907
1736
 
908
1737
  max_name_len = max(len(feat.name) for feat in self.sparse_features)
909
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sparse_features)
1738
+ max_embed_name_len = max(
1739
+ len(feat.embedding_name) for feat in self.sparse_features
1740
+ )
910
1741
  name_width = max(max_name_len, 10) + 2
911
1742
  embed_name_width = max(max_embed_name_len, 15) + 2
912
-
913
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}")
914
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}")
1743
+
1744
+ logger.info(
1745
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
1746
+ )
1747
+ logger.info(
1748
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
1749
+ )
915
1750
  for i, feat in enumerate(self.sparse_features, 1):
916
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
917
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
918
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}")
919
-
1751
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1752
+ embed_dim = (
1753
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1754
+ )
1755
+ logger.info(
1756
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
1757
+ )
1758
+
920
1759
  if self.sequence_features:
921
1760
  logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
922
1761
 
923
1762
  max_name_len = max(len(feat.name) for feat in self.sequence_features)
924
- max_embed_name_len = max(len(feat.embedding_name) for feat in self.sequence_features)
1763
+ max_embed_name_len = max(
1764
+ len(feat.embedding_name) for feat in self.sequence_features
1765
+ )
925
1766
  name_width = max(max_name_len, 10) + 2
926
1767
  embed_name_width = max(max_embed_name_len, 15) + 2
927
-
928
- logger.info(f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}")
929
- logger.info(f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}")
1768
+
1769
+ logger.info(
1770
+ f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
1771
+ )
1772
+ logger.info(
1773
+ f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
1774
+ )
930
1775
  for i, feat in enumerate(self.sequence_features, 1):
931
- vocab_size = feat.vocab_size if hasattr(feat, 'vocab_size') else 'N/A'
932
- embed_dim = feat.embedding_dim if hasattr(feat, 'embedding_dim') else 'N/A'
933
- max_len = feat.max_len if hasattr(feat, 'max_len') else 'N/A'
934
- logger.info(f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}")
935
-
1776
+ vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
1777
+ embed_dim = (
1778
+ feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
1779
+ )
1780
+ max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
1781
+ logger.info(
1782
+ f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
1783
+ )
1784
+
936
1785
  logger.info("")
937
1786
  logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
938
1787
  logger.info(colorize("-" * 80, color="cyan"))
939
-
1788
+
940
1789
  # Model Architecture
941
1790
  logger.info("Model Architecture:")
942
1791
  logger.info(str(self))
943
1792
  logger.info("")
944
-
1793
+
945
1794
  total_params = sum(p.numel() for p in self.parameters())
946
1795
  trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
947
1796
  non_trainable_params = total_params - trainable_params
948
-
1797
+
949
1798
  logger.info(f"Total Parameters: {total_params:,}")
950
1799
  logger.info(f"Trainable Parameters: {trainable_params:,}")
951
1800
  logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
952
-
1801
+
953
1802
  logger.info("Layer-wise Parameters:")
954
1803
  for name, module in self.named_children():
955
1804
  layer_params = sum(p.numel() for p in module.parameters())
956
1805
  if layer_params > 0:
957
1806
  logger.info(f" {name:30s}: {layer_params:,}")
958
-
1807
+
959
1808
  logger.info("")
960
1809
  logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
961
1810
  logger.info(colorize("-" * 80, color="cyan"))
962
-
1811
+
963
1812
  logger.info(f"Task Type: {self.task}")
964
1813
  logger.info(f"Number of Tasks: {self.nums_task}")
965
1814
  logger.info(f"Metrics: {self.metrics}")
966
1815
  logger.info(f"Target Columns: {self.target_columns}")
967
1816
  logger.info(f"Device: {self.device}")
968
-
969
- if hasattr(self, 'optimizer_name'):
1817
+
1818
+ if hasattr(self, "optimizer_name"):
970
1819
  logger.info(f"Optimizer: {self.optimizer_name}")
971
1820
  if self.optimizer_params:
972
1821
  for key, value in self.optimizer_params.items():
973
1822
  logger.info(f" {key:25s}: {value}")
974
-
975
- if hasattr(self, 'scheduler_name') and self.scheduler_name:
1823
+
1824
+ if hasattr(self, "scheduler_name") and self.scheduler_name:
976
1825
  logger.info(f"Scheduler: {self.scheduler_name}")
977
1826
  if self.scheduler_params:
978
1827
  for key, value in self.scheduler_params.items():
979
1828
  logger.info(f" {key:25s}: {value}")
980
-
981
- if hasattr(self, 'loss_config'):
1829
+
1830
+ if hasattr(self, "loss_config"):
982
1831
  logger.info(f"Loss Function: {self.loss_config}")
983
- if hasattr(self, 'loss_weights'):
1832
+ if hasattr(self, "loss_weights"):
984
1833
  logger.info(f"Loss Weights: {self.loss_weights}")
985
-
1834
+
986
1835
  logger.info("Regularization:")
987
1836
  logger.info(f" Embedding L1: {self.embedding_l1_reg}")
988
1837
  logger.info(f" Embedding L2: {self.embedding_l2_reg}")
989
1838
  logger.info(f" Dense L1: {self.dense_l1_reg}")
990
1839
  logger.info(f" Dense L2: {self.dense_l2_reg}")
991
-
1840
+
992
1841
  logger.info("Other Settings:")
993
1842
  logger.info(f" Early Stop Patience: {self.early_stop_patience}")
994
1843
  logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
@@ -997,54 +1846,56 @@ class BaseModel(FeatureSet, nn.Module):
997
1846
  logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
998
1847
 
999
1848
 
1000
-
1001
1849
  class BaseMatchModel(BaseModel):
1002
1850
  """
1003
1851
  Base class for match (retrieval/recall) models
1004
1852
  Supports pointwise, pairwise, and listwise training modes
1005
1853
  """
1854
+
1006
1855
  @property
1007
1856
  def model_name(self) -> str:
1008
1857
  raise NotImplementedError
1009
1858
 
1010
1859
  @property
1011
- def task_type(self) -> str:
1012
- raise NotImplementedError
1013
-
1860
+ def default_task(self) -> str:
1861
+ return "binary"
1862
+
1014
1863
  @property
1015
1864
  def support_training_modes(self) -> list[str]:
1016
1865
  """
1017
1866
  Returns list of supported training modes for this model.
1018
1867
  Override in subclasses to restrict training modes.
1019
-
1868
+
1020
1869
  Returns:
1021
1870
  List of supported modes: ['pointwise', 'pairwise', 'listwise']
1022
1871
  """
1023
- return ['pointwise', 'pairwise', 'listwise']
1024
-
1025
- def __init__(self,
1026
- user_dense_features: list[DenseFeature] | None = None,
1027
- user_sparse_features: list[SparseFeature] | None = None,
1028
- user_sequence_features: list[SequenceFeature] | None = None,
1029
- item_dense_features: list[DenseFeature] | None = None,
1030
- item_sparse_features: list[SparseFeature] | None = None,
1031
- item_sequence_features: list[SequenceFeature] | None = None,
1032
- training_mode: Literal['pointwise', 'pairwise', 'listwise'] = 'pointwise',
1033
- num_negative_samples: int = 4,
1034
- temperature: float = 1.0,
1035
- similarity_metric: Literal['dot', 'cosine', 'euclidean'] = 'dot',
1036
- device: str = 'cpu',
1037
- embedding_l1_reg: float = 0.0,
1038
- dense_l1_reg: float = 0.0,
1039
- embedding_l2_reg: float = 0.0,
1040
- dense_l2_reg: float = 0.0,
1041
- early_stop_patience: int = 20,
1042
- **kwargs):
1043
-
1872
+ return ["pointwise", "pairwise", "listwise"]
1873
+
1874
+ def __init__(
1875
+ self,
1876
+ user_dense_features: list[DenseFeature] | None = None,
1877
+ user_sparse_features: list[SparseFeature] | None = None,
1878
+ user_sequence_features: list[SequenceFeature] | None = None,
1879
+ item_dense_features: list[DenseFeature] | None = None,
1880
+ item_sparse_features: list[SparseFeature] | None = None,
1881
+ item_sequence_features: list[SequenceFeature] | None = None,
1882
+ training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
1883
+ num_negative_samples: int = 4,
1884
+ temperature: float = 1.0,
1885
+ similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
1886
+ device: str = "cpu",
1887
+ embedding_l1_reg: float = 0.0,
1888
+ dense_l1_reg: float = 0.0,
1889
+ embedding_l2_reg: float = 0.0,
1890
+ dense_l2_reg: float = 0.0,
1891
+ early_stop_patience: int = 20,
1892
+ **kwargs,
1893
+ ):
1894
+
1044
1895
  all_dense_features = []
1045
1896
  all_sparse_features = []
1046
1897
  all_sequence_features = []
1047
-
1898
+
1048
1899
  if user_dense_features:
1049
1900
  all_dense_features.extend(user_dense_features)
1050
1901
  if item_dense_features:
@@ -1057,117 +1908,175 @@ class BaseMatchModel(BaseModel):
1057
1908
  all_sequence_features.extend(user_sequence_features)
1058
1909
  if item_sequence_features:
1059
1910
  all_sequence_features.extend(item_sequence_features)
1060
-
1911
+
1061
1912
  super(BaseMatchModel, self).__init__(
1062
1913
  dense_features=all_dense_features,
1063
1914
  sparse_features=all_sparse_features,
1064
1915
  sequence_features=all_sequence_features,
1065
- target=['label'],
1066
- task='binary',
1916
+ target=["label"],
1917
+ task="binary",
1067
1918
  device=device,
1068
1919
  embedding_l1_reg=embedding_l1_reg,
1069
1920
  dense_l1_reg=dense_l1_reg,
1070
1921
  embedding_l2_reg=embedding_l2_reg,
1071
1922
  dense_l2_reg=dense_l2_reg,
1072
1923
  early_stop_patience=early_stop_patience,
1073
- **kwargs
1924
+ **kwargs,
1925
+ )
1926
+
1927
+ self.user_dense_features = (
1928
+ list(user_dense_features) if user_dense_features else []
1074
1929
  )
1075
-
1076
- self.user_dense_features = list(user_dense_features) if user_dense_features else []
1077
- self.user_sparse_features = list(user_sparse_features) if user_sparse_features else []
1078
- self.user_sequence_features = list(user_sequence_features) if user_sequence_features else []
1079
-
1080
- self.item_dense_features = list(item_dense_features) if item_dense_features else []
1081
- self.item_sparse_features = list(item_sparse_features) if item_sparse_features else []
1082
- self.item_sequence_features = list(item_sequence_features) if item_sequence_features else []
1083
-
1930
+ self.user_sparse_features = (
1931
+ list(user_sparse_features) if user_sparse_features else []
1932
+ )
1933
+ self.user_sequence_features = (
1934
+ list(user_sequence_features) if user_sequence_features else []
1935
+ )
1936
+
1937
+ self.item_dense_features = (
1938
+ list(item_dense_features) if item_dense_features else []
1939
+ )
1940
+ self.item_sparse_features = (
1941
+ list(item_sparse_features) if item_sparse_features else []
1942
+ )
1943
+ self.item_sequence_features = (
1944
+ list(item_sequence_features) if item_sequence_features else []
1945
+ )
1946
+
1084
1947
  self.training_mode = training_mode
1085
1948
  self.num_negative_samples = num_negative_samples
1086
1949
  self.temperature = temperature
1087
1950
  self.similarity_metric = similarity_metric
1088
1951
 
1089
- self.user_feature_names = [f.name for f in (self.user_dense_features + self.user_sparse_features + self.user_sequence_features)]
1090
- self.item_feature_names = [f.name for f in (self.item_dense_features + self.item_sparse_features + self.item_sequence_features)]
1952
+ self.user_feature_names = [
1953
+ f.name
1954
+ for f in (
1955
+ self.user_dense_features
1956
+ + self.user_sparse_features
1957
+ + self.user_sequence_features
1958
+ )
1959
+ ]
1960
+ self.item_feature_names = [
1961
+ f.name
1962
+ for f in (
1963
+ self.item_dense_features
1964
+ + self.item_sparse_features
1965
+ + self.item_sequence_features
1966
+ )
1967
+ ]
1091
1968
 
1092
1969
  def get_user_features(self, X_input: dict) -> dict:
1093
1970
  return {
1094
- name: X_input[name]
1095
- for name in self.user_feature_names
1096
- if name in X_input
1971
+ name: X_input[name] for name in self.user_feature_names if name in X_input
1097
1972
  }
1098
1973
 
1099
1974
  def get_item_features(self, X_input: dict) -> dict:
1100
1975
  return {
1101
- name: X_input[name]
1102
- for name in self.item_feature_names
1103
- if name in X_input
1976
+ name: X_input[name] for name in self.item_feature_names if name in X_input
1104
1977
  }
1105
-
1106
- def compile(self,
1107
- optimizer: str | torch.optim.Optimizer = "adam",
1108
- optimizer_params: dict | None = None,
1109
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
1110
- scheduler_params: dict | None = None,
1111
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1112
- loss_params: dict | list[dict] | None = None):
1978
+
1979
+ def compile(
1980
+ self,
1981
+ optimizer: str | torch.optim.Optimizer = "adam",
1982
+ optimizer_params: dict | None = None,
1983
+ scheduler: (
1984
+ str
1985
+ | torch.optim.lr_scheduler._LRScheduler
1986
+ | torch.optim.lr_scheduler.LRScheduler
1987
+ | type[torch.optim.lr_scheduler._LRScheduler]
1988
+ | type[torch.optim.lr_scheduler.LRScheduler]
1989
+ | None
1990
+ ) = None,
1991
+ scheduler_params: dict | None = None,
1992
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
1993
+ loss_params: dict | list[dict] | None = None,
1994
+ ):
1113
1995
  """
1114
1996
  Compile match model with optimizer, scheduler, and loss function.
1115
1997
  Mirrors BaseModel.compile while adding training_mode validation for match tasks.
1116
1998
  """
1117
1999
  if self.training_mode not in self.support_training_modes:
1118
- raise ValueError(f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}")
2000
+ raise ValueError(
2001
+ f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
2002
+ )
1119
2003
  # Call parent compile with match-specific logic
1120
2004
  optimizer_params = optimizer_params or {}
1121
-
1122
- self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
2005
+
2006
+ self.optimizer_name = (
2007
+ optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
2008
+ )
1123
2009
  self.optimizer_params = optimizer_params
1124
2010
  if isinstance(scheduler, str):
1125
2011
  self.scheduler_name = scheduler
1126
2012
  elif scheduler is not None:
1127
2013
  # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
1128
- self.scheduler_name = getattr(scheduler, '__name__', getattr(scheduler.__class__, '__name__', str(scheduler)))
2014
+ self.scheduler_name = getattr(
2015
+ scheduler,
2016
+ "__name__",
2017
+ getattr(scheduler.__class__, "__name__", str(scheduler)),
2018
+ )
1129
2019
  else:
1130
2020
  self.scheduler_name = None
1131
2021
  self.scheduler_params = scheduler_params or {}
1132
2022
  self.loss_config = loss
1133
2023
  self.loss_params = loss_params or {}
1134
2024
 
1135
- self.optimizer_fn = get_optimizer(optimizer=optimizer, params=self.parameters(), **optimizer_params)
2025
+ self.optimizer_fn = get_optimizer(
2026
+ optimizer=optimizer, params=self.parameters(), **optimizer_params
2027
+ )
1136
2028
  # Set loss function based on training mode
1137
2029
  default_losses = {
1138
- 'pointwise': 'bce',
1139
- 'pairwise': 'bpr',
1140
- 'listwise': 'sampled_softmax',
2030
+ "pointwise": "bce",
2031
+ "pairwise": "bpr",
2032
+ "listwise": "sampled_softmax",
1141
2033
  }
1142
2034
 
1143
2035
  if loss is None:
1144
2036
  loss_value = default_losses.get(self.training_mode, "bce")
1145
2037
  elif isinstance(loss, list):
1146
- loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(self.training_mode, "bce")
2038
+ loss_value = (
2039
+ loss[0]
2040
+ if loss and loss[0] is not None
2041
+ else default_losses.get(self.training_mode, "bce")
2042
+ )
1147
2043
  else:
1148
2044
  loss_value = loss
1149
2045
 
1150
2046
  # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
1151
- if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
2047
+ if self.training_mode in {"pairwise", "listwise"} and loss_value in {
2048
+ "bce",
2049
+ "binary_crossentropy",
2050
+ }:
1152
2051
  loss_value = default_losses.get(self.training_mode, loss_value)
1153
2052
  loss_kwargs = get_loss_kwargs(self.loss_params, 0)
1154
2053
  self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
1155
2054
  # set scheduler
1156
- self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
2055
+ self.scheduler_fn = (
2056
+ get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
2057
+ if scheduler
2058
+ else None
2059
+ )
1157
2060
 
1158
- def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
1159
- if self.similarity_metric == 'dot':
2061
+ def compute_similarity(
2062
+ self, user_emb: torch.Tensor, item_emb: torch.Tensor
2063
+ ) -> torch.Tensor:
2064
+ if self.similarity_metric == "dot":
1160
2065
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1161
2066
  # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
1162
- similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size, num_items]
2067
+ similarity = torch.sum(
2068
+ user_emb * item_emb, dim=-1
2069
+ ) # [batch_size, num_items]
1163
2070
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
1164
2071
  # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
1165
2072
  user_emb_expanded = user_emb.unsqueeze(1) # [batch_size, 1, emb_dim]
1166
- similarity = torch.sum(user_emb_expanded * item_emb, dim=-1) # [batch_size, num_items]
2073
+ similarity = torch.sum(
2074
+ user_emb_expanded * item_emb, dim=-1
2075
+ ) # [batch_size, num_items]
1167
2076
  else:
1168
2077
  similarity = torch.sum(user_emb * item_emb, dim=-1) # [batch_size]
1169
-
1170
- elif self.similarity_metric == 'cosine':
2078
+
2079
+ elif self.similarity_metric == "cosine":
1171
2080
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1172
2081
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1173
2082
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1175,8 +2084,8 @@ class BaseMatchModel(BaseModel):
1175
2084
  similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
1176
2085
  else:
1177
2086
  similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
1178
-
1179
- elif self.similarity_metric == 'euclidean':
2087
+
2088
+ elif self.similarity_metric == "euclidean":
1180
2089
  if user_emb.dim() == 3 and item_emb.dim() == 3:
1181
2090
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1182
2091
  elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1184,63 +2093,70 @@ class BaseMatchModel(BaseModel):
1184
2093
  distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
1185
2094
  else:
1186
2095
  distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
1187
- similarity = -distance
1188
-
2096
+ similarity = -distance
2097
+
1189
2098
  else:
1190
2099
  raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
1191
2100
  similarity = similarity / self.temperature
1192
2101
  return similarity
1193
-
2102
+
1194
2103
  def user_tower(self, user_input: dict) -> torch.Tensor:
1195
2104
  raise NotImplementedError
1196
-
2105
+
1197
2106
  def item_tower(self, item_input: dict) -> torch.Tensor:
1198
2107
  raise NotImplementedError
1199
-
1200
- def forward(self, X_input: dict) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
2108
+
2109
+ def forward(
2110
+ self, X_input: dict
2111
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
1201
2112
  user_input = self.get_user_features(X_input)
1202
2113
  item_input = self.get_item_features(X_input)
1203
-
1204
- user_emb = self.user_tower(user_input) # [B, D]
1205
- item_emb = self.item_tower(item_input) # [B, D]
1206
-
1207
- if self.training and self.training_mode in ['pairwise', 'listwise']:
2114
+
2115
+ user_emb = self.user_tower(user_input) # [B, D]
2116
+ item_emb = self.item_tower(item_input) # [B, D]
2117
+
2118
+ if self.training and self.training_mode in ["pairwise", "listwise"]:
1208
2119
  return user_emb, item_emb
1209
2120
 
1210
2121
  similarity = self.compute_similarity(user_emb, item_emb) # [B]
1211
-
1212
- if self.training_mode == 'pointwise':
2122
+
2123
+ if self.training_mode == "pointwise":
1213
2124
  return torch.sigmoid(similarity)
1214
2125
  else:
1215
2126
  return similarity
1216
-
2127
+
1217
2128
  def compute_loss(self, y_pred, y_true):
1218
- if self.training_mode == 'pointwise':
2129
+ if self.training_mode == "pointwise":
1219
2130
  if y_true is None:
1220
2131
  return torch.tensor(0.0, device=self.device)
1221
2132
  return self.loss_fn[0](y_pred, y_true)
1222
-
2133
+
1223
2134
  # pairwise / listwise using inbatch neg
1224
- elif self.training_mode in ['pairwise', 'listwise']:
2135
+ elif self.training_mode in ["pairwise", "listwise"]:
1225
2136
  if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
1226
- raise ValueError("For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation.")
1227
- user_emb, item_emb = y_pred # [B, D], [B, D]
2137
+ raise ValueError(
2138
+ "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
2139
+ )
2140
+ user_emb, item_emb = y_pred # [B, D], [B, D]
1228
2141
  logits = torch.matmul(user_emb, item_emb.t()) # [B, B]
1229
- logits = logits / self.temperature
2142
+ logits = logits / self.temperature
1230
2143
  batch_size = logits.size(0)
1231
- targets = torch.arange(batch_size, device=logits.device) # [0, 1, 2, ..., B-1]
2144
+ targets = torch.arange(
2145
+ batch_size, device=logits.device
2146
+ ) # [0, 1, 2, ..., B-1]
1232
2147
  # Cross-Entropy = InfoNCE
1233
2148
  loss = F.cross_entropy(logits, targets)
1234
- return loss
2149
+ return loss
1235
2150
  else:
1236
2151
  raise ValueError(f"Unknown training mode: {self.training_mode}")
1237
2152
 
1238
-
1239
- def prepare_feature_data(self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int) -> DataLoader:
2153
+ def prepare_feature_data(
2154
+ self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
2155
+ ) -> DataLoader:
1240
2156
  """Prepare data loader for specific features."""
1241
2157
  if isinstance(data, DataLoader):
1242
2158
  return data
1243
-
2159
+
1244
2160
  feature_data = {}
1245
2161
  for feature in features:
1246
2162
  if isinstance(data, dict):
@@ -1249,13 +2165,21 @@ class BaseMatchModel(BaseModel):
1249
2165
  elif isinstance(data, pd.DataFrame):
1250
2166
  if feature.name in data.columns:
1251
2167
  feature_data[feature.name] = data[feature.name].values
1252
- return self.prepare_data_loader(feature_data, batch_size=batch_size, shuffle=False)
2168
+ return self.prepare_data_loader(
2169
+ feature_data, batch_size=batch_size, shuffle=False
2170
+ )
1253
2171
 
1254
- def encode_user(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
2172
+ def encode_user(
2173
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
2174
+ ) -> np.ndarray:
1255
2175
  self.eval()
1256
- all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
2176
+ all_user_features = (
2177
+ self.user_dense_features
2178
+ + self.user_sparse_features
2179
+ + self.user_sequence_features
2180
+ )
1257
2181
  data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
1258
-
2182
+
1259
2183
  embeddings_list = []
1260
2184
  with torch.no_grad():
1261
2185
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
@@ -1264,12 +2188,18 @@ class BaseMatchModel(BaseModel):
1264
2188
  user_emb = self.user_tower(user_input)
1265
2189
  embeddings_list.append(user_emb.cpu().numpy())
1266
2190
  return np.concatenate(embeddings_list, axis=0)
1267
-
1268
- def encode_item(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512) -> np.ndarray:
2191
+
2192
+ def encode_item(
2193
+ self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
2194
+ ) -> np.ndarray:
1269
2195
  self.eval()
1270
- all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
2196
+ all_item_features = (
2197
+ self.item_dense_features
2198
+ + self.item_sparse_features
2199
+ + self.item_sequence_features
2200
+ )
1271
2201
  data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
1272
-
2202
+
1273
2203
  embeddings_list = []
1274
2204
  with torch.no_grad():
1275
2205
  for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):