nextrec 0.3.6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +244 -113
- nextrec/basic/loggers.py +62 -43
- nextrec/basic/metrics.py +268 -119
- nextrec/basic/model.py +1373 -443
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +498 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +42 -24
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +303 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +106 -40
- nextrec/models/match/dssm.py +82 -69
- nextrec/models/match/dssm_v2.py +72 -58
- nextrec/models/match/mind.py +175 -108
- nextrec/models/match/sdm.py +104 -88
- nextrec/models/match/youtube_dnn.py +73 -60
- nextrec/models/multi_task/esmm.py +53 -39
- nextrec/models/multi_task/mmoe.py +70 -47
- nextrec/models/multi_task/ple.py +107 -50
- nextrec/models/multi_task/poso.py +121 -41
- nextrec/models/multi_task/share_bottom.py +54 -38
- nextrec/models/ranking/afm.py +172 -45
- nextrec/models/ranking/autoint.py +84 -61
- nextrec/models/ranking/dcn.py +59 -42
- nextrec/models/ranking/dcn_v2.py +64 -23
- nextrec/models/ranking/deepfm.py +36 -26
- nextrec/models/ranking/dien.py +158 -102
- nextrec/models/ranking/din.py +88 -60
- nextrec/models/ranking/fibinet.py +55 -35
- nextrec/models/ranking/fm.py +32 -26
- nextrec/models/ranking/masknet.py +95 -34
- nextrec/models/ranking/pnn.py +34 -31
- nextrec/models/ranking/widedeep.py +37 -29
- nextrec/models/ranking/xdeepfm.py +63 -41
- nextrec/utils/__init__.py +61 -32
- nextrec/utils/config.py +490 -0
- nextrec/utils/device.py +52 -12
- nextrec/utils/distributed.py +141 -0
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +32 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +531 -0
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/METADATA +15 -5
- nextrec-0.4.2.dist-info/RECORD +69 -0
- nextrec-0.4.2.dist-info/entry_points.txt +2 -0
- nextrec-0.3.6.dist-info/RECORD +0 -64
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/WHEEL +0 -0
- {nextrec-0.3.6.dist-info → nextrec-0.4.2.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 05/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -17,13 +17,21 @@ import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.distributed as dist
 
 from pathlib import Path
 from typing import Union, Literal, Any
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 from nextrec.basic.callback import EarlyStopper
-from nextrec.basic.features import
+from nextrec.basic.features import (
+    DenseFeature,
+    SparseFeature,
+    SequenceFeature,
+    FeatureSet,
+)
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
 
 from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
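
The import hunk above wires in torch.distributed, DistributedSampler, and DDP for the new distributed flow, and regroups the feature classes that model constructors consume. A minimal sketch of building those feature lists follows; the constructor keyword arguments (vocab_size, embed_dim, max_len) are assumptions for illustration, not signatures confirmed by this diff.

```python
# Sketch only: DenseFeature/SparseFeature/SequenceFeature names come from this
# diff, but their constructor arguments below are assumed, not confirmed.
from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature

dense_features = [DenseFeature(name="price")]
sparse_features = [SparseFeature(name="user_id", vocab_size=10_000, embed_dim=16)]
sequence_features = [
    SequenceFeature(name="click_history", vocab_size=50_000, embed_dim=16, max_len=50)
]
```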
@@ -31,79 +39,149 @@ from nextrec.basic.session import resolve_save_path, create_session
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics, check_user_id
 
 from nextrec.data.dataloader import build_tensors_from_data
-from nextrec.data.data_processing import get_column_data, get_user_ids
 from nextrec.data.batch_utils import collate_fn, batch_to_dict
+from nextrec.data.data_processing import get_column_data, get_user_ids
 
 from nextrec.loss import get_loss_fn, get_loss_kwargs
-from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.utils.tensor import to_tensor
-
+from nextrec.utils.device import configure_device
+from nextrec.utils.optimizer import get_optimizer, get_scheduler
+from nextrec.utils.distributed import (
+    gather_numpy,
+    init_process_group,
+    add_distributed_sampler,
+)
 from nextrec import __version__
 
+
 class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
-
+
     @property
-    def
+    def default_task(self) -> str | list[str]:
         raise NotImplementedError
 
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        target: list[str] | str | None = None,
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
+    ):
+        """
+        Initialize a base model.
+
+        Args:
+            dense_features: DenseFeature definitions.
+            sparse_features: SparseFeature definitions.
+            sequence_features: SequenceFeature definitions.
+            target: Target column name.
+            id_columns: Identifier column name; only needs to be specified when GAUC is required.
+            task: Task types, e.g., 'binary', 'regression', or ['binary', 'regression']. If None, falls back to self.default_task.
+            device: Torch device string or torch.device, e.g., 'cpu', 'cuda:0'.
+            embedding_l1_reg: L1 regularization strength for embedding params, e.g., 1e-6.
+            dense_l1_reg: L1 regularization strength for dense params, e.g., 1e-5.
+            embedding_l2_reg: L2 regularization strength for embedding params, e.g., 1e-5.
+            dense_l2_reg: L2 regularization strength for dense params, e.g., 1e-4.
+            early_stop_patience: Number of epochs to wait before early stopping; 0 disables it, e.g., 20.
+            session_id: Session id for logging. If None, a default timestamped id is created.
+            distributed: Set True to enable the DistributedDataParallel training flow.
+            rank: Global rank (defaults to env RANK).
+            world_size: Number of processes (defaults to env WORLD_SIZE).
+            local_rank: Local rank for selecting the CUDA device (defaults to env LOCAL_RANK).
+            ddp_find_unused_parameters: Defaults to False; set True only when the DDP model has unused parameters, which is rarely needed.
+        """
         super(BaseModel, self).__init__()
-
-
-
-
-
+
+        # distributed training settings
+        env_rank = int(os.environ.get("RANK", "0"))
+        env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
+        env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+        self.distributed = distributed or (env_world_size > 1)
+        self.rank = env_rank if rank is None else rank
+        self.world_size = env_world_size if world_size is None else world_size
+        self.local_rank = env_local_rank if local_rank is None else local_rank
+        self.is_main_process = self.rank == 0
+        self.ddp_find_unused_parameters = ddp_find_unused_parameters
+        self.ddp_model: DDP | None = None
+        self.device = configure_device(self.distributed, self.local_rank, device)
 
         self.session_id = session_id
         self.session = create_session(session_id)
-        self.session_path = self.session.root
-        self.checkpoint_path = os.path.join(
-
-
-        self.
+        self.session_path = self.session.root  # pwd/session_id, path for this session
+        self.checkpoint_path = os.path.join(
+            self.session_path, self.model_name + "_checkpoint.model"
+        )  # example: pwd/session_id/DeepFM_checkpoint.model
+        self.best_path = os.path.join(
+            self.session_path, self.model_name + "_best.model"
+        )
+        self.features_config_path = os.path.join(
+            self.session_path, "features_config.pkl"
+        )
+        self.set_all_features(
+            dense_features, sparse_features, sequence_features, target, id_columns
+        )
 
-        self.task = task
-        self.nums_task = len(task) if isinstance(task, list) else 1
+        self.task = self.default_task if task is None else task
+        self.nums_task = len(self.task) if isinstance(self.task, list) else 1
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
         self.embedding_l2_reg = embedding_l2_reg
         self.dense_l2_reg = dense_l2_reg
-        self.regularization_weights = []
+        self.regularization_weights = []
         self.embedding_params = []
         self.loss_weight = None
+
         self.early_stop_patience = early_stop_patience
-        self.max_gradient_norm = 1.0
+        self.max_gradient_norm = 1.0
         self.logger_initialized = False
-        self.training_logger
+        self.training_logger = None
 
-    def register_regularization_weights(
+    def register_regularization_weights(
+        self,
+        embedding_attr: str = "embedding",
+        exclude_modules: list[str] | None = None,
+        include_modules: list[str] | None = None,
+    ) -> None:
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
         if embed_dict is not None:
             self.embedding_params.extend(embed.weight for embed in embed_dict.values())
-        skip_types = (
+        skip_types = (
+            nn.BatchNorm1d,
+            nn.BatchNorm2d,
+            nn.BatchNorm3d,
+            nn.Dropout,
+            nn.Dropout2d,
+            nn.Dropout3d,
+        )
         for name, module in self.named_modules():
-            if (
+            if (
+                module is self
+                or embedding_attr in name
+                or isinstance(module, skip_types)
+                or (include_modules and not any(inc in name for inc in include_modules))
+                or any(exc in name for exc in exclude_modules)
+            ):
                 continue
             if isinstance(module, nn.Linear):
                 self.regularization_weights.append(module.weight)
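
The rewritten `__init__` resolves its distributed settings from explicit arguments first and the torchrun environment variables second, and treats WORLD_SIZE > 1 as an implicit opt-in. A standalone sketch of that fallback logic, mirroring the added lines above (the helper name is hypothetical):

```python
import os

def resolve_distributed_settings(
    distributed: bool = False,
    rank: int | None = None,
    world_size: int | None = None,
    local_rank: int | None = None,
) -> dict:
    """Mirrors BaseModel.__init__: explicit arguments win, otherwise the
    torchrun environment variables (RANK/WORLD_SIZE/LOCAL_RANK) are used."""
    env_rank = int(os.environ.get("RANK", "0"))
    env_world_size = int(os.environ.get("WORLD_SIZE", "1"))
    env_local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return {
        "distributed": distributed or (env_world_size > 1),  # WORLD_SIZE > 1 implies DDP
        "rank": env_rank if rank is None else rank,
        "world_size": env_world_size if world_size is None else world_size,
        "local_rank": env_local_rank if local_rank is None else local_rank,
    }

# Under `torchrun --nproc_per_node=2 train.py`, rank 0 would resolve to:
# {'distributed': True, 'rank': 0, 'world_size': 2, 'local_rank': 0}
```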
@@ -112,14 +190,22 @@ class BaseModel(FeatureSet, nn.Module):
         reg_loss = torch.tensor(0.0, device=self.device)
         if self.embedding_params:
             if self.embedding_l1_reg > 0:
-                reg_loss += self.embedding_l1_reg * sum(
+                reg_loss += self.embedding_l1_reg * sum(
+                    param.abs().sum() for param in self.embedding_params
+                )
             if self.embedding_l2_reg > 0:
-                reg_loss += self.embedding_l2_reg * sum(
+                reg_loss += self.embedding_l2_reg * sum(
+                    (param**2).sum() for param in self.embedding_params
+                )
         if self.regularization_weights:
             if self.dense_l1_reg > 0:
-                reg_loss += self.dense_l1_reg * sum(
+                reg_loss += self.dense_l1_reg * sum(
+                    param.abs().sum() for param in self.regularization_weights
+                )
             if self.dense_l2_reg > 0:
-                reg_loss += self.dense_l2_reg * sum(
+                reg_loss += self.dense_l2_reg * sum(
+                    (param**2).sum() for param in self.regularization_weights
+                )
         return reg_loss
 
     def get_input(self, input_data: dict, require_labels: bool = True):
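
The hunk above shows `add_reg_loss` accumulating l1 * Σ|w| + l2 * Σw² separately over the embedding and dense parameter groups. A self-contained sketch of the same penalty with a worked value:

```python
import torch

def reg_penalty(params: list[torch.Tensor], l1: float, l2: float) -> torch.Tensor:
    """Standalone version of the penalty add_reg_loss accumulates:
    l1 * sum(|w|) + l2 * sum(w**2) over every registered tensor."""
    loss = torch.tensor(0.0)
    if l1 > 0:
        loss = loss + l1 * sum(p.abs().sum() for p in params)
    if l2 > 0:
        loss = loss + l2 * sum((p**2).sum() for p in params)
    return loss

w = torch.tensor([[1.0, -2.0], [0.5, 0.0]])
print(reg_penalty([w], l1=1e-5, l2=1e-4))  # 1e-5 * 3.5 + 1e-4 * 5.25
```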
@@ -128,47 +214,90 @@ class BaseModel(FeatureSet, nn.Module):
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
-                raise KeyError(
+                raise KeyError(
+                    f"[BaseModel-input Error] Feature '{feature.name}' not found in input data."
+                )
             feature_data = get_column_data(feature_source, feature.name)
-            X_input[feature.name] = to_tensor(
+            X_input[feature.name] = to_tensor(
+                feature_data,
+                dtype=(
+                    torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                ),
+                device=self.device,
+            )
         y = None
-        if
+        if len(self.target_columns) > 0 and (
+            require_labels
+            or (
+                label_source
+                and any(name in label_source for name in self.target_columns)
+            )
+        ):  # need labels: training or eval with labels
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
                     if require_labels:
-                        raise KeyError(
+                        raise KeyError(
+                            f"[BaseModel-input Error] Target column '{target_name}' not found in input data."
+                        )
                     continue
                 target_data = get_column_data(label_source, target_name)
                 if target_data is None:
                     if require_labels:
-                        raise ValueError(
+                        raise ValueError(
+                            f"[BaseModel-input Error] Target column '{target_name}' contains no data."
+                        )
                     continue
-                target_tensor = to_tensor(
-
+                target_tensor = to_tensor(
+                    target_data, dtype=torch.float32, device=self.device
+                )
+                target_tensor = target_tensor.view(
+                    target_tensor.size(0), -1
+                )  # always reshape to (batch_size, num_targets)
                 target_tensors.append(target_tensor)
             if target_tensors:
                 y = torch.cat(target_tensors, dim=1)
-                if y.shape[1] == 1:
+                if y.shape[1] == 1:  # single target: flatten to (batch_size,)
                     y = y.view(-1)
         elif require_labels:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-input Error] Labels are required but none were found in the input batch."
+            )
         return X_input, y
 
-    def handle_validation_split(
-
+    def handle_validation_split(
+        self,
+        train_data: dict | pd.DataFrame,
+        validation_split: float,
+        batch_size: int,
+        shuffle: bool,
+        num_workers: int = 0,
+    ):
+        """
+        Split training data into training and validation sets when:
+        1. valid_data is None;
+        2. validation_split is provided.
+        """
         if not (0 < validation_split < 1):
-            raise ValueError(
+            raise ValueError(
+                f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
+            )
         if not isinstance(train_data, (pd.DataFrame, dict)):
-            raise TypeError(
+            raise TypeError(
+                f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
+            )
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
         else:
-            sample_key = next(
-
+            sample_key = next(
+                iter(train_data)
+            )  # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+            total_length = len(train_data[sample_key])  # len(train_data['user_id'])
             for k, v in train_data.items():
                 if len(v) != total_length:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
+                    )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
         split_idx = int(total_length * (1 - validation_split))
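
`handle_validation_split` shuffles indices with a fixed-seed generator (seed 42) and cuts at int(n * (1 - validation_split)). The same split applied to a toy dict dataset, mirroring the lines above and the per-key indexing in the next hunk:

```python
import numpy as np

data = {"user_id": np.arange(10), "label": np.random.randint(0, 2, 10)}
validation_split = 0.2

rng = np.random.default_rng(42)          # same fixed seed as handle_validation_split
indices = rng.permutation(len(data["user_id"]))
split_idx = int(len(indices) * (1 - validation_split))  # 8 train / 2 valid
train_indices, valid_indices = indices[:split_idx], indices[split_idx:]

train_split = {k: np.asarray(v)[train_indices] for k, v in data.items()}
valid_split = {k: np.asarray(v)[valid_indices] for k, v in data.items()}
```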
@@ -181,169 +310,444 @@ class BaseModel(FeatureSet, nn.Module):
             train_split = {}
             valid_split = {}
             for key, value in train_data.items():
-                arr = np.asarray(value)
+                arr = np.asarray(value)
                 train_split[key] = arr[train_indices]
                 valid_split[key] = arr[valid_indices]
-        train_loader = self.prepare_data_loader(
-
+        train_loader = self.prepare_data_loader(
+            train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+        )
+        logging.info(
+            f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
+        )
         return train_loader, valid_split
 
     def compile(
         self,
         optimizer: str | torch.optim.Optimizer = "adam",
         optimizer_params: dict | None = None,
-        scheduler:
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | torch.optim.lr_scheduler.LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | type[torch.optim.lr_scheduler.LRScheduler]
+            | None
+        ) = None,
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | None = None,
     ):
+        """
+        Configure the model for training.
+
+        Args:
+            optimizer: Optimizer name or instance, e.g., 'adam', 'sgd', or torch.optim.Adam().
+            optimizer_params: Optimizer parameters, e.g., {'lr': 1e-3, 'weight_decay': 1e-5}.
+            scheduler: Learning rate scheduler name or instance, e.g., 'step_lr', 'cosine_annealing', or torch.optim.lr_scheduler.StepLR().
+            scheduler_params: Scheduler parameters, e.g., {'step_size': 10, 'gamma': 0.1}.
+            loss: Loss function name, instance, or list for multi-task, e.g., 'bce', 'mse', or torch.nn.BCELoss(); custom loss functions also work.
+            loss_params: Loss function parameters, or a list for multi-task, e.g., {'weight': tensor([0.25, 0.75])}.
+            loss_weights: Weights for each task loss, int/float for single-task or list for multi-task, e.g., 1.0 or [1.0, 0.5].
+        """
+        if loss_params is None:
+            self.loss_params = {}
+        else:
+            self.loss_params = loss_params
         optimizer_params = optimizer_params or {}
-        self.optimizer_name =
+        self.optimizer_name = (
+            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        )
         self.optimizer_params = optimizer_params
-        self.optimizer_fn = get_optimizer(
+        self.optimizer_fn = get_optimizer(
+            optimizer=optimizer,
+            params=self.parameters(),
+            **optimizer_params,
+        )
 
         scheduler_params = scheduler_params or {}
         if isinstance(scheduler, str):
             self.scheduler_name = scheduler
         elif scheduler is None:
             self.scheduler_name = None
-        else:
-            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+        else:  # a custom scheduler instance: derive the class name for logging
+            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
         self.scheduler_params = scheduler_params
-        self.scheduler_fn = (
+        self.scheduler_fn = (
+            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            if scheduler
+            else None
+        )
 
         self.loss_config = loss
         self.loss_params = loss_params or {}
         self.loss_fn = []
-        if isinstance(loss, list):
-
-
+        if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
+            if len(loss) != self.nums_task:
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
+                )
+            loss_list = [loss[i] for i in range(self.nums_task)]
+        else:  # for example: 'bce' -> ['bce', 'bce']
             loss_list = [loss] * self.nums_task
 
         if isinstance(self.loss_params, dict):
             params_list = [self.loss_params] * self.nums_task
         else:  # list[dict]
-            params_list = [
-
+            params_list = [
+                self.loss_params[i] if i < len(self.loss_params) else {}
+                for i in range(self.nums_task)
+            ]
+        self.loss_fn = [
+            get_loss_fn(loss=loss_list[i], **params_list[i])
+            for i in range(self.nums_task)
+        ]
 
         if loss_weights is None:
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1
-                    raise ValueError(
+                if len(loss_weights) != 1:
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
-            self.loss_weights = float(weight_value)
+            self.loss_weights = [float(weight_value)]
         else:
             if isinstance(loss_weights, (int, float)):
                 weights = [float(loss_weights)] * self.nums_task
             elif isinstance(loss_weights, (list, tuple)):
                 weights = [float(w) for w in loss_weights]
                 if len(weights) != self.nums_task:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                    )
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+                )
             self.loss_weights = weights
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
+            )
         if self.nums_task == 1:
-
+            if y_pred.dim() == 1:
+                y_pred = y_pred.view(-1, 1)
+            if y_true.dim() == 1:
+                y_true = y_true.view(-1, 1)
+            if y_pred.shape != y_true.shape:
+                raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+            task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
+            if task_dim == 1:
+                loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
+            else:
+                loss = self.loss_fn[0](y_pred, y_true)
             if self.loss_weights is not None:
-                loss
+                loss *= self.loss_weights[0]
             return loss
+        # multi-task
+        if y_pred.shape != y_true.shape:
+            raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
+        if hasattr(
+            self, "prediction_layer"
+        ):  # use the registered task_slices for multi-task and multi-class
+            slices = self.prediction_layer._task_slices  # type: ignore
         else:
-
-
-
-
-
-
-
-
-
+            slices = [(i, i + 1) for i in range(self.nums_task)]
+        task_losses = []
+        for i, (start, end) in enumerate(slices):  # type: ignore
+            y_pred_i = y_pred[:, start:end]
+            y_true_i = y_true[:, start:end]
+            task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+            if isinstance(self.loss_weights, (list, tuple)):
+                task_loss *= self.loss_weights[i]
+            task_losses.append(task_loss)
+        return torch.stack(task_losses).sum()
+
+    def prepare_data_loader(
+        self,
+        data: dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        shuffle: bool = True,
+        num_workers: int = 0,
+        sampler=None,
+        return_dataset: bool = False,
+    ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
         if isinstance(data, DataLoader):
-            return data
-        tensors = build_tensors_from_data(
+            return (data, None) if return_dataset else data
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=self.all_features,
+            target_columns=self.target_columns,
+            id_columns=self.id_columns,
+        )
         if tensors is None:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
+            )
         dataset = TensorDictDataset(tensors)
-
-
-
-
-
-
-
-
-
-
+        loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False if sampler is not None else shuffle,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+        )
+        return (loader, dataset) if return_dataset else loader
+
+    def fit(
+        self,
+        train_data: dict | pd.DataFrame | DataLoader,
+        valid_data: dict | pd.DataFrame | DataLoader | None = None,
+        metrics: (
+            list[str] | dict[str, list[str]] | None
+        ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+        epochs: int = 1,
+        shuffle: bool = True,
+        batch_size: int = 32,
+        user_id_column: str | None = None,
+        validation_split: float | None = None,
+        num_workers: int = 0,
+        tensorboard: bool = True,
+        auto_distributed_sampler: bool = True,
+    ):
+        """
+        Train the model.
+
+        Args:
+            train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
+            valid_data: Optional validation data; if None and validation_split is set, a split is created.
+            metrics: Metrics names or per-target dict, e.g., {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            epochs: Training epochs.
+            shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
+            batch_size: Batch size (per process when distributed).
+            user_id_column: Column name for GAUC-style metrics.
+            validation_split: Ratio used to split training data when valid_data is None.
+            num_workers: DataLoader worker count.
+            tensorboard: Enable tensorboard logging.
+            auto_distributed_sampler: Attach a DistributedSampler automatically when distributed; set False when data is already sharded per rank.
+
+        Notes:
+            - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
+            - All ranks must call evaluate() together because it performs collective ops.
+        """
+        device_id = self.local_rank if self.device.type == "cuda" else None
+        init_process_group(
+            self.distributed, self.rank, self.world_size, device_id=device_id
+        )
         self.to(self.device)
-
+
+        if (
+            self.distributed
+            and dist.is_available()
+            and dist.is_initialized()
+            and self.ddp_model is None
+        ):
+            device_ids = (
+                [self.local_rank] if self.device.type == "cuda" else None
+            )  # device_ids: which device(s) DDP uses
+            output_device = (
+                self.local_rank if self.device.type == "cuda" else None
+            )  # output_device: where DDP places the output
+            object.__setattr__(
+                self,
+                "ddp_model",
+                DDP(
+                    self,
+                    device_ids=device_ids,
+                    output_device=output_device,
+                    find_unused_parameters=self.ddp_find_unused_parameters,
+                ),
+            )
+
+        if (
+            not self.logger_initialized and self.is_main_process
+        ):  # only the main process initializes the logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger =
+        self.training_logger = (
+            TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+            if self.is_main_process
+            else None
+        )
 
-        self.metrics, self.task_specific_metrics, self.best_metrics_mode =
-
-
+        self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
+            configure_metrics(
+                task=self.task, metrics=metrics, target_names=self.target_columns
+            )
+        )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+        self.early_stopper = EarlyStopper(
+            patience=self.early_stop_patience, mode=self.best_metrics_mode
+        )
+        self.best_metric = (
+            float("-inf") if self.best_metrics_mode == "max" else float("inf")
+        )
+
+        self.needs_user_ids = check_user_id(
+            self.metrics, self.task_specific_metrics
+        )  # check whether user_id is needed for GAUC metrics
         self.epoch_index = 0
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
-        self.best_metric = float('-inf') if self.best_metrics_mode == 'max' else float('inf')
 
+        if not auto_distributed_sampler and self.distributed and self.is_main_process:
+            logging.info(
+                colorize(
+                    "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
+                    color="yellow",
+                )
+            )
+
+        train_sampler: DistributedSampler | None = None
         if validation_split is not None and valid_data is None:
-            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
+            if (
+                auto_distributed_sampler
+                and self.distributed
+                and dist.is_available()
+                and dist.is_initialized()
+            ):
+                base_dataset = getattr(train_loader, "dataset", None)
+                if base_dataset is not None and not isinstance(
+                    getattr(train_loader, "sampler", None), DistributedSampler
+                ):
+                    train_sampler = DistributedSampler(
+                        base_dataset,
+                        num_replicas=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                    )
+                    train_loader = DataLoader(
+                        base_dataset,
+                        batch_size=batch_size,
+                        shuffle=False,
+                        sampler=train_sampler,
+                        collate_fn=collate_fn,
+                        num_workers=num_workers,
+                        drop_last=True,
+                    )
         else:
-
-
-
+            if isinstance(train_data, DataLoader):
+                if auto_distributed_sampler and self.distributed:
+                    train_loader, train_sampler = add_distributed_sampler(
+                        train_data,
+                        distributed=self.distributed,
+                        world_size=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                        default_batch_size=batch_size,
+                        is_main_process=self.is_main_process,
+                    )
+                    # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+                else:
+                    train_loader = train_data
+            else:
+                loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)  # type: ignore
+                if (
+                    auto_distributed_sampler
+                    and self.distributed
+                    and dataset is not None
+                    and dist.is_available()
+                    and dist.is_initialized()
+                ):
+                    train_sampler = DistributedSampler(
+                        dataset,
+                        num_replicas=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                    )
+                    loader = DataLoader(
+                        dataset,
+                        batch_size=batch_size,
+                        shuffle=False,
+                        sampler=train_sampler,
+                        collate_fn=collate_fn,
+                        num_workers=num_workers,
+                        drop_last=True,
+                    )
+                train_loader = loader
+
+        # If a split-based loader was built without a sampler, attach one here when enabled
+        if (
+            self.distributed
+            and auto_distributed_sampler
+            and isinstance(train_loader, DataLoader)
+            and train_sampler is None
+        ):
+            raise NotImplementedError(
+                "[BaseModel-fit Error] auto_distributed_sampler with a pre-defined DataLoader is not supported yet."
+            )
+            # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
+
+        valid_loader, valid_user_ids = self.prepare_validation_data(
+            valid_data=valid_data,
+            batch_size=batch_size,
+            needs_user_ids=self.needs_user_ids,
+            user_id_column=user_id_column,
+            num_workers=num_workers,
+            auto_distributed_sampler=auto_distributed_sampler,
+        )
         try:
            self.steps_per_epoch = len(train_loader)
            is_streaming = False
-        except TypeError:
+        except TypeError:  # streaming data loaders do not support len()
            self.steps_per_epoch = None
            is_streaming = True
 
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        logging.info(
-
-
-
-
-
+        if self.is_main_process:
+            self.summary()
+            logging.info("")
+            if self.training_logger and self.training_logger.enable_tensorboard:
+                tb_dir = self.training_logger.tensorboard_logdir
+                if tb_dir:
+                    user = getpass.getuser()
+                    host = socket.gethostname()
+                    tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
+                    ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
+                    logging.info(
+                        colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
+                    )
+                    logging.info(colorize("To view logs, run:", color="cyan"))
+                    logging.info(colorize(f" {tb_cmd}", color="cyan"))
+                    logging.info(colorize("Then SSH port forward:", color="cyan"))
+                    logging.info(colorize(f" {ssh_hint}", color="cyan"))
+
+            logging.info("")
+            logging.info(colorize("=" * 80, bold=True))
+            if is_streaming:
+                logging.info(colorize("Start streaming training", bold=True))
+            else:
+                logging.info(colorize("Start training", bold=True))
+            logging.info(colorize("=" * 80, bold=True))
+            logging.info("")
+            logging.info(colorize(f"Model device: {self.device}", bold=True))
 
         for epoch in range(epochs):
             self.epoch_index = epoch
-            if is_streaming:
+            if is_streaming and self.is_main_process:
                 logging.info("")
-                logging.info(
+                logging.info(
+                    colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)
+                )  # streaming mode: print the epoch header before the progress bar
 
             # handle train result
-
-
+            if (
+                self.distributed
+                and hasattr(train_loader, "sampler")
+                and isinstance(train_loader.sampler, DistributedSampler)
+            ):
+                train_loader.sampler.set_epoch(epoch)
+            train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
+            if isinstance(train_result, tuple):  # (avg_loss, metrics_dict)
                 train_loss, train_metrics = train_result
             else:
                 train_loss = train_result
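
Taken together, the compile() and fit() docstrings above describe the intended call pattern. A hedged end-to-end sketch follows; the DeepFM constructor arguments and the train_df variable are assumptions for illustration, while the compile()/fit() keywords come from this diff:

```python
# Sketch only: DeepFM's constructor signature is assumed; compile()/fit()
# keywords are taken from the docstrings in this diff.
from nextrec.models.ranking.deepfm import DeepFM

model = DeepFM(
    dense_features=dense_features,      # feature lists as defined earlier
    sparse_features=sparse_features,
    target="label",
    device="cuda:0",
)
model.compile(
    optimizer="adam",
    optimizer_params={"lr": 1e-3, "weight_decay": 1e-5},
    loss="bce",
)
model.fit(
    train_data=train_df,                # pandas DataFrame or dict (assumed variable)
    validation_split=0.1,
    metrics=["auc", "logloss"],
    epochs=5,
    batch_size=512,
)
# Distributed run: launch the same script with
#   torchrun --nproc_per_node=4 train.py
# and fit() picks up RANK/WORLD_SIZE/LOCAL_RANK from the environment.
```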
@@ -354,15 +758,20 @@ class BaseModel(FeatureSet, nn.Module):
             if self.nums_task == 1:
                 log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
                 if train_metrics:
-                    metrics_str = ", ".join(
+                    metrics_str = ", ".join(
+                        [f"{k}={v:.4f}" for k, v in train_metrics.items()]
+                    )
                     log_str += f", {metrics_str}"
-
+                if self.is_main_process:
+                    logging.info(colorize(log_str))
                 train_log_payload["loss"] = float(train_loss)
                 if train_metrics:
                     train_log_payload.update(train_metrics)
             else:
                 total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss  # type: ignore
-                log_str =
+                log_str = (
+                    f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
+                )
                 if train_metrics:
                     # group metrics by task
                     task_metrics = {}
@@ -378,21 +787,41 @@ class BaseModel(FeatureSet, nn.Module):
                     task_metric_strs = []
                     for target_name in self.target_columns:
                         if target_name in task_metrics:
-                            metrics_str = ", ".join(
+                            metrics_str = ", ".join(
+                                [
+                                    f"{k}={v:.4f}"
+                                    for k, v in task_metrics[target_name].items()
+                                ]
+                            )
                             task_metric_strs.append(f"{target_name}[{metrics_str}]")
                     log_str += ", " + ", ".join(task_metric_strs)
-
+                if self.is_main_process:
+                    logging.info(colorize(log_str))
                 train_log_payload["loss"] = float(total_loss_val)
                 if train_metrics:
                     train_log_payload.update(train_metrics)
             if self.training_logger:
-                self.training_logger.log_metrics(
+                self.training_logger.log_metrics(
+                    train_log_payload, step=epoch + 1, split="train"
+                )
             if valid_loader is not None:
                 # pass user_ids only if needed for GAUC metric
-                val_metrics = self.evaluate(
+                val_metrics = self.evaluate(
+                    valid_loader,
+                    user_ids=valid_user_ids if self.needs_user_ids else None,
+                    num_workers=num_workers,
+                )  # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
                 if self.nums_task == 1:
-                    metrics_str = ", ".join(
-
+                    metrics_str = ", ".join(
+                        [f"{k}={v:.4f}" for k, v in val_metrics.items()]
+                    )
+                    if self.is_main_process:
+                        logging.info(
+                            colorize(
+                                f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
+                                color="cyan",
+                            )
+                        )
                 else:
                     # multi task metrics
                     task_metrics = {}
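
The comment on the evaluate() call above shows multi-task metrics coming back as a flat dict keyed like 'auc_target1'. The grouping code itself is elided in this view, so the following is an assumed reconstruction of how such keys can be regrouped per target into the 'target1[auc=..., logloss=...]' log format:

```python
# Assumed reconstruction; the actual grouping loop is not visible in this diff.
val_metrics = {"auc_target1": 0.75, "logloss_target1": 0.45, "mse_target2": 3.2}
target_columns = ["target1", "target2"]

task_metrics: dict[str, dict[str, float]] = {}
for key, value in val_metrics.items():
    for target in target_columns:
        if key.endswith(f"_{target}"):
            metric = key[: -len(target) - 1]          # strip the '_<target>' suffix
            task_metrics.setdefault(target, {})[metric] = value

parts = [
    f"{t}[" + ", ".join(f"{k}={v:.4f}" for k, v in m.items()) + "]"
    for t, m in task_metrics.items()
]
print(", ".join(parts))  # target1[auc=0.7500, logloss=0.4500], target2[mse=3.2000]
```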
@@ -407,25 +836,58 @@ class BaseModel(FeatureSet, nn.Module):
                    task_metric_strs = []
                    for target_name in self.target_columns:
                        if target_name in task_metrics:
-                            metrics_str = ", ".join(
+                            metrics_str = ", ".join(
+                                [
+                                    f"{k}={v:.4f}"
+                                    for k, v in task_metrics[target_name].items()
+                                ]
+                            )
                            task_metric_strs.append(f"{target_name}[{metrics_str}]")
-
+                    if self.is_main_process:
+                        logging.info(
+                            colorize(
+                                f" Epoch {epoch + 1}/{epochs} - Valid: "
+                                + ", ".join(task_metric_strs),
+                                color="cyan",
+                            )
+                        )
                if val_metrics and self.training_logger:
-                    self.training_logger.log_metrics(
+                    self.training_logger.log_metrics(
+                        val_metrics, step=epoch + 1, split="valid"
+                    )
                # Handle empty validation metrics
                if not val_metrics:
-                    self.
-
-
+                    if self.is_main_process:
+                        self.save_model(
+                            self.checkpoint_path, add_timestamp=False, verbose=False
+                        )
+                        self.best_checkpoint_path = self.checkpoint_path
+                        logging.info(
+                            colorize(
+                                "Warning: No validation metrics computed. Skipping validation for this epoch.",
+                                color="yellow",
+                            )
+                        )
                    continue
                if self.nums_task == 1:
                    primary_metric_key = self.metrics[0]
                else:
                    primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
-                primary_metric = val_metrics.get(
+                primary_metric = val_metrics.get(
+                    primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
+                )  # get the primary metric value, defaulting to the first metric if not found
+
+                # In distributed mode, broadcast primary_metric to ensure all processes use the same value
+                if self.distributed and dist.is_available() and dist.is_initialized():
+                    metric_tensor = torch.tensor(
+                        [primary_metric], device=self.device, dtype=torch.float32
+                    )
+                    dist.broadcast(metric_tensor, src=0)
+                    primary_metric = float(metric_tensor.item())
+
                improved = False
                # early stopping check
-                if self.best_metrics_mode ==
+                if self.best_metrics_mode == "max":
                    if primary_metric > self.best_metric:
                        self.best_metric = primary_metric
                        improved = True
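
The broadcast added above keeps early stopping consistent: rank 0's primary metric overwrites every other rank's copy before the improvement comparison. The same pattern as a standalone sketch (the function name is hypothetical):

```python
import torch
import torch.distributed as dist

def sync_primary_metric(primary_metric: float, device: torch.device) -> float:
    """Broadcast pattern used above: rank 0's metric value replaces every
    other rank's, so the early-stopping decision is identical everywhere."""
    if dist.is_available() and dist.is_initialized():
        t = torch.tensor([primary_metric], device=device, dtype=torch.float32)
        dist.broadcast(t, src=0)   # all ranks receive rank 0's value
        return float(t.item())
    return primary_metric          # single-process fallback: unchanged
```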
@@ -433,119 +895,287 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
433
895
|
if primary_metric < self.best_metric:
|
|
434
896
|
self.best_metric = primary_metric
|
|
435
897
|
improved = True
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
if
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
898
|
+
|
|
899
|
+
# save checkpoint and best model for main process
|
|
900
|
+
if self.is_main_process:
|
|
901
|
+
self.save_model(
|
|
902
|
+
self.checkpoint_path, add_timestamp=False, verbose=False
|
|
903
|
+
)
|
|
904
|
+
logging.info(" ")
|
|
905
|
+
if improved:
|
|
906
|
+
logging.info(
|
|
907
|
+
colorize(
|
|
908
|
+
f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
|
|
909
|
+
)
|
|
910
|
+
)
|
|
911
|
+
self.save_model(
|
|
912
|
+
self.best_path, add_timestamp=False, verbose=False
|
|
913
|
+
)
|
|
914
|
+
self.best_checkpoint_path = self.best_path
|
|
915
|
+
self.early_stopper.trial_counter = 0
|
|
916
|
+
else:
|
|
917
|
+
self.early_stopper.trial_counter += 1
|
|
918
|
+
logging.info(
|
|
919
|
+
colorize(
|
|
920
|
+
f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
|
|
921
|
+
)
|
|
922
|
+
)
|
|
923
|
+
if self.early_stopper.trial_counter >= self.early_stopper.patience:
|
|
924
|
+
self.stop_training = True
|
|
925
|
+
logging.info(
|
|
926
|
+
colorize(
|
|
927
|
+
f"Early stopping triggered after {epoch + 1} epochs",
|
|
928
|
+
color="bright_red",
|
|
929
|
+
bold=True,
|
|
930
|
+
)
|
|
931
|
+
)
|
|
443
932
|
else:
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
break
|
|
933
|
+
# Non-main processes also update trial_counter to keep in sync
|
|
934
|
+
if improved:
|
|
935
|
+
self.early_stopper.trial_counter = 0
|
|
936
|
+
else:
|
|
937
|
+
self.early_stopper.trial_counter += 1
|
|
450
938
|
else:
|
|
451
|
-
self.
|
|
452
|
-
|
|
453
|
-
|
|
939
|
+
if self.is_main_process:
|
|
940
|
+
self.save_model(
|
|
941
|
+
self.checkpoint_path, add_timestamp=False, verbose=False
|
|
942
|
+
)
|
|
943
|
+
self.save_model(self.best_path, add_timestamp=False, verbose=False)
|
|
944
|
+
self.best_checkpoint_path = self.best_path
|
|
945
|
+
|
|
946
|
+
# Broadcast stop_training flag to all processes (always, regardless of validation)
|
|
947
|
+
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
948
|
+
stop_tensor = torch.tensor(
|
|
949
|
+
[int(self.stop_training)], device=self.device
|
|
950
|
+
)
|
|
951
|
+
dist.broadcast(stop_tensor, src=0)
|
|
952
|
+
self.stop_training = bool(stop_tensor.item())
|
|
953
|
+
|
|
454
954
|
if self.stop_training:
|
|
455
955
|
break
|
|
456
956
|
if self.scheduler_fn is not None:
|
|
457
|
-
if isinstance(
|
|
957
|
+
if isinstance(
|
|
958
|
+
self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
|
|
959
|
+
):
|
|
458
960
|
if valid_loader is not None:
|
|
459
961
|
self.scheduler_fn.step(primary_metric)
|
|
460
962
|
else:
|
|
461
|
-
self.scheduler_fn.step()
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
963
|
+
self.scheduler_fn.step()
|
|
964
|
+
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
965
|
+
dist.barrier() # dist.barrier() will wait for all processes, like async all_reduce()
|
|
966
|
+
if self.is_main_process:
|
|
967
|
+
logging.info(" ")
|
|
968
|
+
logging.info(colorize("Training finished.", bold=True))
|
|
969
|
+
logging.info(" ")
|
|
465
970
|
if valid_loader is not None:
|
|
466
|
-
|
|
467
|
-
|
|
971
|
+
if self.is_main_process:
|
|
972
|
+
logging.info(
|
|
973
|
+
colorize(f"Load best model from: {self.best_checkpoint_path}")
|
|
974
|
+
)
|
|
975
|
+
self.load_model(
|
|
976
|
+
self.best_checkpoint_path, map_location=self.device, verbose=False
|
|
977
|
+
)
|
|
468
978
|
if self.training_logger:
|
|
469
979
|
self.training_logger.close()
|
|
470
980
|
return self
|
|
471
981
|
|
|
472
|
-
def train_epoch(
|
|
982
|
+
def train_epoch(
|
|
983
|
+
self, train_loader: DataLoader, is_streaming: bool = False
|
|
984
|
+
) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
|
|
985
|
+
# use ddp model for distributed training
|
|
986
|
+
model = self.ddp_model if getattr(self, "ddp_model") is not None else self
|
|
473
987
|
accumulated_loss = 0.0
|
|
474
|
-
|
|
988
|
+
model.train() # type: ignore
|
|
475
989
|
num_batches = 0
|
|
476
990
|
y_true_list = []
|
|
477
991
|
y_pred_list = []
|
|
478
992
|
|
|
479
993
|
user_ids_list = [] if self.needs_user_ids else None
|
|
994
|
+
tqdm_disable = not self.is_main_process
|
|
480
995
|
if self.steps_per_epoch is not None:
|
|
481
|
-
batch_iter = enumerate(
|
|
996
|
+
batch_iter = enumerate(
|
|
997
|
+
tqdm.tqdm(
|
|
998
|
+
train_loader,
|
|
999
|
+
desc=f"Epoch {self.epoch_index + 1}",
|
|
1000
|
+
total=self.steps_per_epoch,
|
|
1001
|
+
disable=tqdm_disable,
|
|
1002
|
+
)
|
|
1003
|
+
)
|
|
482
1004
|
else:
|
|
483
1005
|
desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
|
|
484
|
-
batch_iter = enumerate(
|
|
1006
|
+
batch_iter = enumerate(
|
|
1007
|
+
tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
|
|
1008
|
+
)
|
|
485
1009
|
for batch_index, batch_data in batch_iter:
|
|
486
1010
|
batch_dict = batch_to_dict(batch_data)
|
|
487
1011
|
X_input, y_true = self.get_input(batch_dict, require_labels=True)
|
|
488
|
-
|
|
1012
|
+
# call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
|
|
1013
|
+
y_pred = model(X_input) # type: ignore
|
|
1014
|
+
|
|
489
1015
|
loss = self.compute_loss(y_pred, y_true)
|
|
490
1016
|
reg_loss = self.add_reg_loss()
|
|
491
1017
|
total_loss = loss + reg_loss
|
|
492
1018
|
self.optimizer_fn.zero_grad()
|
|
493
1019
|
total_loss.backward()
|
|
494
|
-
|
|
1020
|
+
|
|
1021
|
+
params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
|
|
1022
|
+
nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
|
|
495
1023
|
self.optimizer_fn.step()
|
|
496
1024
|
accumulated_loss += loss.item()
|
|
1025
|
+
|
|
497
1026
|
if y_true is not None:
|
|
498
1027
|
y_true_list.append(y_true.detach().cpu().numpy())
|
|
499
1028
|
if self.needs_user_ids and user_ids_list is not None:
|
|
500
|
-
batch_user_id = get_user_ids(
|
|
1029
|
+
batch_user_id = get_user_ids(
|
|
1030
|
+
data=batch_dict, id_columns=self.id_columns
|
|
1031
|
+
)
|
|
501
1032
|
if batch_user_id is not None:
|
|
502
1033
|
user_ids_list.append(batch_user_id)
|
|
503
1034
|
if y_pred is not None and isinstance(y_pred, torch.Tensor):
|
|
504
1035
|
y_pred_list.append(y_pred.detach().cpu().numpy())
|
|
505
1036
|
num_batches += 1
|
|
1037
|
+
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
1038
|
+
loss_tensor = torch.tensor(
|
|
1039
|
+
[accumulated_loss, num_batches], device=self.device, dtype=torch.float32
|
|
1040
|
+
)
|
|
1041
|
+
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
|
|
1042
|
+
accumulated_loss = loss_tensor[0].item()
|
|
1043
|
+
num_batches = int(loss_tensor[1].item())
|
|
506
1044
|
avg_loss = accumulated_loss / max(num_batches, 1)
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
1045
|
+
|
|
1046
|
+
y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
|
|
1047
|
+
y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
|
|
1048
|
+
combined_user_ids_local = (
|
|
1049
|
+
np.concatenate(user_ids_list, axis=0)
|
|
1050
|
+
if self.needs_user_ids and user_ids_list
|
|
1051
|
+
else None
|
|
1052
|
+
)
|
|
1053
|
+
|
|
1054
|
+
# gather across ranks even when local is empty to avoid DDP hang
|
|
1055
|
+
y_true_all = gather_numpy(self, y_true_all_local)
|
|
1056
|
+
y_pred_all = gather_numpy(self, y_pred_all_local)
|
|
1057
|
+
+        combined_user_ids = (
+            gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
+        )
+
+        if (
+            y_true_all is not None
+            and y_pred_all is not None
+            and len(y_true_all) > 0
+            and len(y_pred_all) > 0
+        ):
+            metrics_dict = evaluate_metrics(
+                y_true=y_true_all,
+                y_pred=y_pred_all,
+                metrics=self.metrics,
+                task=self.task,
+                target_names=self.target_columns,
+                task_specific_metrics=self.task_specific_metrics,
+                user_ids=combined_user_ids,
+            )
             return avg_loss, metrics_dict
         return avg_loss

-    def prepare_validation_data(
+    def prepare_validation_data(
+        self,
+        valid_data: dict | pd.DataFrame | DataLoader | None,
+        batch_size: int,
+        needs_user_ids: bool,
+        user_id_column: str | None = "user_id",
+        num_workers: int = 0,
+        auto_distributed_sampler: bool = True,
+    ) -> tuple[DataLoader | None, np.ndarray | None]:
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
-
-
+            if auto_distributed_sampler and self.distributed:
+                raise NotImplementedError(
+                    "[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
+                )
+                # valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
+            else:
+                valid_loader = valid_data
+            return valid_loader, None
+        valid_sampler = None
+        valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True)  # type: ignore
+        if (
+            auto_distributed_sampler
+            and self.distributed
+            and valid_dataset is not None
+            and dist.is_available()
+            and dist.is_initialized()
+        ):
+            valid_sampler = DistributedSampler(
+                valid_dataset,
+                num_replicas=self.world_size,
+                rank=self.rank,
+                shuffle=False,
+                drop_last=False,
+            )
+            valid_loader = DataLoader(
+                valid_dataset,
+                batch_size=batch_size,
+                shuffle=False,
+                sampler=valid_sampler,
+                collate_fn=collate_fn,
+                num_workers=num_workers,
+            )
         valid_user_ids = None
         if needs_user_ids:
             if user_id_column is None:
-                raise ValueError(
-
+                raise ValueError(
+                    "[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics."
+                )
+            # In distributed mode, user_ids will be collected during evaluation from each batch
+            # and gathered across all processes, so we don't pre-extract them here
+            if not self.distributed:
+                valid_user_ids = get_user_ids(
+                    data=valid_data, id_columns=user_id_column
+                )
         return valid_loader, valid_user_ids

-    def evaluate(
-
-
-
-
-
-
+    def evaluate(
+        self,
+        data: dict | pd.DataFrame | DataLoader,
+        metrics: list[str] | dict[str, list[str]] | None = None,
+        batch_size: int = 32,
+        user_ids: np.ndarray | None = None,
+        user_id_column: str = "user_id",
+        num_workers: int = 0,
+    ) -> dict:
+        """
+        **IMPORTANT for Distributed Training:**
+        In distributed mode, this method uses collective communication operations (all_gather).
+        All processes must call this method simultaneously, even if you only want results on rank 0.
+        Failing to do so will cause the program to hang indefinitely.
+
+        Evaluate the model on the given data.
+
+        Args:
+            data: Evaluation data (dict/df/DataLoader).
+            metrics: Metric names or a per-target dict, e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
+            batch_size: Batch size (per process when distributed).
+            user_ids: Optional array of user IDs for GAUC-style metrics; if None and needed, extracted from data using user_id_column, e.g. np.array([...]).
+            user_id_column: Column name for user IDs if user_ids is not provided, e.g. 'user_id'.
+            num_workers: DataLoader worker count.
+        """
+        model = self.ddp_model if getattr(self, "ddp_model", None) is not None else self
+        model.eval()
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
+            )
         needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
-
+
         if isinstance(data, DataLoader):
             data_loader = data
         else:
             if user_ids is None and needs_user_ids:
                 user_ids = get_user_ids(data=data, id_columns=user_id_column)
-            data_loader = self.prepare_data_loader(
+            data_loader = self.prepare_data_loader(
+                data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+            )
         y_true_list = []
         y_pred_list = []
         collected_user_ids = []
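Note on the hunks above and below: they lean on a gather_numpy(self, array) helper whose implementation is not part of these hunks. A minimal sketch of what such a rank-aware gather can look like — an assumption, not the library's actual code — using the standard torch.distributed.all_gather_object API:

import numpy as np
import torch.distributed as dist

def gather_numpy_sketch(local):
    """All-gather per-rank numpy arrays (possibly None) and concatenate them."""
    if not (dist.is_available() and dist.is_initialized()):
        return local
    buckets = [None] * dist.get_world_size()
    dist.all_gather_object(buckets, local)  # collective: every rank must call this
    parts = [b for b in buckets if b is not None and len(b) > 0]
    return np.concatenate(parts, axis=0) if parts else None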
@@ -555,30 +1185,25 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
-                y_pred =
+                y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.cpu().numpy())
                 if needs_user_ids and user_ids is None:
-                    batch_user_id = get_user_ids(
+                    batch_user_id = get_user_ids(
+                        data=batch_dict, id_columns=self.id_columns
+                    )
                     if batch_user_id is not None:
                         collected_user_ids.append(batch_user_id)
-
-
-
-
-
-        else
-
-
-
-        if len(y_pred_list) > 0:
-            y_pred_all = np.concatenate(y_pred_list, axis=0)
-        else:
-            y_pred_all = None
-            logging.info(colorize(f" Warning: No y_pred collected from evaluation data", color="yellow"))
-
+        if self.is_main_process:
+            logging.info(" ")
+            logging.info(
+                colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
+            )
+        y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
+        y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
+
         # Convert metrics to list if it's a dict
         if isinstance(eval_metrics, dict):
             # For dict metrics, we need to collect all unique metric names
@@ -589,11 +1214,44 @@ class BaseModel(FeatureSet, nn.Module):
                         unique_metrics.append(m)
             metrics_to_use = unique_metrics
         else:
-            metrics_to_use = eval_metrics
-
-        if
-
-
+            metrics_to_use = eval_metrics
+        final_user_ids_local = user_ids
+        if final_user_ids_local is None and collected_user_ids:
+            final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
+
+        # gather across ranks even when local arrays are empty to keep collectives aligned
+        y_true_all = gather_numpy(self, y_true_all_local)
+        y_pred_all = gather_numpy(self, y_pred_all_local)
+        final_user_ids = (
+            gather_numpy(self, final_user_ids_local) if needs_user_ids else None
+        )
+        if (
+            y_true_all is None
+            or y_pred_all is None
+            or len(y_true_all) == 0
+            or len(y_pred_all) == 0
+        ):
+            if self.is_main_process:
+                logging.info(
+                    colorize(
+                        " Warning: Not enough evaluation data to compute metrics after gathering",
+                        color="yellow",
+                    )
+                )
+            return {}
+        if self.is_main_process:
+            logging.info(
+                colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
+            )
+        metrics_dict = evaluate_metrics(
+            y_true=y_true_all,
+            y_pred=y_pred_all,
+            metrics=metrics_to_use,
+            task=self.task,
+            target_names=self.target_columns,
+            task_specific_metrics=self.task_specific_metrics,
+            user_ids=final_user_ids,
+        )
         return metrics_dict

     def predict(
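As the new evaluate() docstring stresses, the method issues collectives in distributed mode, so every rank must enter it together. A minimal usage sketch (model and valid_df are placeholders, e.g. under torchrun):

metrics = model.evaluate(
    data=valid_df,                # dict / DataFrame / DataLoader
    metrics=["auc", "logloss"],
    batch_size=256,               # per-process batch size
    user_id_column="user_id",     # used when a GAUC-style metric is requested
)
if model.is_main_process:         # metrics are identical on all ranks after gathering
    print(metrics)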
@@ -603,43 +1261,100 @@ class BaseModel(FeatureSet, nn.Module):
         save_path: str | os.PathLike | None = None,
         save_format: Literal["csv", "parquet"] = "csv",
         include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
         num_workers: int = 0,
     ) -> pd.DataFrame | np.ndarray:
+        """
+        Note: predict does not support distributed mode currently; treat it as a single-process operation.
+        Make predictions on the given data.
+
+        Args:
+            data: Input data for prediction (file path, dict, DataFrame, or DataLoader).
+            batch_size: Batch size for prediction (per process when distributed).
+            save_path: Optional path to save predictions; if None, predictions are not saved to disk.
+            save_format: Format to save predictions ('csv' or 'parquet').
+            include_ids: Whether to include ID columns in the output; if None, includes them if id_columns are set.
+            id_columns: Column name(s) to use as IDs; if None, uses the model's id_columns.
+            return_dataframe: Whether to return predictions as a pandas DataFrame; if False, returns a NumPy array.
+            streaming_chunk_size: Number of rows per chunk when using streaming mode for large datasets.
+            num_workers: DataLoader worker count.
+        """
         self.eval()
+        # Use prediction-time id_columns if provided, otherwise fall back to model's id_columns
+        predict_id_columns = id_columns if id_columns is not None else self.id_columns
+        if isinstance(predict_id_columns, str):
+            predict_id_columns = [predict_id_columns]
+
         if include_ids is None:
-            include_ids = bool(
-        include_ids = include_ids and bool(
+            include_ids = bool(predict_id_columns)
+        include_ids = include_ids and bool(predict_id_columns)

+        # Use streaming mode for large file saves without loading all data into memory
         if save_path is not None and not return_dataframe:
-            return self.
-
-
-
-
-
+            return self.predict_streaming(
+                data=data,
+                batch_size=batch_size,
+                save_path=save_path,
+                save_format=save_format,
+                include_ids=include_ids,
+                streaming_chunk_size=streaming_chunk_size,
+                return_dataframe=return_dataframe,
+                id_columns=predict_id_columns,
+            )
+
+        # Create DataLoader based on data type
+        if isinstance(data, DataLoader):
             data_loader = data
-
-
-
-
+        elif isinstance(data, (str, os.PathLike)):
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=self.target_columns,
+                id_columns=predict_id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                load_full=False,
+                chunk_size=streaming_chunk_size,
+            )
+        else:
+            data_loader = self.prepare_data_loader(
+                data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+            )
+
+        y_pred_list = []
+        id_buffers = (
+            {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+        )
+        id_arrays = None
+
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
                 X_input, _ = self.get_input(batch_dict, require_labels=False)
-                y_pred = self
+                y_pred = self(X_input)
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.detach().cpu().numpy())
-                if include_ids and
-                    for id_name in
+                if include_ids and predict_id_columns and batch_dict.get("ids"):
+                    for id_name in predict_id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        id_np =
-
+                        id_np = (
+                            id_tensor.detach().cpu().numpy()
+                            if isinstance(id_tensor, torch.Tensor)
+                            else np.asarray(id_tensor)
+                        )
+                        id_buffers[id_name].append(
+                            id_np.reshape(id_np.shape[0], -1)
+                            if id_np.ndim == 1
+                            else id_np
+                        )
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
         else:
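The reworked predict() accepts prediction-time ID columns and can write results to disk. A minimal usage sketch (file names and the ID column are placeholders):

preds = model.predict(
    data="test.parquet",          # file path, dict, DataFrame, or DataLoader
    batch_size=1024,
    id_columns="user_id",         # carried through to the output frame
    save_path="preds.csv",        # optional on-disk copy
)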
@@ -657,11 +1372,13 @@ class BaseModel(FeatureSet, nn.Module):
                 pred_columns.append(f"{name}_pred")
             while len(pred_columns) < num_outputs:
                 pred_columns.append(f"pred_{len(pred_columns)}")
-            if include_ids and
+            if include_ids and predict_id_columns:
                 id_arrays = {}
                 for id_name, pieces in id_buffers.items():
                     if pieces:
-                        concatenated = np.concatenate(
+                        concatenated = np.concatenate(
+                            [p.reshape(p.shape[0], -1) for p in pieces], axis=0
+                        )
                         id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
                     else:
                         id_arrays[id_name] = np.array([], dtype=np.int64)
@@ -669,34 +1386,52 @@ class BaseModel(FeatureSet, nn.Module):
                 id_df = pd.DataFrame(id_arrays)
                 pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
                 if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)})."
+                    )
                 output = pd.concat([id_df, pred_df], axis=1)
             else:
                 output = y_pred_all
         else:
-            output =
+            output = (
+                pd.DataFrame(y_pred_all, columns=pred_columns)
+                if return_dataframe
+                else y_pred_all
+            )
         if save_path is not None:
             if save_format not in ("csv", "parquet"):
-                raise ValueError(
+                raise ValueError(
+                    f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
+                )
             suffix = ".csv" if save_format == "csv" else ".parquet"
-            target_path = resolve_save_path(
+            target_path = resolve_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=suffix,
+                add_timestamp=True if save_path is None else False,
+            )
             if isinstance(output, pd.DataFrame):
                 df_to_save = output
             else:
                 df_to_save = pd.DataFrame(y_pred_all, columns=pred_columns)
-            if include_ids and
+            if include_ids and predict_id_columns and id_arrays is not None:
                 id_df = pd.DataFrame(id_arrays)
                 if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
+                    )
                 df_to_save = pd.concat([id_df, df_to_save], axis=1)
             if save_format == "csv":
                 df_to_save.to_csv(target_path, index=False)
             else:
                 df_to_save.to_parquet(target_path, index=False)
-            logging.info(
+            logging.info(
+                colorize(f"Predictions saved to: {target_path}", color="green")
+            )
         return output

-    def
+    def predict_streaming(
         self,
         data: str | dict | pd.DataFrame | DataLoader,
         batch_size: int,
@@ -705,23 +1440,46 @@ class BaseModel(FeatureSet, nn.Module):
         include_ids: bool,
         streaming_chunk_size: int,
         return_dataframe: bool,
+        id_columns: list[str] | None = None,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(
-
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=self.target_columns,
+                id_columns=id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                load_full=False,
+                chunk_size=streaming_chunk_size,
+            )
         elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(
+            data_loader = self.prepare_data_loader(
+                data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data

         suffix = ".csv" if save_format == "csv" else ".parquet"
-        target_path = resolve_save_path(
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session.predictions_dir,
+            default_name="predictions",
+            suffix=suffix,
+            add_timestamp=True if save_path is None else False,
+        )
         target_path.parent.mkdir(parents=True, exist_ok=True)
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None

-        pred_columns
-        collected_frames
+        pred_columns = None
+        collected_frames = []  # only used when return_dataframe is True

         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -741,32 +1499,42 @@ class BaseModel(FeatureSet, nn.Module):
                         pred_columns.append(f"{name}_pred")
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")
-
-                id_arrays_batch
-                if include_ids and
-                    for id_name in
+
+                id_arrays_batch = {}
+                if include_ids and id_columns and batch_dict.get("ids"):
+                    for id_name in id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        id_np =
+                        id_np = (
+                            id_tensor.detach().cpu().numpy()
+                            if isinstance(id_tensor, torch.Tensor)
+                            else np.asarray(id_tensor)
+                        )
                         id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])

                 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
                 if id_arrays_batch:
                     id_df = pd.DataFrame(id_arrays_batch)
                     if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
-                        raise ValueError(
+                        raise ValueError(
+                            f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
+                        )
                     df_batch = pd.concat([id_df, df_batch], axis=1)

                 if save_format == "csv":
-                    df_batch.to_csv(
+                    df_batch.to_csv(
+                        target_path, mode="a", header=not header_written, index=False
+                    )
                     header_written = True
                 else:
                     try:
                         import pyarrow as pa
                         import pyarrow.parquet as pq
                     except ImportError as exc:  # pragma: no cover
-                        raise ImportError(
+                        raise ImportError(
+                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
+                        ) from exc
                     table = pa.Table.from_pandas(df_batch, preserve_index=False)
                     if parquet_writer is None:
                         parquet_writer = pq.ParquetWriter(target_path, table.schema)
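The parquet branch above appends batch tables to a single file through one ParquetWriter. A standalone sketch of that incremental-write pattern using pyarrow's public API (the chunks here are stand-ins):

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

writer = None
for chunk in (pd.DataFrame({"pred_0": [0.1, 0.9]}) for _ in range(3)):
    table = pa.Table.from_pandas(chunk, preserve_index=False)
    if writer is None:
        writer = pq.ParquetWriter("predictions.parquet", table.schema)  # schema fixed by first chunk
    writer.write_table(table)
if writer is not None:
    writer.close()  # finalizes the file footer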
@@ -777,14 +1545,36 @@ class BaseModel(FeatureSet, nn.Module):
             parquet_writer.close()
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
-            return
+            return (
+                pd.concat(collected_frames, ignore_index=True)
+                if collected_frames
+                else pd.DataFrame(columns=pred_columns or [])
+            )
         return pd.DataFrame(columns=pred_columns or [])

-    def save_model(
+    def save_model(
+        self,
+        save_path: str | Path | None = None,
+        add_timestamp: bool | None = None,
+        verbose: bool = True,
+    ):
         add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path = resolve_save_path(
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session_path,
+            default_name=self.model_name,
+            suffix=".model",
+            add_timestamp=add_timestamp,
+        )
         model_path = Path(target_path)
-
+
+        model_to_save = (
+            self.ddp_model.module
+            if getattr(self, "ddp_model", None) is not None
+            else self
+        )
+        torch.save(model_to_save.state_dict(), model_path)
+        # torch.save(self.state_dict(), model_path)

         config_path = self.features_config_path
         features_config = {
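save_model now unwraps DistributedDataParallel before serializing, so checkpoints never carry the "module." key prefix. The same idea outside NextRec, as a sketch:

import torch
import torch.nn as nn

net = nn.Linear(4, 1)
maybe_ddp = net  # or torch.nn.parallel.DistributedDataParallel(net, ...) under torchrun
to_save = maybe_ddp.module if hasattr(maybe_ddp, "module") else maybe_ddp
torch.save(to_save.state_dict(), "model.pt")  # plain parameter keys either way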
@@ -797,29 +1587,47 @@ class BaseModel(FeatureSet, nn.Module):
             pickle.dump(features_config, f)
         self.features_config_path = str(config_path)
         if verbose:
-            logging.info(
-
-
+            logging.info(
+                colorize(
+                    f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",
+                    color="green",
+                )
+            )
+
+    def load_model(
+        self,
+        save_path: str | Path,
+        map_location: str | torch.device | None = "cpu",
+        verbose: bool = True,
+    ):
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
             model_files = sorted(base_path.glob("*.model"))
             if not model_files:
-                raise FileNotFoundError(
+                raise FileNotFoundError(
+                    f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
+                )
             model_path = model_files[-1]
             config_dir = base_path
         else:
-            model_path =
+            model_path = (
+                base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            )
             config_dir = model_path.parent
         if not model_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-load-model Error] Model file does not exist: {model_path}"
+            )

         state_dict = torch.load(model_path, map_location=map_location)
         self.load_state_dict(state_dict)

         features_config_path = config_dir / "features_config.pkl"
         if not features_config_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}"
+            )
         with open(features_config_path, "rb") as f:
             features_config = pickle.load(f)

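load_model accepts either the checkpoint directory or the .model file itself and restores the pickled feature config alongside the weights. A hypothetical round trip (paths are placeholders):

model.save_model("checkpoints/deepfm")   # writes a .model file plus features_config.pkl
model.load_model("checkpoints/deepfm")   # the directory or the .model path both work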
@@ -829,11 +1637,22 @@ class BaseModel(FeatureSet, nn.Module):
         dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
         sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
         sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self.set_all_features(
+        self.set_all_features(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            id_columns=id_columns,
+        )

         cfg_version = features_config.get("version")
         if verbose:
-            logging.info(
+            logging.info(
+                colorize(
+                    f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",
+                    color="green",
+                )
+            )

     @classmethod
     def from_checkpoint(
@@ -845,23 +1664,29 @@ class BaseModel(FeatureSet, nn.Module):
         **kwargs: Any,
     ) -> "BaseModel":
         """
-
-
+        Load a model from a checkpoint path. The checkpoint path should contain
+        a .model file and a features_config.pkl file.
         """
         base_path = Path(checkpoint_path)
         verbose = kwargs.pop("verbose", True)
         if base_path.is_dir():
             model_candidates = sorted(base_path.glob("*.model"))
             if not model_candidates:
-                raise FileNotFoundError(
+                raise FileNotFoundError(
+                    f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
+                )
             model_file = model_candidates[-1]
             config_dir = base_path
         else:
-            model_file =
+            model_file = (
+                base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            )
             config_dir = model_file.parent
         features_config_path = config_dir / "features_config.pkl"
         if not features_config_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}"
+            )
         with open(features_config_path, "rb") as f:
             features_config = pickle.load(f)
         all_features = features_config.get("all_features", [])
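from_checkpoint rebuilds the feature set from features_config.pkl before loading weights, so a model can be restored without re-declaring its features. A hypothetical one-liner (MyRanker and the path are placeholders):

model = MyRanker.from_checkpoint("checkpoints/my_ranker")  # rebuilds features, then loads weights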
@@ -887,108 +1712,132 @@ class BaseModel(FeatureSet, nn.Module):

     def summary(self):
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-        logger.info(
+        logger.info(
+            colorize(
+                f"Model Summary: {self.model_name}", color="bright_blue", bold=True
+            )
+        )
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.dense_features:
             logger.info(f"Dense Features ({len(self.dense_features)}):")
             for i, feat in enumerate(self.dense_features, 1):
-                embed_dim = feat.embedding_dim if hasattr(feat,
+                embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
                 logger.info(f"  {i}. {feat.name:20s}")
-
+
         if self.sparse_features:
             logger.info(f"\nSparse Features ({len(self.sparse_features)}):")

             max_name_len = max(len(feat.name) for feat in self.sparse_features)
-            max_embed_name_len = max(
+            max_embed_name_len = max(
+                len(feat.embedding_name) for feat in self.sparse_features
+            )
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-
+
+            logger.info(
+                f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
+            )
+            logger.info(
+                f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
+            )
             for i, feat in enumerate(self.sparse_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat,
-                embed_dim =
-
-
+                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                embed_dim = (
+                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                )
+                logger.info(
+                    f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
+                )
+
         if self.sequence_features:
             logger.info(f"\nSequence Features ({len(self.sequence_features)}):")

             max_name_len = max(len(feat.name) for feat in self.sequence_features)
-            max_embed_name_len = max(
+            max_embed_name_len = max(
+                len(feat.embedding_name) for feat in self.sequence_features
+            )
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-
+
+            logger.info(
+                f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
+            )
+            logger.info(
+                f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
+            )
             for i, feat in enumerate(self.sequence_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat,
-                embed_dim =
-
-
-
+                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                embed_dim = (
+                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                )
+                max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
+                logger.info(
+                    f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
+                )
+
         logger.info("")
         logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         # Model Architecture
         logger.info("Model Architecture:")
         logger.info(str(self))
         logger.info("")
-
+
         total_params = sum(p.numel() for p in self.parameters())
         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
         non_trainable_params = total_params - trainable_params
-
+
         logger.info(f"Total Parameters: {total_params:,}")
         logger.info(f"Trainable Parameters: {trainable_params:,}")
         logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
+
         logger.info("Layer-wise Parameters:")
         for name, module in self.named_children():
             layer_params = sum(p.numel() for p in module.parameters())
             if layer_params > 0:
                 logger.info(f"  {name:30s}: {layer_params:,}")
-
+
         logger.info("")
         logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         logger.info(f"Task Type: {self.task}")
         logger.info(f"Number of Tasks: {self.nums_task}")
         logger.info(f"Metrics: {self.metrics}")
         logger.info(f"Target Columns: {self.target_columns}")
         logger.info(f"Device: {self.device}")
-
-        if hasattr(self,
+
+        if hasattr(self, "optimizer_name"):
             logger.info(f"Optimizer: {self.optimizer_name}")
             if self.optimizer_params:
                 for key, value in self.optimizer_params.items():
                     logger.info(f"  {key:25s}: {value}")
-
-        if hasattr(self,
+
+        if hasattr(self, "scheduler_name") and self.scheduler_name:
             logger.info(f"Scheduler: {self.scheduler_name}")
             if self.scheduler_params:
                 for key, value in self.scheduler_params.items():
                     logger.info(f"  {key:25s}: {value}")
-
-        if hasattr(self,
+
+        if hasattr(self, "loss_config"):
             logger.info(f"Loss Function: {self.loss_config}")
-        if hasattr(self,
+        if hasattr(self, "loss_weights"):
             logger.info(f"Loss Weights: {self.loss_weights}")
-
+
         logger.info("Regularization:")
         logger.info(f"  Embedding L1: {self.embedding_l1_reg}")
         logger.info(f"  Embedding L2: {self.embedding_l2_reg}")
         logger.info(f"  Dense L1: {self.dense_l1_reg}")
         logger.info(f"  Dense L2: {self.dense_l2_reg}")
-
+
         logger.info("Other Settings:")
         logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
         logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
@@ -997,54 +1846,56 @@ class BaseModel(FeatureSet, nn.Module):
             logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")


-
 class BaseMatchModel(BaseModel):
     """
     Base class for match (retrieval/recall) models
     Supports pointwise, pairwise, and listwise training modes
     """
+
     @property
     def model_name(self) -> str:
         raise NotImplementedError

     @property
-    def
-
-
+    def default_task(self) -> str:
+        return "binary"
+
     @property
     def support_training_modes(self) -> list[str]:
         """
         Returns list of supported training modes for this model.
         Override in subclasses to restrict training modes.
-
+
         Returns:
             List of supported modes: ['pointwise', 'pairwise', 'listwise']
         """
-        return [
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return ["pointwise", "pairwise", "listwise"]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        **kwargs,
+    ):
+
         all_dense_features = []
         all_sparse_features = []
         all_sequence_features = []
-
+
         if user_dense_features:
             all_dense_features.extend(user_dense_features)
         if item_dense_features:
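The expanded __init__ shows the full two-tower configuration surface. A hypothetical setup — the feature-class constructor arguments here are assumptions, not confirmed by this diff:

user_feats = [SparseFeature(name="user_id", vocab_size=10_000, embedding_dim=16)]
item_feats = [SparseFeature(name="item_id", vocab_size=50_000, embedding_dim=16)]
model = DSSM(                        # any BaseMatchModel subclass
    user_sparse_features=user_feats,
    item_sparse_features=item_feats,
    training_mode="listwise",        # in-batch softmax (InfoNCE) during training
    temperature=0.05,
    similarity_metric="cosine",
)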
@@ -1057,117 +1908,175 @@ class BaseMatchModel(BaseModel):
             all_sequence_features.extend(user_sequence_features)
         if item_sequence_features:
             all_sequence_features.extend(item_sequence_features)
-
+
         super(BaseMatchModel, self).__init__(
             dense_features=all_dense_features,
             sparse_features=all_sparse_features,
             sequence_features=all_sequence_features,
-            target=[
-            task=
+            target=["label"],
+            task="binary",
             device=device,
             embedding_l1_reg=embedding_l1_reg,
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
             early_stop_patience=early_stop_patience,
-            **kwargs
+            **kwargs,
+        )
+
+        self.user_dense_features = (
+            list(user_dense_features) if user_dense_features else []
         )
-
-
-
-        self.user_sequence_features =
-
-
-
-        self.
-
+        self.user_sparse_features = (
+            list(user_sparse_features) if user_sparse_features else []
+        )
+        self.user_sequence_features = (
+            list(user_sequence_features) if user_sequence_features else []
+        )
+
+        self.item_dense_features = (
+            list(item_dense_features) if item_dense_features else []
+        )
+        self.item_sparse_features = (
+            list(item_sparse_features) if item_sparse_features else []
+        )
+        self.item_sequence_features = (
+            list(item_sequence_features) if item_sequence_features else []
+        )
+
         self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric

-        self.user_feature_names = [
-
+        self.user_feature_names = [
+            f.name
+            for f in (
+                self.user_dense_features
+                + self.user_sparse_features
+                + self.user_sequence_features
+            )
+        ]
+        self.item_feature_names = [
+            f.name
+            for f in (
+                self.item_dense_features
+                + self.item_sparse_features
+                + self.item_sequence_features
+            )
+        ]

     def get_user_features(self, X_input: dict) -> dict:
         return {
-            name: X_input[name]
-            for name in self.user_feature_names
-            if name in X_input
+            name: X_input[name] for name in self.user_feature_names if name in X_input
         }

     def get_item_features(self, X_input: dict) -> dict:
         return {
-            name: X_input[name]
-            for name in self.item_feature_names
-            if name in X_input
+            name: X_input[name] for name in self.item_feature_names if name in X_input
         }
-
-    def compile(
-
-
-
-
-
+
+    def compile(
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | torch.optim.lr_scheduler.LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | type[torch.optim.lr_scheduler.LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+    ):
         """
         Compile match model with optimizer, scheduler, and loss function.
         Mirrors BaseModel.compile while adding training_mode validation for match tasks.
         """
         if self.training_mode not in self.support_training_modes:
-            raise ValueError(
+            raise ValueError(
+                f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+            )
         # Call parent compile with match-specific logic
         optimizer_params = optimizer_params or {}
-
-        self.optimizer_name =
+
+        self.optimizer_name = (
+            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        )
         self.optimizer_params = optimizer_params
         if isinstance(scheduler, str):
             self.scheduler_name = scheduler
         elif scheduler is not None:
             # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
+            self.scheduler_name = getattr(
+                scheduler,
+                "__name__",
+                getattr(scheduler.__class__, "__name__", str(scheduler)),
+            )
         else:
             self.scheduler_name = None
         self.scheduler_params = scheduler_params or {}
         self.loss_config = loss
         self.loss_params = loss_params or {}

-        self.optimizer_fn = get_optimizer(
+        self.optimizer_fn = get_optimizer(
+            optimizer=optimizer, params=self.parameters(), **optimizer_params
+        )
         # Set loss function based on training mode
         default_losses = {
-
-
-
+            "pointwise": "bce",
+            "pairwise": "bpr",
+            "listwise": "sampled_softmax",
         }

         if loss is None:
             loss_value = default_losses.get(self.training_mode, "bce")
         elif isinstance(loss, list):
-            loss_value =
+            loss_value = (
+                loss[0]
+                if loss and loss[0] is not None
+                else default_losses.get(self.training_mode, "bce")
+            )
         else:
             loss_value = loss

         # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
-        if self.training_mode in {"pairwise", "listwise"} and loss_value in {
+        if self.training_mode in {"pairwise", "listwise"} and loss_value in {
+            "bce",
+            "binary_crossentropy",
+        }:
             loss_value = default_losses.get(self.training_mode, loss_value)
         loss_kwargs = get_loss_kwargs(self.loss_params, 0)
         self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
         # set scheduler
-        self.scheduler_fn =
+        self.scheduler_fn = (
+            get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
+            if scheduler
+            else None
+        )

-    def compute_similarity(
-
+    def compute_similarity(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
+        if self.similarity_metric == "dot":
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
-                similarity = torch.sum(
+                similarity = torch.sum(
+                    user_emb * item_emb, dim=-1
+                )  # [batch_size, num_items]
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
                 # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
                 user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
+                similarity = torch.sum(
+                    user_emb_expanded * item_emb, dim=-1
+                )  # [batch_size, num_items]
             else:
                 similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
-
-        elif self.similarity_metric ==
+
+        elif self.similarity_metric == "cosine":
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1175,8 +2084,8 @@ class BaseMatchModel(BaseModel):
                 similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
             else:
                 similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
-        elif self.similarity_metric ==
+
+        elif self.similarity_metric == "euclidean":
             if user_emb.dim() == 3 and item_emb.dim() == 3:
                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
             elif user_emb.dim() == 2 and item_emb.dim() == 3:
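compute_similarity divides every metric's score by the temperature. A standalone check mirroring the "dot" branch on 2-D inputs:

import torch

user = torch.randn(4, 8)    # [B, D]
items = torch.randn(4, 8)   # [B, D]
sim = (user * items).sum(dim=-1) / 0.1  # dot product, then temperature scaling
print(sim.shape)            # torch.Size([4])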
@@ -1184,63 +2093,70 @@ class BaseMatchModel(BaseModel):
                 distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
             else:
                 distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -distance
+
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
         similarity = similarity / self.temperature
         return similarity
-
+
     def user_tower(self, user_input: dict) -> torch.Tensor:
         raise NotImplementedError
-
+
     def item_tower(self, item_input: dict) -> torch.Tensor:
         raise NotImplementedError
-
-    def forward(
+
+    def forward(
+        self, X_input: dict
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         user_input = self.get_user_features(X_input)
         item_input = self.get_item_features(X_input)
-
-        user_emb = self.user_tower(user_input)
-        item_emb = self.item_tower(item_input)
-
-        if self.training and self.training_mode in [
+
+        user_emb = self.user_tower(user_input)  # [B, D]
+        item_emb = self.item_tower(item_input)  # [B, D]
+
+        if self.training and self.training_mode in ["pairwise", "listwise"]:
             return user_emb, item_emb

         similarity = self.compute_similarity(user_emb, item_emb)  # [B]
-
-        if self.training_mode ==
+
+        if self.training_mode == "pointwise":
             return torch.sigmoid(similarity)
         else:
             return similarity
-
+
     def compute_loss(self, y_pred, y_true):
-        if self.training_mode ==
+        if self.training_mode == "pointwise":
             if y_true is None:
                 return torch.tensor(0.0, device=self.device)
             return self.loss_fn[0](y_pred, y_true)
-
+
         # pairwise / listwise using inbatch neg
-        elif self.training_mode in [
+        elif self.training_mode in ["pairwise", "listwise"]:
             if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
-                raise ValueError(
-
+                raise ValueError(
+                    "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
+                )
+            user_emb, item_emb = y_pred  # [B, D], [B, D]
             logits = torch.matmul(user_emb, item_emb.t())  # [B, B]
-            logits = logits / self.temperature
+            logits = logits / self.temperature
             batch_size = logits.size(0)
-            targets = torch.arange(
+            targets = torch.arange(
+                batch_size, device=logits.device
+            )  # [0, 1, 2, ..., B-1]
             # Cross-Entropy = InfoNCE
             loss = F.cross_entropy(logits, targets)
-            return loss
+            return loss
         else:
             raise ValueError(f"Unknown training mode: {self.training_mode}")

-
-
+    def prepare_feature_data(
+        self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
+    ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
-
+
         feature_data = {}
         for feature in features:
             if isinstance(data, dict):
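The pairwise/listwise branch of compute_loss treats the diagonal of the [B, B] score matrix as positives and every off-diagonal entry as an in-batch negative. The same InfoNCE computation in isolation:

import torch
import torch.nn.functional as F

user_emb = torch.randn(8, 32)            # [B, D]
item_emb = torch.randn(8, 32)            # [B, D]; row i is user i's positive item
logits = user_emb @ item_emb.t() / 0.07  # [B, B] similarity matrix with temperature
targets = torch.arange(8)                # positive index for each row
loss = F.cross_entropy(logits, targets)  # InfoNCE / softmax over the batch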
@@ -1249,13 +2165,21 @@ class BaseMatchModel(BaseModel):
             elif isinstance(data, pd.DataFrame):
                 if feature.name in data.columns:
                     feature_data[feature.name] = data[feature.name].values
-        return self.prepare_data_loader(
+        return self.prepare_data_loader(
+            feature_data, batch_size=batch_size, shuffle=False
+        )

-    def encode_user(
+    def encode_user(
+        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+    ) -> np.ndarray:
         self.eval()
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
         data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
-
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
@@ -1264,12 +2188,18 @@ class BaseMatchModel(BaseModel):
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
-
-    def encode_item(
+
+    def encode_item(
+        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+    ) -> np.ndarray:
         self.eval()
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
         data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
-
+
         embeddings_list = []
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):