nextrec 0.3.6-py3-none-any.whl → 0.4.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/layers.py +32 -15
  3. nextrec/basic/model.py +435 -187
  4. nextrec/data/data_processing.py +31 -19
  5. nextrec/data/dataloader.py +40 -10
  6. nextrec/models/generative/hstu.py +3 -2
  7. nextrec/models/match/dssm.py +0 -1
  8. nextrec/models/match/dssm_v2.py +0 -1
  9. nextrec/models/match/mind.py +0 -1
  10. nextrec/models/match/sdm.py +0 -1
  11. nextrec/models/match/youtube_dnn.py +0 -1
  12. nextrec/models/multi_task/esmm.py +5 -7
  13. nextrec/models/multi_task/mmoe.py +10 -6
  14. nextrec/models/multi_task/ple.py +10 -6
  15. nextrec/models/multi_task/poso.py +9 -6
  16. nextrec/models/multi_task/share_bottom.py +10 -7
  17. nextrec/models/ranking/afm.py +113 -21
  18. nextrec/models/ranking/autoint.py +15 -9
  19. nextrec/models/ranking/dcn.py +8 -11
  20. nextrec/models/ranking/deepfm.py +5 -5
  21. nextrec/models/ranking/dien.py +4 -4
  22. nextrec/models/ranking/din.py +4 -4
  23. nextrec/models/ranking/fibinet.py +4 -4
  24. nextrec/models/ranking/fm.py +4 -4
  25. nextrec/models/ranking/masknet.py +4 -5
  26. nextrec/models/ranking/pnn.py +4 -4
  27. nextrec/models/ranking/widedeep.py +4 -4
  28. nextrec/models/ranking/xdeepfm.py +4 -4
  29. nextrec/utils/__init__.py +7 -3
  30. nextrec/utils/device.py +30 -0
  31. nextrec/utils/distributed.py +114 -0
  32. nextrec/utils/synthetic_data.py +413 -0
  33. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
  34. nextrec-0.4.1.dist-info/RECORD +66 -0
  35. nextrec-0.3.6.dist-info/RECORD +0 -64
  36. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
  37. {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py CHANGED
@@ -2,10 +2,9 @@
  Base Model & Base Match Model Class
 
  Date: create on 27/10/2025
- Checkpoint: edit on 02/12/2025
+ Checkpoint: edit on 05/12/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
-
  import os
  import tqdm
  import pickle
@@ -17,10 +16,13 @@ import pandas as pd
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ import torch.distributed as dist
 
  from pathlib import Path
  from typing import Union, Literal, Any
  from torch.utils.data import DataLoader
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel as DDP
 
  from nextrec.basic.callback import EarlyStopper
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
@@ -31,22 +33,23 @@ from nextrec.basic.session import resolve_save_path, create_session
  from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
  from nextrec.data.dataloader import build_tensors_from_data
- from nextrec.data.data_processing import get_column_data, get_user_ids
  from nextrec.data.batch_utils import collate_fn, batch_to_dict
+ from nextrec.data.data_processing import get_column_data, get_user_ids
 
  from nextrec.loss import get_loss_fn, get_loss_kwargs
- from nextrec.utils import get_optimizer, get_scheduler
  from nextrec.utils.tensor import to_tensor
-
+ from nextrec.utils.device import configure_device
+ from nextrec.utils.optimizer import get_optimizer, get_scheduler
+ from nextrec.utils.distributed import gather_numpy, init_process_group, add_distributed_sampler
  from nextrec import __version__
 
  class BaseModel(FeatureSet, nn.Module):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
-
+
  @property
- def task_type(self) -> str:
+ def default_task(self) -> str | list[str]:
  raise NotImplementedError
 
  def __init__(self,
@@ -55,21 +58,57 @@ class BaseModel(FeatureSet, nn.Module):
  sequence_features: list[SequenceFeature] | None = None,
  target: list[str] | str | None = None,
  id_columns: list[str] | str | None = None,
- task: str|list[str] = 'binary',
+ task: str | list[str] | None = None,
  device: str = 'cpu',
+ early_stop_patience: int = 20,
+ session_id: str | None = None,
  embedding_l1_reg: float = 0.0,
  dense_l1_reg: float = 0.0,
  embedding_l2_reg: float = 0.0,
  dense_l2_reg: float = 0.0,
- early_stop_patience: int = 20,
- session_id: str | None = None,):
-
+
+ distributed: bool = False,
+ rank: int | None = None,
+ world_size: int | None = None,
+ local_rank: int | None = None,
+ ddp_find_unused_parameters: bool = False,):
+ """
+ Initialize a base model.
+
+ Args:
+ dense_features: DenseFeature definitions.
+ sparse_features: SparseFeature definitions.
+ sequence_features: SequenceFeature definitions.
+ target: Target column name.
+ id_columns: Identifier column name, only need to specify if GAUC is required.
+ task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+ device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
+ embedding_l1_reg: L1 regularization strength for embedding params. e.g., 1e-6.
+ dense_l1_reg: L1 regularization strength for dense params. e.g., 1e-5.
+ embedding_l2_reg: L2 regularization strength for embedding params. e.g., 1e-5.
+ dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.
+ early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+ session_id: Session id for logging. If None, a default id with timestamps will be created.
+ distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
+ rank: Global rank (defaults to env RANK).
+ world_size: Number of processes (defaults to env WORLD_SIZE).
+ local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
+ ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
+ """
  super(BaseModel, self).__init__()
- try:
- self.device = torch.device(device)
- except Exception as e:
- logging.warning("[BaseModel Warning] Invalid device , defaulting to CPU.")
- self.device = torch.device('cpu')
+
+ # distributed training settings
+ env_rank = int(os.environ.get("RANK", "0"))
+ env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ self.distributed = distributed or (env_world_size > 1)
+ self.rank = env_rank if rank is None else rank
+ self.world_size = env_world_size if world_size is None else world_size
+ self.local_rank = env_local_rank if local_rank is None else local_rank
+ self.is_main_process = self.rank == 0
+ self.ddp_find_unused_parameters = ddp_find_unused_parameters
+ self.ddp_model: DDP | None = None
+ self.device = configure_device(self.distributed, self.local_rank, device)
 
  self.session_id = session_id
  self.session = create_session(session_id)
@@ -79,8 +118,8 @@ class BaseModel(FeatureSet, nn.Module):
  self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
  self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
 
- self.task = task
- self.nums_task = len(task) if isinstance(task, list) else 1
+ self.task = self.default_task if task is None else task
+ self.nums_task = len(self.task) if isinstance(self.task, list) else 1
 
  self.embedding_l1_reg = embedding_l1_reg
  self.dense_l1_reg = dense_l1_reg
@@ -89,10 +128,11 @@ class BaseModel(FeatureSet, nn.Module):
  self.regularization_weights = []
  self.embedding_params = []
  self.loss_weight = None
+
  self.early_stop_patience = early_stop_patience
  self.max_gradient_norm = 1.0
  self.logger_initialized = False
- self.training_logger: TrainingLogger | None = None
+ self.training_logger = None
 
  def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
  exclude_modules = exclude_modules or []
@@ -145,18 +185,22 @@ class BaseModel(FeatureSet, nn.Module):
  raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
  continue
  target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
- target_tensor = target_tensor.view(target_tensor.size(0), -1)
+ target_tensor = target_tensor.view(target_tensor.size(0), -1) # always reshape to (batch_size, num_targets)
  target_tensors.append(target_tensor)
  if target_tensors:
  y = torch.cat(target_tensors, dim=1)
- if y.shape[1] == 1:
+ if y.shape[1] == 1: # no need to do that again
  y = y.view(-1)
  elif require_labels:
  raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
  return X_input, y
 
- def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,) -> tuple[DataLoader, dict | pd.DataFrame]:
- """This function will split training data into training and validation sets when: 1. valid_data is None; 2. validation_split is provided."""
+ def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,):
+ """
+ This function will split training data into training and validation sets when:
+ 1. valid_data is None;
+ 2. validation_split is provided.
+ """
  if not (0 < validation_split < 1):
  raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
  if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -189,15 +233,30 @@ class BaseModel(FeatureSet, nn.Module):
  return train_loader, valid_split
 
  def compile(
- self,
- optimizer: str | torch.optim.Optimizer = "adam",
- optimizer_params: dict | None = None,
- scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
- scheduler_params: dict | None = None,
- loss: str | nn.Module | list[str | nn.Module] | None = "bce",
- loss_params: dict | list[dict] | None = None,
- loss_weights: int | float | list[int | float] | None = None,
- ):
+ self,
+ optimizer: str | torch.optim.Optimizer = "adam",
+ optimizer_params: dict | None = None,
+ scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+ scheduler_params: dict | None = None,
+ loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+ loss_params: dict | list[dict] | None = None,
+ loss_weights: int | float | list[int | float] | None = None,
+ ):
+ """
+ Configure the model for training.
+ Args:
+ optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
+ optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
+ scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+ scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
+ loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
+ loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
+ loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+ """
+ if loss_params is None:
+ self.loss_params = {}
+ else:
+ self.loss_params = loss_params
  optimizer_params = optimizer_params or {}
  self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
  self.optimizer_params = optimizer_params
@@ -217,7 +276,9 @@ class BaseModel(FeatureSet, nn.Module):
  self.loss_params = loss_params or {}
  self.loss_fn = []
  if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
- loss_list = [loss[i] if i < len(loss) else None for i in range(self.nums_task)]
+ if len(loss) != self.nums_task:
+ raise ValueError(f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task}).")
+ loss_list = [loss[i] for i in range(self.nums_task)]
  else: # for example: 'bce' -> ['bce', 'bce']
  loss_list = [loss] * self.nums_task
 
@@ -231,12 +292,12 @@ class BaseModel(FeatureSet, nn.Module):
  self.loss_weights = None
  elif self.nums_task == 1:
  if isinstance(loss_weights, (list, tuple)):
- if len(loss_weights) != 1 and isinstance(loss_weights, (list, tuple)):
+ if len(loss_weights) != 1:
  raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
  weight_value = loss_weights[0]
  else:
  weight_value = loss_weights
- self.loss_weights = float(weight_value)
+ self.loss_weights = [float(weight_value)]
  else:
  if isinstance(loss_weights, (int, float)):
  weights = [float(loss_weights)] * self.nums_task
@@ -250,29 +311,48 @@ class BaseModel(FeatureSet, nn.Module):
 
  def compute_loss(self, y_pred, y_true):
  if y_true is None:
- raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required to compute loss.")
+ raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required.")
  if self.nums_task == 1:
- loss = self.loss_fn[0](y_pred, y_true)
+ if y_pred.dim() == 1:
+ y_pred = y_pred.view(-1, 1)
+ if y_true.dim() == 1:
+ y_true = y_true.view(-1, 1)
+ if y_pred.shape != y_true.shape:
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+ task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1] # type: ignore
+ if task_dim == 1:
+ loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
+ else:
+ loss = self.loss_fn[0](y_pred, y_true)
  if self.loss_weights is not None:
- loss = loss * self.loss_weights
+ loss *= self.loss_weights[0]
  return loss
+ # multi-task
+ if y_pred.shape != y_true.shape:
+ raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+ if hasattr(self, "prediction_layer"): # we need to use registered task_slices for multi-task and multi-class
+ slices = self.prediction_layer._task_slices # type: ignore
  else:
- task_losses = []
- for i in range(self.nums_task):
- task_loss = self.loss_fn[i](y_pred[:, i], y_true[:, i])
- if isinstance(self.loss_weights, (list, tuple)):
- task_loss = task_loss * self.loss_weights[i]
- task_losses.append(task_loss)
- return torch.stack(task_losses).sum()
-
- def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0,) -> DataLoader:
+ slices = [(i, i + 1) for i in range(self.nums_task)]
+ task_losses = []
+ for i, (start, end) in enumerate(slices): # type: ignore
+ y_pred_i = y_pred[:, start:end]
+ y_true_i = y_true[:, start:end]
+ task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+ if isinstance(self.loss_weights, (list, tuple)):
+ task_loss *= self.loss_weights[i]
+ task_losses.append(task_loss)
+ return torch.stack(task_losses).sum()
+
+ def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0, sampler=None, return_dataset: bool = False) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
  if isinstance(data, DataLoader):
- return data
+ return (data, None) if return_dataset else data
  tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
  if tensors is None:
  raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
  dataset = TensorDictDataset(tensors)
- return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn, num_workers=num_workers)
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
+ return (loader, dataset) if return_dataset else loader
 
  def fit(self,
  train_data: dict | pd.DataFrame | DataLoader,
@@ -282,27 +362,82 @@ class BaseModel(FeatureSet, nn.Module):
  user_id_column: str | None = None,
  validation_split: float | None = None,
  num_workers: int = 0,
- tensorboard: bool = True,):
+ tensorboard: bool = True,
+ auto_distributed_sampler: bool = True,):
+ """
+ Train the model.
+
+ Args:
+ train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
+ valid_data: Optional validation data; if None and validation_split is set, a split is created.
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+ epochs: Training epochs.
+ shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
+ batch_size: Batch size (per process when distributed).
+ user_id_column: Column name for GAUC-style metrics;.
+ validation_split: Ratio to split training data when valid_data is None.
+ num_workers: DataLoader worker count.
+ tensorboard: Enable tensorboard logging.
+ auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
+
+ Notes:
+ - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
+ - All ranks must call evaluate() together because it performs collective ops.
+ """
+ device_id = self.local_rank if self.device.type == "cuda" else None
+ init_process_group(self.distributed, self.rank, self.world_size, device_id=device_id)
  self.to(self.device)
- if not self.logger_initialized:
+
+ if self.distributed and dist.is_available() and dist.is_initialized() and self.ddp_model is None:
+ device_ids = [self.local_rank] if self.device.type == "cuda" else None # device_ids means which device to use in ddp
+ output_device = self.local_rank if self.device.type == "cuda" else None # output_device means which device to place the output in ddp
+ object.__setattr__(self, "ddp_model", DDP(self, device_ids=device_ids, output_device=output_device, find_unused_parameters=self.ddp_find_unused_parameters))
+
+ if not self.logger_initialized and self.is_main_process: # only main process initializes logger
  setup_logger(session_id=self.session_id)
  self.logger_initialized = True
- self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+ self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard) if self.is_main_process else None
 
  self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
  self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+ self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
+
  self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
  self.epoch_index = 0
  self.stop_training = False
  self.best_checkpoint_path = self.best_path
- self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
 
+ if not auto_distributed_sampler and self.distributed and self.is_main_process:
+ logging.info(colorize("[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.", color="yellow"))
+
+ train_sampler: DistributedSampler | None = None
  if validation_split is not None and valid_data is None:
  train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
+ if auto_distributed_sampler and self.distributed and dist.is_available() and dist.is_initialized():
+ base_dataset = getattr(train_loader, "dataset", None)
+ if base_dataset is not None and not isinstance(getattr(train_loader, "sampler", None), DistributedSampler):
+ train_sampler = DistributedSampler(base_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+ train_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
  else:
- train_loader = (train_data if isinstance(train_data, DataLoader) else self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers))
+ if isinstance(train_data, DataLoader):
+ if auto_distributed_sampler and self.distributed:
+ train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+ # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+ else:
+ train_loader = train_data
+ else:
+ loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
+ if auto_distributed_sampler and self.distributed and dataset is not None and dist.is_available() and dist.is_initialized():
+ train_sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
+ train_loader = loader
+
+ # If split-based loader was built without sampler, attach here when enabled
+ if self.distributed and auto_distributed_sampler and isinstance(train_loader, DataLoader) and train_sampler is None:
+ raise NotImplementedError("[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+ # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
 
- valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers)
+ valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers, auto_distributed_sampler=auto_distributed_sampler)
  try:
  self.steps_per_epoch = len(train_loader)
  is_streaming = False
@@ -310,38 +445,41 @@ class BaseModel(FeatureSet, nn.Module):
  self.steps_per_epoch = None
  is_streaming = True
 
- self.summary()
- logging.info("")
- if self.training_logger and self.training_logger.enable_tensorboard:
- tb_dir = self.training_logger.tensorboard_logdir
- if tb_dir:
- user = getpass.getuser()
- host = socket.gethostname()
- tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
- ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
- logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
- logging.info(colorize("To view logs, run:", color="cyan"))
- logging.info(colorize(f" {tb_cmd}", color="cyan"))
- logging.info(colorize("Then SSH port forward:", color="cyan"))
- logging.info(colorize(f" {ssh_hint}", color="cyan"))
-
- logging.info("")
- logging.info(colorize("=" * 80, bold=True))
- if is_streaming:
- logging.info(colorize(f"Start streaming training", bold=True))
- else:
- logging.info(colorize(f"Start training", bold=True))
- logging.info(colorize("=" * 80, bold=True))
- logging.info("")
- logging.info(colorize(f"Model device: {self.device}", bold=True))
+ if self.is_main_process:
+ self.summary()
+ logging.info("")
+ if self.training_logger and self.training_logger.enable_tensorboard:
+ tb_dir = self.training_logger.tensorboard_logdir
+ if tb_dir:
+ user = getpass.getuser()
+ host = socket.gethostname()
+ tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+ ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+ logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+ logging.info(colorize("To view logs, run:", color="cyan"))
+ logging.info(colorize(f" {tb_cmd}", color="cyan"))
+ logging.info(colorize("Then SSH port forward:", color="cyan"))
+ logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
+ logging.info("")
+ logging.info(colorize("=" * 80, bold=True))
+ if is_streaming:
+ logging.info(colorize(f"Start streaming training", bold=True))
+ else:
+ logging.info(colorize(f"Start training", bold=True))
+ logging.info(colorize("=" * 80, bold=True))
+ logging.info("")
+ logging.info(colorize(f"Model device: {self.device}", bold=True))
 
  for epoch in range(epochs):
  self.epoch_index = epoch
- if is_streaming:
+ if is_streaming and self.is_main_process:
  logging.info("")
  logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
 
  # handle train result
+ if self.distributed and hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
+ train_loader.sampler.set_epoch(epoch)
  train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
  if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
  train_loss, train_metrics = train_result
@@ -356,7 +494,8 @@ class BaseModel(FeatureSet, nn.Module):
  if train_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
  log_str += f", {metrics_str}"
- logging.info(colorize(log_str))
+ if self.is_main_process:
+ logging.info(colorize(log_str))
  train_log_payload["loss"] = float(train_loss)
  if train_metrics:
  train_log_payload.update(train_metrics)
@@ -381,7 +520,8 @@ class BaseModel(FeatureSet, nn.Module):
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
  log_str += ", " + ", ".join(task_metric_strs)
- logging.info(colorize(log_str))
+ if self.is_main_process:
+ logging.info(colorize(log_str))
  train_log_payload["loss"] = float(total_loss_val)
  if train_metrics:
  train_log_payload.update(train_metrics)
@@ -392,7 +532,8 @@ class BaseModel(FeatureSet, nn.Module):
  val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
  if self.nums_task == 1:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
+ if self.is_main_process:
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
  else:
  # multi task metrics
  task_metrics = {}
@@ -409,20 +550,29 @@ class BaseModel(FeatureSet, nn.Module):
  if target_name in task_metrics:
  metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
  task_metric_strs.append(f"{target_name}[{metrics_str}]")
- logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
+ if self.is_main_process:
+ logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
  if val_metrics and self.training_logger:
  self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
  # Handle empty validation metrics
  if not val_metrics:
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
- self.best_checkpoint_path = self.checkpoint_path
- logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
+ if self.is_main_process:
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ self.best_checkpoint_path = self.checkpoint_path
+ logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
  continue
  if self.nums_task == 1:
  primary_metric_key = self.metrics[0]
  else:
  primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
  primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
+
+ # In distributed mode, broadcast primary_metric to ensure all processes use the same value
+ if self.distributed and dist.is_available() and dist.is_initialized():
+ metric_tensor = torch.tensor([primary_metric], device=self.device, dtype=torch.float32)
+ dist.broadcast(metric_tensor, src=0)
+ primary_metric = float(metric_tensor.item())
+
  improved = False
  # early stopping check
  if self.best_metrics_mode == 'max':
@@ -433,24 +583,40 @@ class BaseModel(FeatureSet, nn.Module):
  if primary_metric < self.best_metric:
  self.best_metric = primary_metric
  improved = True
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
- logging.info(" ")
- if improved:
- logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self.best_checkpoint_path = self.best_path
- self.early_stopper.trial_counter = 0
+
+ # save checkpoint and best model for main process
+ if self.is_main_process:
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ logging.info(" ")
+ if improved:
+ logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
+ self.best_checkpoint_path = self.best_path
+ self.early_stopper.trial_counter = 0
+ else:
+ self.early_stopper.trial_counter += 1
+ logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
+ if self.early_stopper.trial_counter >= self.early_stopper.patience:
+ self.stop_training = True
+ logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
  else:
- self.early_stopper.trial_counter += 1
- logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
- if self.early_stopper.trial_counter >= self.early_stopper.patience:
- self.stop_training = True
- logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
- break
+ # Non-main processes also update trial_counter to keep in sync
+ if improved:
+ self.early_stopper.trial_counter = 0
+ else:
+ self.early_stopper.trial_counter += 1
  else:
- self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
- self.save_model(self.best_path, add_timestamp=False, verbose=False)
- self.best_checkpoint_path = self.best_path
+ if self.is_main_process:
+ self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+ self.save_model(self.best_path, add_timestamp=False, verbose=False)
+ self.best_checkpoint_path = self.best_path
+
+ # Broadcast stop_training flag to all processes (always, regardless of validation)
+ if self.distributed and dist.is_available() and dist.is_initialized():
+ stop_tensor = torch.tensor([int(self.stop_training)], device=self.device)
+ dist.broadcast(stop_tensor, src=0)
+ self.stop_training = bool(stop_tensor.item())
+
  if self.stop_training:
  break
  if self.scheduler_fn is not None:
@@ -459,41 +625,53 @@ class BaseModel(FeatureSet, nn.Module):
  self.scheduler_fn.step(primary_metric)
  else:
  self.scheduler_fn.step()
- logging.info(" ")
- logging.info(colorize("Training finished.", bold=True))
- logging.info(" ")
+ if self.distributed and dist.is_available() and dist.is_initialized():
+ dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
+ if self.is_main_process:
+ logging.info(" ")
+ logging.info(colorize("Training finished.", bold=True))
+ logging.info(" ")
  if valid_loader is not None:
- logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
+ if self.is_main_process:
+ logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
  self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
  if self.training_logger:
  self.training_logger.close()
  return self
 
  def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+ # use ddp model for distributed training
+ model = self.ddp_model if getattr(self, "ddp_model") is not None else self
  accumulated_loss = 0.0
- self.train()
+ model.train() # type: ignore
  num_batches = 0
  y_true_list = []
  y_pred_list = []
 
  user_ids_list = [] if self.needs_user_ids else None
+ tqdm_disable = not self.is_main_process
  if self.steps_per_epoch is not None:
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch, disable=tqdm_disable))
  else:
  desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
- batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
+ batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable))
  for batch_index, batch_data in batch_iter:
  batch_dict = batch_to_dict(batch_data)
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
- y_pred = self.forward(X_input)
+ # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
+ y_pred = model(X_input) # type: ignore
+
  loss = self.compute_loss(y_pred, y_true)
  reg_loss = self.add_reg_loss()
  total_loss = loss + reg_loss
  self.optimizer_fn.zero_grad()
  total_loss.backward()
- nn.utils.clip_grad_norm_(self.parameters(), self.max_gradient_norm)
+
+ params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
+ nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
  self.optimizer_fn.step()
  accumulated_loss += loss.item()
+
  if y_true is not None:
  y_true_list.append(y_true.detach().cpu().numpy())
  if self.needs_user_ids and user_ids_list is not None:
@@ -503,38 +681,78 @@ class BaseModel(FeatureSet, nn.Module):
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
  y_pred_list.append(y_pred.detach().cpu().numpy())
  num_batches += 1
+ if self.distributed and dist.is_available() and dist.is_initialized():
+ loss_tensor = torch.tensor([accumulated_loss, num_batches], device=self.device, dtype=torch.float32)
+ dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
+ accumulated_loss = loss_tensor[0].item()
+ num_batches = int(loss_tensor[1].item())
  avg_loss = accumulated_loss / max(num_batches, 1)
- if len(y_true_list) > 0 and len(y_pred_list) > 0: # Compute metrics if requested
- y_true_all = np.concatenate(y_true_list, axis=0)
- y_pred_all = np.concatenate(y_pred_list, axis=0)
- combined_user_ids = None
- if self.needs_user_ids and user_ids_list:
- combined_user_ids = np.concatenate(user_ids_list, axis=0)
+
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+ combined_user_ids_local = np.concatenate(user_ids_list, axis=0) if self.needs_user_ids and user_ids_list else None
+
+ # gather across ranks even when local is empty to avoid DDP hang
+ y_true_all = gather_numpy(self, y_true_all_local)
+ y_pred_all = gather_numpy(self, y_pred_all_local)
+ combined_user_ids = gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
+
+ if y_true_all is not None and y_pred_all is not None and len(y_true_all) > 0 and len(y_pred_all) > 0:
  metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
  return avg_loss, metrics_dict
  return avg_loss
 
- def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0,) -> tuple[DataLoader | None, np.ndarray | None]:
+ def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0, auto_distributed_sampler: bool = True,) -> tuple[DataLoader | None, np.ndarray | None]:
  if valid_data is None:
  return None, None
  if isinstance(valid_data, DataLoader):
- return valid_data, None
- valid_loader = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+ if auto_distributed_sampler and self.distributed:
+ raise NotImplementedError("[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+ # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
+ else:
+ valid_loader = valid_data
+ return valid_loader, None
+ valid_sampler = None
+ valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
+ if auto_distributed_sampler and self.distributed and valid_dataset is not None and dist.is_available() and dist.is_initialized():
+ valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, drop_last=False)
+ valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler, collate_fn=collate_fn, num_workers=num_workers)
  valid_user_ids = None
  if needs_user_ids:
  if user_id_column is None:
  raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
- valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
+ # In distributed mode, user_ids will be collected during evaluation from each batch
+ # and gathered across all processes, so we don't pre-extract them here
+ if not self.distributed:
+ valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
  return valid_loader, valid_user_ids
 
- def evaluate(self,
- data: dict | pd.DataFrame | DataLoader,
- metrics: list[str] | dict[str, list[str]] | None = None,
- batch_size: int = 32,
- user_ids: np.ndarray | None = None,
- user_id_column: str = 'user_id',
- num_workers: int = 0,) -> dict:
- self.eval()
+ def evaluate(
+ self,
+ data: dict | pd.DataFrame | DataLoader,
+ metrics: list[str] | dict[str, list[str]] | None = None,
+ batch_size: int = 32,
+ user_ids: np.ndarray | None = None,
+ user_id_column: str = 'user_id',
+ num_workers: int = 0,) -> dict:
+ """
+ **IMPORTANT for Distributed Training:**
+ in distributed mode, this method uses collective communication operations (all_gather).
+ all processes must call this method simultaneously, even if you only want results on rank 0.
+ failing to do so will cause the program to hang indefinitely.
+
+ Evaluate the model on the given data.
+
+ Args:
+ data: Evaluation data (dict/df/DataLoader).
+ metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+ batch_size: Batch size (per process when distributed).
+ user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, will be extracted from data using user_id_column. e.g. np.array([...])
+ user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
+ num_workers: DataLoader worker count.
+ """
+ model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+ model.eval()
  eval_metrics = metrics if metrics is not None else self.metrics
  if eval_metrics is None:
  raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
@@ -555,7 +773,7 @@ class BaseModel(FeatureSet, nn.Module):
  batch_count += 1
  batch_dict = batch_to_dict(batch_data)
  X_input, y_true = self.get_input(batch_dict, require_labels=True)
- y_pred = self.forward(X_input)
+ y_pred = model(X_input)
  if y_true is not None:
  y_true_list.append(y_true.cpu().numpy())
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -564,20 +782,11 @@ class BaseModel(FeatureSet, nn.Module):
  batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
  if batch_user_id is not None:
  collected_user_ids.append(batch_user_id)
- logging.info(" ")
- logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
- if len(y_true_list) > 0:
- y_true_all = np.concatenate(y_true_list, axis=0)
- logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
- else:
- y_true_all = None
- logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
-
- if len(y_pred_list) > 0:
- y_pred_all = np.concatenate(y_pred_list, axis=0)
- else:
- y_pred_all = None
- logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
+ if self.is_main_process:
+ logging.info(" ")
+ logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
+ y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+ y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
 
  # Convert metrics to list if it's a dict
  if isinstance(eval_metrics, dict):
@@ -590,51 +799,86 @@ class BaseModel(FeatureSet, nn.Module):
  metrics_to_use = unique_metrics
  else:
  metrics_to_use = eval_metrics
- final_user_ids = user_ids
- if final_user_ids is None and collected_user_ids:
- final_user_ids = np.concatenate(collected_user_ids, axis=0)
+ final_user_ids_local = user_ids
+ if final_user_ids_local is None and collected_user_ids:
+ final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
+
+ # gather across ranks even when local arrays are empty to keep collectives aligned
+ y_true_all = gather_numpy(self, y_true_all_local)
+ y_pred_all = gather_numpy(self, y_pred_all_local)
+ final_user_ids = gather_numpy(self, final_user_ids_local) if needs_user_ids else None
+ if y_true_all is None or y_pred_all is None or len(y_true_all) == 0 or len(y_pred_all) == 0:
+ if self.is_main_process:
+ logging.info(colorize(" Warning: Not enough evaluation data to compute metrics after gathering", color="yellow"))
+ return {}
+ if self.is_main_process:
+ logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
  metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
  return metrics_dict
 
  def predict(
- self,
- data: str | dict | pd.DataFrame | DataLoader,
- batch_size: int = 32,
- save_path: str | os.PathLike | None = None,
- save_format: Literal["csv", "parquet"] = "csv",
- include_ids: bool | None = None,
- return_dataframe: bool = True,
- streaming_chunk_size: int = 10000,
- num_workers: int = 0,
- ) -> pd.DataFrame | np.ndarray:
+ self,
+ data: str | dict | pd.DataFrame | DataLoader,
+ batch_size: int = 32,
+ save_path: str | os.PathLike | None = None,
+ save_format: Literal["csv", "parquet"] = "csv",
+ include_ids: bool | None = None,
+ id_columns: str | list[str] | None = None,
+ return_dataframe: bool = True,
+ streaming_chunk_size: int = 10000,
+ num_workers: int = 0,
+ ) -> pd.DataFrame | np.ndarray:
+ """
+ Note: predict does not support distributed mode currently, consider it as a single-process operation.
+ Make predictions on the given data.
+
+ Args:
+ data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+ batch_size: Batch size for prediction (per process when distributed).
+ save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+ save_format: Format to save predictions ('csv' or 'parquet').
+ include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
+ id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+ return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+ streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
+ num_workers: DataLoader worker count.
+ """
  self.eval()
+ # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
+ predict_id_columns = id_columns if id_columns is not None else self.id_columns
+ if isinstance(predict_id_columns, str):
+ predict_id_columns = [predict_id_columns]
+
  if include_ids is None:
- include_ids = bool(self.id_columns)
- include_ids = include_ids and bool(self.id_columns)
+ include_ids = bool(predict_id_columns)
+ include_ids = include_ids and bool(predict_id_columns)
 
+ # Use streaming mode for large file saves without loading all data into memory
  if save_path is not None and not return_dataframe:
- return self._predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe)
- if isinstance(data, (str, os.PathLike)):
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns,)
+ return self.predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe, id_columns=predict_id_columns)
+
+ # Create DataLoader based on data type
+ if isinstance(data, DataLoader):
+ data_loader = data
+ elif isinstance(data, (str, os.PathLike)):
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=predict_id_columns,)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
- elif not isinstance(data, DataLoader):
- data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  else:
- data_loader = data
+ data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
- y_pred_list: list[np.ndarray] = []
- id_buffers: dict[str, list[np.ndarray]] = {name: [] for name in (self.id_columns or [])} if include_ids else {}
- id_arrays: dict[str, np.ndarray] | None = None
+ y_pred_list = []
+ id_buffers = {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+ id_arrays = None
 
  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
  batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
  X_input, _ = self.get_input(batch_dict, require_labels=False)
- y_pred = self.forward(X_input)
+ y_pred = self(X_input)
  if y_pred is not None and isinstance(y_pred, torch.Tensor):
  y_pred_list.append(y_pred.detach().cpu().numpy())
- if include_ids and self.id_columns and batch_dict.get("ids"):
- for id_name in self.id_columns:
+ if include_ids and predict_id_columns and batch_dict.get("ids"):
+ for id_name in predict_id_columns:
  if id_name not in batch_dict["ids"]:
  continue
  id_tensor = batch_dict["ids"][id_name]
@@ -657,7 +901,7 @@ class BaseModel(FeatureSet, nn.Module):
  pred_columns.append(f"{name}_pred")
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
- if include_ids and self.id_columns:
+ if include_ids and predict_id_columns:
  id_arrays = {}
  for id_name, pieces in id_buffers.items():
  if pieces:
@@ -684,7 +928,7 @@ class BaseModel(FeatureSet, nn.Module):
  df_to_save = output
  else:
  df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
- if include_ids and self.id_columns and id_arrays is not None:
+ if include_ids and predict_id_columns and id_arrays is not None:
  id_df = pd.DataFrame(id_arrays)
  if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
  raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
@@ -696,7 +940,7 @@ class BaseModel(FeatureSet, nn.Module):
  logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
  return output
 
- def _predict_streaming(
+ def predict_streaming(
  self,
  data: str | dict | pd.DataFrame | DataLoader,
  batch_size: int,
@@ -705,9 +949,10 @@ class BaseModel(FeatureSet, nn.Module):
  include_ids: bool,
  streaming_chunk_size: int,
  return_dataframe: bool,
+ id_columns: list[str] | None = None,
  ) -> pd.DataFrame:
  if isinstance(data, (str, os.PathLike)):
- rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=self.id_columns)
+ rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=id_columns)
  data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
  elif not isinstance(data, DataLoader):
  data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
@@ -720,8 +965,8 @@ class BaseModel(FeatureSet, nn.Module):
  header_written = target_path.exists() and target_path.stat().st_size > 0
  parquet_writer = None
 
- pred_columns: list[str] | None = None
- collected_frames: list[pd.DataFrame] = []
+ pred_columns = None
+ collected_frames = [] # only used when return_dataframe is True
 
  with torch.no_grad():
  for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -742,9 +987,9 @@ class BaseModel(FeatureSet, nn.Module):
  while len(pred_columns) < num_outputs:
  pred_columns.append(f"pred_{len(pred_columns)}")
 
- id_arrays_batch: dict[str, np.ndarray] = {}
- if include_ids and self.id_columns and batch_dict.get("ids"):
- for id_name in self.id_columns:
+ id_arrays_batch = {}
+ if include_ids and id_columns and batch_dict.get("ids"):
+ for id_name in id_columns:
  if id_name not in batch_dict["ids"]:
  continue
  id_tensor = batch_dict["ids"][id_name]
@@ -784,7 +1029,10 @@ class BaseModel(FeatureSet, nn.Module):
  add_timestamp = False if add_timestamp is None else add_timestamp
  target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
  model_path = Path(target_path)
- torch.save(self.state_dict(), model_path)
+
+ model_to_save = (self.ddp_model.module if getattr(self, "ddp_model", None) is not None else self)
+ torch.save(model_to_save.state_dict(), model_path)
+ # torch.save(self.state_dict(), model_path)
 
  config_path = self.features_config_path
  features_config = {
@@ -845,8 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
  **kwargs: Any,
  ) -> "BaseModel":
  """
- Factory that reconstructs a model instance (including feature specs)
- from a saved checkpoint directory or *.model file.
+ Load a model from a checkpoint path. The checkpoint path should contain:
+ a .model file and a features_config.pkl file.
  """
  base_path = Path(checkpoint_path)
  verbose = kwargs.pop("verbose", True)
@@ -1006,10 +1254,10 @@ class BaseMatchModel(BaseModel):
  @property
  def model_name(self) -> str:
  raise NotImplementedError
-
+
  @property
- def task_type(self) -> str:
- raise NotImplementedError
+ def default_task(self) -> str:
+ return "binary"
 
  @property
  def support_training_modes(self) -> list[str]:
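
The distributed flow added in this diff (DDP wrapping in fit(), automatic DistributedSampler attachment, rank-0 logging and checkpointing, and collective gathering in evaluate()) is driven by the new constructor flag together with the RANK/WORLD_SIZE/LOCAL_RANK environment variables. The sketch below illustrates how that contract might be exercised under torchrun; it is an assumption-laden illustration, not part of the package: the feature definitions, the DeepFM constructor keyword arguments, and the file paths are hypothetical, while the distributed, auto_distributed_sampler, compile, fit, and evaluate usage follows the docstrings shown above.

```python
# Hypothetical launch: torchrun --nproc_per_node=2 train_ddp.py
# torchrun sets RANK / WORLD_SIZE / LOCAL_RANK, which BaseModel.__init__ reads as defaults.
import pandas as pd
from nextrec.basic.features import DenseFeature, SparseFeature  # feature kwargs below are assumed
from nextrec.models.ranking.deepfm import DeepFM                # constructor kwargs below are assumed

train_df = pd.read_parquet("train.parquet")   # placeholder dataset paths
valid_df = pd.read_parquet("valid.parquet")

model = DeepFM(
    dense_features=[DenseFeature("price")],
    sparse_features=[SparseFeature("item_id", vocab_size=100_000, embedding_dim=16)],
    target="label",
    distributed=True,   # enables the DDP path added in 0.4.1 (also inferred when WORLD_SIZE > 1)
)
model.compile(optimizer="adam", optimizer_params={"lr": 1e-3}, loss="bce")

# Every rank must reach fit() and evaluate(): both run collective ops (broadcast / all_gather),
# so calling them on rank 0 only would hang, per the evaluate() docstring in this diff.
model.fit(
    train_data=train_df,
    valid_data=valid_df,
    metrics=["auc", "logloss"],
    epochs=5,
    batch_size=1024,                 # per-process batch size
    auto_distributed_sampler=True,   # shard the data across ranks automatically
)
metrics = model.evaluate(valid_df, batch_size=1024)  # same gathered result on every rank
```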