nextrec 0.3.5__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +0 -30
- nextrec/__version__.py +1 -1
- nextrec/basic/layers.py +32 -15
- nextrec/basic/loggers.py +1 -1
- nextrec/basic/model.py +440 -189
- nextrec/basic/session.py +4 -2
- nextrec/data/__init__.py +0 -25
- nextrec/data/data_processing.py +31 -19
- nextrec/data/dataloader.py +51 -16
- nextrec/models/generative/__init__.py +0 -5
- nextrec/models/generative/hstu.py +3 -2
- nextrec/models/match/__init__.py +0 -13
- nextrec/models/match/dssm.py +0 -1
- nextrec/models/match/dssm_v2.py +0 -1
- nextrec/models/match/mind.py +0 -1
- nextrec/models/match/sdm.py +0 -1
- nextrec/models/match/youtube_dnn.py +0 -1
- nextrec/models/multi_task/__init__.py +0 -0
- nextrec/models/multi_task/esmm.py +5 -7
- nextrec/models/multi_task/mmoe.py +10 -6
- nextrec/models/multi_task/ple.py +10 -6
- nextrec/models/multi_task/poso.py +9 -6
- nextrec/models/multi_task/share_bottom.py +10 -7
- nextrec/models/ranking/__init__.py +0 -27
- nextrec/models/ranking/afm.py +113 -21
- nextrec/models/ranking/autoint.py +15 -9
- nextrec/models/ranking/dcn.py +8 -11
- nextrec/models/ranking/deepfm.py +5 -5
- nextrec/models/ranking/dien.py +4 -4
- nextrec/models/ranking/din.py +4 -4
- nextrec/models/ranking/fibinet.py +4 -4
- nextrec/models/ranking/fm.py +4 -4
- nextrec/models/ranking/masknet.py +4 -5
- nextrec/models/ranking/pnn.py +4 -4
- nextrec/models/ranking/widedeep.py +4 -4
- nextrec/models/ranking/xdeepfm.py +4 -4
- nextrec/utils/__init__.py +7 -3
- nextrec/utils/device.py +32 -1
- nextrec/utils/distributed.py +114 -0
- nextrec/utils/synthetic_data.py +413 -0
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/METADATA +15 -5
- nextrec-0.4.1.dist-info/RECORD +66 -0
- nextrec-0.3.5.dist-info/RECORD +0 -63
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/WHEEL +0 -0
- {nextrec-0.3.5.dist-info → nextrec-0.4.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,10 +2,9 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 05/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
-
 import os
 import tqdm
 import pickle
@@ -17,10 +16,13 @@ import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from pathlib import Path
 from typing import Union, Literal, Any
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureSet
@@ -31,22 +33,23 @@ from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
 from nextrec.data.dataloader import build_tensors_from_data
-from nextrec.data.data_processing import get_column_data, get_user_ids
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
+from nextrec.data.data_processing import get_column_data, get_user_ids
 
 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.utils.tensor import to_tensor
-
+from nextrec.utils.device import configure_device
+from nextrec.utils.optimizer import get_optimizer, get_scheduler
+from nextrec.utils.distributed import gather_numpy, init_process_group, add_distributed_sampler
 from nextrec import __version__
 
 class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
-
+
     @property
-    def
+    def default_task(self) -> str | list[str]:
         raise NotImplementedError
 
     def __init__(self,
@@ -55,21 +58,57 @@ class BaseModel(FeatureSet, nn.Module):
                  sequence_features: list[SequenceFeature] | None = None,
                  target: list[str] | str | None = None,
                  id_columns: list[str] | str | None = None,
-                 task: str|list[str] =
+                 task: str | list[str] | None = None,
                  device: str = 'cpu',
+                 early_stop_patience: int = 20,
+                 session_id: str | None = None,
                  embedding_l1_reg: float = 0.0,
                  dense_l1_reg: float = 0.0,
                  embedding_l2_reg: float = 0.0,
                  dense_l2_reg: float = 0.0,
-
-
-
+
+                 distributed: bool = False,
+                 rank: int | None = None,
+                 world_size: int | None = None,
+                 local_rank: int | None = None,
+                 ddp_find_unused_parameters: bool = False,):
+        """
+        Initialize a base model.
+
+        Args:
+            dense_features: DenseFeature definitions.
+            sparse_features: SparseFeature definitions.
+            sequence_features: SequenceFeature definitions.
+            target: Target column name.
+            id_columns: Identifier column name; only needs to be specified if GAUC is required.
+            task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+            device: Torch device string or torch.device, e.g., 'cpu', 'cuda:0'.
+            embedding_l1_reg: L1 regularization strength for embedding params, e.g., 1e-6.
+            dense_l1_reg: L1 regularization strength for dense params, e.g., 1e-5.
+            embedding_l2_reg: L2 regularization strength for embedding params, e.g., 1e-5.
+            dense_l2_reg: L2 regularization strength for dense params, e.g., 1e-4.
+            early_stop_patience: Patience in epochs for early stopping; 0 to disable, e.g., 20.
+            session_id: Session id for logging. If None, a default id with timestamps will be created.
+            distributed: Enable the DistributedDataParallel flow; set True to enable distributed training.
+            rank: Global rank (defaults to env RANK).
+            world_size: Number of processes (defaults to env WORLD_SIZE).
+            local_rank: Local rank for selecting the CUDA device (defaults to env LOCAL_RANK).
+            ddp_find_unused_parameters: Default False; set True only when the DDP model has unused parameters. In most cases this should stay False.
+        """
         super(BaseModel, self).__init__()
-
-
-
-
-
+
+        # distributed training settings
+        env_rank = int(os.environ.get("RANK", "0"))
+        env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
+        env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self.distributed = distributed or (env_world_size > 1)
+        self.rank = env_rank if rank is None else rank
+        self.world_size = env_world_size if world_size is None else world_size
+        self.local_rank = env_local_rank if local_rank is None else local_rank
+        self.is_main_process = self.rank == 0
+        self.ddp_find_unused_parameters = ddp_find_unused_parameters
+        self.ddp_model: DDP | None = None
+        self.device = configure_device(self.distributed, self.local_rank, device)
 
         self.session_id = session_id
         self.session = create_session(session_id)
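The constructor now resolves its distributed settings from torchrun-style environment variables when explicit arguments are omitted. A minimal standalone sketch of that fallback logic, assuming the standard RANK/WORLD_SIZE/LOCAL_RANK variables that `torchrun` exports (the device choice below is an illustrative stand-in for nextrec's `configure_device` helper, whose body is not part of this diff):

```python
import os
import torch

# Launch example (shell): torchrun --nproc_per_node=4 train.py
# torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK for every worker process.

def resolve_distributed_settings(distributed: bool = False,
                                 rank: int | None = None,
                                 world_size: int | None = None,
                                 local_rank: int | None = None) -> dict:
    """Mirror the constructor's fallback: explicit args win, env vars fill the gaps."""
    env_rank = int(os.environ.get("RANK", "0"))
    env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
    env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    resolved = {
        "distributed": distributed or env_world_size > 1,  # WORLD_SIZE > 1 implies a multi-process launch
        "rank": env_rank if rank is None else rank,
        "world_size": env_world_size if world_size is None else world_size,
        "local_rank": env_local_rank if local_rank is None else local_rank,
    }
    resolved["is_main_process"] = resolved["rank"] == 0
    # Illustrative device choice: pin each process to its local GPU when available.
    resolved["device"] = torch.device(f"cuda:{resolved['local_rank']}") if torch.cuda.is_available() else torch.device("cpu")
    return resolved

print(resolve_distributed_settings())  # single-process run: rank 0 of world size 1
```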
@@ -79,8 +118,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.features_config_path = os.path.join(self.session_path, "features_config.pkl")
         self.set_all_features(dense_features, sparse_features, sequence_features, target, id_columns)
 
-        self.task = task
-        self.nums_task = len(task) if isinstance(task, list) else 1
+        self.task = self.default_task if task is None else task
+        self.nums_task = len(self.task) if isinstance(self.task, list) else 1
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
@@ -89,10 +128,11 @@ class BaseModel(FeatureSet, nn.Module):
         self.regularization_weights = []
         self.embedding_params = []
         self.loss_weight = None
+
         self.early_stop_patience = early_stop_patience
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
-        self.training_logger
+        self.training_logger = None
 
     def register_regularization_weights(self, embedding_attr: str = "embedding", exclude_modules: list[str] | None = None, include_modules: list[str] | None = None) -> None:
         exclude_modules = exclude_modules or []
@@ -145,18 +185,22 @@ class BaseModel(FeatureSet, nn.Module):
                 raise ValueError(f"[BaseModel-input Error] Target column '{target_name}' contains no data.")
                 continue
             target_tensor = to_tensor(target_data, dtype=torch.float32, device=self.device)
-            target_tensor = target_tensor.view(target_tensor.size(0), -1)
+            target_tensor = target_tensor.view(target_tensor.size(0), -1)  # always reshape to (batch_size, num_targets)
             target_tensors.append(target_tensor)
         if target_tensors:
             y = torch.cat(target_tensors, dim=1)
-            if y.shape[1] == 1:
+            if y.shape[1] == 1:  # no need to do that again
                 y = y.view(-1)
         elif require_labels:
             raise ValueError("[BaseModel-input Error] Labels are required but none were found in the input batch.")
         return X_input, y
 
-    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool,
-        """
+    def handle_validation_split(self, train_data: dict | pd.DataFrame, validation_split: float, batch_size: int, shuffle: bool, num_workers: int = 0,):
+        """
+        This function will split training data into training and validation sets when:
+        1. valid_data is None;
+        2. validation_split is provided.
+        """
         if not (0 < validation_split < 1):
             raise ValueError(f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}")
         if not isinstance(train_data, (pd.DataFrame, dict)):
@@ -184,20 +228,35 @@ class BaseModel(FeatureSet, nn.Module):
             arr = np.asarray(value)
             train_split[key] = arr[train_indices]
             valid_split[key] = arr[valid_indices]
-        train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle)
+        train_loader = self.prepare_data_loader(train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
         logging.info(f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples")
         return train_loader, valid_split
 
     def compile(
-
-
-
-
-
-
-
-
-
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: str | torch.optim.lr_scheduler._LRScheduler | torch.optim.lr_scheduler.LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | type[torch.optim.lr_scheduler.LRScheduler] | None = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+    ):
+        """
+        Configure the model for training.
+        Args:
+            optimizer: Optimizer name or instance, e.g., 'adam', 'sgd', or torch.optim.Adam().
+            optimizer_params: Optimizer parameters, e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
+            scheduler: Learning rate scheduler name or instance, e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+            scheduler_params: Scheduler parameters, e.g., {'step_size': 10, 'gamma': 0.1}.
+            loss: Loss function name, instance, or list for multi-task, e.g., 'bce', 'mse', or torch.nn.BCELoss(); custom loss functions are also supported.
+            loss_params: Loss function parameters, or a list for multi-task, e.g., {'weight': tensor([0.25, 0.75])}.
+            loss_weights: Weights for each task loss; int/float for single-task or list for multi-task, e.g., 1.0 or [1.0, 0.5].
+        """
+        if loss_params is None:
+            self.loss_params = {}
+        else:
+            self.loss_params = loss_params
         optimizer_params = optimizer_params or {}
         self.optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self.optimizer_params = optimizer_params
@@ -217,7 +276,9 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_params = loss_params or {}
         self.loss_fn = []
         if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
-
+            if len(loss) != self.nums_task:
+                raise ValueError(f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task}).")
+            loss_list = [loss[i] for i in range(self.nums_task)]
         else:  # for example: 'bce' -> ['bce', 'bce']
             loss_list = [loss] * self.nums_task
 
@@ -231,12 +292,12 @@ class BaseModel(FeatureSet, nn.Module):
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1
+                if len(loss_weights) != 1:
                     raise ValueError("[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup.")
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
-            self.loss_weights = float(weight_value)
+            self.loss_weights = [float(weight_value)]
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
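The list-vs-string handling above broadcasts a single loss name to every task and validates list lengths against `nums_task`. A standalone sketch of that normalization, using illustrative names rather than nextrec internals:

```python
# Minimal sketch of how compile() normalizes the `loss` argument:
# a single name is broadcast to all tasks, a list must match nums_task.
def normalize_losses(loss, nums_task: int) -> list:
    if isinstance(loss, list):
        if len(loss) != nums_task:
            raise ValueError(f"Number of loss functions ({len(loss)}) must match number of tasks ({nums_task}).")
        return list(loss)
    return [loss] * nums_task

print(normalize_losses("bce", 2))           # ['bce', 'bce']
print(normalize_losses(["bce", "mse"], 2))  # ['bce', 'mse']
```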
@@ -250,29 +311,48 @@ class BaseModel(FeatureSet, nn.Module):
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
-            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required
+            raise ValueError("[BaseModel-compute_loss Error] Ground truth labels (y_true) are required.")
         if self.nums_task == 1:
-
+            if y_pred.dim() == 1:
+                y_pred = y_pred.view(-1, 1)
+            if y_true.dim() == 1:
+                y_true = y_true.view(-1, 1)
+            if y_pred.shape != y_true.shape:
+                raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+            task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
+            if task_dim == 1:
+                loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
+            else:
+                loss = self.loss_fn[0](y_pred, y_true)
             if self.loss_weights is not None:
-                loss
+                loss *= self.loss_weights[0]
             return loss
+        # multi-task
+        if y_pred.shape != y_true.shape:
+            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        if hasattr(self, "prediction_layer"):  # we need to use registered task_slices for multi-task and multi-class
+            slices = self.prediction_layer._task_slices  # type: ignore
         else:
-
-
-
-
-
-
-
-
-
+            slices = [(i, i + 1) for i in range(self.nums_task)]
+        task_losses = []
+        for i, (start, end) in enumerate(slices):  # type: ignore
+            y_pred_i = y_pred[:, start:end]
+            y_true_i = y_true[:, start:end]
+            task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+            if isinstance(self.loss_weights, (list, tuple)):
+                task_loss *= self.loss_weights[i]
+            task_losses.append(task_loss)
+        return torch.stack(task_losses).sum()
+
+    def prepare_data_loader(self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 32, shuffle: bool = True, num_workers: int = 0, sampler=None, return_dataset: bool = False) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
         if isinstance(data, DataLoader):
-            return data
+            return (data, None) if return_dataset else data
         tensors = build_tensors_from_data(data=data, raw_data=data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
         if tensors is None:
             raise ValueError("[BaseModel-prepare_data_loader Error] No data available to create DataLoader.")
         dataset = TensorDictDataset(tensors)
-
+        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False if sampler is not None else shuffle, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers)
+        return (loader, dataset) if return_dataset else loader
 
     def fit(self,
             train_data: dict | pd.DataFrame | DataLoader,
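The multi-task branch of `compute_loss` slices the stacked prediction matrix per task before applying each task's loss and weight. A runnable sketch of the same arithmetic with the default `(i, i + 1)` slices:

```python
import torch
import torch.nn as nn

# Standalone sketch of the multi-task loss path: slice the stacked prediction
# matrix per task, apply each task's loss, weight it, and sum.
y_pred = torch.sigmoid(torch.randn(8, 2))           # batch of 8, two 1-dim tasks
y_true = torch.randint(0, 2, (8, 2)).float()
loss_fns = [nn.BCELoss(), nn.BCELoss()]
loss_weights = [1.0, 0.5]
slices = [(0, 1), (1, 2)]                           # default (i, i + 1) slices per task

task_losses = []
for i, (start, end) in enumerate(slices):
    task_loss = loss_fns[i](y_pred[:, start:end], y_true[:, start:end])
    task_losses.append(task_loss * loss_weights[i])
total = torch.stack(task_losses).sum()
print(total.item())
```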
@@ -281,27 +361,83 @@ class BaseModel(FeatureSet, nn.Module):
             epochs:int=1, shuffle:bool=True, batch_size:int=32,
             user_id_column: str | None = None,
             validation_split: float | None = None,
-
+            num_workers: int = 0,
+            tensorboard: bool = True,
+            auto_distributed_sampler: bool = True,):
+        """
+        Train the model.
+
+        Args:
+            train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
+            valid_data: Optional validation data; if None and validation_split is set, a split is created.
+            metrics: Metric names or per-target dict, e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            epochs: Training epochs.
+            shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
+            batch_size: Batch size (per process when distributed).
+            user_id_column: Column name for GAUC-style metrics.
+            validation_split: Ratio used to split training data when valid_data is None.
+            num_workers: DataLoader worker count.
+            tensorboard: Enable tensorboard logging.
+            auto_distributed_sampler: Attach DistributedSampler automatically when distributed; set False when data is already sharded per rank.
+
+        Notes:
+            - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
+            - All ranks must call evaluate() together because it performs collective ops.
+        """
+        device_id = self.local_rank if self.device.type == "cuda" else None
+        init_process_group(self.distributed, self.rank, self.world_size, device_id=device_id)
         self.to(self.device)
-
+
+        if self.distributed and dist.is_available() and dist.is_initialized() and self.ddp_model is None:
+            device_ids = [self.local_rank] if self.device.type == "cuda" else None  # device_ids means which device to use in ddp
+            output_device = self.local_rank if self.device.type == "cuda" else None  # output_device means which device to place the output in ddp
+            object.__setattr__(self, "ddp_model", DDP(self, device_ids=device_ids, output_device=output_device, find_unused_parameters=self.ddp_find_unused_parameters))
+
+        if not self.logger_initialized and self.is_main_process:  # only main process initializes logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+        self.training_logger = TrainingLogger(session=self.session, enable_tensorboard=tensorboard) if self.is_main_process else None
 
         self.metrics, self.task_specific_metrics, self.best_metrics_mode = configure_metrics(task=self.task, metrics=metrics, target_names=self.target_columns)  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
         self.early_stopper = EarlyStopper(patience=self.early_stop_patience, mode=self.best_metrics_mode)
+        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
+
         self.needs_user_ids = check_user_id(self.metrics, self.task_specific_metrics)  # check user_id needed for GAUC metrics
         self.epoch_index = 0
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
-        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
 
+        if not auto_distributed_sampler and self.distributed and self.is_main_process:
+            logging.info(colorize("[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.", color="yellow"))
+
+        train_sampler: DistributedSampler | None = None
         if validation_split is not None and valid_data is None:
-            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle,)  # type: ignore
+            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
+            if auto_distributed_sampler and self.distributed and dist.is_available() and dist.is_initialized():
+                base_dataset = getattr(train_loader, "dataset", None)
+                if base_dataset is not None and not isinstance(getattr(train_loader, "sampler", None), DistributedSampler):
+                    train_sampler = DistributedSampler(base_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+                    train_loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
         else:
-
+            if isinstance(train_data, DataLoader):
+                if auto_distributed_sampler and self.distributed:
+                    train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+                    # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+                else:
+                    train_loader = train_data
+            else:
+                loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)  # type: ignore
+                if auto_distributed_sampler and self.distributed and dataset is not None and dist.is_available() and dist.is_initialized():
+                    train_sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True)
+                    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, collate_fn=collate_fn, num_workers=num_workers, drop_last=True)
+                train_loader = loader
+
+        # If split-based loader was built without sampler, attach here when enabled
+        if self.distributed and auto_distributed_sampler and isinstance(train_loader, DataLoader) and train_sampler is None:
+            raise NotImplementedError("[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+            # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
 
-        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column)
+        valid_loader, valid_user_ids = self.prepare_validation_data(valid_data=valid_data, batch_size=batch_size, needs_user_ids=self.needs_user_ids, user_id_column=user_id_column, num_workers=num_workers, auto_distributed_sampler=auto_distributed_sampler)
         try:
             self.steps_per_epoch = len(train_loader)
             is_streaming = False
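The sampler-attachment logic above follows the standard PyTorch pattern: rebuild the DataLoader around a `DistributedSampler` with `shuffle=False`, since the sampler now owns both shuffling and per-rank sharding. A small self-contained illustration (passing `num_replicas`/`rank` explicitly avoids needing an initialized process group here):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Build the dataset once, then let a DistributedSampler shard it per rank.
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_size=16, shuffle=False, sampler=sampler, drop_last=True)
print(len(loader))  # this rank sees 50 samples -> 3 full batches with drop_last
```

Note that epoch-dependent shuffling requires calling `sampler.set_epoch(epoch)` each epoch, which is exactly what the loop in the next hunk adds.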
@@ -309,38 +445,41 @@ class BaseModel(FeatureSet, nn.Module):
                 self.steps_per_epoch = None
                 is_streaming = True
 
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.is_main_process:
+                self.summary()
+                logging.info("")
+                if self.training_logger and self.training_logger.enable_tensorboard:
+                    tb_dir = self.training_logger.tensorboard_logdir
+                    if tb_dir:
+                        user = getpass.getuser()
+                        host = socket.gethostname()
+                        tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                        ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                        logging.info(colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan"))
+                        logging.info(colorize("To view logs, run:", color="cyan"))
+                        logging.info(colorize(f"  {tb_cmd}", color="cyan"))
+                        logging.info(colorize("Then SSH port forward:", color="cyan"))
+                        logging.info(colorize(f"  {ssh_hint}", color="cyan"))
+
+                logging.info("")
+                logging.info(colorize("=" * 80, bold=True))
+                if is_streaming:
+                    logging.info(colorize(f"Start streaming training", bold=True))
+                else:
+                    logging.info(colorize(f"Start training", bold=True))
+                logging.info(colorize("=" * 80, bold=True))
+                logging.info("")
+                logging.info(colorize(f"Model device: {self.device}", bold=True))
 
             for epoch in range(epochs):
                 self.epoch_index = epoch
-                if is_streaming:
+                if is_streaming and self.is_main_process:
                     logging.info("")
                     logging.info(colorize(f"Epoch {epoch + 1}/{epochs}", bold=True))  # streaming mode, print epoch header before progress bar
 
                 # handle train result
+                if self.distributed and hasattr(train_loader, "sampler") and isinstance(train_loader.sampler, DistributedSampler):
+                    train_loader.sampler.set_epoch(epoch)
                 train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
                 if isinstance(train_result, tuple):  # [avg_loss, metrics_dict]
                     train_loss, train_metrics = train_result
|
|
|
355
494
|
if train_metrics:
|
|
356
495
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in train_metrics.items()])
|
|
357
496
|
log_str += f", {metrics_str}"
|
|
358
|
-
|
|
497
|
+
if self.is_main_process:
|
|
498
|
+
logging.info(colorize(log_str))
|
|
359
499
|
train_log_payload["loss"] = float(train_loss)
|
|
360
500
|
if train_metrics:
|
|
361
501
|
train_log_payload.update(train_metrics)
|
|
@@ -380,7 +520,8 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
380
520
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
|
|
381
521
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
382
522
|
log_str += ", " + ", ".join(task_metric_strs)
|
|
383
|
-
|
|
523
|
+
if self.is_main_process:
|
|
524
|
+
logging.info(colorize(log_str))
|
|
384
525
|
train_log_payload["loss"] = float(total_loss_val)
|
|
385
526
|
if train_metrics:
|
|
386
527
|
train_log_payload.update(train_metrics)
|
|
@@ -388,10 +529,11 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
388
529
|
self.training_logger.log_metrics(train_log_payload, step=epoch + 1, split="train")
|
|
389
530
|
if valid_loader is not None:
|
|
390
531
|
# pass user_ids only if needed for GAUC metric
|
|
391
|
-
val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
532
|
+
val_metrics = self.evaluate(valid_loader, user_ids=valid_user_ids if self.needs_user_ids else None, num_workers=num_workers) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
392
533
|
if self.nums_task == 1:
|
|
393
534
|
metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in val_metrics.items()])
|
|
394
|
-
|
|
535
|
+
if self.is_main_process:
|
|
536
|
+
logging.info(colorize(f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}", color="cyan"))
|
|
395
537
|
else:
|
|
396
538
|
# multi task metrics
|
|
397
539
|
task_metrics = {}
|
|
@@ -408,20 +550,29 @@ class BaseModel(FeatureSet, nn.Module):
                             if target_name in task_metrics:
                                 metrics_str = ", ".join([f"{k}={v:.4f}" for k, v in task_metrics[target_name].items()])
                                 task_metric_strs.append(f"{target_name}[{metrics_str}]")
-
+                        if self.is_main_process:
+                            logging.info(colorize(f"  Epoch {epoch + 1}/{epochs} - Valid: " + ", ".join(task_metric_strs), color="cyan"))
                     if val_metrics and self.training_logger:
                         self.training_logger.log_metrics(val_metrics, step=epoch + 1, split="valid")
                     # Handle empty validation metrics
                     if not val_metrics:
-                        self.
-
-
+                        if self.is_main_process:
+                            self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                            self.best_checkpoint_path = self.checkpoint_path
+                            logging.info(colorize(f"Warning: No validation metrics computed. Skipping validation for this epoch.", color="yellow"))
                         continue
                     if self.nums_task == 1:
                         primary_metric_key = self.metrics[0]
                     else:
                         primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
                     primary_metric = val_metrics.get(primary_metric_key, val_metrics[list(val_metrics.keys())[0]])  # get primary metric value, default to first metric if not found
+
+                    # In distributed mode, broadcast primary_metric to ensure all processes use the same value
+                    if self.distributed and dist.is_available() and dist.is_initialized():
+                        metric_tensor = torch.tensor([primary_metric], device=self.device, dtype=torch.float32)
+                        dist.broadcast(metric_tensor, src=0)
+                        primary_metric = float(metric_tensor.item())
+
                     improved = False
                     # early stopping check
                     if self.best_metrics_mode == 'max':
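Broadcasting the primary metric from rank 0 keeps early-stopping decisions identical on every process. A runnable single-process sketch of that collective using the gloo backend (real training would have `world_size > 1`):

```python
import os
import torch
import torch.distributed as dist

# Sketch of the rank-0 metric broadcast used in fit().
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

primary_metric = 0.8123 if dist.get_rank() == 0 else 0.0  # only rank 0 computed it
metric_tensor = torch.tensor([primary_metric], dtype=torch.float32)
dist.broadcast(metric_tensor, src=0)                       # every rank now holds rank 0's value
primary_metric = float(metric_tensor.item())
print(primary_metric)
dist.destroy_process_group()
```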
@@ -432,24 +583,40 @@ class BaseModel(FeatureSet, nn.Module):
                         if primary_metric < self.best_metric:
                             self.best_metric = primary_metric
                             improved = True
-
-
-                    if
-
-
-
+
+                    # save checkpoint and best model for main process
+                    if self.is_main_process:
+                        self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                        logging.info(" ")
+                        if improved:
+                            logging.info(colorize(f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"))
+                            self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                            self.best_checkpoint_path = self.best_path
+                            self.early_stopper.trial_counter = 0
+                        else:
+                            self.early_stopper.trial_counter += 1
+                            logging.info(colorize(f"No improvement for {self.early_stopper.trial_counter} epoch(s)"))
+                            if self.early_stopper.trial_counter >= self.early_stopper.patience:
+                                self.stop_training = True
+                                logging.info(colorize(f"Early stopping triggered after {epoch + 1} epochs", color="bright_red", bold=True))
                     else:
-
-
-
-
-
-                        break
+                        # Non-main processes also update trial_counter to keep in sync
+                        if improved:
+                            self.early_stopper.trial_counter = 0
+                        else:
+                            self.early_stopper.trial_counter += 1
                 else:
-                    self.
-
-
+                    if self.is_main_process:
+                        self.save_model(self.checkpoint_path, add_timestamp=False, verbose=False)
+                        self.save_model(self.best_path, add_timestamp=False, verbose=False)
+                        self.best_checkpoint_path = self.best_path
+
+                # Broadcast stop_training flag to all processes (always, regardless of validation)
+                if self.distributed and dist.is_available() and dist.is_initialized():
+                    stop_tensor = torch.tensor([int(self.stop_training)], device=self.device)
+                    dist.broadcast(stop_tensor, src=0)
+                    self.stop_training = bool(stop_tensor.item())
+
                 if self.stop_training:
                     break
                 if self.scheduler_fn is not None:
@@ -458,41 +625,53 @@ class BaseModel(FeatureSet, nn.Module):
                         self.scheduler_fn.step(primary_metric)
                     else:
                         self.scheduler_fn.step()
-
-
-
+            if self.distributed and dist.is_available() and dist.is_initialized():
+                dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
+            if self.is_main_process:
+                logging.info(" ")
+                logging.info(colorize("Training finished.", bold=True))
+                logging.info(" ")
             if valid_loader is not None:
-
+                if self.is_main_process:
+                    logging.info(colorize(f"Load best model from: {self.best_checkpoint_path}"))
                 self.load_model(self.best_checkpoint_path, map_location=self.device, verbose=False)
             if self.training_logger:
                 self.training_logger.close()
             return self
 
     def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+        # use ddp model for distributed training
+        model = self.ddp_model if getattr(self, "ddp_model") is not None else self
         accumulated_loss = 0.0
-
+        model.train()  # type: ignore
         num_batches = 0
         y_true_list = []
         y_pred_list = []
 
         user_ids_list = [] if self.needs_user_ids else None
+        tqdm_disable = not self.is_main_process
         if self.steps_per_epoch is not None:
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch))
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=f"Epoch {self.epoch_index + 1}", total=self.steps_per_epoch, disable=tqdm_disable))
         else:
             desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
-            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc))
+            batch_iter = enumerate(tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable))
         for batch_index, batch_data in batch_iter:
            batch_dict = batch_to_dict(batch_data)
            X_input, y_true = self.get_input(batch_dict, require_labels=True)
-
+            # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
+            y_pred = model(X_input)  # type: ignore
+
            loss = self.compute_loss(y_pred, y_true)
            reg_loss = self.add_reg_loss()
            total_loss = loss + reg_loss
            self.optimizer_fn.zero_grad()
            total_loss.backward()
-
+
+            params = model.parameters() if self.ddp_model is not None else self.parameters()  # type: ignore  # ddp model parameters or self parameters
+            nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
            self.optimizer_fn.step()
            accumulated_loss += loss.item()
+
            if y_true is not None:
                y_true_list.append(y_true.detach().cpu().numpy())
                if self.needs_user_ids and user_ids_list is not None:
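The all_reduce in `train_epoch` sums `[loss_sum, batch_count]` pairs across ranks before dividing, which keeps the global average exact even when ranks process different batch counts. Written out for two hypothetical ranks:

```python
import torch

# What dist.all_reduce(loss_tensor, op=SUM) computes, unrolled for two ranks:
rank_stats = [torch.tensor([12.6, 20.0]), torch.tensor([9.8, 16.0])]  # [loss_sum, n_batches] per rank
reduced = torch.stack(rank_stats).sum(dim=0)                           # tensor([22.4, 36.0])
avg_loss = reduced[0].item() / max(int(reduced[1].item()), 1)
print(round(avg_loss, 4))  # (12.6 + 9.8) / 36 = 0.6222
```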
@@ -502,37 +681,78 @@ class BaseModel(FeatureSet, nn.Module):
             if y_pred is not None and isinstance(y_pred, torch.Tensor):
                 y_pred_list.append(y_pred.detach().cpu().numpy())
             num_batches += 1
+        if self.distributed and dist.is_available() and dist.is_initialized():
+            loss_tensor = torch.tensor([accumulated_loss, num_batches], device=self.device, dtype=torch.float32)
+            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
+            accumulated_loss = loss_tensor[0].item()
+            num_batches = int(loss_tensor[1].item())
         avg_loss = accumulated_loss / max(num_batches, 1)
-
-
-
-
-
-
+
+        y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+        y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+        combined_user_ids_local = np.concatenate(user_ids_list, axis=0) if self.needs_user_ids and user_ids_list else None
+
+        # gather across ranks even when local is empty to avoid DDP hang
+        y_true_all = gather_numpy(self, y_true_all_local)
+        y_pred_all = gather_numpy(self, y_pred_all_local)
+        combined_user_ids = gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
+
+        if y_true_all is not None and y_pred_all is not None and len(y_true_all) > 0 and len(y_pred_all) > 0:
            metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=self.metrics, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=combined_user_ids)
            return avg_loss, metrics_dict
        return avg_loss
 
-    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id') -> tuple[DataLoader | None, np.ndarray | None]:
+    def prepare_validation_data(self, valid_data: dict | pd.DataFrame | DataLoader | None, batch_size: int, needs_user_ids: bool, user_id_column: str | None = 'user_id', num_workers: int = 0, auto_distributed_sampler: bool = True,) -> tuple[DataLoader | None, np.ndarray | None]:
        if valid_data is None:
            return None, None
        if isinstance(valid_data, DataLoader):
-
-
+            if auto_distributed_sampler and self.distributed:
+                raise NotImplementedError("[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet.")
+                # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
+            else:
+                valid_loader = valid_data
+            return valid_loader, None
+        valid_sampler = None
+        valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True)  # type: ignore
+        if auto_distributed_sampler and self.distributed and valid_dataset is not None and dist.is_available() and dist.is_initialized():
+            valid_sampler = DistributedSampler(valid_dataset, num_replicas=self.world_size, rank=self.rank, shuffle=False, drop_last=False)
+            valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, sampler=valid_sampler, collate_fn=collate_fn, num_workers=num_workers)
        valid_user_ids = None
        if needs_user_ids:
            if user_id_column is None:
                raise ValueError("[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics.")
-
+            # In distributed mode, user_ids will be collected during evaluation from each batch
+            # and gathered across all processes, so we don't pre-extract them here
+            if not self.distributed:
+                valid_user_ids = get_user_ids(data=valid_data, id_columns=user_id_column)
        return valid_loader, valid_user_ids
 
-    def evaluate(
-
-
-
-
-
+    def evaluate(
+        self,
+        data: dict | pd.DataFrame | DataLoader,
+        metrics: list[str] | dict[str, list[str]] | None = None,
+        batch_size: int = 32,
+        user_ids: np.ndarray | None = None,
+        user_id_column: str = 'user_id',
+        num_workers: int = 0,) -> dict:
+        """
+        **IMPORTANT for distributed training:**
+        In distributed mode, this method uses collective communication operations (all_gather).
+        All processes must call this method simultaneously, even if you only want results on rank 0.
+        Failing to do so will cause the program to hang indefinitely.
+
+        Evaluate the model on the given data.
+
+        Args:
+            data: Evaluation data (dict/df/DataLoader).
+            metrics: Metric names or per-target dict, e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            batch_size: Batch size (per process when distributed).
+            user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, extracted from data using user_id_column, e.g. np.array([...]).
+            user_id_column: Column name for user IDs if user_ids is not provided, e.g. 'user_id'.
+            num_workers: DataLoader worker count.
+        """
+        model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+        model.eval()
        eval_metrics = metrics if metrics is not None else self.metrics
        if eval_metrics is None:
            raise ValueError("[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first.")
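`gather_numpy` comes from `nextrec.utils.distributed`, whose body is not shown in this diff. A plausible sketch of its contract, assuming an `all_gather_object`-based implementation: every rank contributes its (possibly None) local array and receives the concatenation, so the collective stays aligned even on ranks with no local data:

```python
import numpy as np
import torch.distributed as dist

# Hypothetical reconstruction of gather_numpy's behavior; the real helper may differ.
def gather_numpy_sketch(model, local: np.ndarray | None) -> np.ndarray | None:
    if not (getattr(model, "distributed", False) and dist.is_available() and dist.is_initialized()):
        return local                                   # single-process: pass through
    buckets: list[np.ndarray | None] = [None] * dist.get_world_size()
    dist.all_gather_object(buckets, local)             # collective: every rank must call this
    chunks = [b for b in buckets if b is not None and len(b) > 0]
    return np.concatenate(chunks, axis=0) if chunks else None
```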
@@ -543,7 +763,7 @@ class BaseModel(FeatureSet, nn.Module):
         else:
             if user_ids is None and needs_user_ids:
                 user_ids = get_user_ids(data=data, id_columns=user_id_column)
-            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False)
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
         y_true_list = []
         y_pred_list = []
         collected_user_ids = []
@@ -553,7 +773,7 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
-                y_pred =
+                y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
@@ -562,20 +782,11 @@ class BaseModel(FeatureSet, nn.Module):
                     batch_user_id = get_user_ids(data=batch_dict, id_columns=self.id_columns)
                     if batch_user_id is not None:
                         collected_user_ids.append(batch_user_id)
-
-
-
-
-
-        else:
-            y_true_all = None
-            logging.info(colorize(f"  Warning: No y_true collected from evaluation data", color="yellow"))
-
-        if len(y_pred_list) > 0:
-            y_pred_all = np.concatenate(y_pred_list, axis=0)
-        else:
-            y_pred_all = None
-            logging.info(colorize(f"  Warning: No y_pred collected from evaluation data", color="yellow"))
+        if self.is_main_process:
+            logging.info(" ")
+            logging.info(colorize(f"  Evaluation batches processed: {batch_count}", color="cyan"))
+        y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+        y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
 
         # Convert metrics to list if it's a dict
         if isinstance(eval_metrics, dict):
@@ -588,50 +799,86 @@ class BaseModel(FeatureSet, nn.Module):
             metrics_to_use = unique_metrics
         else:
             metrics_to_use = eval_metrics
-
-        if
-
+        final_user_ids_local = user_ids
+        if final_user_ids_local is None and collected_user_ids:
+            final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
+
+        # gather across ranks even when local arrays are empty to keep collectives aligned
+        y_true_all = gather_numpy(self, y_true_all_local)
+        y_pred_all = gather_numpy(self, y_pred_all_local)
+        final_user_ids = gather_numpy(self, final_user_ids_local) if needs_user_ids else None
+        if y_true_all is None or y_pred_all is None or len(y_true_all) == 0 or len(y_pred_all) == 0:
+            if self.is_main_process:
+                logging.info(colorize("  Warning: Not enough evaluation data to compute metrics after gathering", color="yellow"))
+            return {}
+        if self.is_main_process:
+            logging.info(colorize(f"  Evaluation samples: {y_true_all.shape[0]}", color="cyan"))
         metrics_dict = evaluate_metrics(y_true=y_true_all, y_pred=y_pred_all, metrics=metrics_to_use, task=self.task, target_names=self.target_columns, task_specific_metrics=self.task_specific_metrics, user_ids=final_user_ids,)
         return metrics_dict
 
     def predict(
-
-
-
-
-
-
-
-
-
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: Literal["csv", "parquet"] = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        streaming_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame | np.ndarray:
+        """
+        Note: predict does not currently support distributed mode; treat it as a single-process operation.
+        Make predictions on the given data.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction (per process when distributed).
+            save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output; if None, IDs are included when id_columns are set.
+            id_columns: Column name(s) to use as IDs; if None, uses the model's id_columns.
+            return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+            streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
+            num_workers: DataLoader worker count.
+        """
         self.eval()
+        # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
+        predict_id_columns = id_columns if id_columns is not None else self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+
         if include_ids is None:
-            include_ids = bool(
-        include_ids = include_ids and bool(
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)
 
+        # Use streaming mode for large file saves without loading all data into memory
         if save_path is not None and not return_dataframe:
-            return self.
-
-
+            return self.predict_streaming(data=data, batch_size=batch_size, save_path=save_path, save_format=save_format, include_ids=include_ids, streaming_chunk_size=streaming_chunk_size, return_dataframe=return_dataframe, id_columns=predict_id_columns)
+
+        # Create DataLoader based on data type
+        if isinstance(data, DataLoader):
+            data_loader = data
+        elif isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=predict_id_columns,)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
-        elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
         else:
-            data_loader = data
+            data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
 
-        y_pred_list
-        id_buffers
-        id_arrays
+        y_pred_list = []
+        id_buffers = {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+        id_arrays = None
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
-                y_pred = self
+                y_pred = self(X_input)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.detach().cpu().numpy())
-                if include_ids and
-                    for id_name in
+                if include_ids and predict_id_columns and batch_dict.get("ids"):
+                    for id_name in predict_id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
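`predict` buffers per-batch prediction arrays and ID tensors, then concatenates and joins them column-wise into the output frame. A standalone sketch of that assembly (the `click_pred` column name mirrors the `f"{name}_pred"` convention above):

```python
import numpy as np
import pandas as pd

# Per-batch prediction arrays and id buffers are concatenated, then joined column-wise.
y_pred_list = [np.array([[0.2], [0.7]]), np.array([[0.5]])]        # two batches of predictions
id_buffers = {"user_id": [np.array([101, 102]), np.array([103])]}  # matching id pieces per batch

y_pred_all = np.concatenate(y_pred_list, axis=0)
id_arrays = {name: np.concatenate(pieces, axis=0) for name, pieces in id_buffers.items()}
output = pd.concat([pd.DataFrame(id_arrays), pd.DataFrame(y_pred_all, columns=["click_pred"])], axis=1)
print(output)
#    user_id  click_pred
# 0      101         0.2
# 1      102         0.7
# 2      103         0.5
```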
@@ -654,7 +901,7 @@ class BaseModel(FeatureSet, nn.Module):
                 pred_columns.append(f"{name}_pred")
             while len(pred_columns) < num_outputs:
                 pred_columns.append(f"pred_{len(pred_columns)}")
-        if include_ids and
+        if include_ids and predict_id_columns:
             id_arrays = {}
             for id_name, pieces in id_buffers.items():
                 if pieces:
@@ -681,7 +928,7 @@ class BaseModel(FeatureSet, nn.Module):
             df_to_save = output
         else:
             df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
-        if include_ids and
+        if include_ids and predict_id_columns and id_arrays is not None:
             id_df = pd.DataFrame(id_arrays)
             if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
                 raise ValueError(f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)}).")
@@ -693,7 +940,7 @@ class BaseModel(FeatureSet, nn.Module):
             logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         return output
 
-    def
+    def predict_streaming(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int,
@@ -702,9 +949,10 @@ class BaseModel(FeatureSet, nn.Module):
         include_ids: bool,
         streaming_chunk_size: int,
         return_dataframe: bool,
+        id_columns: list[str] | None = None,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=
+            rec_loader = RecDataLoader(dense_features=self.dense_features, sparse_features=self.sparse_features, sequence_features=self.sequence_features, target=self.target_columns, id_columns=id_columns)
             data_loader = rec_loader.create_dataloader(data=data, batch_size=batch_size, shuffle=False, load_full=False, chunk_size=streaming_chunk_size,)
         elif not isinstance(data, DataLoader):
             data_loader = self.prepare_data_loader(data, batch_size=batch_size, shuffle=False,)
@@ -717,8 +965,8 @@ class BaseModel(FeatureSet, nn.Module):
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
 
-        pred_columns
-        collected_frames
+        pred_columns = None
+        collected_frames = []  # only used when return_dataframe is True
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -739,9 +987,9 @@ class BaseModel(FeatureSet, nn.Module):
                 while len(pred_columns) < num_outputs:
                     pred_columns.append(f"pred_{len(pred_columns)}")
 
-                id_arrays_batch
-                if include_ids and
-                    for id_name in
+                id_arrays_batch = {}
+                if include_ids and id_columns and batch_dict.get("ids"):
+                    for id_name in id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
@@ -781,7 +1029,10 @@ class BaseModel(FeatureSet, nn.Module):
         add_timestamp = False if add_timestamp is None else add_timestamp
         target_path = resolve_save_path(path=save_path, default_dir=self.session_path, default_name=self.model_name, suffix=".model", add_timestamp=add_timestamp)
         model_path = Path(target_path)
-
+
+        model_to_save = (self.ddp_model.module if getattr(self, "ddp_model", None) is not None else self)
+        torch.save(model_to_save.state_dict(), model_path)
+        # torch.save(self.state_dict(), model_path)
 
         config_path = self.features_config_path
         features_config = {
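Saving `ddp_model.module.state_dict()` instead of the wrapper's state dict matters because DDP prefixes every parameter key with `module.`; unwrapping keeps checkpoints loadable by the bare model. A brief illustration:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# A bare module has clean parameter keys.
model = nn.Linear(4, 1)
print(list(model.state_dict())[0])  # 'weight'

# After wrapping (requires an initialized process group, shown conceptually):
#   ddp = DDP(model)
#   list(ddp.state_dict())[0] == 'module.weight'   # prefixed key
#   torch.save(ddp.module.state_dict(), path)      # strip the wrapper, save clean keys
```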
@@ -842,8 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
         **kwargs: Any,
     ) -> "BaseModel":
         """
-
-
+        Load a model from a checkpoint path. The checkpoint path should contain
+        a .model file and a features_config.pkl file.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
@@ -1003,10 +1254,10 @@ class BaseMatchModel(BaseModel):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
-
+
     @property
-    def
-
+    def default_task(self) -> str:
+        return "binary"
 
     @property
     def support_training_modes(self) -> list[str]: