nextrec 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. nextrec/__init__.py +0 -30
  2. nextrec/__version__.py +1 -1
  3. nextrec/basic/layers.py +32 -15
  4. nextrec/basic/loggers.py +1 -1
  5. nextrec/basic/model.py +440 -189
  6. nextrec/basic/session.py +4 -2
  7. nextrec/data/__init__.py +0 -25
  8. nextrec/data/data_processing.py +31 -19
  9. nextrec/data/dataloader.py +51 -16
  10. nextrec/models/generative/__init__.py +0 -5
  11. nextrec/models/generative/hstu.py +3 -2
  12. nextrec/models/match/__init__.py +0 -13
  13. nextrec/models/match/dssm.py +0 -1
  14. nextrec/models/match/dssm_v2.py +0 -1
  15. nextrec/models/match/mind.py +0 -1
  16. nextrec/models/match/sdm.py +0 -1
  17. nextrec/models/match/youtube_dnn.py +0 -1
  18. nextrec/models/multi_task/__init__.py +0 -0
  19. nextrec/models/multi_task/esmm.py +5 -7
  20. nextrec/models/multi_task/mmoe.py +10 -6
  21. nextrec/models/multi_task/ple.py +10 -6
  22. nextrec/models/multi_task/poso.py +9 -6
  23. nextrec/models/multi_task/share_bottom.py +10 -7
  24. nextrec/models/ranking/__init__.py +0 -27
  25. nextrec/models/ranking/afm.py +113 -21
  26. nextrec/models/ranking/autoint.py +15 -9
  27. nextrec/models/ranking/dcn.py +8 -11
  28. nextrec/models/ranking/deepfm.py +5 -5
  29. nextrec/models/ranking/dien.py +4 -4
  30. nextrec/models/ranking/din.py +4 -4
  31. nextrec/models/ranking/fibinet.py +4 -4
  32. nextrec/models/ranking/fm.py +4 -4
  33. nextrec/models/ranking/masknet.py +4 -5
  34. nextrec/models/ranking/pnn.py +4 -4
  35. nextrec/models/ranking/widedeep.py +4 -4
  36. nextrec/models/ranking/xdeepfm.py +4 -4
  37. nextrec/utils/__init__.py +7 -3
  38. nextrec/utils/device.py +32 -1
  39. nextrec/utils/distributed.py +114 -0
  40. nextrec/utils/synthetic_data.py +413 -0
  41. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
  42. nextrec-0.4.1.dist-info/RECORD +66 -0
  43. nextrec-0.3.5.dist-info/RECORD +0 -63
  44. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
  45. {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,10 +2,9 @@
2
2
  Base Model & Base Match Model Class
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 02/12/2025
5
+ Checkpoint: edit on 05/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
-
9
8
  import os
10
9
  import tqdm
11
10
  import pickle
@@ -17,10 +16,13 @@ import pandas as pd
17
16
  import torch
18
17
  import torch.nn as nn
19
18
  import torch.nn.functional as F
19
+ import torch.distributed as dist
20
20
 
21
21
  from pathlib import Path
22
22
  from typing import Union, Literal, Any
23
23
  from torch.utils.data import DataLoader
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.nn.parallel import DistributedDataParallel as DDP
24
26
 
25
27
  from nextrec.basic.callback import EarlyStopper
26
28
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
@@ -31,22 +33,23 @@ from nextrec.basic.session import resolve_save_path, create_session
31
33
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
32
34
 
33
35
  from nextrec.data.dataloader import build_tensors_from_data
34
- from nextrec.data.data_processing import get_column_data, get_user_ids
35
36
  from nextrec.data.batch_utils import collate_fn, batch_to_dict
37
+ from nextrec.data.data_processing import get_column_data, get_user_ids
36
38
 
37
39
  from nextrec.loss import get_loss_fn, get_loss_kwargs
38
- from nextrec.utils import get_optimizer, get_scheduler
39
40
  from nextrec.utils.tensor import to_tensor
40
-
41
+ from nextrec.utils.device import configure_device
42
+ from nextrec.utils.optimizer import get_optimizer, get_scheduler
43
+ from nextrec.utils.distributed import gather_numpy, init_process_group, add_distributed_sampler
41
44
  from nextrec import __version__
42
45
 
43
46
  class BaseModel(FeatureSet, nn.Module):
44
47
  @property
45
48
  def model_name(self) -> str:
46
49
  raise NotImplementedError
47
-
50
+
48
51
  @property
49
- def task_type(self) -> str:
52
+ def default_task(self) -> str | list[str]:
50
53
  raise NotImplementedError
51
54
 
52
55
  def __init__(self,
@@ -55,21 +58,57 @@ class BaseModel(FeatureSet, nn.Module):
55
58
  sequence_features: list[SequenceFeature] | None = None,
56
59
  target: list[str] | str | None = None,
57
60
  id_columns: list[str] | str | None = None,
58
- task: str|list[str] = 'binary',
61
+ task: str | list[str] | None = None,
59
62
  device: str = 'cpu',
63
+ early_stop_patience: int = 20,
64
+ session_id: str | None = None,
60
65
  embedding_l1_reg: float = 0.0,
61
66
  dense_l1_reg: float = 0.0,
62
67
  embedding_l2_reg: float = 0.0,
63
68
  dense_l2_reg: float = 0.0,
64
- early_stop_patience: int = 20,
65
- session_id: str | None = None,):
66
-
69
+
70
+ distributed: bool = False,
71
+ rank: int | None = None,
72
+ world_size: int | None = None,
73
+ local_rank: int | None = None,
74
+ ddp_find_unused_parameters: bool = False,):
75
+ """
76
+ Initialize a base model.
77
+
78
+ Args:
79
+ dense_features: DenseFeature definitions.
80
+ sparse_features: SparseFeature definitions.
81
+ sequence_features: SequenceFeature definitions.
82
+ target: Target column name.
83
+ id_columns: Identifier column name, only need to specify if GAUC is required.
84
+ task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
85
+ device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
86
+ embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
87
+ dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
88
+ embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
89
+ dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
90
+ early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
91
+ session_id: Session id for logging. If None, a default id with timestamps will be created.
92
+ distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
93
+ rank: Global rank (defaults to env RANK).
94
+ world_size: Number of processes (defaults to env WORLD_SIZE).
95
+ local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
96
+ ddp_find_unused_parameters: Default False; set it to True only when unused parameters exist in the DDP model. In most cases this should be False.
97
+ """
67
98
  super(BaseModel, self).__init__()
