nextrec 0.3.6__py3-none-any.whl → 0.4.1__py3-none-any.whl
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -15
- nextrec/basic/model.py +435 -187
- nextrec/data/data_processing.py +31 -19
- nextrec/data/dataloader.py +40 -10
- nextrec/models/generative/hstu.py +3 -2
- nextrec/models/match/dssm.py +0 -1
- nextrec/models/match/dssm_v2.py +0 -1
- nextrec/models/match/mind.py +0 -1
- nextrec/models/match/sdm.py +0 -1
- nextrec/models/match/youtube_dnn.py +0 -1
- nextrec/models/multi_task/esmm.py +5 -7
- nextrec/models/multi_task/mmoe.py +10 -6
- nextrec/models/multi_task/ple.py +10 -6
- nextrec/models/multi_task/poso.py +9 -6
- nextrec/models/multi_task/share_bottom.py +10 -7
- nextrec/models/ranking/afm.py +113 -21
- nextrec/models/ranking/autoint.py +15 -9
- nextrec/models/ranking/dcn.py +8 -11
- nextrec/models/ranking/deepfm.py +5 -5
- nextrec/models/ranking/dien.py +4 -4
- nextrec/models/ranking/din.py +4 -4
- nextrec/models/ranking/fibinet.py +4 -4
- nextrec/models/ranking/fm.py +4 -4
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +4 -4
- nextrec/models/ranking/widedeep.py +4 -4
- nextrec/models/ranking/xdeepfm.py +4 -4
- nextrec/utils/__init__.py +7 -3
- nextrec/utils/device.py +30 -0
- nextrec/utils/distributed.py +114 -0
- nextrec/utils/synthetic_data.py +413 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
- nextrec-0.4.1.dist-info/RECORD +66 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
```diff
@@ -2,10 +2,9 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 05/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
-
 import os
 import tqdm
 import pickle
@@ -17,10 +16,13 @@ import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from pathlib import Path
 from typing import Union, Literal, Any
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
@@ -31,22 +33,23 @@ from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
 from nextrec.data.dataloader import build_tensors_from_data
-from nextrec.data.data_processing import get_column_data, get_user_ids
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
+from nextrec.data.data_processing import get_column_data, get_user_ids
 
 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.utils.tensor import to_tensor
-
+from nextrec.utils.device import configure_device
+from nextrec.utils.optimizer import get_optimizer, get_scheduler
+from nextrec.utils.distributed import gather_numpy, init_process_group, add_distributed_sampler
 from nextrec import __version__
 
 class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
-
+
     @property
-    def
+    def default_task(self) -> str | list[str]:
         raise NotImplementedError
 
     def __init__(self,
```
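`configure_device` comes from the new `nextrec/utils/device.py` (+30 lines), whose body is not shown in this diff. A minimal sketch of what such a helper plausibly does, given how it is called below (`configure_device(self.distributed, self.local_rank, device)`) — this is an assumption, not the shipped implementation:

```python
import torch

def configure_device(distributed: bool, local_rank: int, device: str) -> torch.device:
    # Hypothetical sketch: bind each DDP process to its local GPU;
    # otherwise honor the user-supplied device string, falling back to CPU.
    if distributed and torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
        return torch.device(f"cuda:{local_rank}")
    if device.startswith("cuda") and not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(device)
```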
```diff
@@ -55,21 +58,57 @@ class BaseModel(FeatureSet, nn.Module):
                  sequence_features: list[SequenceFeature] | None = None,
                  target: list[str] | str | None = None,
                  id_columns: list[str] | str | None = None,
-                 task: str|list[str] =
+                 task: str | list[str] | None = None,
                  device: str = 'cpu',
+                 early_stop_patience: int = 20,
+                 session_id: str | None = None,
                  embedding_l1_reg: float = 0.0,
                  dense_l1_reg: float = 0.0,
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
-
-
-
+
+                 distributed: bool = False,
+                 rank: int | None = None,
+                 world_size: int | None = None,
+                 local_rank: int | None = None,
+                 ddp_find_unused_parameters: bool = False,):
+        """
+        Initialize a base model.
+
+        Args:
+            dense_features: DenseFeature definitions.
+            sparse_features: SparseFeature definitions.
+            sequence_features: SequenceFeature definitions.
+            target: Target column name.
+            id_columns: Identifier column name; only needed when GAUC is required.
+            task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+            device: Torch device string or torch.device, e.g., 'cpu', 'cuda:0'.
+            embedding_l1_reg: L1 regularization strength for embedding params, e.g., 1e-6.
+            dense_l1_reg: L1 regularization strength for dense params, e.g., 1e-5.
+            embedding_l2_reg: L2 regularization strength for embedding params, e.g., 1e-5.
+            dense_l2_reg: L2 regularization strength for dense params, e.g., 1e-4.
+            early_stop_patience: Patience in epochs for early stopping; 0 to disable, e.g., 20.
+            session_id: Session id for logging. If None, a default id with a timestamp is created.
+            distributed: Set True to enable the DistributedDataParallel training flow.
+            rank: Global rank (defaults to env RANK).
+            world_size: Number of processes (defaults to env WORLD_SIZE).
+            local_rank: Local rank for selecting the CUDA device (defaults to env LOCAL_RANK).
+            ddp_find_unused_parameters: Defaults to False; set it True only when the DDP model has unused parameters, which is rarely the case.
+        """
         super(BaseModel, self).__init__()
-
-
-
-
-
+
+        # distributed training settings
+        env_rank = int(os.environ.get("RANK", "0"))
+        env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
+        env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self.distributed = distributed or (env_world_size > 1)
+        self.rank = env_rank if rank is None else rank
+        self.world_size = env_world_size if world_size is None else world_size
+        self.local_rank = env_local_rank if local_rank is None else local_rank
+        self.is_main_process = self.rank == 0
+        self.ddp_find_unused_parameters = ddp_find_unused_parameters
+        self.ddp_model: DDP | None = None
+        self.device = configure_device(self.distributed, self.local_rank, device)
 
         self.session_id = session_id
         self.session = create_session(session_id)
```
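The fallback chain above means the constructor needs no explicit distributed arguments under a multi-process launcher. An illustration (the values are what `torchrun --nproc_per_node=4` would export to its third worker):

```python
import os

# torchrun exports these per worker; shown here for worker 2 of 4.
os.environ.update({"RANK": "2", "WORLD_SIZE": "4", "LOCAL_RANK": "2"})

rank = int(os.environ.get("RANK", "0"))              # 2
world_size = int(os.environ.get("WORLD_SIZE", "1"))  # 4
local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # 2

distributed = False or (world_size > 1)  # True: WORLD_SIZE > 1 flips it on
is_main_process = rank == 0              # False on this worker
```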
```diff
@@ -79,8 +118,8 @@
         self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
 
-        self.task = task
-        self.nums_task = len(task) if isinstance(task, list) else 1
+        self.task = self.default_task if task is None else task
+        self.nums_task = len(self.task) if isinstance(self.task, list) else 1
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
@@ -89,10 +128,11 @@
         self.regularization_weights = []
         self.embedding_params = []
         self.loss_weight = None
+
         self.early_stop_patience = early_stop_patience
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
-        self.training_logger
+        self.training_logger = None
 
     def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
@@ -145,18 +185,22 @@
                 raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                 continue
             target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
-            target_tensor = target_tensor.view(target_tensor.size(0), -1)
+            target_tensor = target_tensor.view(target_tensor.size(0), -1)  # always reshape to (batch_size, num_targets)
             target_tensors.append(target_tensor)
         if target_tensors:
             y = torch.cat(target_tensors, dim=1)
-            if y.shape[1] == 1:
+            if y.shape[1] == 1:  # no need to do that again
                 y = y.view(-1)
         elif require_labels:
             raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
         return X_input, y
 
-    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,)
-        """
+    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,):
+        """
+        This function will split training data into training and validation sets when:
+        1. valid_data is None;
+        2. validation_split is provided.
+        """
         if not (0 < validation_split < 1):
             raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
         if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -189,15 +233,30 @@
         return train_loader, valid_split
 
     def compile(
-
-
-
-
-
-
-
-
-
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+    ):
+        """
+        Configure the model for training.
+        Args:
+            optimizer: Optimizer name or instance. e.g., 'adam', 'sgd', or torch.optim.Adam().
+            optimizer_params: Optimizer parameters. e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
+            scheduler: Learning rate scheduler name or instance. e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+            scheduler_params: Scheduler parameters. e.g., {'step_size': 10, 'gamma': 0.1}.
+            loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
+            loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
+            loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+        """
+        if loss_params is None:
+            self.loss_params = {}
+        else:
+            self.loss_params = loss_params
         optimizer_params = optimizer_params or {}
         self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self.optimizer_params = optimizer_params
```
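A usage sketch built from the signature and docstring above (the `model` object is assumed to be any BaseModel subclass; parameter values are illustrative):

```python
# Single-task binary classification, mirroring the docstring examples.
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    scheduler="step_lr",
    scheduler_params={"step_size": 10, "gamma": 0.1},
    loss="bce",
    loss_weights=1.0,  # stored internally as [1.0] after this release's change
)
```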
```diff
@@ -217,7 +276,9 @@
         self.loss_params = loss_params or {}
         self.loss_fn = []
         if isinstance(loss, list): # for example: ['bce', 'mse'] -> ['bce', 'mse']
-
+            if len(loss) != self.nums_task:
+                raise ValueError(f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task}).")
+            loss_list = [loss[i] for i in range(self.nums_task)]
         else: # for example: 'bce' -> ['bce', 'bce']
             loss_list = [loss] * self.nums_task
 
@@ -231,12 +292,12 @@
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1
+                if len(loss_weights) != 1:
                     raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
-            self.loss_weights = float(weight_value)
+            self.loss_weights = [float(weight_value)]
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
@@ -250,29 +311,48 @@
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
-            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required
+            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required.")
         if self.nums_task == 1:
-
+            if y_pred.dim() == 1:
+                y_pred = y_pred.view(-1, 1)
+            if y_true.dim() == 1:
+                y_true = y_true.view(-1, 1)
+            if y_pred.shape != y_true.shape:
+                raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+            task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
+            if task_dim == 1:
+                loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
+            else:
+                loss = self.loss_fn[0](y_pred, y_true)
             if self.loss_weights is not None:
-                loss
+                loss *= self.loss_weights[0]
             return loss
+        # multi-task
+        if y_pred.shape != y_true.shape:
+            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        if hasattr(self, "prediction_layer"): # we need to use registered task_slices for multi-task and multi-class
+            slices = self.prediction_layer._task_slices  # type: ignore
         else:
-
-
-
-
-
-
-
-
+            slices = [(i, i + 1) for i in range(self.nums_task)]
+        task_losses = []
+        for i, (start, end) in enumerate(slices):  # type: ignore
+            y_pred_i = y_pred[:, start:end]
+            y_true_i = y_true[:, start:end]
+            task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+            if isinstance(self.loss_weights, (list, tuple)):
+                task_loss *= self.loss_weights[i]
+            task_losses.append(task_loss)
+        return torch.stack(task_losses).sum()
+
+    def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0, sampler=None, return_dataset: bool = False) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
         if isinstance(data, DataLoader):
-            return data
+            return (data, None) if return_dataset else data
         tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
         if tensors is None:
             raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
         dataset = TensorDictDataset(tensors)
-
+        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
+        return (loader, dataset) if return_dataset else loader
 
     def fit(self,
             train_data: dict | pd.DataFrame | DataLoader,
```
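The multi-task branch above slices the stacked prediction matrix column by column; when there is no `prediction_layer`, each task owns exactly one output column. A self-contained sketch of that arithmetic (the tasks and tensors are illustrative, not from the package):

```python
import torch
import torch.nn as nn

# Two tasks -> default slices [(0, 1), (1, 2)]: one column each.
y_pred = torch.tensor([[0.8, 2.5], [0.3, 1.0]])  # columns: [click prob, rating]
y_true = torch.tensor([[1.0, 2.0], [0.0, 1.5]])
loss_fns = [nn.BCELoss(), nn.MSELoss()]          # as in compile(loss=['bce', 'mse'])
loss_weights = [1.0, 0.5]                        # as in compile(loss_weights=[1.0, 0.5])

task_losses = []
for i, (start, end) in enumerate([(0, 1), (1, 2)]):
    task_loss = loss_fns[i](y_pred[:, start:end], y_true[:, start:end])
    task_losses.append(task_loss * loss_weights[i])

total = torch.stack(task_losses).sum()  # weighted sum, as in compute_loss
```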
```diff
@@ -282,27 +362,82 @@
             user_id_column: str | None = None,
             validation_split: float | None = None,
             num_workers: int = 0,
-            tensorboard: bool = True,
+            tensorboard: bool = True,
+            auto_distributed_sampler: bool = True,):
+        """
+        Train the model.
+
+        Args:
+            train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
+            valid_data: Optional validation data; if None and validation_split is set, a split is created.
+            metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            epochs: Training epochs.
+            shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
+            batch_size: Batch size (per process when distributed).
+            user_id_column: Column name for GAUC-style metrics.
+            validation_split: Ratio to split training data when valid_data is None.
+            num_workers: DataLoader worker count.
+            tensorboard: Enable tensorboard logging.
+            auto_distributed_sampler: Attach DistributedSampler automatically when distributed; set False when data is already sharded per rank.
+
+        Notes:
+            - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
+            - All ranks must call evaluate() together because it performs collective ops.
+        """
+        device_id = self.local_rank if self.device.type == "cuda" else None
+        init_process_group(self.distributed, self.rank, self.world_size, device_id=device_id)
         self.to(self.device)
-
+
+        if self.distributed and dist.is_available() and dist.is_initialized() and self.ddp_model is None:
+            device_ids = [self.local_rank] if self.device.type == "cuda" else None  # device_ids means which device to use in ddp
+            output_device = self.local_rank if self.device.type == "cuda" else None  # output_device means which device to place the output in ddp
+            object.__setattr__(self, "ddp_model", DDP(self, device_ids=device_ids, output_device=output_device, find_unused_parameters=self.ddp_find_unused_parameters))
+
+        if not self.logger_initialized and self.is_main_process: # only main process initializes logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+        self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard) if self.is_main_process else None
 
         self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns) # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
         self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
+
         self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics) # check user_id needed for GAUC metrics
         self.epoch_index = 0
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
-        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
 
+        if not auto_distributed_sampler and self.distributed and self.is_main_process:
+            logging.info(colorize("[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.", color="yellow"))
+
+        train_sampler: DistributedSampler | None = None
         if validation_split is not None and valid_data is None:
             train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) # type: ignore
+            if auto_distributed_sampler and self.distributed and dist.is_available() and dist.is_initialized():
+                base_dataset = getattr(train_loader, "dataset", None)
+                if base_dataset is not None and not isinstance(getattr(train_loader, "sampler", None), DistributedSampler):
+                    train_sampler = DistributedSampler(base_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+                    train_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
         else:
-
+            if isinstance(train_data, DataLoader):
+                if auto_distributed_sampler and self.distributed:
+                    train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+                    # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+                else:
+                    train_loader = train_data
+            else:
+                loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True) # type: ignore
+                if auto_distributed_sampler and self.distributed and dataset is not None and dist.is_available() and dist.is_initialized():
+                    train_sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+                    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
+                train_loader = loader
+
+        # If split-based loader was built without sampler, attach here when enabled
+        if self.distributed and auto_distributed_sampler and isinstance(train_loader, DataLoader) and train_sampler is None:
+            raise NotImplementedError("[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+            # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
 
-        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers)
+        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers, auto_distributed_sampler=auto_distributed_sampler)
         try:
             self.steps_per_epoch = len(train_loader)
             is_streaming = False
```
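Putting the Notes into practice: the DDP path expects one process per device, each seeing the env vars that `init_process_group()` consumes. A launch sketch (the script name and data variables are placeholders):

```python
# Shell: torchrun --nproc_per_node=4 train.py
# torchrun sets RANK / WORLD_SIZE / LOCAL_RANK for every worker, which is
# all the constructor and fit() need to wire up DDP.

# train.py (sketch)
model.fit(
    train_data=train_df,            # placeholder DataFrame; auto-sharded below
    validation_split=0.1,
    epochs=10,
    batch_size=512,                 # per-process batch size
    auto_distributed_sampler=True,  # attach a DistributedSampler per rank (default)
)
```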
```diff
@@ -310,38 +445,41 @@
             self.steps_per_epoch = None
             is_streaming = True
 
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self.is_main_process:
+            self.summary()
+            logging.info("")
+            if self.training_logger and self.training_logger.enable_tensorboard:
+                tb_dir = self.training_logger.tensorboard_logdir
+                if tb_dir:
+                    user = getpass.getuser()
+                    host = socket.gethostname()
+                    tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                    ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                    logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+                    logging.info(colorize("To view logs, run:", color="cyan"))
+                    logging.info(colorize(f" {tb_cmd}", color="cyan"))
+                    logging.info(colorize("Then SSH port forward:", color="cyan"))
+                    logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
+            logging.info("")
+            logging.info(colorize("=" * 80, bold=True))
+            if is_streaming:
+                logging.info(colorize(f"Start streaming training", bold=True))
+            else:
+                logging.info(colorize(f"Start training", bold=True))
+            logging.info(colorize("=" * 80, bold=True))
+            logging.info("")
+            logging.info(colorize(f"Model device: {self.device}", bold=True))
 
         for epoch in range(epochs):
             self.epoch_index = epoch
-            if is_streaming:
+            if is_streaming and self.is_main_process:
                 logging.info("")
                 logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)) # streaming mode, print epoch header before progress bar
 
             # handle train result
+            if self.distributed and hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
+                train_loader.sampler.set_epoch(epoch)
             train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
             if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
                 train_loss, train_metrics = train_result
@@ -356,7 +494,8 @@
                 if train_metrics:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
                     log_str += f", {metrics_str}"
-
+                if self.is_main_process:
+                    logging.info(colorize(log_str))
                 train_log_payload["loss"] = float(train_loss)
                 if train_metrics:
                     train_log_payload.update(train_metrics)
@@ -381,7 +520,8 @@
                         metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                         task_metric_strs.append(f"{target_name}[{metrics_str}]")
                     log_str += ", " + ", ".join(task_metric_strs)
-
+                if self.is_main_process:
+                    logging.info(colorize(log_str))
                 train_log_payload["loss"] = float(total_loss_val)
                 if train_metrics:
                     train_log_payload.update(train_metrics)
@@ -392,7 +532,8 @@
                 val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
                     metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
-
+                    if self.is_main_process:
+                        logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
                 else:
                     # multi task metrics
                     task_metrics = {}
@@ -409,20 +550,29 @@
                         if target_name in task_metrics:
                             metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
-
+                    if self.is_main_process:
+                        logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
                 if val_metrics and self.training_logger:
                     self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
                 # Handle empty validation metrics
                 if not val_metrics:
-                    self.
-
-
+                    if self.is_main_process:
+                        self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                        self.best_checkpoint_path = self.checkpoint_path
+                        logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                     continue
                 if self.nums_task == 1:
                     primary_metric_key = self.metrics[0]
                 else:
                     primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
                 primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]]) # get primary metric value, default to first metric if not found
+
+                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
+                if self.distributed and dist.is_available() and dist.is_initialized():
+                    metric_tensor = torch.tensor([primary_metric], device=self.device, dtype=torch.float32)
+                    dist.broadcast(metric_tensor, src=0)
+                    primary_metric = float(metric_tensor.item())
+
                 improved = False
                 # early stopping check
                 if self.best_metrics_mode == 'max':
@@ -433,24 +583,40 @@
                     if primary_metric < self.best_metric:
                         self.best_metric = primary_metric
                         improved = True
-
-
-                if
-
-
-
-
+
+                # save checkpoint and best model for main process
+                if self.is_main_process:
+                    self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                    logging.info(" ")
+                    if improved:
+                        logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
+                        self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                        self.best_checkpoint_path = self.best_path
+                        self.early_stopper.trial_counter = 0
+                    else:
+                        self.early_stopper.trial_counter += 1
+                        logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
+                        if self.early_stopper.trial_counter >= self.early_stopper.patience:
+                            self.stop_training = True
+                            logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
                 else:
-
-
-
-
-
-                        break
+                    # Non-main processes also update trial_counter to keep in sync
+                    if improved:
+                        self.early_stopper.trial_counter = 0
+                    else:
+                        self.early_stopper.trial_counter += 1
             else:
-                self.
-
-
+                if self.is_main_process:
+                    self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                    self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                    self.best_checkpoint_path = self.best_path
+
+            # Broadcast stop_training flag to all processes (always, regardless of validation)
+            if self.distributed and dist.is_available() and dist.is_initialized():
+                stop_tensor = torch.tensor([int(self.stop_training)], device=self.device)
+                dist.broadcast(stop_tensor, src=0)
+                self.stop_training = bool(stop_tensor.item())
+
             if self.stop_training:
                 break
             if self.scheduler_fn is not None:
@@ -459,41 +625,53 @@
                     self.scheduler_fn.step(primary_metric)
                 else:
                     self.scheduler_fn.step()
-
-
-
+        if self.distributed and dist.is_available() and dist.is_initialized():
+            dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
+        if self.is_main_process:
+            logging.info(" ")
+            logging.info(colorize("Training finished.", bold=True))
+            logging.info(" ")
         if valid_loader is not None:
-
+            if self.is_main_process:
+                logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
             self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
         if self.training_logger:
             self.training_logger.close()
         return self
 
     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+        # use ddp model for distributed training
+        model = self.ddp_model if getattr(self, "ddp_model") is not None else self
         accumulated_loss = 0.0
-
+        model.train() # type: ignore
         num_batches = 0
         y_true_list = []
         y_pred_list = []
 
         user_ids_list = [] if self.needs_user_ids else None
+        tqdm_disable = not self.is_main_process
        if self.steps_per_epoch is not None:
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch, disable=tqdm_disable))
         else:
             desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable))
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
-
+            # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
+            y_pred = model(X_input) # type: ignore
+
             loss = self.compute_loss(y_pred, y_true)
             reg_loss = self.add_reg_loss()
             total_loss = loss + reg_loss
             self.optimizer_fn.zero_grad()
             total_loss.backward()
-
+
+            params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
+            nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
             self.optimizer_fn.step()
             accumulated_loss += loss.item()
+
             if y_true is not None:
                 y_true_list.append(y_true.detach().cpu().numpy())
             if self.needs_user_ids and user_ids_list is not None:
```
```diff
@@ -503,38 +681,78 @@
             if y_pred is not None and isinstance(y_pred, torch.Tensor):
                 y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
+        if self.distributed and dist.is_available() and dist.is_initialized():
+            loss_tensor = torch.tensor([accumulated_loss, num_batches], device=self.device, dtype=torch.float32)
+            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
+            accumulated_loss = loss_tensor[0].item()
+            num_batches = int(loss_tensor[1].item())
         avg_loss = accumulated_loss / max(num_batches, 1)
-
-
-
-
-
-
+
+        y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+        y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+        combined_user_ids_local = np.concatenate(user_ids_list, axis=0) if self.needs_user_ids and user_ids_list else None
+
+        # gather across ranks even when local is empty to avoid DDP hang
+        y_true_all = gather_numpy(self, y_true_all_local)
+        y_pred_all = gather_numpy(self, y_pred_all_local)
+        combined_user_ids = gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
+
+        if y_true_all is not None and y_pred_all is not None and len(y_true_all) > 0 and len(y_pred_all) > 0:
             metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
             return avg_loss, metrics_dict
         return avg_loss
 
-    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0,) -> tuple[DataLoader | None, np.ndarray | None]:
+    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0, auto_distributed_sampler: bool = True,) -> tuple[DataLoader | None, np.ndarray | None]:
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
-
-
+            if auto_distributed_sampler and self.distributed:
+                raise NotImplementedError("[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+                # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
+            else:
+                valid_loader = valid_data
+            return valid_loader, None
+        valid_sampler = None
+        valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
+        if auto_distributed_sampler and self.distributed and valid_dataset is not None and dist.is_available() and dist.is_initialized():
+            valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, drop_last=False)
+            valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler, collate_fn=collate_fn, num_workers=num_workers)
         valid_user_ids = None
         if needs_user_ids:
             if user_id_column is None:
                 raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
-
+            # In distributed mode, user_ids will be collected during evaluation from each batch
+            # and gathered across all processes, so we don't pre-extract them here
+            if not self.distributed:
+                valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
         return valid_loader, valid_user_ids
 
-    def evaluate(
-
-
-
-
-
-
-
+    def evaluate(
+        self,
+        data: dict | pd.DataFrame | DataLoader,
+        metrics: list[str] | dict[str, list[str]] | None = None,
+        batch_size: int = 32,
+        user_ids: np.ndarray | None = None,
+        user_id_column: str = 'user_id',
+        num_workers: int = 0,) -> dict:
+        """
+        **IMPORTANT for Distributed Training:**
+        In distributed mode, this method uses collective communication operations (all_gather).
+        All processes must call this method simultaneously, even if you only want results on rank 0.
+        Failing to do so will cause the program to hang indefinitely.
+
+        Evaluate the model on the given data.
+
+        Args:
+            data: Evaluation data (dict/df/DataLoader).
+            metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            batch_size: Batch size (per process when distributed).
+            user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, will be extracted from data using user_id_column. e.g. np.array([...])
+            user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
+            num_workers: DataLoader worker count.
+        """
+        model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+        model.eval()
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
             raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
```
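`gather_numpy` is defined in the new `nextrec/utils/distributed.py` (+114 lines), which this diff does not show. A speculative sketch consistent with its call sites — every rank participates even with an empty local array, which is why the comments warn about DDP hangs; the shipped implementation may differ:

```python
import numpy as np
import torch.distributed as dist

def gather_numpy(model, local_array):
    # Hypothetical helper: concatenate per-rank numpy arrays on every rank.
    if not (model.distributed and dist.is_available() and dist.is_initialized()):
        return local_array
    # All ranks must enter this collective together, even with local_array=None,
    # otherwise the remaining ranks block forever.
    gathered = [None] * model.world_size
    dist.all_gather_object(gathered, local_array)
    pieces = [a for a in gathered if a is not None and len(a) > 0]
    return np.concatenate(pieces, axis=0) if pieces else None
```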
```diff
@@ -555,7 +773,7 @@
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
-                y_pred =
+                y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -564,20 +782,11 @@
                     batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
                         collected_user_ids.append(batch_user_id)
-
-
-
-
-
-        else:
-            y_true_all = None
-            logging.info(colorize(f" Warning: No y_true collected from evaluation data", color="yellow"))
-
-        if len(y_pred_list) > 0:
-            y_pred_all = np.concatenate(y_pred_list, axis=0)
-        else:
-            y_pred_all = None
-            logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
+        if self.is_main_process:
+            logging.info(" ")
+            logging.info(colorize(f" Evaluation batches processed: {batch_count}", color="cyan"))
+        y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+        y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
 
         # Convert metrics to list if it's a dict
         if isinstance(eval_metrics, dict):
```
```diff
@@ -590,51 +799,86 @@
                 metrics_to_use = unique_metrics
         else:
             metrics_to_use = eval_metrics
-
-        if
-
+        final_user_ids_local = user_ids
+        if final_user_ids_local is None and collected_user_ids:
+            final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
+
+        # gather across ranks even when local arrays are empty to keep collectives aligned
+        y_true_all = gather_numpy(self, y_true_all_local)
+        y_pred_all = gather_numpy(self, y_pred_all_local)
+        final_user_ids = gather_numpy(self, final_user_ids_local) if needs_user_ids else None
+        if y_true_all is None or y_pred_all is None or len(y_true_all) == 0 or len(y_pred_all) == 0:
+            if self.is_main_process:
+                logging.info(colorize(" Warning: Not enough evaluation data to compute metrics after gathering", color="yellow"))
+            return {}
+        if self.is_main_process:
+            logging.info(colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
         metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
         return metrics_dict
 
     def predict(
-
-
-
-
-
-
-
-
-
-
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: Literal["csv", "parquet"] = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        streaming_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame | np.ndarray:
+        """
+        Note: predict does not currently support distributed mode; treat it as a single-process operation.
+        Make predictions on the given data.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction (per process when distributed).
+            save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output; if None, includes if id_columns are set.
+            id_columns: Column name(s) to use as IDs; if None, uses model's id_columns.
+            return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+            streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
+            num_workers: DataLoader worker count.
+        """
         self.eval()
+        # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
+        predict_id_columns = id_columns if id_columns is not None else self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+
         if include_ids is None:
-            include_ids = bool(
-        include_ids = include_ids and bool(
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)
 
+        # Use streaming mode for large file saves without loading all data into memory
         if save_path is not None and not return_dataframe:
-            return self.
-
-
+            return self.predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe, id_columns=predict_id_columns)
+
+        # Create DataLoader based on data type
+        if isinstance(data, DataLoader):
+            data_loader = data
+        elif isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=predict_id_columns,)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
-        elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
         else:
-            data_loader = data
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
-        y_pred_list
-        id_buffers
-        id_arrays
+        y_pred_list = []
+        id_buffers = {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+        id_arrays = None
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
-                y_pred = self
+                y_pred = self(X_input)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.detach().cpu().numpy())
-                if include_ids and
-                    for id_name in
+                if include_ids and predict_id_columns and batch_dict.get("ids"):
+                    for id_name in predict_id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
```
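Both predict() paths described in the docstring, sketched (the file names and test DataFrame are placeholders):

```python
# In-memory path: returns a DataFrame with one "<target>_pred" column per output.
preds = model.predict(test_df, batch_size=256, include_ids=True, id_columns="user_id")

# Streaming path: save_path plus return_dataframe=False routes into
# predict_streaming(), writing chunks of streaming_chunk_size rows at a time.
model.predict(
    "test_data.parquet",           # placeholder input file
    save_path="predictions.csv",   # placeholder output file
    save_format="csv",
    return_dataframe=False,
    streaming_chunk_size=10000,
)
```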
```diff
@@ -657,7 +901,7 @@
                 pred_columns.append(f"{name}_pred")
             while len(pred_columns) < num_outputs:
                 pred_columns.append(f"pred_{len(pred_columns)}")
-        if include_ids and
+        if include_ids and predict_id_columns:
             id_arrays = {}
             for id_name, pieces in id_buffers.items():
                 if pieces:
@@ -684,7 +928,7 @@
                 df_to_save = output
             else:
                 df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
-            if include_ids and
+            if include_ids and predict_id_columns and id_arrays is not None:
                 id_df = pd.DataFrame(id_arrays)
                 if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
                     raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
@@ -696,7 +940,7 @@
             logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         return output
 
-    def
+    def predict_streaming(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int,
@@ -705,9 +949,10 @@
         include_ids: bool,
         streaming_chunk_size: int,
         return_dataframe: bool,
+        id_columns: list[str] | None = None,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=id_columns)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
             data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
@@ -720,8 +965,8 @@
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
 
-        pred_columns
-        collected_frames
+        pred_columns = None
+        collected_frames = [] # only used when return_dataframe is True
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -742,9 +987,9 @@
                 while len(pred_columns) < num_outputs:
                     pred_columns.append(f"pred_{len(pred_columns)}")
 
-                id_arrays_batch
-                if include_ids and
-                    for id_name in
+                id_arrays_batch = {}
+                if include_ids and id_columns and batch_dict.get("ids"):
+                    for id_name in id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
```
```diff
@@ -784,7 +1029,10 @@
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
         model_path = Path(target_path)
-
+
+        model_to_save = (self.ddp_model.module if getattr(self, "ddp_model", None) is not None else self)
+        torch.save(model_to_save.state_dict(), model_path)
+        # torch.save(self.state_dict(), model_path)
 
         config_path = self.features_config_path
         features_config = {
```
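The unwrap matters because DDP prefixes every parameter key with `module.`; saving the wrapper's state_dict would produce checkpoints a bare model cannot load. A quick illustration:

```python
import torch.nn as nn

net = nn.Linear(4, 1)
print(list(net.state_dict()))  # ['weight', 'bias']

# Inside an initialized process group, DDP(net).state_dict() would yield
# ['module.weight', 'module.bias'] -- hence saving ddp_model.module.state_dict()
# keeps checkpoint keys loadable by the plain model.
```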
```diff
@@ -845,8 +1093,8 @@
         **kwargs: Any,
     ) -> "BaseModel":
         """
-
-
+        Load a model from a checkpoint path. The checkpoint path should contain:
+        a .model file and a features_config.pkl file.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
@@ -1006,10 +1254,10 @@ class BaseMatchModel(BaseModel):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
-
+
     @property
-    def
-
+    def default_task(self) -> str:
+        return "binary"
 
     @property
     def support_training_modes(self) -> list[str]:
```