68
- try:
69
- self.device = torch.device(device)
70
- except Exception as e:
71
- logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
72
- self.device = torch.device('cpu')
99
+
100
+ # distributed training settings
101
+ env_rank = int(os.environ.get("RANK", "0"))
102
+ env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
103
+ env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
104
+ self.distributed = distributed or (env_world_size > 1)
105
+ self.rank = env_rank if rank is None else rank
106
+ self.world_size = env_world_size if world_size is None else world_size
107
+ self.local_rank = env_local_rank if local_rank is None else local_rank
108
+ self.is_main_process = self.rank == 0
109
+ self.ddp_find_unused_parameters = ddp_find_unused_parameters
110
+ self.ddp_model: DDP | None = None
111
+ self.device = configure_device(self.distributed, self.local_rank, device)
73
112
 
74
113
  self.session_id = session_id
75
114
  self.session = create_session(session_id)
@@ -79,8 +118,8 @@ class BaseModel(FeatureSet, nn.Module):
79
118
  self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
80
119
  self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
81
120
 
82
- self.task = task
83
- self.nums_task = len(task) if isinstance(task, list) else 1
121
+ self.task = self.default_task if task is None else task
122
+ self.nums_task = len(self.task) if isinstance(self.task, list) else 1
84
123
 
85
124
  self.embedding_l1_reg = embedding_l1_reg
86
125
  self.dense_l1_reg = dense_l1_reg
@@ -89,10 +128,11 @@ class BaseModel(FeatureSet, nn.Module):
89
128
  self.regularization_weights = []
90
129
  self.embedding_params = []
91
130
  self.loss_weight = None
131
+
92
132
  self.early_stop_patience = early_stop_patience
93
133
  self.max_gradient_norm = 1.0
94
134
  self.logger_initialized = False
95
- self.training_logger: TrainingLogger | None = None
135
+ self.training_logger = None
96
136
 
97
137
  def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
98
138
  exclude_modules = exclude_modules or []
@@ -145,18 +185,22 @@ class BaseModel(FeatureSet, nn.Module):
145
185
  raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
146
186
  continue
147
187
  target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
148
- target_tensor = target_tensor.view(target_tensor.size(0), -1)
188
+ target_tensor = target_tensor.view(target_tensor.size(0), -1) # always reshape to (batch_size, num_targets)
149
189
  target_tensors.append(target_tensor)
150
190
  if target_tensors:
151
191
  y = torch.cat(target_tensors, dim=1)
152
- if y.shape[1] == 1:
192
+ if y.shape[1] == 1: # no need to do that again
153
193
  y = y.view(-1)
154
194
  elif require_labels:
155
195
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
156
196
  return X_input, y
157
197
 
158
- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,) -> tuple[DataLoader, dict | pd.DataFrame]:
159
- """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
198
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,):
199
+ """
200
+ This function will split training data into training and validation sets when:
201
+ 1. valid_data is None;
202
+ 2. validation_split is provided.
203
+ """
160
204
  if not (0 < validation_split < 1):
161
205
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
162
206
  if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -184,20 +228,35 @@ class BaseModel(FeatureSet, nn.Module):
184
228
  arr = np.asarray(value)
185
229
  train_split[key] = arr[train_indices]
186
230
  valid_split[key] = arr[valid_indices]
187
- train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
231
+ train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
188
232
  logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
189
233
  return train_loader, valid_split
190
234
 
191
235
  def compile(
192
- self,
193
- optimizer: str | torch.optim.Optimizer = "adam",
194
- optimizer_params: dict | None = None,
195
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
196
- scheduler_params: dict | None = None,
197
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
198
- loss_params: dict | list[dict] | None = None,
199
- loss_weights: int | float | list[int | float] | None = None,
200
- ):
236
+ self,
237
+ optimizer: str | torch.optim.Optimizer = "adam",
238
+ optimizer_params: dict | None = None,
239
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
240
+ scheduler_params: dict | None = None,
241
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
242
+ loss_params: dict | list[dict] | None = None,
243
+ loss_weights: int | float | list[int | float] | None = None,
244
+ ):
245
+ """
246
+ Configure the model for training.
247
+ Args:
248
+ optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
249
+ optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
250
+ scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
251
+ scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
252
+ loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
253
+ loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
254
+ loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
255
+ """
256
+ if loss_params is None:
257
+ self.loss_params = {}
258
+ else:
259
+ self.loss_params = loss_params
201
260
  optimizer_params = optimizer_params or {}
202
261
  self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
203
262
  self.optimizer_params = optimizer_params
@@ -217,7 +276,9 @@ class BaseModel(FeatureSet, nn.Module):
217
276
  self.loss_params = loss_params or {}
218
277
  self.loss_fn = []
219
278
  if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
220
- loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
279
+ if len(loss) != self.nums_task:
280
+ raise ValueError(f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task}).")
281
+ loss_list = [loss[i] for i in range(self.nums_task)]
221
282
  else: # for example: 'bce' -> ['bce', 'bce']
222
283
  loss_list = [loss] * self.nums_task
223
284
 
@@ -231,12 +292,12 @@ class BaseModel(FeatureSet, nn.Module):
231
292
  self.loss_weights = None
232
293
  elif self.nums_task == 1:
233
294
  if isinstance(loss_weights, (list, tuple)):
234
- if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
295
+ if len(loss_weights) != 1:
235
296
  raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
236
297
  weight_value = loss_weights[0]
237
298
  else:
238
299
  weight_value = loss_weights
239
- self.loss_weights = float(weight_value)
300
+ self.loss_weights = [float(weight_value)]
240
301
  else:
241
302
  if isinstance(loss_weights, (int, float)):
242
303
  weights = [float(loss_weights)] * self.nums_task
@@ -250,29 +311,48 @@ class BaseModel(FeatureSet, nn.Module):
250
311
 
251
312
  def compute_loss(self, y_pred, y_true):
252
313
  if y_true is None:
253
- raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
314
+ raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required.")
254
315
  if self.nums_task == 1:
255
- loss = self.loss_fn[0](y_pred, y_true)
316
+ if y_pred.dim() == 1:
317
+ y_pred = y_pred.view(-1, 1)
318
+ if y_true.dim() == 1:
319
+ y_true = y_true.view(-1, 1)
320
+ if y_pred.shape != y_true.shape:
321
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
322
+ task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
323
+ if task_dim == 1:
324
+ loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
325
+ else:
326
+ loss = self.loss_fn[0](y_pred, y_true)
256
327
  if self.loss_weights is not None:
257
- loss = loss * self.loss_weights
328
+ loss *= self.loss_weights[0]
258
329
  return loss
330
+ # multi-task
331
+ if y_pred.shape != y_true.shape:
332
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
333
+ if hasattr(self, "prediction_layer"): # we need to use registered task_slices for multi-task and multi-class
334
+ slices = self.prediction_layer._task_slices # type: ignore
259
335
  else:
260
- task_losses = []
261
- for i in range(self.nums_task):
262
- task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
263
- if isinstance(self.loss_weights, (list, tuple)):
264
- task_loss = task_loss * self.loss_weights[i]
265
- task_losses.append(task_loss)
266
- return torch.stack(task_losses).sum()
267
-
268
- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True,):
336
+ slices = [(i, i + 1) for i in range(self.nums_task)]
337
+ task_losses = []
338
+ for i, (start, end) in enumerate(slices): # type: ignore
339
+ y_pred_i = y_pred[:, start:end]
340
+ y_true_i = y_true[:, start:end]
341
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
342
+ if isinstance(self.loss_weights, (list, tuple)):
343
+ task_loss *= self.loss_weights[i]
344
+ task_losses.append(task_loss)
345
+ return torch.stack(task_losses).sum()
346
+
347
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0, sampler=None, return_dataset: bool = False) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
269
348
  if isinstance(data, DataLoader):
270
- return data
349
+ return (data, None) if return_dataset else data
271
350
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
272
351
  if tensors is None:
273
352
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
274
353
  dataset = TensorDictDataset(tensors)
275
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)
354
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
355
+ return (loader, dataset) if return_dataset else loader
276
356
 
277
357
  def fit(self,
278
358
  train_data: dict | pd.DataFrame | DataLoader,
@@ -281,27 +361,83 @@ class BaseModel(FeatureSet, nn.Module):
281
361
  epochs:int=1, shuffle:bool=True, batch_size:int=32,
282
362
  user_id_column: str | None = None,
283
363
  validation_split: float | None = None,
284
- tensorboard: bool = True,):
364
+ num_workers: int = 0,
365
+ tensorboard: bool = True,
366
+ auto_distributed_sampler: bool = True,):
367
+ """
368
+ Train the model.
369
+
370
+ Args:
371
+ train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
372
+ valid_data: Optional validation data; if None and validation_split is set, a split is created.
373
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
374
+ epochs: Training epochs.
375
+ shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
376
+ batch_size: Batch size (per process when distributed).
377
+ user_id_column: Column name for GAUC-style metrics.
378
+ validation_split: Ratio to split training data when valid_data is None.
379
+ num_workers: DataLoader worker count.
380
+ tensorboard: Enable tensorboard logging.
381
+ auto_distributed_sampler: Attach DistributedSampler automatically when distributed; set to False when data is already sharded per rank.
382
+
383
+ Notes:
384
+ - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
385
+ - All ranks must call evaluate() together because it performs collective ops.
386
+ """
387
+ device_id = self.local_rank if self.device.type == "cuda" else None
388
+ init_process_group(self.distributed, self.rank, self.world_size, device_id=device_id)
285
389
  self.to(self.device)
286
- if not self.logger_initialized:
390
+
391
+ if self.distributed and dist.is_available() and dist.is_initialized() and self.ddp_model is None:
392
+ device_ids = [self.local_rank] if self.device.type == "cuda" else None # device_ids means which device to use in ddp
393
+ output_device = self.local_rank if self.device.type == "cuda" else None # output_device means which device to place the output in ddp
394
+ object.__setattr__(self, "ddp_model", DDP(self, device_ids=device_ids, output_device=output_device, find_unused_parameters=self.ddp_find_unused_parameters))
395
+
396
+ if not self.logger_initialized and self.is_main_process: # only main process initializes logger
287
397
  setup_logger(session_id=self.session_id)
288
398
  self.logger_initialized = True
289
- self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
399
+ self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard) if self.is_main_process else None
290
400
 
291
401
  self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
292
402
  self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
403
+ self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
404
+
293
405
  self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
294
406
  self.epoch_index = 0
295
407
  self.stop_training = False
296
408
  self.best_checkpoint_path = self.best_path
297
- self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
298
409
 
410
+ if not auto_distributed_sampler and self.distributed and self.is_main_process:
411
+ logging.info(colorize("[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.", color="yellow"))
412
+
413
+ train_sampler: DistributedSampler | None = None
299
414
  if validation_split is not None and valid_data is None:
300
- train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,) # type: ignore
415
+ train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
416
+ if auto_distributed_sampler and self.distributed and dist.is_available() and dist.is_initialized():
417
+ base_dataset = getattr(train_loader, "dataset", None)
418
+ if base_dataset is not None and not isinstance(getattr(train_loader, "sampler", None), DistributedSampler):
419
+ train_sampler = DistributedSampler(base_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
420
+ train_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
301
421
  else:
302
- train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle))
422
+ if isinstance(train_data, DataLoader):
423
+ if auto_distributed_sampler and self.distributed:
424
+ train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
425
+ # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
426
+ else:
427
+ train_loader = train_data
428
+ else:
429
+ loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
430
+ if auto_distributed_sampler and self.distributed and dataset is not None and dist.is_available() and dist.is_initialized():
431
+ train_sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
432
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
433
+ train_loader = loader
434
+
435
+ # If split-based loader was built without sampler, attach here when enabled
436
+ if self.distributed and auto_distributed_sampler and isinstance(train_loader, DataLoader) and train_sampler is None:
437
+ raise NotImplementedError("[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
438
+ # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
303
439
 
304
- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
440
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers, auto_distributed_sampler=auto_distributed_sampler)
305
441
  try:
306
442
  self.steps_per_epoch = len(train_loader)
307
443
  is_streaming = False
@@ -309,38 +445,41 @@ class BaseModel(FeatureSet, nn.Module):
309
445
  self.steps_per_epoch = None
310
446
  is_streaming = True
311
447
 
312
- self.summary()
313
- logging.info("")
314
- if self.training_logger and self.training_logger.enable_tensorboard:
315
- tb_dir = self.training_logger.tensorboard_logdir
316
- if tb_dir:
317
- user = getpass.getuser()
318
- host = socket.gethostname()
319
- tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
320
- ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
321
- logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
322
- logging.info(colorize("To view logs, run:", color="cyan"))
323
- logging.info(colorize(f" {tb_cmd}", color="cyan"))
324
- logging.info(colorize("Then SSH port forward:", color="cyan"))
325
- logging.info(colorize(f" {ssh_hint}", color="cyan"))
326
-
327
- logging.info("")
328
- logging.info(colorize("=" * 80, bold=True))
329
- if is_streaming:
330
- logging.info(colorize(f"Start streaming training", bold=True))
331
- else:
332
- logging.info(colorize(f"Start training", bold=True))
333
- logging.info(colorize("=" * 80, bold=True))
334
- logging.info("")
335
- logging.info(colorize(f"Model device: {self.device}", bold=True))
448
+ if self.is_main_process:
449
+ self.summary()
450
+ logging.info("")
451
+ if self.training_logger and self.training_logger.enable_tensorboard:
452
+ tb_dir = self.training_logger.tensorboard_logdir
453
+ if tb_dir:
454
+ user = getpass.getuser()
455
+ host = socket.gethostname()
456
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
457
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
458
+ logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
459
+ logging.info(colorize("To view logs, run:", color="cyan"))
460
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
461
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
462
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
463
+
464
+ logging.info("")
465
+ logging.info(colorize("=" * 80, bold=True))
466
+ if is_streaming:
467
+ logging.info(colorize(f"Start streaming training", bold=True))
468
+ else:
469
+ logging.info(colorize(f"Start training", bold=True))
470
+ logging.info(colorize("=" * 80, bold=True))
471
+ logging.info("")
472
+ logging.info(colorize(f"Model device: {self.device}", bold=True))
336
473
 
337
474
  for epoch in range(epochs):
338
475
  self.epoch_index = epoch
339
- if is_streaming:
476
+ if is_streaming and self.is_main_process:
340
477
  logging.info("")
341
478
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
342
479
 
343
480
  # handle train result
481
+ if self.distributed and hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
482
+ train_loader.sampler.set_epoch(epoch)
344
483
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
345
484
  if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
346
485
  train_loss, train_metrics = train_result
@@ -355,7 +494,8 @@ class BaseModel(FeatureSet, nn.Module):
355
494
  if train_metrics:
356
495
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
357
496
  log_str += f", {metrics_str}"
358
- logging.info(colorize(log_str))
497
+ if self.is_main_process:
498
+ logging.info(colorize(log_str))
359
499
  train_log_payload["loss"] = float(train_loss)
360
500
  if train_metrics:
361
501
  train_log_payload.update(train_metrics)
@@ -380,7 +520,8 @@ class BaseModel(FeatureSet, nn.Module):
380
520
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
381
521
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
382
522
  log_str += ", " + ", ".join(task_metric_strs)
383
- logging.info(colorize(log_str))
523
+ if self.is_main_process:
524
+ logging.info(colorize(log_str))
384
525
  train_log_payload["loss"] = float(total_loss_val)
385
526
  if train_metrics:
386
527
  train_log_payload.update(train_metrics)
@@ -388,10 +529,11 @@ class BaseModel(FeatureSet, nn.Module):
388
529
  self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
389
530
  if valid_loader is not None:
390
531
  # pass user_ids only if needed for GAUC metric
391
- val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
532
+ val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
392
533
  if self.nums_task == 1:
393
534
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
394
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
535
+ if self.is_main_process:
536
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
395
537
  else:
396
538
  # multi task metrics
397
539
  task_metrics = {}
@@ -408,20 +550,29 @@ class BaseModel(FeatureSet, nn.Module):
408
550
  if target_name in task_metrics:
409
551
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
410
552
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
411
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
553
+ if self.is_main_process:
554
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
412
555
  if val_metrics and self.training_logger:
413
556
  self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
414
557
  # Handle empty validation metrics
415
558
  if not val_metrics:
416
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
417
- self.best_checkpoint_path = self.checkpoint_path
418
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
559
+ if self.is_main_process:
560
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
561
+ self.best_checkpoint_path = self.checkpoint_path
562
+ logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
419
563
  continue
420
564
  if self.nums_task == 1:
421
565
  primary_metric_key = self.metrics[0]
422
566
  else:
423
567
  primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
424
568
  primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
569
+
570
+ # In distributed mode, broadcast primary_metric to ensure all processes use the same value
571
+ if self.distributed and dist.is_available() and dist.is_initialized():
572
+ metric_tensor = torch.tensor([primary_metric], device=self.device, dtype=torch.float32)
573
+ dist.broadcast(metric_tensor, src=0)
574
+ primary_metric = float(metric_tensor.item())
575
+
425
576
  improved = False
426
577
  # early stopping check
427
578
  if self.best_metrics_mode == 'max':
@@ -432,24 +583,40 @@ class BaseModel(FeatureSet, nn.Module):
432
583
  if primary_metric < self.best_metric:
433
584
  self.best_metric = primary_metric
434
585
  improved = True
435
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
436
- logging.info(" ")
437
- if improved:
438
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
439
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
440
- self.best_checkpoint_path = self.best_path
441
- self.early_stopper.trial_counter = 0
586
+
587
+ # save checkpoint and best model for main process
588
+ if self.is_main_process:
589
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
590
+ logging.info(" ")
591
+ if improved:
592
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
593
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
594
+ self.best_checkpoint_path = self.best_path
595
+ self.early_stopper.trial_counter = 0
596
+ else:
597
+ self.early_stopper.trial_counter += 1
598
+ logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
599
+ if self.early_stopper.trial_counter >= self.early_stopper.patience:
600
+ self.stop_training = True
601
+ logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
442
602
  else:
443
- self.early_stopper.trial_counter += 1
444
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
445
- if self.early_stopper.trial_counter >= self.early_stopper.patience:
446
- self.stop_training = True
447
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
448
- break
603
+ # Non-main processes also update trial_counter to keep in sync
604
+ if improved:
605
+ self.early_stopper.trial_counter = 0
606
+ else:
607
+ self.early_stopper.trial_counter += 1
449
608
  else:
450
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
451
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
452
- self.best_checkpoint_path = self.best_path
609
+ if self.is_main_process:
610
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
611
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
612
+ self.best_checkpoint_path = self.best_path
613
+
614
+ # Broadcast stop_training flag to all processes (always, regardless of validation)
615
+ if self.distributed and dist.is_available() and dist.is_initialized():
616
+ stop_tensor = torch.tensor([int(self.stop_training)], device=self.device)
617
+ dist.broadcast(stop_tensor, src=0)
618
+ self.stop_training = bool(stop_tensor.item())
619
+
453
620
  if self.stop_training:
454
621
  break
455
622
  if self.scheduler_fn is not None:
@@ -458,41 +625,53 @@ class BaseModel(FeatureSet, nn.Module):
458
625
  self.scheduler_fn.step(primary_metric)
459
626
  else:
460
627
  self.scheduler_fn.step()
461
- logging.info(" ")
462
- logging.info(colorize("Training finished.", bold=True))
463
- logging.info(" ")
628
+ if self.distributed and dist.is_available() and dist.is_initialized():
629
+ dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
630
+ if self.is_main_process:
631
+ logging.info(" ")
632
+ logging.info(colorize("Training finished.", bold=True))
633
+ logging.info(" ")
464
634
  if valid_loader is not None:
465
- logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
635
+ if self.is_main_process:
636
+ logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
466
637
  self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
467
638
  if self.training_logger:
468
639
  self.training_logger.close()
469
640
  return self
470
641
 
471
642
  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
643
+ # use ddp model for distributed training
644
+ model = self.ddp_model if getattr(self, "ddp_model") is not None else self
472
645
  accumulated_loss = 0.0
473
- self.train()
646
+ model.train() # type: ignore
474
647
  num_batches = 0
475
648
  y_true_list = []
476
649
  y_pred_list = []
477
650
 
478
651
  user_ids_list = [] if self.needs_user_ids else None
652
+ tqdm_disable = not self.is_main_process
479
653
  if self.steps_per_epoch is not None:
480
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
654
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch, disable=tqdm_disable))
481
655
  else:
482
656
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
483
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
657
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable))
484
658
  for batch_index, batch_data in batch_iter:
485
659
  batch_dict = batch_to_dict(batch_data)
486
660
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
487
- y_pred = self.forward(X_input)
661
+ # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
662
+ y_pred = model(X_input) # type: ignore
663
+
488
664
  loss = self.compute_loss(y_pred, y_true)
489
665
  reg_loss = self.add_reg_loss()
490
666
  total_loss = loss + reg_loss
491
667
  self.optimizer_fn.zero_grad()
492
668
  total_loss.backward()
493
- nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
669
+
670
+ params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
671
+ nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
494
672
  self.optimizer_fn.step()
495
673
  accumulated_loss += loss.item()
674
+
496
675
  if y_true is not None:
497
676
  y_true_list.append(y_true.detach().cpu().numpy())
498
677
  if self.needs_user_ids and user_ids_list is not None:
@@ -502,37 +681,78 @@ class BaseModel(FeatureSet, nn.Module):
502
681
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
503
682
  y_pred_list.append(y_pred.detach().cpu().numpy())
504
683
  num_batches += 1
684
+ if self.distributed and dist.is_available() and dist.is_initialized():
685
+ loss_tensor = torch.tensor([accumulated_loss, num_batches], device=self.device, dtype=torch.float32)
686
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
687
+ accumulated_loss = loss_tensor[0].item()
688
+ num_batches = int(loss_tensor[1].item())
505
689
  avg_loss = accumulated_loss / max(num_batches, 1)
506
- if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
507
- y_true_all = np.concatenate(y_true_list, axis=0)
508
- y_pred_all = np.concatenate(y_pred_list, axis=0)
509
- combined_user_ids = None
510
- if self.needs_user_ids and user_ids_list:
511
- combined_user_ids = np.concatenate(user_ids_list, axis=0)
690
+
691
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
692
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
693
+ combined_user_ids_local = np.concatenate(user_ids_list, axis=0) if self.needs_user_ids and user_ids_list else None
694
+
695
+ # gather across ranks even when local is empty to avoid DDP hang
696
+ y_true_all = gather_numpy(self, y_true_all_local)
697
+ y_pred_all = gather_numpy(self, y_pred_all_local)
698
+ combined_user_ids = gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
699
+
700
+ if y_true_all is not None and y_pred_all is not None and len(y_true_all) > 0 and len(y_pred_all) > 0:
512
701
  metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
513
702
  return avg_loss, metrics_dict
514
703
  return avg_loss
515
704
 
516
- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
705
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0, auto_distributed_sampler: bool = True,) -> tuple[DataLoader | None, np.ndarray | None]:
517
706
  if valid_data is None:
518
707
  return None, None
519
708
  if isinstance(valid_data, DataLoader):
520
- return valid_data, None
521
- valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False)
709
+ if auto_distributed_sampler and self.distributed:
710
+ raise NotImplementedError("[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
711
+ # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
712
+ else:
713
+ valid_loader = valid_data
714
+ return valid_loader, None
715
+ valid_sampler = None
716
+ valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
717
+ if auto_distributed_sampler and self.distributed and valid_dataset is not None and dist.is_available() and dist.is_initialized():
718
+ valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, drop_last=False)
719
+ valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler, collate_fn=collate_fn, num_workers=num_workers)
522
720
  valid_user_ids = None
523
721
  if needs_user_ids:
524
722
  if user_id_column is None:
525
723
  raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
526
- valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
724
+ # In distributed mode, user_ids will be collected during evaluation from each batch
725
+ # and gathered across all processes, so we don't pre-extract them here
726
+ if not self.distributed:
727
+ valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
527
728
  return valid_loader, valid_user_ids
528
729
 
529
- def evaluate(self,
530
- data: dict | pd.DataFrame | DataLoader,
531
- metrics: list[str] | dict[str, list[str]] | None = None,
532
- batch_size: int = 32,
533
- user_ids: np.ndarray | None = None,
534
- user_id_column: str = 'user_id') -> dict:
535
- self.eval()
730
+ def evaluate(
731
+ self,
732
+ data: dict | pd.DataFrame | DataLoader,
733
+ metrics: list[str] | dict[str, list[str]] | None = None,
734
+ batch_size: int = 32,
735
+ user_ids: np.ndarray | None = None,
736
+ user_id_column: str = 'user_id',
737
+ num_workers: int = 0,) -> dict:
738
+ """
739
+ **IMPORTANT for Distributed Training:**
740
+ in distributed mode, this method uses collective communication operations (all_gather).
741
+ all processes must call this method simultaneously, even if you only want results on rank 0.
742
+ failing to do so will cause the program to hang indefinitely.
743
+
744
+ Evaluate the model on the given data.
745
+
746
+ Args:
747
+ data: Evaluation data (dict/df/DataLoader).
748
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
749
+ batch_size: Batch size (per process when distributed).
750
+ user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, will be extracted from data using user_id_column. e.g. np.array([...])
751
+ user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
752
+ num_workers: DataLoader worker count.
753
+ """
754
+ model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
755
+ model.eval()
536
756
  eval_metrics = metrics if metrics is not None else self.metrics
537
757
  if eval_metrics is None:
538
758
  raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
@@ -543,7 +763,7 @@ class BaseModel(FeatureSet, nn.Module):
543
763
  else:
544
764
  if user_ids is None and needs_user_ids:
545
765
  user_ids = get_user_ids(data=data, id_columns=user_id_column)
546
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
766
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
547
767
  y_true_list = []
548
768
  y_pred_list = []
549
769
  collected_user_ids = []
@@ -553,7 +773,7 @@ class BaseModel(FeatureSet, nn.Module):
553
773
  batch_count += 1
554
774
  batch_dict = batch_to_dict(batch_data)
555
775
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
556
- y_pred = self.forward(X_input)
776
+ y_pred = model(X_input)
557
777
  if y_true is not None:
558
778
  y_true_list.append(y_true.cpu().numpy())
559
779
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -562,20 +782,11 @@ class BaseModel(FeatureSet, nn.Module):
562
782
  batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
563
783
  if batch_user_id is not None:
564
784
  collected_user_ids.append(batch_user_id)
565
- logging.info(" ")
566
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
567
- if len(y_true_list) > 0:
568
- y_true_all = np.concatenate(y_true_list, axis=0)
569
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
570
- else:
571
- y_true_all = None
572
- logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
573
-
574
- if len(y_pred_list) > 0:
575
- y_pred_all = np.concatenate(y_pred_list, axis=0)
576
- else:
577
- y_pred_all = None
578
- logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
785
+ if self.is_main_process:
786
+ logging.info(" ")
787
+ logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
788
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
789
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
579
790
 
580
791
  # Convert metrics to list if it's a dict
581
792
  if isinstance(eval_metrics, dict):
@@ -588,50 +799,86 @@ class BaseModel(FeatureSet, nn.Module):
588
799
  metrics_to_use = unique_metrics
589
800
  else:
590
801
  metrics_to_use = eval_metrics
591
- final_user_ids = user_ids
592
- if final_user_ids is None and collected_user_ids:
593
- final_user_ids = np.concatenate(collected_user_ids, axis=0)
802
+ final_user_ids_local = user_ids
803
+ if final_user_ids_local is None and collected_user_ids:
804
+ final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
805
+
806
+ # gather across ranks even when local arrays are empty to keep collectives aligned
807
+ y_true_all = gather_numpy(self, y_true_all_local)
808
+ y_pred_all = gather_numpy(self, y_pred_all_local)
809
+ final_user_ids = gather_numpy(self, final_user_ids_local) if needs_user_ids else None
810
+ if y_true_all is None or y_pred_all is None or len(y_true_all) == 0 or len(y_pred_all) == 0:
811
+ if self.is_main_process:
812
+ logging.info(colorize(" Warning: Not enough evaluation data to compute metrics after gathering", color="yellow"))
813
+ return {}
814
+ if self.is_main_process:
815
+ logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
594
816
  metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
595
817
  return metrics_dict
596
818
 
597
819
  def predict(
598
- self,
599
- data: str | dict | pd.DataFrame | DataLoader,
600
- batch_size: int = 32,
601
- save_path: str | os.PathLike | None = None,
602
- save_format: Literal["csv", "parquet"] = "csv",
603
- include_ids: bool | None = None,
604
- return_dataframe: bool = True,
605
- streaming_chunk_size: int = 10000,
606
- ) -> pd.DataFrame | np.ndarray:
820
+ self,
821
+ data: str | dict | pd.DataFrame | DataLoader,
822
+ batch_size: int = 32,
823
+ save_path: str | os.PathLike | None = None,
824
+ save_format: Literal["csv", "parquet"] = "csv",
825
+ include_ids: bool | None = None,
826
+ id_columns: str | list[str] | None = None,
827
+ return_dataframe: bool = True,
828
+ streaming_chunk_size: int = 10000,
829
+ num_workers: int = 0,
830
+ ) -> pd.DataFrame | np.ndarray:
831
+ """
832
+ Note: predict does not support distributed mode currently, consider it as a single-process operation.
833
+ Make predictions on the given data.
834
+
835
+ Args:
836
+ data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
837
+ batch_size: Batch size for prediction (per process when distributed).
838
+ save_path: Optional path to save predictions; if None, predictions are not saved to disk.
839
+ save_format: Format to save predictions ('csv' or 'parquet').
840
+ include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
841
+ id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
842
+ return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
843
+ streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
844
+ num_workers: DataLoader worker count.
845
+ """
607
846
  self.eval()
847
+ # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
848
+ predict_id_columns = id_columns if id_columns is not None else self.id_columns
849
+ if isinstance(predict_id_columns, str):
850
+ predict_id_columns = [predict_id_columns]
851
+
608
852
  if include_ids is None:
609
- include_ids = bool(self.id_columns)
610
- include_ids = include_ids and bool(self.id_columns)
853
+ include_ids = bool(predict_id_columns)
854
+ include_ids = include_ids and bool(predict_id_columns)
611
855
 
856
+ # Use streaming mode for large file saves without loading all data into memory
612
857
  if save_path is not None and not return_dataframe:
613
- return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
614
- if isinstance(data, (str, os.PathLike)):
615
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
858
+ return self.predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe, id_columns=predict_id_columns)
859
+
860
+ # Create DataLoader based on data type
861
+ if isinstance(data, DataLoader):
862
+ data_loader = data
863
+ elif isinstance(data, (str, os.PathLike)):
864
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=predict_id_columns,)
616
865
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
617
- elif not isinstance(data, DataLoader):
618
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
619
866
  else:
620
- data_loader = data
867
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
621
868
 
622
- y_pred_list: list[np.ndarray] = []
623
- id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
624
- id_arrays: dict[str, np.ndarray] | None = None
869
+ y_pred_list = []
870
+ id_buffers = {name: [] for name in (predict_id_columns or [])} if include_ids else {}
871
+ id_arrays = None
625
872
 
626
873
  with torch.no_grad():
627
874
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
628
875
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
629
876
  X_input, _ = self.get_input(batch_dict, require_labels=False)
630
- y_pred = self.forward(X_input)
877
+ y_pred = self(X_input)
631
878
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
632
879
  y_pred_list.append(y_pred.detach().cpu().numpy())
633
- if include_ids and self.id_columns and batch_dict.get("ids"):
634
- for id_name in self.id_columns:
880
+ if include_ids and predict_id_columns and batch_dict.get("ids"):
881
+ for id_name in predict_id_columns:
635
882
  if id_name not in batch_dict["ids"]:
636
883
  continue
637
884
  id_tensor = batch_dict["ids"][id_name]
@@ -654,7 +901,7 @@ class BaseModel(FeatureSet, nn.Module):
654
901
  pred_columns.append(f"{name}_pred")
655
902
  while len(pred_columns) < num_outputs:
656
903
  pred_columns.append(f"pred_{len(pred_columns)}")
657
- if include_ids and self.id_columns:
904
+ if include_ids and predict_id_columns:
658
905
  id_arrays = {}
659
906
  for id_name, pieces in id_buffers.items():
660
907
  if pieces:
@@ -681,7 +928,7 @@ class BaseModel(FeatureSet, nn.Module):
681
928
  df_to_save = output
682
929
  else:
683
930
  df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
684
- if include_ids and self.id_columns and id_arrays is not None:
931
+ if include_ids and predict_id_columns and id_arrays is not None:
685
932
  id_df = pd.DataFrame(id_arrays)
686
933
  if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
687
934
  raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
@@ -693,7 +940,7 @@ class BaseModel(FeatureSet, nn.Module):
693
940
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
694
941
  return output
695
942
 
696
- def _predict_streaming(
943
+ def predict_streaming(
697
944
  self,
698
945
  data: str | dict | pd.DataFrame | DataLoader,
699
946
  batch_size: int,
@@ -702,9 +949,10 @@ class BaseModel(FeatureSet, nn.Module):
702
949
  include_ids: bool,
703
950
  streaming_chunk_size: int,
704
951
  return_dataframe: bool,
952
+ id_columns: list[str] | None = None,
705
953
  ) -> pd.DataFrame:
706
954
  if isinstance(data, (str, os.PathLike)):
707
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
955
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=id_columns)
708
956
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
709
957
  elif not isinstance(data, DataLoader):
710
958
  data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
@@ -717,8 +965,8 @@ class BaseModel(FeatureSet, nn.Module):
717
965
  header_written = target_path.exists() and target_path.stat().st_size > 0
718
966
  parquet_writer = None
719
967
 
720
- pred_columns: list[str] | None = None
721
- collected_frames: list[pd.DataFrame] = []
968
+ pred_columns = None
969
+ collected_frames = [] # only used when return_dataframe is True
722
970
 
723
971
  with torch.no_grad():
724
972
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -739,9 +987,9 @@ class BaseModel(FeatureSet, nn.Module):
739
987
  while len(pred_columns) < num_outputs:
740
988
  pred_columns.append(f"pred_{len(pred_columns)}")
741
989
 
742
- id_arrays_batch: dict[str, np.ndarray] = {}
743
- if include_ids and self.id_columns and batch_dict.get("ids"):
744
- for id_name in self.id_columns:
990
+ id_arrays_batch = {}
991
+ if include_ids and id_columns and batch_dict.get("ids"):
992
+ for id_name in id_columns:
745
993
  if id_name not in batch_dict["ids"]:
746
994
  continue
747
995
  id_tensor = batch_dict["ids"][id_name]
@@ -781,7 +1029,10 @@ class BaseModel(FeatureSet, nn.Module):
781
1029
  add_timestamp = False if add_timestamp is None else add_timestamp
782
1030
  target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
783
1031
  model_path = Path(target_path)
784
- torch.save(self.state_dict(), model_path)
1032
+
1033
+ model_to_save = (self.ddp_model.module if getattr(self, "ddp_model", None) is not None else self)
1034
+ torch.save(model_to_save.state_dict(), model_path)
1035
+ # torch.save(self.state_dict(), model_path)
785
1036
 
786
1037
  config_path = self.features_config_path
787
1038
  features_config = {
@@ -842,8 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
842
1093
  **kwargs: Any,
843
1094
  ) -> "BaseModel":
844
1095
  """
845
- Factory that reconstructs a model instance (including feature specs)
846
- from a saved checkpoint directory or *.model file.
1096
+ Load a model from a checkpoint path. The checkpoint path should contain:
1097
+ a .model file and a features_config.pkl file.
847
1098
  """
848
1099
  base_path = Path(checkpoint_path)
849
1100
  verbose = kwargs.pop("verbose", True)
@@ -1003,10 +1254,10 @@ class BaseMatchModel(BaseModel):
1003
1254
  @property
1004
1255
  def model_name(self) -> str:
1005
1256
  raise NotImplementedError
1006
-
1257
+
1007
1258
  @property
1008
- def task_type(self) -> str:
1009
- raise NotImplementedError
1259
+ def default_task(self) -> str:
1260
+ return "binary"
1010
1261
 
1011
1262
  @property
1012
1263
  def support_training_modes(self) -> list[str]: