nextrec 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__init__.py +1 -1
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +10 -5
- nextrec/basic/callback.py +1 -0
- nextrec/basic/features.py +30 -22
- nextrec/basic/layers.py +250 -112
- nextrec/basic/loggers.py +63 -44
- nextrec/basic/metrics.py +270 -120
- nextrec/basic/model.py +1084 -402
- nextrec/basic/session.py +10 -3
- nextrec/cli.py +492 -0
- nextrec/data/__init__.py +19 -25
- nextrec/data/batch_utils.py +11 -3
- nextrec/data/data_processing.py +51 -45
- nextrec/data/data_utils.py +26 -15
- nextrec/data/dataloader.py +273 -96
- nextrec/data/preprocessor.py +320 -199
- nextrec/loss/listwise.py +17 -9
- nextrec/loss/loss_utils.py +7 -8
- nextrec/loss/pairwise.py +2 -0
- nextrec/loss/pointwise.py +30 -12
- nextrec/models/generative/hstu.py +103 -38
- nextrec/models/match/dssm.py +82 -68
- nextrec/models/match/dssm_v2.py +72 -57
- nextrec/models/match/mind.py +175 -107
- nextrec/models/match/sdm.py +104 -87
- nextrec/models/match/youtube_dnn.py +73 -59
- nextrec/models/multi_task/esmm.py +69 -46
- nextrec/models/multi_task/mmoe.py +91 -53
- nextrec/models/multi_task/ple.py +117 -58
- nextrec/models/multi_task/poso.py +163 -55
- nextrec/models/multi_task/share_bottom.py +63 -36
- nextrec/models/ranking/afm.py +80 -45
- nextrec/models/ranking/autoint.py +74 -57
- nextrec/models/ranking/dcn.py +110 -48
- nextrec/models/ranking/dcn_v2.py +265 -45
- nextrec/models/ranking/deepfm.py +39 -24
- nextrec/models/ranking/dien.py +335 -146
- nextrec/models/ranking/din.py +158 -92
- nextrec/models/ranking/fibinet.py +134 -52
- nextrec/models/ranking/fm.py +68 -26
- nextrec/models/ranking/masknet.py +95 -33
- nextrec/models/ranking/pnn.py +128 -58
- nextrec/models/ranking/widedeep.py +40 -28
- nextrec/models/ranking/xdeepfm.py +67 -40
- nextrec/utils/__init__.py +59 -34
- nextrec/utils/config.py +496 -0
- nextrec/utils/device.py +30 -20
- nextrec/utils/distributed.py +36 -9
- nextrec/utils/embedding.py +1 -0
- nextrec/utils/feature.py +1 -0
- nextrec/utils/file.py +33 -11
- nextrec/utils/initializer.py +61 -16
- nextrec/utils/model.py +22 -0
- nextrec/utils/optimizer.py +25 -9
- nextrec/utils/synthetic_data.py +283 -165
- nextrec/utils/tensor.py +24 -13
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/METADATA +53 -24
- nextrec-0.4.3.dist-info/RECORD +69 -0
- nextrec-0.4.3.dist-info/entry_points.txt +2 -0
- nextrec-0.4.1.dist-info/RECORD +0 -66
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/WHEEL +0 -0
- {nextrec-0.4.1.dist-info → nextrec-0.4.3.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -5,6 +5,7 @@ Date: create on 27/10/2025
 Checkpoint: edit on 05/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
+
 import os
 import tqdm
 import pickle
@@ -25,7 +26,12 @@ from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from nextrec.basic.callback import EarlyStopper
-from nextrec.basic.features import
+from nextrec.basic.features import (
+    DenseFeature,
+    SparseFeature,
+    SequenceFeature,
+    FeatureSet,
+)
 from nextrec.data.dataloader import TensorDictDataset, RecDataLoader
 
 from nextrec.basic.loggers import setup_logger, colorize, TrainingLogger
@@ -40,9 +46,14 @@ from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.utils.tensor import to_tensor
 from nextrec.utils.device import configure_device
 from nextrec.utils.optimizer import get_optimizer, get_scheduler
-from nextrec.utils.distributed import
+from nextrec.utils.distributed import (
+    gather_numpy,
+    init_process_group,
+    add_distributed_sampler,
+)
 from nextrec import __version__
 
+
 class BaseModel(FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
@@ -52,26 +63,27 @@ class BaseModel(FeatureSet, nn.Module):
     def default_task(self) -> str | list[str]:
         raise NotImplementedError
 
-    def __init__(
+    def __init__(
+        self,
+        dense_features: list[DenseFeature] | None = None,
+        sparse_features: list[SparseFeature] | None = None,
+        sequence_features: list[SequenceFeature] | None = None,
+        target: list[str] | str | None = None,
+        id_columns: list[str] | str | None = None,
+        task: str | list[str] | None = None,
+        device: str = "cpu",
+        early_stop_patience: int = 20,
+        session_id: str | None = None,
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        distributed: bool = False,
+        rank: int | None = None,
+        world_size: int | None = None,
+        local_rank: int | None = None,
+        ddp_find_unused_parameters: bool = False,
+    ):
         """
         Initialize a base model.
 
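The expanded `__init__` signature above makes the distributed-training configuration explicit. A hedged sketch of how a caller might wire these arguments from the environment variables `torchrun` sets; `SomeNextRecModel` is a placeholder for any `BaseModel` subclass, and only the keyword names are taken from the signature above:

```python
import os

# Hypothetical wiring; SomeNextRecModel stands in for any BaseModel subclass.
ddp_kwargs = dict(
    distributed=True,
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 1)),
    local_rank=int(os.environ.get("LOCAL_RANK", 0)),
    ddp_find_unused_parameters=False,
)
# model = SomeNextRecModel(sparse_features=..., target="label", task="binary", **ddp_kwargs)
```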
@@ -112,11 +124,19 @@ class BaseModel(FeatureSet, nn.Module):
 
         self.session_id = session_id
         self.session = create_session(session_id)
-        self.session_path = self.session.root
-        self.checkpoint_path = os.path.join(
-        self.
+        self.session_path = self.session.root  # pwd/session_id, path for this session
+        self.checkpoint_path = os.path.join(
+            self.session_path, self.model_name + "_checkpoint.model"
+        )  # example: pwd/session_id/DeepFM_checkpoint.model
+        self.best_path = os.path.join(
+            self.session_path, self.model_name + "_best.model"
+        )
+        self.features_config_path = os.path.join(
+            self.session_path, "features_config.pkl"
+        )
+        self.set_all_features(
+            dense_features, sparse_features, sequence_features, target, id_columns
+        )
 
         self.task = self.default_task if task is None else task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
@@ -125,25 +145,43 @@ class BaseModel(FeatureSet, nn.Module):
         self.dense_l1_reg = dense_l1_reg
         self.embedding_l2_reg = embedding_l2_reg
         self.dense_l2_reg = dense_l2_reg
-        self.regularization_weights = []
+        self.regularization_weights = []
         self.embedding_params = []
         self.loss_weight = None
 
         self.early_stop_patience = early_stop_patience
-        self.max_gradient_norm = 1.0
+        self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
 
-    def register_regularization_weights(
+    def register_regularization_weights(
+        self,
+        embedding_attr: str = "embedding",
+        exclude_modules: list[str] | None = None,
+        include_modules: list[str] | None = None,
+    ) -> None:
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
         embed_dict = getattr(embedding_layer, "embed_dict", None)
         if embed_dict is not None:
             self.embedding_params.extend(embed.weight for embed in embed_dict.values())
-        skip_types = (
+        skip_types = (
+            nn.BatchNorm1d,
+            nn.BatchNorm2d,
+            nn.BatchNorm3d,
+            nn.Dropout,
+            nn.Dropout2d,
+            nn.Dropout3d,
+        )
         for name, module in self.named_modules():
-            if (
+            if (
+                module is self
+                or embedding_attr in name
+                or isinstance(module, skip_types)
+                or (include_modules and not any(inc in name for inc in include_modules))
+                or any(exc in name for exc in exclude_modules)
+            ):
                 continue
             if isinstance(module, nn.Linear):
                 self.regularization_weights.append(module.weight)
@@ -152,14 +190,22 @@ class BaseModel(FeatureSet, nn.Module):
         reg_loss = torch.tensor(0.0, device=self.device)
         if self.embedding_params:
             if self.embedding_l1_reg > 0:
-                reg_loss += self.embedding_l1_reg * sum(
+                reg_loss += self.embedding_l1_reg * sum(
+                    param.abs().sum() for param in self.embedding_params
+                )
             if self.embedding_l2_reg > 0:
-                reg_loss += self.embedding_l2_reg * sum(
+                reg_loss += self.embedding_l2_reg * sum(
+                    (param**2).sum() for param in self.embedding_params
+                )
         if self.regularization_weights:
             if self.dense_l1_reg > 0:
-                reg_loss += self.dense_l1_reg * sum(
+                reg_loss += self.dense_l1_reg * sum(
+                    param.abs().sum() for param in self.regularization_weights
+                )
             if self.dense_l2_reg > 0:
-                reg_loss += self.dense_l2_reg * sum(
+                reg_loss += self.dense_l2_reg * sum(
+                    (param**2).sum() for param in self.regularization_weights
+                )
         return reg_loss
 
     def get_input(self, input_data: dict, require_labels: bool = True):
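The regularization hunks above are a pure re-wrap; the underlying pattern is plain PyTorch: L1 terms sum absolute values, L2 terms sum squares, each over an explicitly registered parameter list rather than over all parameters. A standalone sketch (not nextrec code):

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 16)
dense = nn.Linear(16, 1)
embedding_params = [embedding.weight]       # mirrors self.embedding_params
regularization_weights = [dense.weight]     # mirrors self.regularization_weights

l1, l2 = 1e-6, 1e-5
reg_loss = torch.tensor(0.0)
reg_loss = reg_loss + l1 * sum(p.abs().sum() for p in embedding_params)
reg_loss = reg_loss + l2 * sum((p ** 2).sum() for p in regularization_weights)
```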
@@ -168,51 +214,90 @@ class BaseModel(FeatureSet, nn.Module):
         X_input = {}
         for feature in self.all_features:
             if feature.name not in feature_source:
-                raise KeyError(
+                raise KeyError(
+                    f"[BaseModel-input Error] Feature '{feature.name}' not found in input data."
+                )
             feature_data = get_column_data(feature_source, feature.name)
-            X_input[feature.name] = to_tensor(
+            X_input[feature.name] = to_tensor(
+                feature_data,
+                dtype=(
+                    torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                ),
+                device=self.device,
+            )
         y = None
-        if
+        if len(self.target_columns) > 0 and (
+            require_labels
+            or (
+                label_source
+                and any(name in label_source for name in self.target_columns)
+            )
+        ):  # need labels: training or eval with labels
             target_tensors = []
             for target_name in self.target_columns:
                 if label_source is None or target_name not in label_source:
                     if require_labels:
-                        raise KeyError(
+                        raise KeyError(
+                            f"[BaseModel-input Error] Target column '{target_name}' not found in input data."
+                        )
                     continue
                 target_data = get_column_data(label_source, target_name)
                 if target_data is None:
                     if require_labels:
-                        raise ValueError(
+                        raise ValueError(
+                            f"[BaseModel-input Error] Target column '{target_name}' contains no data."
+                        )
                     continue
-                target_tensor = to_tensor(
+                target_tensor = to_tensor(
+                    target_data, dtype=torch.float32, device=self.device
+                )
+                target_tensor = target_tensor.view(
+                    target_tensor.size(0), -1
+                )  # always reshape to (batch_size, num_targets)
                 target_tensors.append(target_tensor)
             if target_tensors:
                 y = torch.cat(target_tensors, dim=1)
-                if y.shape[1] == 1:
+                if y.shape[1] == 1:  # no need to do that again
                     y = y.view(-1)
         elif require_labels:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-input Error] Labels are required but none were found in the input batch."
+            )
         return X_input, y
 
-    def handle_validation_split(
+    def handle_validation_split(
+        self,
+        train_data: dict | pd.DataFrame,
+        validation_split: float,
+        batch_size: int,
+        shuffle: bool,
+        num_workers: int = 0,
+    ):
         """
-        This function will split training data into training and validation sets when:
-        1. valid_data is None;
+        This function will split training data into training and validation sets when:
+        1. valid_data is None;
         2. validation_split is provided.
         """
         if not (0 < validation_split < 1):
-            raise ValueError(
+            raise ValueError(
+                f"[BaseModel-validation Error] validation_split must be between 0 and 1, got {validation_split}"
+            )
         if not isinstance(train_data, (pd.DataFrame, dict)):
-            raise TypeError(
+            raise TypeError(
+                f"[BaseModel-validation Error] train_data must be a pandas DataFrame or a dict, got {type(train_data)}"
+            )
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
         else:
-            sample_key = next(
+            sample_key = next(
+                iter(train_data)
+            )  # pick the first key to check length, for example: 'user_id': [1,2,3,4,5]
+            total_length = len(train_data[sample_key])  # len(train_data['user_id'])
             for k, v in train_data.items():
                 if len(v) != total_length:
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-validation Error] Length of field '{k}' ({len(v)}) != length of field '{sample_key}' ({total_length})"
+                    )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
         split_idx = int(total_length * (1 - validation_split))
@@ -225,23 +310,34 @@ class BaseModel(FeatureSet, nn.Module):
         train_split = {}
         valid_split = {}
         for key, value in train_data.items():
-            arr = np.asarray(value)
+            arr = np.asarray(value)
             train_split[key] = arr[train_indices]
             valid_split[key] = arr[valid_indices]
-        train_loader = self.prepare_data_loader(
+        train_loader = self.prepare_data_loader(
+            train_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
+        )
+        logging.info(
+            f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
+        )
         return train_loader, valid_split
 
     def compile(
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | torch.optim.lr_scheduler.LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | type[torch.optim.lr_scheduler.LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+        loss_weights: int | float | list[int | float] | None = None,
+    ):
         """
         Configure the model for training.
         Args:
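`handle_validation_split` draws the permutation from a fixed-seed generator, so the train/validation split is reproducible across runs. The core logic in isolation (a sketch with made-up data):

```python
import numpy as np

data = {"user_id": np.arange(10), "label": np.ones(10)}
validation_split = 0.2
rng = np.random.default_rng(42)                # fixed seed -> same split every run
indices = rng.permutation(len(data["user_id"]))
split_idx = int(len(indices) * (1 - validation_split))
train_idx, valid_idx = indices[:split_idx], indices[split_idx:]
train = {k: np.asarray(v)[train_idx] for k, v in data.items()}
valid = {k: np.asarray(v)[valid_idx] for k, v in data.items()}
```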
@@ -258,42 +354,62 @@ class BaseModel(FeatureSet, nn.Module):
         else:
             self.loss_params = loss_params
         optimizer_params = optimizer_params or {}
-        self.optimizer_name =
+        self.optimizer_name = (
+            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        )
         self.optimizer_params = optimizer_params
-        self.optimizer_fn = get_optimizer(
+        self.optimizer_fn = get_optimizer(
+            optimizer=optimizer,
+            params=self.parameters(),
+            **optimizer_params,
+        )
 
         scheduler_params = scheduler_params or {}
         if isinstance(scheduler, str):
             self.scheduler_name = scheduler
         elif scheduler is None:
             self.scheduler_name = None
-        else:
-            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)
+        else:  # for custom scheduler instance, need to provide class name for logging
+            self.scheduler_name = getattr(scheduler, "__name__", scheduler.__class__.__name__)  # type: ignore
         self.scheduler_params = scheduler_params
-        self.scheduler_fn = (
+        self.scheduler_fn = (
+            get_scheduler(scheduler, self.optimizer_fn, **scheduler_params)
+            if scheduler
+            else None
+        )
 
         self.loss_config = loss
         self.loss_params = loss_params or {}
         self.loss_fn = []
-        if isinstance(loss, list):
+        if isinstance(loss, list):  # for example: ['bce', 'mse'] -> ['bce', 'mse']
             if len(loss) != self.nums_task:
-                raise ValueError(
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss functions ({len(loss)}) must match number of tasks ({self.nums_task})."
+                )
             loss_list = [loss[i] for i in range(self.nums_task)]
-        else:
+        else:  # for example: 'bce' -> ['bce', 'bce']
             loss_list = [loss] * self.nums_task
 
         if isinstance(self.loss_params, dict):
             params_list = [self.loss_params] * self.nums_task
         else:  # list[dict]
-            params_list = [
+            params_list = [
+                self.loss_params[i] if i < len(self.loss_params) else {}
+                for i in range(self.nums_task)
+            ]
+        self.loss_fn = [
+            get_loss_fn(loss=loss_list[i], **params_list[i])
+            for i in range(self.nums_task)
+        ]
 
         if loss_weights is None:
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
                 if len(loss_weights) != 1:
-                    raise ValueError(
+                    raise ValueError(
+                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
+                    )
                 weight_value = loss_weights[0]
             else:
                 weight_value = loss_weights
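`compile` normalizes the `loss` argument to one entry per task: a single spec is broadcast, a list must match the task count. The same rule in isolation (a sketch):

```python
def normalize_losses(loss, nums_task: int) -> list:
    if isinstance(loss, list):
        if len(loss) != nums_task:
            raise ValueError("number of loss functions must match number of tasks")
        return list(loss)
    return [loss] * nums_task                  # e.g. 'bce' -> ['bce', 'bce']

assert normalize_losses("bce", 2) == ["bce", "bce"]
assert normalize_losses(["bce", "mse"], 2) == ["bce", "mse"]
```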
@@ -304,14 +420,20 @@ class BaseModel(FeatureSet, nn.Module):
         elif isinstance(loss_weights, (list, tuple)):
             weights = [float(w) for w in loss_weights]
             if len(weights) != self.nums_task:
-                raise ValueError(
+                raise ValueError(
+                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
+                )
         else:
-            raise TypeError(
+            raise TypeError(
+                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
+            )
         self.loss_weights = weights
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
+            )
         if self.nums_task == 1:
             if y_pred.dim() == 1:
                 y_pred = y_pred.view(-1, 1)
@@ -319,7 +441,7 @@ class BaseModel(FeatureSet, nn.Module):
                 y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
                 raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-            task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]
+            task_dim = self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
             if task_dim == 1:
                 loss = self.loss_fn[0](y_pred.view(-1), y_true.view(-1))
             else:
@@ -330,12 +452,14 @@ class BaseModel(FeatureSet, nn.Module):
         # multi-task
         if y_pred.shape != y_true.shape:
             raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-        if hasattr(
+        if hasattr(
+            self, "prediction_layer"
+        ):  # we need to use registered task_slices for multi-task and multi-class
+            slices = self.prediction_layer.task_slices  # type: ignore
         else:
             slices = [(i, i + 1) for i in range(self.nums_task)]
         task_losses = []
-        for i, (start, end) in enumerate(slices):
+        for i, (start, end) in enumerate(slices):  # type: ignore
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
             task_loss = self.loss_fn[i](y_pred_i, y_true_i)
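Multi-task losses are computed by slicing contiguous column ranges from the stacked prediction tensor; the registered `task_slices` let one task own several columns (e.g. a multi-class head). The slicing step alone (a sketch):

```python
import torch

y_pred = torch.randn(4, 3)                    # task 0 -> column 0, task 1 -> columns 1:3
slices = [(0, 1), (1, 3)]                     # per-task (start, end) column ranges
per_task = [y_pred[:, start:end] for start, end in slices]
print([t.shape for t in per_task])            # [torch.Size([4, 1]), torch.Size([4, 2])]
```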
@@ -344,26 +468,55 @@ class BaseModel(FeatureSet, nn.Module):
             task_losses.append(task_loss)
         return torch.stack(task_losses).sum()
 
-    def prepare_data_loader(
+    def prepare_data_loader(
+        self,
+        data: dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        shuffle: bool = True,
+        num_workers: int = 0,
+        sampler=None,
+        return_dataset: bool = False,
+    ) -> DataLoader | tuple[DataLoader, TensorDictDataset | None]:
         if isinstance(data, DataLoader):
             return (data, None) if return_dataset else data
-        tensors = build_tensors_from_data(
+        tensors = build_tensors_from_data(
+            data=data,
+            raw_data=data,
+            features=self.all_features,
+            target_columns=self.target_columns,
+            id_columns=self.id_columns,
+        )
         if tensors is None:
-            raise ValueError(
+            raise ValueError(
+                "[BaseModel-prepare_data_loader Error] No data available to create DataLoader."
+            )
         dataset = TensorDictDataset(tensors)
-        loader = DataLoader(
+        loader = DataLoader(
+            dataset,
+            batch_size=batch_size,
+            shuffle=False if sampler is not None else shuffle,
+            sampler=sampler,
+            collate_fn=collate_fn,
+            num_workers=num_workers,
+        )
         return (loader, dataset) if return_dataset else loader
 
-    def fit(
+    def fit(
+        self,
+        train_data: dict | pd.DataFrame | DataLoader,
+        valid_data: dict | pd.DataFrame | DataLoader | None = None,
+        metrics: (
+            list[str] | dict[str, list[str]] | None
+        ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
+        epochs: int = 1,
+        shuffle: bool = True,
+        batch_size: int = 32,
+        user_id_column: str | None = None,
+        validation_split: float | None = None,
+        num_workers: int = 0,
+        tensorboard: bool = True,
+        auto_distributed_sampler: bool = True,
+    ):
         """
         Train the model.
 
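`prepare_data_loader` wraps a dict or DataFrame into a `TensorDictDataset` plus a standard `DataLoader`. That dataset's implementation is not part of this diff; a minimal stand-in for the dict-of-tensors idea, relying on PyTorch's default collate for dict batches:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class DictDataset(Dataset):
    """One dict of equal-length tensors; __getitem__ returns a per-row dict."""
    def __init__(self, tensors: dict[str, torch.Tensor]):
        self.tensors = tensors
        self.length = len(next(iter(tensors.values())))
    def __len__(self) -> int:
        return self.length
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return {k: v[idx] for k, v in self.tensors.items()}

loader = DataLoader(
    DictDataset({"user_id": torch.arange(8), "label": torch.zeros(8)}),
    batch_size=4,
    shuffle=True,                              # default collate stacks each dict field
)
```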
@@ -385,63 +538,168 @@ class BaseModel(FeatureSet, nn.Module):
         - All ranks must call evaluate() together because it performs collective ops.
         """
         device_id = self.local_rank if self.device.type == "cuda" else None
-        init_process_group(
+        init_process_group(
+            self.distributed, self.rank, self.world_size, device_id=device_id
+        )
         self.to(self.device)
 
-        if
+        if (
+            self.distributed
+            and dist.is_available()
+            and dist.is_initialized()
+            and self.ddp_model is None
+        ):
+            device_ids = (
+                [self.local_rank] if self.device.type == "cuda" else None
+            )  # device_ids means which device to use in ddp
+            output_device = (
+                self.local_rank if self.device.type == "cuda" else None
+            )  # output_device means which device to place the output in ddp
+            object.__setattr__(
+                self,
+                "ddp_model",
+                DDP(
+                    self,
+                    device_ids=device_ids,
+                    output_device=output_device,
+                    find_unused_parameters=self.ddp_find_unused_parameters,
+                ),
+            )
+
+        if (
+            not self.logger_initialized and self.is_main_process
+        ):  # only main process initializes logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger =
+        self.training_logger = (
+            TrainingLogger(session=self.session, enable_tensorboard=tensorboard)
+            if self.is_main_process
+            else None
+        )
 
-        self.metrics, self.task_specific_metrics, self.best_metrics_mode =
+        self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
+            configure_metrics(
+                task=self.task, metrics=metrics, target_names=self.target_columns
+            )
+        )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'
+        self.early_stopper = EarlyStopper(
+            patience=self.early_stop_patience, mode=self.best_metrics_mode
+        )
+        self.best_metric = (
+            float("-inf") if self.best_metrics_mode == "max" else float("inf")
+        )
 
-        self.needs_user_ids = check_user_id(
+        self.needs_user_ids = check_user_id(
+            self.metrics, self.task_specific_metrics
+        )  # check user_id needed for GAUC metrics
         self.epoch_index = 0
         self.stop_training = False
         self.best_checkpoint_path = self.best_path
 
         if not auto_distributed_sampler and self.distributed and self.is_main_process:
-            logging.info(
+            logging.info(
+                colorize(
+                    "[Distributed Info] auto_distributed_sampler=False; assuming data is already sharded per rank.",
+                    color="yellow",
+                )
+            )
 
         train_sampler: DistributedSampler | None = None
         if validation_split is not None and valid_data is None:
-            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
+            train_loader, valid_data = self.handle_validation_split(train_data=train_data, validation_split=validation_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
-            if
+            if (
+                auto_distributed_sampler
+                and self.distributed
+                and dist.is_available()
+                and dist.is_initialized()
+            ):
                 base_dataset = getattr(train_loader, "dataset", None)
-                if base_dataset is not None and not isinstance(
+                if base_dataset is not None and not isinstance(
+                    getattr(train_loader, "sampler", None), DistributedSampler
+                ):
+                    train_sampler = DistributedSampler(
+                        base_dataset,
+                        num_replicas=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                    )
+                    train_loader = DataLoader(
+                        base_dataset,
+                        batch_size=batch_size,
+                        shuffle=False,
+                        sampler=train_sampler,
+                        collate_fn=collate_fn,
+                        num_workers=num_workers,
+                        drop_last=True,
+                    )
         else:
             if isinstance(train_data, DataLoader):
                 if auto_distributed_sampler and self.distributed:
-                    train_loader, train_sampler = add_distributed_sampler(
+                    train_loader, train_sampler = add_distributed_sampler(
+                        train_data,
+                        distributed=self.distributed,
+                        world_size=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                        default_batch_size=batch_size,
+                        is_main_process=self.is_main_process,
+                    )
                     # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
                 else:
                     train_loader = train_data
             else:
                 loader, dataset = self.prepare_data_loader(train_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, return_dataset=True)  # type: ignore
-                if
+                if (
+                    auto_distributed_sampler
+                    and self.distributed
+                    and dataset is not None
+                    and dist.is_available()
+                    and dist.is_initialized()
+                ):
+                    train_sampler = DistributedSampler(
+                        dataset,
+                        num_replicas=self.world_size,
+                        rank=self.rank,
+                        shuffle=shuffle,
+                        drop_last=True,
+                    )
+                    loader = DataLoader(
+                        dataset,
+                        batch_size=batch_size,
+                        shuffle=False,
+                        sampler=train_sampler,
+                        collate_fn=collate_fn,
+                        num_workers=num_workers,
+                        drop_last=True,
+                    )
                 train_loader = loader
 
         # If split-based loader was built without sampler, attach here when enabled
-        if
+        if (
+            self.distributed
+            and auto_distributed_sampler
+            and isinstance(train_loader, DataLoader)
+            and train_sampler is None
+        ):
+            raise NotImplementedError(
+                "[BaseModel-fit Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
+            )
         # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
-
-        valid_loader, valid_user_ids = self.prepare_validation_data(
+
+        valid_loader, valid_user_ids = self.prepare_validation_data(
+            valid_data=valid_data,
+            batch_size=batch_size,
+            needs_user_ids=self.needs_user_ids,
+            user_id_column=user_id_column,
+            num_workers=num_workers,
+            auto_distributed_sampler=auto_distributed_sampler,
+        )
         try:
             self.steps_per_epoch = len(train_loader)
             is_streaming = False
-        except TypeError:
+        except TypeError:  # streaming data loader does not supported len()
             self.steps_per_epoch = None
             is_streaming = True
 
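The block above is the standard DDP recipe: wrap the module once, swap the loader's sampler for a `DistributedSampler`, and disable `DataLoader`-level shuffling because the sampler shuffles. The same recipe standalone (a sketch that assumes the process group is already initialized, e.g. by `torchrun`):

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

def wrap_for_ddp(model: torch.nn.Module, dataset: Dataset, batch_size: int, local_rank: int):
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, shuffle=False)
    device_ids = [local_rank] if torch.cuda.is_available() else None
    return DDP(model, device_ids=device_ids), loader, sampler
    # each epoch the caller must run sampler.set_epoch(epoch) for a fresh shuffle
```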
@@ -455,7 +713,9 @@ class BaseModel(FeatureSet, nn.Module):
             host = socket.gethostname()
             tb_cmd = f"tensorboard --logdir {tb_dir} --port 6006"
             ssh_hint = f"ssh -L 6006:localhost:6006 {user}@{host}"
-            logging.info(
+            logging.info(
+                colorize(f"TensorBoard logs saved to: {tb_dir}", color="cyan")
+            )
             logging.info(colorize("To view logs, run:", color="cyan"))
             logging.info(colorize(f" {tb_cmd}", color="cyan"))
             logging.info(colorize("Then SSH port forward:", color="cyan"))
@@ -464,9 +724,9 @@ class BaseModel(FeatureSet, nn.Module):
         logging.info("")
         logging.info(colorize("=" * 80, bold=True))
         if is_streaming:
-            logging.info(colorize(
+            logging.info(colorize("Start streaming training", bold=True))
         else:
-            logging.info(colorize(
+            logging.info(colorize("Start training", bold=True))
         logging.info(colorize("=" * 80, bold=True))
         logging.info("")
         logging.info(colorize(f"Model device: {self.device}", bold=True))
|
|
|
475
735
|
self.epoch_index = epoch
|
|
476
736
|
if is_streaming and self.is_main_process:
|
|
477
737
|
logging.info("")
|
|
478
|
-
logging.info(
|
|
738
|
+
logging.info(
|
|
739
|
+
colorize(f"Epoch {epoch + 1}/{epochs}", bold=True)
|
|
740
|
+
) # streaming mode, print epoch header before progress bar
|
|
479
741
|
|
|
480
742
|
# handle train result
|
|
481
|
-
if
|
|
743
|
+
if (
|
|
744
|
+
self.distributed
|
|
745
|
+
and hasattr(train_loader, "sampler")
|
|
746
|
+
and isinstance(train_loader.sampler, DistributedSampler)
|
|
747
|
+
):
|
|
482
748
|
train_loader.sampler.set_epoch(epoch)
|
|
483
|
-
train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
|
|
484
|
-
if isinstance(train_result, tuple):
|
|
749
|
+
train_result = self.train_epoch(train_loader, is_streaming=is_streaming)
|
|
750
|
+
if isinstance(train_result, tuple): # [avg_loss, metrics_dict]
|
|
485
751
|
train_loss, train_metrics = train_result
|
|
486
752
|
else:
|
|
487
753
|
train_loss = train_result
|
|
@@ -492,7 +758,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
492
758
|
if self.nums_task == 1:
|
|
493
759
|
log_str = f"Epoch {epoch + 1}/{epochs} - Train: loss={train_loss:.4f}"
|
|
494
760
|
if train_metrics:
|
|
495
|
-
metrics_str = ", ".join(
|
|
761
|
+
metrics_str = ", ".join(
|
|
762
|
+
[f"{k}={v:.4f}" for k, v in train_metrics.items()]
|
|
763
|
+
)
|
|
496
764
|
log_str += f", {metrics_str}"
|
|
497
765
|
if self.is_main_process:
|
|
498
766
|
logging.info(colorize(log_str))
|
|
@@ -501,7 +769,9 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
501
769
|
train_log_payload.update(train_metrics)
|
|
502
770
|
else:
|
|
503
771
|
total_loss_val = np.sum(train_loss) if isinstance(train_loss, np.ndarray) else train_loss # type: ignore
|
|
504
|
-
log_str =
|
|
772
|
+
log_str = (
|
|
773
|
+
f"Epoch {epoch + 1}/{epochs} - Train: loss={total_loss_val:.4f}"
|
|
774
|
+
)
|
|
505
775
|
if train_metrics:
|
|
506
776
|
# group metrics by task
|
|
507
777
|
task_metrics = {}
|
|
@@ -517,7 +787,12 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
517
787
|
task_metric_strs = []
|
|
518
788
|
for target_name in self.target_columns:
|
|
519
789
|
if target_name in task_metrics:
|
|
520
|
-
metrics_str = ", ".join(
|
|
790
|
+
metrics_str = ", ".join(
|
|
791
|
+
[
|
|
792
|
+
f"{k}={v:.4f}"
|
|
793
|
+
for k, v in task_metrics[target_name].items()
|
|
794
|
+
]
|
|
795
|
+
)
|
|
521
796
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
522
797
|
log_str += ", " + ", ".join(task_metric_strs)
|
|
523
798
|
if self.is_main_process:
|
|
@@ -526,14 +801,27 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
526
801
|
if train_metrics:
|
|
527
802
|
train_log_payload.update(train_metrics)
|
|
528
803
|
if self.training_logger:
|
|
529
|
-
self.training_logger.log_metrics(
|
|
804
|
+
self.training_logger.log_metrics(
|
|
805
|
+
train_log_payload, step=epoch + 1, split="train"
|
|
806
|
+
)
|
|
530
807
|
if valid_loader is not None:
|
|
531
808
|
# pass user_ids only if needed for GAUC metric
|
|
532
|
-
val_metrics = self.evaluate(
|
|
809
|
+
val_metrics = self.evaluate(
|
|
810
|
+
valid_loader,
|
|
811
|
+
user_ids=valid_user_ids if self.needs_user_ids else None,
|
|
812
|
+
num_workers=num_workers,
|
|
813
|
+
) # {'auc': 0.75, 'logloss': 0.45} or {'auc_target1': 0.75, 'logloss_target1': 0.45, 'mse_target2': 3.2}
|
|
533
814
|
if self.nums_task == 1:
|
|
534
|
-
metrics_str = ", ".join(
|
|
815
|
+
metrics_str = ", ".join(
|
|
816
|
+
[f"{k}={v:.4f}" for k, v in val_metrics.items()]
|
|
817
|
+
)
|
|
535
818
|
if self.is_main_process:
|
|
536
|
-
logging.info(
|
|
819
|
+
logging.info(
|
|
820
|
+
colorize(
|
|
821
|
+
f" Epoch {epoch + 1}/{epochs} - Valid: {metrics_str}",
|
|
822
|
+
color="cyan",
|
|
823
|
+
)
|
|
824
|
+
)
|
|
537
825
|
else:
|
|
538
826
|
# multi task metrics
|
|
539
827
|
task_metrics = {}
|
|
@@ -548,34 +836,58 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
548
836
|
task_metric_strs = []
|
|
549
837
|
for target_name in self.target_columns:
|
|
550
838
|
if target_name in task_metrics:
|
|
551
|
-
metrics_str = ", ".join(
|
|
839
|
+
metrics_str = ", ".join(
|
|
840
|
+
[
|
|
841
|
+
f"{k}={v:.4f}"
|
|
842
|
+
for k, v in task_metrics[target_name].items()
|
|
843
|
+
]
|
|
844
|
+
)
|
|
552
845
|
task_metric_strs.append(f"{target_name}[{metrics_str}]")
|
|
553
846
|
if self.is_main_process:
|
|
554
|
-
logging.info(
|
|
847
|
+
logging.info(
|
|
848
|
+
colorize(
|
|
849
|
+
f" Epoch {epoch + 1}/{epochs} - Valid: "
|
|
850
|
+
+ ", ".join(task_metric_strs),
|
|
851
|
+
color="cyan",
|
|
852
|
+
)
|
|
853
|
+
)
|
|
555
854
|
if val_metrics and self.training_logger:
|
|
556
|
-
self.training_logger.log_metrics(
|
|
855
|
+
self.training_logger.log_metrics(
|
|
856
|
+
val_metrics, step=epoch + 1, split="valid"
|
|
857
|
+
)
|
|
557
858
|
# Handle empty validation metrics
|
|
558
859
|
if not val_metrics:
|
|
559
860
|
if self.is_main_process:
|
|
560
|
-
self.save_model(
|
|
861
|
+
self.save_model(
|
|
862
|
+
self.checkpoint_path, add_timestamp=False, verbose=False
|
|
863
|
+
)
|
|
561
864
|
self.best_checkpoint_path = self.checkpoint_path
|
|
562
|
-
logging.info(
|
|
865
|
+
logging.info(
|
|
866
|
+
colorize(
|
|
867
|
+
"Warning: No validation metrics computed. Skipping validation for this epoch.",
|
|
868
|
+
color="yellow",
|
|
869
|
+
)
|
|
870
|
+
)
|
|
563
871
|
continue
|
|
564
872
|
if self.nums_task == 1:
|
|
565
873
|
primary_metric_key = self.metrics[0]
|
|
566
874
|
else:
|
|
567
875
|
primary_metric_key = f"{self.metrics[0]}_{self.target_columns[0]}"
|
|
568
|
-
primary_metric = val_metrics.get(
|
|
569
|
-
|
|
876
|
+
primary_metric = val_metrics.get(
|
|
877
|
+
primary_metric_key, val_metrics[list(val_metrics.keys())[0]]
|
|
878
|
+
) # get primary metric value, default to first metric if not found
|
|
879
|
+
|
|
570
880
|
# In distributed mode, broadcast primary_metric to ensure all processes use the same value
|
|
571
881
|
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
572
|
-
metric_tensor = torch.tensor(
|
|
882
|
+
metric_tensor = torch.tensor(
|
|
883
|
+
[primary_metric], device=self.device, dtype=torch.float32
|
|
884
|
+
)
|
|
573
885
|
dist.broadcast(metric_tensor, src=0)
|
|
574
886
|
primary_metric = float(metric_tensor.item())
|
|
575
|
-
|
|
887
|
+
|
|
576
888
|
improved = False
|
|
577
889
|
# early stopping check
|
|
578
|
-
if self.best_metrics_mode ==
|
|
890
|
+
if self.best_metrics_mode == "max":
|
|
579
891
|
if primary_metric > self.best_metric:
|
|
580
892
|
self.best_metric = primary_metric
|
|
581
893
|
improved = True
|
|
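Broadcasting the primary metric from rank 0 keeps the early-stopping decision identical on every rank; ranks that disagreed would deadlock on the next collective. The helper pattern in isolation (a sketch):

```python
import torch
import torch.distributed as dist

def broadcast_scalar(value: float, device: torch.device) -> float:
    """Rank 0's value wins on all ranks."""
    t = torch.tensor([value], device=device, dtype=torch.float32)
    dist.broadcast(t, src=0)                   # collective: every rank must call
    return float(t.item())
```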
@@ -586,19 +898,37 @@ class BaseModel(FeatureSet, nn.Module):
 
                 # save checkpoint and best model for main process
                 if self.is_main_process:
-                    self.save_model(
+                    self.save_model(
+                        self.checkpoint_path, add_timestamp=False, verbose=False
+                    )
                     logging.info(" ")
                     if improved:
-                        logging.info(
+                        logging.info(
+                            colorize(
+                                f"Validation {primary_metric_key} improved to {self.best_metric:.4f}"
+                            )
+                        )
+                        self.save_model(
+                            self.best_path, add_timestamp=False, verbose=False
+                        )
                         self.best_checkpoint_path = self.best_path
                         self.early_stopper.trial_counter = 0
                     else:
                         self.early_stopper.trial_counter += 1
-                        logging.info(
+                        logging.info(
+                            colorize(
+                                f"No improvement for {self.early_stopper.trial_counter} epoch(s)"
+                            )
+                        )
                         if self.early_stopper.trial_counter >= self.early_stopper.patience:
                             self.stop_training = True
-                            logging.info(
+                            logging.info(
+                                colorize(
+                                    f"Early stopping triggered after {epoch + 1} epochs",
+                                    color="bright_red",
+                                    bold=True,
+                                )
+                            )
                 else:
                     # Non-main processes also update trial_counter to keep in sync
                     if improved:
@@ -607,43 +937,55 @@ class BaseModel(FeatureSet, nn.Module):
                         self.early_stopper.trial_counter += 1
             else:
                 if self.is_main_process:
-                    self.save_model(
+                    self.save_model(
+                        self.checkpoint_path, add_timestamp=False, verbose=False
+                    )
                     self.save_model(self.best_path, add_timestamp=False, verbose=False)
                     self.best_checkpoint_path = self.best_path
 
             # Broadcast stop_training flag to all processes (always, regardless of validation)
             if self.distributed and dist.is_available() and dist.is_initialized():
-                stop_tensor = torch.tensor(
+                stop_tensor = torch.tensor(
+                    [int(self.stop_training)], device=self.device
+                )
                 dist.broadcast(stop_tensor, src=0)
                 self.stop_training = bool(stop_tensor.item())
-
+
             if self.stop_training:
                 break
             if self.scheduler_fn is not None:
-                if isinstance(
+                if isinstance(
+                    self.scheduler_fn, torch.optim.lr_scheduler.ReduceLROnPlateau
+                ):
                     if valid_loader is not None:
                         self.scheduler_fn.step(primary_metric)
                 else:
-                    self.scheduler_fn.step()
+                    self.scheduler_fn.step()
         if self.distributed and dist.is_available() and dist.is_initialized():
-            dist.barrier()
+            dist.barrier()  # dist.barrier() will wait for all processes, like async all_reduce()
         if self.is_main_process:
             logging.info(" ")
             logging.info(colorize("Training finished.", bold=True))
             logging.info(" ")
         if valid_loader is not None:
             if self.is_main_process:
-                logging.info(
+                logging.info(
+                    colorize(f"Load best model from: {self.best_checkpoint_path}")
+                )
+            self.load_model(
+                self.best_checkpoint_path, map_location=self.device, verbose=False
+            )
         if self.training_logger:
             self.training_logger.close()
         return self
 
-    def train_epoch(
+    def train_epoch(
+        self, train_loader: DataLoader, is_streaming: bool = False
+    ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
         # use ddp model for distributed training
         model = self.ddp_model if getattr(self, "ddp_model") is not None else self
         accumulated_loss = 0.0
-        model.train()
+        model.train()  # type: ignore
         num_batches = 0
         y_true_list = []
         y_pred_list = []
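`ReduceLROnPlateau` is special-cased above because its `step()` consumes the monitored metric, while other schedulers step without arguments:

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
param.sum().backward()
opt.step()                                    # step the optimizer once before any scheduler

plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", patience=2)
plateau.step(0.73)                            # plateau scheduler consumes the metric

cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
cosine.step()                                 # other schedulers take no metric
```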
@@ -651,15 +993,24 @@ class BaseModel(FeatureSet, nn.Module):
         user_ids_list = [] if self.needs_user_ids else None
         tqdm_disable = not self.is_main_process
         if self.steps_per_epoch is not None:
-            batch_iter = enumerate(
+            batch_iter = enumerate(
+                tqdm.tqdm(
+                    train_loader,
+                    desc=f"Epoch {self.epoch_index + 1}",
+                    total=self.steps_per_epoch,
+                    disable=tqdm_disable,
+                )
+            )
         else:
             desc = "Batches" if is_streaming else f"Epoch {self.epoch_index + 1}"
-            batch_iter = enumerate(
+            batch_iter = enumerate(
+                tqdm.tqdm(train_loader, desc=desc, disable=tqdm_disable)
+            )
         for batch_index, batch_data in batch_iter:
             batch_dict = batch_to_dict(batch_data)
             X_input, y_true = self.get_input(batch_dict, require_labels=True)
             # call via __call__ so DDP hooks run (no grad sync if calling .forward directly)
-            y_pred = model(X_input)
+            y_pred = model(X_input)  # type: ignore
 
             loss = self.compute_loss(y_pred, y_true)
             reg_loss = self.add_reg_loss()
|
|
|
667
1018
|
self.optimizer_fn.zero_grad()
|
|
668
1019
|
total_loss.backward()
|
|
669
1020
|
|
|
670
|
-
params = model.parameters() if self.ddp_model is not None else self.parameters()
|
|
1021
|
+
params = model.parameters() if self.ddp_model is not None else self.parameters() # type: ignore # ddp model parameters or self parameters
|
|
671
1022
|
nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
|
|
672
1023
|
self.optimizer_fn.step()
|
|
673
1024
|
accumulated_loss += loss.item()
|
|
@@ -675,66 +1026,123 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
675
1026
|
if y_true is not None:
|
|
676
1027
|
y_true_list.append(y_true.detach().cpu().numpy())
|
|
677
1028
|
if self.needs_user_ids and user_ids_list is not None:
|
|
678
|
-
batch_user_id = get_user_ids(
|
|
1029
|
+
batch_user_id = get_user_ids(
|
|
1030
|
+
data=batch_dict, id_columns=self.id_columns
|
|
1031
|
+
)
|
|
679
1032
|
if batch_user_id is not None:
|
|
680
1033
|
user_ids_list.append(batch_user_id)
|
|
681
1034
|
if y_pred is not None and isinstance(y_pred, torch.Tensor):
|
|
682
1035
|
y_pred_list.append(y_pred.detach().cpu().numpy())
|
|
683
1036
|
num_batches += 1
|
|
684
1037
|
if self.distributed and dist.is_available() and dist.is_initialized():
|
|
685
|
-
loss_tensor = torch.tensor(
|
|
1038
|
+
loss_tensor = torch.tensor(
|
|
1039
|
+
[accumulated_loss, num_batches], device=self.device, dtype=torch.float32
|
|
1040
|
+
)
|
|
686
1041
|
dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
|
|
687
1042
|
accumulated_loss = loss_tensor[0].item()
|
|
688
1043
|
num_batches = int(loss_tensor[1].item())
|
|
689
1044
|
avg_loss = accumulated_loss / max(num_batches, 1)
|
|
690
|
-
|
|
1045
|
+
|
|
691
1046
|
y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
|
|
692
1047
|
y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
|
|
693
|
-
combined_user_ids_local =
|
|
1048
|
+
combined_user_ids_local = (
|
|
1049
|
+
np.concatenate(user_ids_list, axis=0)
|
|
1050
|
+
if self.needs_user_ids and user_ids_list
|
|
1051
|
+
else None
|
|
1052
|
+
)
|
|
694
1053
|
|
|
695
1054
|
# gather across ranks even when local is empty to avoid DDP hang
|
|
696
1055
|
y_true_all = gather_numpy(self, y_true_all_local)
|
|
697
1056
|
y_pred_all = gather_numpy(self, y_pred_all_local)
|
|
698
|
-
combined_user_ids =
|
|
1057
|
+
combined_user_ids = (
|
|
1058
|
+
gather_numpy(self, combined_user_ids_local) if self.needs_user_ids else None
|
|
1059
|
+
)
|
|
699
1060
|
|
|
700
|
-
if
|
|
701
|
-
|
|
1061
|
+
if (
|
|
1062
|
+
y_true_all is not None
|
|
1063
|
+
and y_pred_all is not None
|
|
1064
|
+
and len(y_true_all) > 0
|
|
1065
|
+
and len(y_pred_all) > 0
|
|
1066
|
+
):
|
|
1067
|
+
metrics_dict = evaluate_metrics(
|
|
1068
|
+
y_true=y_true_all,
|
|
1069
|
+
y_pred=y_pred_all,
|
|
1070
|
+
metrics=self.metrics,
|
|
1071
|
+
task=self.task,
|
|
1072
|
+
target_names=self.target_columns,
|
|
1073
|
+
task_specific_metrics=self.task_specific_metrics,
|
|
1074
|
+
user_ids=combined_user_ids,
|
|
1075
|
+
)
|
|
702
1076
|
return avg_loss, metrics_dict
|
|
703
1077
|
return avg_loss
|
|
704
1078
|
|
|
705
|
-
def prepare_validation_data(
|
|
1079
|
+
def prepare_validation_data(
|
|
1080
|
+
self,
|
|
1081
|
+
valid_data: dict | pd.DataFrame | DataLoader | None,
|
|
1082
|
+
batch_size: int,
|
|
1083
|
+
needs_user_ids: bool,
|
|
1084
|
+
user_id_column: str | None = "user_id",
|
|
1085
|
+
num_workers: int = 0,
|
|
1086
|
+
auto_distributed_sampler: bool = True,
|
|
1087
|
+
) -> tuple[DataLoader | None, np.ndarray | None]:
|
|
706
1088
|
if valid_data is None:
|
|
707
1089
|
return None, None
|
|
708
1090
|
if isinstance(valid_data, DataLoader):
|
|
709
1091
|
if auto_distributed_sampler and self.distributed:
|
|
710
|
-
raise NotImplementedError(
|
|
1092
|
+
raise NotImplementedError(
|
|
1093
|
+
"[BaseModel-prepare_validation_data Error] auto_distributed_sampler with pre-defined DataLoader is not supported yet."
|
|
1094
|
+
)
|
|
711
1095
|
# valid_loader, _ = add_distributed_sampler(valid_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=False, drop_last=False, default_batch_size=batch_size, is_main_process=self.is_main_process)
|
|
712
1096
|
else:
|
|
713
1097
|
valid_loader = valid_data
|
|
714
1098
|
return valid_loader, None
|
|
715
1099
|
valid_sampler = None
|
|
716
1100
|
valid_loader, valid_dataset = self.prepare_data_loader(valid_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, return_dataset=True) # type: ignore
|
|
717
|
-
if
|
|
718
|
-
|
|
719
|
-
|
|
1101
|
+
if (
|
|
1102
|
+
auto_distributed_sampler
|
|
1103
|
+
and self.distributed
|
|
1104
|
+
and valid_dataset is not None
|
|
1105
|
+
and dist.is_available()
|
|
1106
|
+
and dist.is_initialized()
|
|
1107
|
+
):
|
|
1108
|
+
valid_sampler = DistributedSampler(
|
|
1109
|
+
valid_dataset,
|
|
1110
|
+
num_replicas=self.world_size,
|
|
1111
|
+
rank=self.rank,
|
|
1112
|
+
shuffle=False,
|
|
1113
|
+
drop_last=False,
|
|
1114
|
+
)
|
|
1115
|
+
valid_loader = DataLoader(
|
|
1116
|
+
valid_dataset,
|
|
1117
|
+
batch_size=batch_size,
|
|
1118
|
+
shuffle=False,
|
|
1119
|
+
sampler=valid_sampler,
|
|
1120
|
+
collate_fn=collate_fn,
|
|
1121
|
+
num_workers=num_workers,
|
|
1122
|
+
)
|
|
720
1123
|
valid_user_ids = None
|
|
721
1124
|
if needs_user_ids:
|
|
722
1125
|
if user_id_column is None:
|
|
723
|
-
raise ValueError(
|
|
1126
|
+
raise ValueError(
|
|
1127
|
+
"[BaseModel-validation Error] user_id_column must be specified when user IDs are needed for validation metrics."
|
|
1128
|
+
)
|
|
724
1129
|
# In distributed mode, user_ids will be collected during evaluation from each batch
|
|
725
1130
|
# and gathered across all processes, so we don't pre-extract them here
|
|
726
1131
|
if not self.distributed:
|
|
727
|
-
valid_user_ids = get_user_ids(
|
|
1132
|
+
valid_user_ids = get_user_ids(
|
|
1133
|
+
data=valid_data, id_columns=user_id_column
|
|
1134
|
+
)
|
|
728
1135
|
return valid_loader, valid_user_ids
|
|
729
1136
|
|
|
730
1137
|
def evaluate(
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
1138
|
+
self,
|
|
1139
|
+
data: dict | pd.DataFrame | DataLoader,
|
|
1140
|
+
metrics: list[str] | dict[str, list[str]] | None = None,
|
|
1141
|
+
batch_size: int = 32,
|
|
1142
|
+
user_ids: np.ndarray | None = None,
|
|
1143
|
+
user_id_column: str = "user_id",
|
|
1144
|
+
num_workers: int = 0,
|
|
1145
|
+
) -> dict:
|
|
738
1146
|
"""
|
|
739
1147
|
**IMPORTANT for Distributed Training:**
|
|
740
1148
|
in distributed mode, this method uses collective communication operations (all_gather).
|
|
@@ -755,15 +1163,19 @@ class BaseModel(FeatureSet, nn.Module):
|
|
|
755
1163
|
model.eval()
|
|
756
1164
|
eval_metrics = metrics if metrics is not None else self.metrics
|
|
757
1165
|
if eval_metrics is None:
|
|
758
|
-
raise ValueError(
|
|
1166
|
+
raise ValueError(
|
|
1167
|
+
"[BaseModel-evaluate Error] No metrics specified for evaluation. Please provide metrics parameter or call fit() first."
|
|
1168
|
+
)
|
|
759
1169
|
needs_user_ids = check_user_id(eval_metrics, self.task_specific_metrics)
|
|
760
|
-
|
|
1170
|
+
|
|
761
1171
|
if isinstance(data, DataLoader):
|
|
762
1172
|
data_loader = data
|
|
763
1173
|
else:
|
|
764
1174
|
if user_ids is None and needs_user_ids:
|
|
765
1175
|
user_ids = get_user_ids(data=data, id_columns=user_id_column)
|
|
766
|
-
data_loader = self.prepare_data_loader(
|
|
1176
|
+
data_loader = self.prepare_data_loader(
|
|
1177
|
+
data, batch_size=batch_size, shuffle=False, num_workers=num_workers
|
|
1178
|
+
)
|
|
767
1179
|
y_true_list = []
|
|
768
1180
|
y_pred_list = []
|
|
769
1181
|
collected_user_ids = []
|
|
@@ -779,15 +1191,19 @@ class BaseModel(FeatureSet, nn.Module):
                 if y_pred is not None and isinstance(y_pred, torch.Tensor):
                     y_pred_list.append(y_pred.cpu().numpy())
                 if needs_user_ids and user_ids is None:
-                    batch_user_id = get_user_ids(
+                    batch_user_id = get_user_ids(
+                        data=batch_dict, id_columns=self.id_columns
+                    )
                     if batch_user_id is not None:
                         collected_user_ids.append(batch_user_id)
         if self.is_main_process:
             logging.info(" ")
-            logging.info(
+            logging.info(
+                colorize(f" Evaluation batches processed: {batch_count}", color="cyan")
+            )
         y_true_all_local = np.concatenate(y_true_list, axis=0) if y_true_list else None
         y_pred_all_local = np.concatenate(y_pred_list, axis=0) if y_pred_list else None
-
+
         # Convert metrics to list if it's a dict
         if isinstance(eval_metrics, dict):
             # For dict metrics, we need to collect all unique metric names
@@ -798,7 +1214,7 @@ class BaseModel(FeatureSet, nn.Module):
                     unique_metrics.append(m)
             metrics_to_use = unique_metrics
         else:
-            metrics_to_use = eval_metrics
+            metrics_to_use = eval_metrics
         final_user_ids_local = user_ids
         if final_user_ids_local is None and collected_user_ids:
             final_user_ids_local = np.concatenate(collected_user_ids, axis=0)
@@ -806,28 +1222,50 @@ class BaseModel(FeatureSet, nn.Module):
         # gather across ranks even when local arrays are empty to keep collectives aligned
         y_true_all = gather_numpy(self, y_true_all_local)
         y_pred_all = gather_numpy(self, y_pred_all_local)
-        final_user_ids =
+        final_user_ids = (
+            gather_numpy(self, final_user_ids_local) if needs_user_ids else None
+        )
+        if (
+            y_true_all is None
+            or y_pred_all is None
+            or len(y_true_all) == 0
+            or len(y_pred_all) == 0
+        ):
             if self.is_main_process:
-                logging.info(
+                logging.info(
+                    colorize(
+                        " Warning: Not enough evaluation data to compute metrics after gathering",
+                        color="yellow",
+                    )
+                )
             return {}
         if self.is_main_process:
-            logging.info(
+            logging.info(
+                colorize(f" Evaluation samples: {y_true_all.shape[0]}", color="cyan")
+            )
+        metrics_dict = evaluate_metrics(
+            y_true=y_true_all,
+            y_pred=y_pred_all,
+            metrics=metrics_to_use,
+            task=self.task,
+            target_names=self.target_columns,
+            task_specific_metrics=self.task_specific_metrics,
+            user_ids=final_user_ids,
+        )
         return metrics_dict
 
     def predict(
+        self,
+        data: str | dict | pd.DataFrame | DataLoader,
+        batch_size: int = 32,
+        save_path: str | os.PathLike | None = None,
+        save_format: Literal["csv", "parquet"] = "csv",
+        include_ids: bool | None = None,
+        id_columns: str | list[str] | None = None,
+        return_dataframe: bool = True,
+        streaming_chunk_size: int = 10000,
+        num_workers: int = 0,
+    ) -> pd.DataFrame | np.ndarray:
         """
         Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.
|
|
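The evaluate hunks above make the distributed contract explicit: every rank calls gather_numpy, even when its local array is empty, so the collective call count stays aligned across ranks, and metrics are then computed once on the gathered arrays. As a rough illustration only (nextrec's actual gather_numpy helper may be implemented differently), an object-based all-gather of ragged numpy arrays can be sketched like this:

    import numpy as np
    import torch.distributed as dist

    def gather_numpy_sketch(arr):
        # Hypothetical stand-in for nextrec's gather_numpy. Assumes an
        # initialized process group; every rank must call it (even with
        # arr=None), otherwise the collective deadlocks.
        if not (dist.is_available() and dist.is_initialized()):
            return arr
        buckets = [None] * dist.get_world_size()
        dist.all_gather_object(buckets, arr)  # pickle-based, so ragged shapes are fine
        pieces = [b for b in buckets if b is not None and len(b) > 0]
        return np.concatenate(pieces, axis=0) if pieces else None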
@@ -848,28 +1286,53 @@ class BaseModel(FeatureSet, nn.Module):
         predict_id_columns = id_columns if id_columns is not None else self.id_columns
         if isinstance(predict_id_columns, str):
             predict_id_columns = [predict_id_columns]
-
+
         if include_ids is None:
             include_ids = bool(predict_id_columns)
         include_ids = include_ids and bool(predict_id_columns)
 
         # Use streaming mode for large file saves without loading all data into memory
         if save_path is not None and not return_dataframe:
-            return self.predict_streaming(
-
+            return self.predict_streaming(
+                data=data,
+                batch_size=batch_size,
+                save_path=save_path,
+                save_format=save_format,
+                include_ids=include_ids,
+                streaming_chunk_size=streaming_chunk_size,
+                return_dataframe=return_dataframe,
+                id_columns=predict_id_columns,
+            )
+
         # Create DataLoader based on data type
         if isinstance(data, DataLoader):
             data_loader = data
         elif isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(
-
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=self.target_columns,
+                id_columns=predict_id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                load_full=False,
+                chunk_size=streaming_chunk_size,
+            )
         else:
-            data_loader = self.prepare_data_loader(
-
+            data_loader = self.prepare_data_loader(
+                data, batch_size=batch_size, shuffle=False, num_workers=num_workers
+            )
+
         y_pred_list = []
-        id_buffers =
+        id_buffers = (
+            {name: [] for name in (predict_id_columns or [])} if include_ids else {}
+        )
         id_arrays = None
-
+
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
                 batch_dict = batch_to_dict(batch_data, include_ids=include_ids)
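The rewrapped predict entry point above dispatches on input type: an existing DataLoader is used as-is, a file path is read through RecDataLoader in chunked streaming mode, and in-memory dicts or DataFrames go through prepare_data_loader. A hedged usage sketch (file names are invented, and model is assumed to be a trained BaseModel subclass):

    import pandas as pd

    # In-memory scoring: returns a DataFrame with one prediction column per target.
    df = pd.read_parquet("test.parquet")  # hypothetical input file
    preds = model.predict(df, batch_size=1024)

    # Large-file scoring: with save_path set and return_dataframe=False,
    # predict() delegates to predict_streaming() and writes chunk by chunk.
    model.predict(
        "test.parquet",
        batch_size=1024,
        save_path="preds.parquet",
        save_format="parquet",
        return_dataframe=False,
        streaming_chunk_size=10000,
    )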
@@ -882,8 +1345,16 @@ class BaseModel(FeatureSet, nn.Module):
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        id_np =
-
+                        id_np = (
+                            id_tensor.detach().cpu().numpy()
+                            if isinstance(id_tensor, torch.Tensor)
+                            else np.asarray(id_tensor)
+                        )
+                        id_buffers[id_name].append(
+                            id_np.reshape(id_np.shape[0], -1)
+                            if id_np.ndim == 1
+                            else id_np
+                        )
         if len(y_pred_list) > 0:
             y_pred_all = np.concatenate(y_pred_list, axis=0)
         else:
@@ -898,14 +1369,16 @@ class BaseModel(FeatureSet, nn.Module):
         pred_columns: list[str] = []
         if self.target_columns:
             for name in self.target_columns[:num_outputs]:
-                pred_columns.append(f"{name}
+                pred_columns.append(f"{name}")
         while len(pred_columns) < num_outputs:
             pred_columns.append(f"pred_{len(pred_columns)}")
         if include_ids and predict_id_columns:
             id_arrays = {}
             for id_name, pieces in id_buffers.items():
                 if pieces:
-                    concatenated = np.concatenate(
+                    concatenated = np.concatenate(
+                        [p.reshape(p.shape[0], -1) for p in pieces], axis=0
+                    )
                     id_arrays[id_name] = concatenated.reshape(concatenated.shape[0])
                 else:
                     id_arrays[id_name] = np.array([], dtype=np.int64)
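The id-buffer hunks above normalize whatever the loader yields for an id column (tensor or array, 1-D or 2-D) into 2-D chunks that concatenate cleanly and then flatten back to one id per row. The reshape logic in isolation, as a toy check:

    import numpy as np
    import torch

    id_tensor = torch.tensor([7, 8, 9])                 # one batch of ids
    id_np = id_tensor.detach().cpu().numpy()            # shape (3,)
    chunk = id_np.reshape(id_np.shape[0], -1)           # shape (3, 1)
    concatenated = np.concatenate([chunk, chunk], axis=0)
    flat = concatenated.reshape(concatenated.shape[0])  # shape (6,), one id per row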
@@ -913,17 +1386,31 @@ class BaseModel(FeatureSet, nn.Module):
                 id_df = pd.DataFrame(id_arrays)
                 pred_df = pd.DataFrame(y_pred_all, columns=pred_columns)
                 if len(id_df) and len(pred_df) and len(id_df) != len(pred_df):
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(pred_df)})."
+                    )
                 output = pd.concat([id_df, pred_df], axis=1)
             else:
                 output = y_pred_all
         else:
-            output =
+            output = (
+                pd.DataFrame(y_pred_all, columns=pred_columns)
+                if return_dataframe
+                else y_pred_all
+            )
         if save_path is not None:
             if save_format not in ("csv", "parquet"):
-                raise ValueError(
+                raise ValueError(
+                    f"[BaseModel-predict Error] Unsupported save_format '{save_format}'. Choose from 'csv' or 'parquet'."
+                )
             suffix = ".csv" if save_format == "csv" else ".parquet"
-            target_path = resolve_save_path(
+            target_path = resolve_save_path(
+                path=save_path,
+                default_dir=self.session.predictions_dir,
+                default_name="predictions",
+                suffix=suffix,
+                add_timestamp=True if save_path is None else False,
+            )
             if isinstance(output, pd.DataFrame):
                 df_to_save = output
             else:
@@ -931,13 +1418,17 @@ class BaseModel(FeatureSet, nn.Module):
             if include_ids and predict_id_columns and id_arrays is not None:
                 id_df = pd.DataFrame(id_arrays)
                 if len(id_df) and len(df_to_save) and len(id_df) != len(df_to_save):
-                    raise ValueError(
+                    raise ValueError(
+                        f"[BaseModel-predict Error] Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_to_save)})."
+                    )
                 df_to_save = pd.concat([id_df, df_to_save], axis=1)
             if save_format == "csv":
                 df_to_save.to_csv(target_path, index=False)
             else:
                 df_to_save.to_parquet(target_path, index=False)
-            logging.info(
+            logging.info(
+                colorize(f"Predictions saved to: {target_path}", color="green")
+            )
         return output
 
     def predict_streaming(
@@ -952,21 +1443,43 @@ class BaseModel(FeatureSet, nn.Module):
         id_columns: list[str] | None = None,
     ) -> pd.DataFrame:
         if isinstance(data, (str, os.PathLike)):
-            rec_loader = RecDataLoader(
-
+            rec_loader = RecDataLoader(
+                dense_features=self.dense_features,
+                sparse_features=self.sparse_features,
+                sequence_features=self.sequence_features,
+                target=self.target_columns,
+                id_columns=id_columns,
+            )
+            data_loader = rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                load_full=False,
+                chunk_size=streaming_chunk_size,
+            )
         elif not isinstance(data, DataLoader):
-            data_loader = self.prepare_data_loader(
+            data_loader = self.prepare_data_loader(
+                data,
+                batch_size=batch_size,
+                shuffle=False,
+            )
         else:
             data_loader = data
 
         suffix = ".csv" if save_format == "csv" else ".parquet"
-        target_path = resolve_save_path(
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session.predictions_dir,
+            default_name="predictions",
+            suffix=suffix,
+            add_timestamp=True if save_path is None else False,
+        )
         target_path.parent.mkdir(parents=True, exist_ok=True)
         header_written = target_path.exists() and target_path.stat().st_size > 0
         parquet_writer = None
 
         pred_columns = None
-        collected_frames = []
+        collected_frames = []  # only used when return_dataframe is True
 
         with torch.no_grad():
             for batch_data in tqdm.tqdm(data_loader, desc="Predicting"):
@@ -983,35 +1496,45 @@ class BaseModel(FeatureSet, nn.Module):
                     pred_columns = []
                     if self.target_columns:
                         for name in self.target_columns[:num_outputs]:
-                            pred_columns.append(f"{name}
+                            pred_columns.append(f"{name}")
                     while len(pred_columns) < num_outputs:
                         pred_columns.append(f"pred_{len(pred_columns)}")
-
+
                 id_arrays_batch = {}
                 if include_ids and id_columns and batch_dict.get("ids"):
                     for id_name in id_columns:
                         if id_name not in batch_dict["ids"]:
                             continue
                         id_tensor = batch_dict["ids"][id_name]
-                        id_np =
+                        id_np = (
+                            id_tensor.detach().cpu().numpy()
+                            if isinstance(id_tensor, torch.Tensor)
+                            else np.asarray(id_tensor)
+                        )
                         id_arrays_batch[id_name] = id_np.reshape(id_np.shape[0])
 
                 df_batch = pd.DataFrame(y_pred_np, columns=pred_columns)
                 if id_arrays_batch:
                     id_df = pd.DataFrame(id_arrays_batch)
                     if len(id_df) and len(df_batch) and len(id_df) != len(df_batch):
-                        raise ValueError(
+                        raise ValueError(
+                            f"Mismatch between id rows ({len(id_df)}) and prediction rows ({len(df_batch)})."
+                        )
                     df_batch = pd.concat([id_df, df_batch], axis=1)
 
                 if save_format == "csv":
-                    df_batch.to_csv(
+                    df_batch.to_csv(
+                        target_path, mode="a", header=not header_written, index=False
+                    )
                     header_written = True
                 else:
                     try:
                         import pyarrow as pa
                         import pyarrow.parquet as pq
                     except ImportError as exc:  # pragma: no cover
-                        raise ImportError(
+                        raise ImportError(
+                            "[BaseModel-predict-streaming Error] Parquet streaming save requires pyarrow to be installed."
+                        ) from exc
                     table = pa.Table.from_pandas(df_batch, preserve_index=False)
                     if parquet_writer is None:
                         parquet_writer = pq.ParquetWriter(target_path, table.schema)
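The streaming branch above appends CSV chunks with mode="a" and, for parquet, keeps a single ParquetWriter open across batches, since a parquet file cannot simply be appended to after the fact. The same pyarrow pattern as a standalone sketch (file name invented):

    import pandas as pd
    import pyarrow as pa
    import pyarrow.parquet as pq

    writer = None
    for chunk in (pd.DataFrame({"pred": [0.1, 0.9]}) for _ in range(3)):
        table = pa.Table.from_pandas(chunk, preserve_index=False)
        if writer is None:
            # The schema is fixed by the first chunk; later chunks must match it.
            writer = pq.ParquetWriter("predictions.parquet", table.schema)
        writer.write_table(table)
    if writer is not None:
        writer.close()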
@@ -1022,15 +1545,34 @@ class BaseModel(FeatureSet, nn.Module):
             parquet_writer.close()
         logging.info(colorize(f"Predictions saved to: {target_path}", color="green"))
         if return_dataframe:
-            return
+            return (
+                pd.concat(collected_frames, ignore_index=True)
+                if collected_frames
+                else pd.DataFrame(columns=pred_columns or [])
+            )
         return pd.DataFrame(columns=pred_columns or [])
 
-    def save_model(
+    def save_model(
+        self,
+        save_path: str | Path | None = None,
+        add_timestamp: bool | None = None,
+        verbose: bool = True,
+    ):
         add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path = resolve_save_path(
+        target_path = resolve_save_path(
+            path=save_path,
+            default_dir=self.session_path,
+            default_name=self.model_name,
+            suffix=".model",
+            add_timestamp=add_timestamp,
+        )
         model_path = Path(target_path)
 
-        model_to_save = (
+        model_to_save = (
+            self.ddp_model.module
+            if getattr(self, "ddp_model", None) is not None
+            else self
+        )
         torch.save(model_to_save.state_dict(), model_path)
         # torch.save(self.state_dict(), model_path)
 
@@ -1045,29 +1587,47 @@ class BaseModel(FeatureSet, nn.Module):
             pickle.dump(features_config, f)
         self.features_config_path = str(config_path)
         if verbose:
-            logging.info(
-
-
+            logging.info(
+                colorize(
+                    f"Model saved to: {model_path}, features config saved to: {config_path}, NextRec version: {__version__}",
+                    color="green",
+                )
+            )
+
+    def load_model(
+        self,
+        save_path: str | Path,
+        map_location: str | torch.device | None = "cpu",
+        verbose: bool = True,
+    ):
         self.to(self.device)
         base_path = Path(save_path)
         if base_path.is_dir():
             model_files = sorted(base_path.glob("*.model"))
             if not model_files:
-                raise FileNotFoundError(
+                raise FileNotFoundError(
+                    f"[BaseModel-load-model Error] No *.model file found in directory: {base_path}"
+                )
             model_path = model_files[-1]
             config_dir = base_path
         else:
-            model_path =
+            model_path = (
+                base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            )
             config_dir = model_path.parent
         if not model_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-load-model Error] Model file does not exist: {model_path}"
+            )
 
         state_dict = torch.load(model_path, map_location=map_location)
         self.load_state_dict(state_dict)
 
         features_config_path = config_dir / "features_config.pkl"
         if not features_config_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-load-model Error] features_config.pkl not found in: {config_dir}"
+            )
         with open(features_config_path, "rb") as f:
             features_config = pickle.load(f)
 
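save_model writes two artifacts side by side (the state_dict as <name>.model plus a pickled features_config.pkl), and load_model accepts either a file or a directory, picking the last *.model in sorted order in the latter case. A hedged round-trip sketch, assuming a trained instance and using DSSM as a stand-in subclass (its exact constructor arguments are not shown in this diff):

    model.save_model("checkpoints/run1/dssm.model")
    # -> checkpoints/run1/dssm.model + checkpoints/run1/features_config.pkl

    # Later, on a fresh instance of the same architecture:
    restored = DSSM(...)  # hypothetical; construct with matching features
    restored.load_model("checkpoints/run1", map_location="cpu")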
@@ -1077,11 +1637,22 @@ class BaseModel(FeatureSet, nn.Module):
         dense_features = [f for f in all_features if isinstance(f, DenseFeature)]
         sparse_features = [f for f in all_features if isinstance(f, SparseFeature)]
         sequence_features = [f for f in all_features if isinstance(f, SequenceFeature)]
-        self.set_all_features(
+        self.set_all_features(
+            dense_features=dense_features,
+            sparse_features=sparse_features,
+            sequence_features=sequence_features,
+            target=target,
+            id_columns=id_columns,
+        )
 
         cfg_version = features_config.get("version")
         if verbose:
-            logging.info(
+            logging.info(
+                colorize(
+                    f"Model weights loaded from: {model_path}, features config loaded from: {features_config_path}, NextRec version: {cfg_version}",
+                    color="green",
+                )
+            )
 
     @classmethod
     def from_checkpoint(
@@ -1101,15 +1672,21 @@ class BaseModel(FeatureSet, nn.Module):
         if base_path.is_dir():
             model_candidates = sorted(base_path.glob("*.model"))
             if not model_candidates:
-                raise FileNotFoundError(
+                raise FileNotFoundError(
+                    f"[BaseModel-from-checkpoint Error] No *.model file found under: {base_path}"
+                )
             model_file = model_candidates[-1]
             config_dir = base_path
         else:
-            model_file =
+            model_file = (
+                base_path.with_suffix(".model") if base_path.suffix == "" else base_path
+            )
             config_dir = model_file.parent
         features_config_path = config_dir / "features_config.pkl"
         if not features_config_path.exists():
-            raise FileNotFoundError(
+            raise FileNotFoundError(
+                f"[BaseModel-from-checkpoint Error] features_config.pkl not found next to checkpoint: {features_config_path}"
+            )
         with open(features_config_path, "rb") as f:
             features_config = pickle.load(f)
         all_features = features_config.get("all_features", [])
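from_checkpoint is the classmethod counterpart to load_model: it resolves the *.model file and the features_config.pkl sitting next to it, then rebuilds the model without requiring the feature lists up front. The remaining keyword arguments live in the part of the signature not shown in this hunk, so this one-liner only assumes a checkpoint path is accepted:

    model = DSSM.from_checkpoint("checkpoints/run1")  # hypothetical subclass and path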
@@ -1135,108 +1712,132 @@ class BaseModel(FeatureSet, nn.Module):
 
     def summary(self):
         logger = logging.getLogger()
-
+
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-        logger.info(
+        logger.info(
+            colorize(
+                f"Model Summary: {self.model_name}", color="bright_blue", bold=True
+            )
+        )
         logger.info(colorize("=" * 80, color="bright_blue", bold=True))
-
+
         logger.info("")
         logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         if self.dense_features:
             logger.info(f"Dense Features ({len(self.dense_features)}):")
             for i, feat in enumerate(self.dense_features, 1):
-                embed_dim = feat.embedding_dim if hasattr(feat,
+                embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
                 logger.info(f" {i}. {feat.name:20s}")
-
+
         if self.sparse_features:
             logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
 
             max_name_len = max(len(feat.name) for feat in self.sparse_features)
-            max_embed_name_len = max(
+            max_embed_name_len = max(
+                len(feat.embedding_name) for feat in self.sparse_features
+            )
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
+            )
+            logger.info(
+                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
+            )
             for i, feat in enumerate(self.sparse_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat,
-                embed_dim =
-
-
+                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                embed_dim = (
+                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                )
+                logger.info(
+                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
+                )
+
         if self.sequence_features:
             logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
 
             max_name_len = max(len(feat.name) for feat in self.sequence_features)
-            max_embed_name_len = max(
+            max_embed_name_len = max(
+                len(feat.embedding_name) for feat in self.sequence_features
+            )
             name_width = max(max_name_len, 10) + 2
             embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-
+
+            logger.info(
+                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
+            )
+            logger.info(
+                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
+            )
             for i, feat in enumerate(self.sequence_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat,
-                embed_dim =
-
-
-
+                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
+                embed_dim = (
+                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
+                )
+                max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
+                logger.info(
+                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
+                )
+
         logger.info("")
         logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         # Model Architecture
         logger.info("Model Architecture:")
         logger.info(str(self))
         logger.info("")
-
+
         total_params = sum(p.numel() for p in self.parameters())
         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
         non_trainable_params = total_params - trainable_params
-
+
         logger.info(f"Total Parameters: {total_params:,}")
         logger.info(f"Trainable Parameters: {trainable_params:,}")
         logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
+
         logger.info("Layer-wise Parameters:")
         for name, module in self.named_children():
             layer_params = sum(p.numel() for p in module.parameters())
             if layer_params > 0:
                 logger.info(f" {name:30s}: {layer_params:,}")
-
+
         logger.info("")
         logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
         logger.info(colorize("-" * 80, color="cyan"))
-
+
         logger.info(f"Task Type: {self.task}")
         logger.info(f"Number of Tasks: {self.nums_task}")
         logger.info(f"Metrics: {self.metrics}")
         logger.info(f"Target Columns: {self.target_columns}")
         logger.info(f"Device: {self.device}")
-
-        if hasattr(self,
+
+        if hasattr(self, "optimizer_name"):
             logger.info(f"Optimizer: {self.optimizer_name}")
             if self.optimizer_params:
                 for key, value in self.optimizer_params.items():
                     logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self,
+
+        if hasattr(self, "scheduler_name") and self.scheduler_name:
             logger.info(f"Scheduler: {self.scheduler_name}")
             if self.scheduler_params:
                 for key, value in self.scheduler_params.items():
                     logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self,
+
+        if hasattr(self, "loss_config"):
             logger.info(f"Loss Function: {self.loss_config}")
-        if hasattr(self,
+        if hasattr(self, "loss_weights"):
             logger.info(f"Loss Weights: {self.loss_weights}")
-
+
         logger.info("Regularization:")
         logger.info(f" Embedding L1: {self.embedding_l1_reg}")
         logger.info(f" Embedding L2: {self.embedding_l2_reg}")
         logger.info(f" Dense L1: {self.dense_l1_reg}")
         logger.info(f" Dense L2: {self.dense_l2_reg}")
-
+
         logger.info("Other Settings:")
         logger.info(f" Early Stop Patience: {self.early_stop_patience}")
         logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
@@ -1245,54 +1846,56 @@ class BaseModel(FeatureSet, nn.Module):
         logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
 
 
-
 class BaseMatchModel(BaseModel):
     """
    Base class for match (retrieval/recall) models
    Supports pointwise, pairwise, and listwise training modes
    """
+
    @property
    def model_name(self) -> str:
        raise NotImplementedError
-
+
    @property
    def default_task(self) -> str:
        return "binary"
-
+
    @property
    def support_training_modes(self) -> list[str]:
        """
        Returns list of supported training modes for this model.
        Override in subclasses to restrict training modes.
-
+
        Returns:
            List of supported modes: ['pointwise', 'pairwise', 'listwise']
        """
-        return [
-
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return ["pointwise", "pairwise", "listwise"]
+
+    def __init__(
+        self,
+        user_dense_features: list[DenseFeature] | None = None,
+        user_sparse_features: list[SparseFeature] | None = None,
+        user_sequence_features: list[SequenceFeature] | None = None,
+        item_dense_features: list[DenseFeature] | None = None,
+        item_sparse_features: list[SparseFeature] | None = None,
+        item_sequence_features: list[SequenceFeature] | None = None,
+        training_mode: Literal["pointwise", "pairwise", "listwise"] = "pointwise",
+        num_negative_samples: int = 4,
+        temperature: float = 1.0,
+        similarity_metric: Literal["dot", "cosine", "euclidean"] = "dot",
+        device: str = "cpu",
+        embedding_l1_reg: float = 0.0,
+        dense_l1_reg: float = 0.0,
+        embedding_l2_reg: float = 0.0,
+        dense_l2_reg: float = 0.0,
+        early_stop_patience: int = 20,
+        **kwargs,
+    ):
+
        all_dense_features = []
        all_sparse_features = []
        all_sequence_features = []
-
+
        if user_dense_features:
            all_dense_features.extend(user_dense_features)
        if item_dense_features:
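The rebuilt BaseMatchModel constructor above folds user-side and item-side features into a single feature set for the parent class but keeps the split so each tower can route its own inputs; training_mode then decides how the loss is formed. A hedged construction sketch (the SparseFeature keyword names follow the attributes used by summary() but are assumptions, and DSSM stands in for any concrete subclass):

    from nextrec.basic.features import SparseFeature

    user_feats = [SparseFeature(name="user_id", vocab_size=10_000, embedding_dim=16)]
    item_feats = [SparseFeature(name="item_id", vocab_size=50_000, embedding_dim=16)]

    model = DSSM(  # hypothetical subclass of BaseMatchModel
        user_sparse_features=user_feats,
        item_sparse_features=item_feats,
        training_mode="listwise",       # trained with in-batch softmax (InfoNCE)
        temperature=0.05,
        similarity_metric="cosine",
    )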
@@ -1305,117 +1908,175 @@ class BaseMatchModel(BaseModel):
            all_sequence_features.extend(user_sequence_features)
        if item_sequence_features:
            all_sequence_features.extend(item_sequence_features)
-
+
        super(BaseMatchModel, self).__init__(
            dense_features=all_dense_features,
            sparse_features=all_sparse_features,
            sequence_features=all_sequence_features,
-            target=[
-            task=
+            target=["label"],
+            task="binary",
            device=device,
            embedding_l1_reg=embedding_l1_reg,
            dense_l1_reg=dense_l1_reg,
            embedding_l2_reg=embedding_l2_reg,
            dense_l2_reg=dense_l2_reg,
            early_stop_patience=early_stop_patience,
-            **kwargs
+            **kwargs,
+        )
+
+        self.user_dense_features = (
+            list(user_dense_features) if user_dense_features else []
        )
-
-
-
-        self.user_sequence_features =
-
-
-
-        self.
-
+        self.user_sparse_features = (
+            list(user_sparse_features) if user_sparse_features else []
+        )
+        self.user_sequence_features = (
+            list(user_sequence_features) if user_sequence_features else []
+        )
+
+        self.item_dense_features = (
+            list(item_dense_features) if item_dense_features else []
+        )
+        self.item_sparse_features = (
+            list(item_sparse_features) if item_sparse_features else []
+        )
+        self.item_sequence_features = (
+            list(item_sequence_features) if item_sequence_features else []
+        )
+
        self.training_mode = training_mode
        self.num_negative_samples = num_negative_samples
        self.temperature = temperature
        self.similarity_metric = similarity_metric
 
-        self.user_feature_names = [
-
+        self.user_feature_names = [
+            f.name
+            for f in (
+                self.user_dense_features
+                + self.user_sparse_features
+                + self.user_sequence_features
+            )
+        ]
+        self.item_feature_names = [
+            f.name
+            for f in (
+                self.item_dense_features
+                + self.item_sparse_features
+                + self.item_sequence_features
+            )
+        ]
 
    def get_user_features(self, X_input: dict) -> dict:
        return {
-            name: X_input[name]
-            for name in self.user_feature_names
-            if name in X_input
+            name: X_input[name] for name in self.user_feature_names if name in X_input
        }
 
    def get_item_features(self, X_input: dict) -> dict:
        return {
-            name: X_input[name]
-            for name in self.item_feature_names
-            if name in X_input
+            name: X_input[name] for name in self.item_feature_names if name in X_input
        }
-
-    def compile(
-
-
-
-
-
-
+
+    def compile(
+        self,
+        optimizer: str | torch.optim.Optimizer = "adam",
+        optimizer_params: dict | None = None,
+        scheduler: (
+            str
+            | torch.optim.lr_scheduler._LRScheduler
+            | torch.optim.lr_scheduler.LRScheduler
+            | type[torch.optim.lr_scheduler._LRScheduler]
+            | type[torch.optim.lr_scheduler.LRScheduler]
+            | None
+        ) = None,
+        scheduler_params: dict | None = None,
+        loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+        loss_params: dict | list[dict] | None = None,
+    ):
        """
        Compile match model with optimizer, scheduler, and loss function.
        Mirrors BaseModel.compile while adding training_mode validation for match tasks.
        """
        if self.training_mode not in self.support_training_modes:
-            raise ValueError(
+            raise ValueError(
+                f"{self.model_name} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
+            )
        # Call parent compile with match-specific logic
        optimizer_params = optimizer_params or {}
-
-        self.optimizer_name =
+
+        self.optimizer_name = (
+            optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
+        )
        self.optimizer_params = optimizer_params
        if isinstance(scheduler, str):
            self.scheduler_name = scheduler
        elif scheduler is not None:
            # Try to get __name__ first (for class types), then __class__.__name__ (for instances)
-            self.scheduler_name = getattr(
+            self.scheduler_name = getattr(
+                scheduler,
+                "__name__",
+                getattr(scheduler.__class__, "__name__", str(scheduler)),
+            )
        else:
            self.scheduler_name = None
        self.scheduler_params = scheduler_params or {}
        self.loss_config = loss
        self.loss_params = loss_params or {}
 
-        self.optimizer_fn = get_optimizer(
+        self.optimizer_fn = get_optimizer(
+            optimizer=optimizer, params=self.parameters(), **optimizer_params
+        )
        # Set loss function based on training mode
        default_losses = {
-
-
-
+            "pointwise": "bce",
+            "pairwise": "bpr",
+            "listwise": "sampled_softmax",
        }
 
        if loss is None:
            loss_value = default_losses.get(self.training_mode, "bce")
        elif isinstance(loss, list):
-            loss_value =
+            loss_value = (
+                loss[0]
+                if loss and loss[0] is not None
+                else default_losses.get(self.training_mode, "bce")
+            )
        else:
            loss_value = loss
 
        # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
-        if self.training_mode in {"pairwise", "listwise"} and loss_value in {
+        if self.training_mode in {"pairwise", "listwise"} and loss_value in {
+            "bce",
+            "binary_crossentropy",
+        }:
            loss_value = default_losses.get(self.training_mode, loss_value)
        loss_kwargs = get_loss_kwargs(self.loss_params, 0)
        self.loss_fn = [get_loss_fn(loss=loss_value, **loss_kwargs)]
        # set scheduler
-        self.scheduler_fn =
+        self.scheduler_fn = (
+            get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {}))
+            if scheduler
+            else None
+        )
 
-    def compute_similarity(
-
+    def compute_similarity(
+        self, user_emb: torch.Tensor, item_emb: torch.Tensor
+    ) -> torch.Tensor:
+        if self.similarity_metric == "dot":
            if user_emb.dim() == 3 and item_emb.dim() == 3:
                # [batch_size, num_items, emb_dim] @ [batch_size, num_items, emb_dim]
-                similarity = torch.sum(
+                similarity = torch.sum(
+                    user_emb * item_emb, dim=-1
+                )  # [batch_size, num_items]
            elif user_emb.dim() == 2 and item_emb.dim() == 3:
                # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
                user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
+                similarity = torch.sum(
+                    user_emb_expanded * item_emb, dim=-1
+                )  # [batch_size, num_items]
            else:
                similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
-
-        elif self.similarity_metric ==
+
+        elif self.similarity_metric == "cosine":
            if user_emb.dim() == 3 and item_emb.dim() == 3:
                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
            elif user_emb.dim() == 2 and item_emb.dim() == 3:
@@ -1423,8 +2084,8 @@ class BaseMatchModel(BaseModel):
                similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
            else:
                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
-        elif self.similarity_metric ==
+
+        elif self.similarity_metric == "euclidean":
            if user_emb.dim() == 3 and item_emb.dim() == 3:
                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
            elif user_emb.dim() == 2 and item_emb.dim() == 3:
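compute_similarity, reflowed across the hunks above, supports three metrics and broadcasts a 2-D user embedding against 3-D candidate items before reducing over the embedding dimension; the final division by temperature sharpens or flattens the scores. A toy check of the dot-product path:

    import torch

    user_emb = torch.randn(4, 8)     # [batch_size, emb_dim]
    item_emb = torch.randn(4, 5, 8)  # [batch_size, num_items, emb_dim]

    # 2-D user vs 3-D items: unsqueeze, broadcast, then reduce over emb_dim.
    scores = torch.sum(user_emb.unsqueeze(1) * item_emb, dim=-1)  # [4, 5]
    scores = scores / 0.1  # temperature < 1 sharpens the score distribution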
@@ -1432,63 +2093,70 @@ class BaseMatchModel(BaseModel):
                distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
            else:
                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -distance
+
        else:
            raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
        similarity = similarity / self.temperature
        return similarity
-
+
    def user_tower(self, user_input: dict) -> torch.Tensor:
        raise NotImplementedError
-
+
    def item_tower(self, item_input: dict) -> torch.Tensor:
        raise NotImplementedError
-
-    def forward(
+
+    def forward(
+        self, X_input: dict
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        user_input = self.get_user_features(X_input)
        item_input = self.get_item_features(X_input)
-
-        user_emb = self.user_tower(user_input)
-        item_emb = self.item_tower(item_input)
-
-        if self.training and self.training_mode in [
+
+        user_emb = self.user_tower(user_input)  # [B, D]
+        item_emb = self.item_tower(item_input)  # [B, D]
+
+        if self.training and self.training_mode in ["pairwise", "listwise"]:
            return user_emb, item_emb
 
        similarity = self.compute_similarity(user_emb, item_emb)  # [B]
-
-        if self.training_mode ==
+
+        if self.training_mode == "pointwise":
            return torch.sigmoid(similarity)
        else:
            return similarity
-
+
    def compute_loss(self, y_pred, y_true):
-        if self.training_mode ==
+        if self.training_mode == "pointwise":
            if y_true is None:
                return torch.tensor(0.0, device=self.device)
            return self.loss_fn[0](y_pred, y_true)
-
+
        # pairwise / listwise using inbatch neg
-        elif self.training_mode in [
+        elif self.training_mode in ["pairwise", "listwise"]:
            if not isinstance(y_pred, (tuple, list)) or len(y_pred) != 2:
-                raise ValueError(
-
+                raise ValueError(
+                    "For pairwise/listwise training, forward should return (user_emb, item_emb). Please check BaseMatchModel.forward implementation."
+                )
+            user_emb, item_emb = y_pred  # [B, D], [B, D]
            logits = torch.matmul(user_emb, item_emb.t())  # [B, B]
-            logits = logits / self.temperature
+            logits = logits / self.temperature
            batch_size = logits.size(0)
-            targets = torch.arange(
+            targets = torch.arange(
+                batch_size, device=logits.device
+            )  # [0, 1, 2, ..., B-1]
            # Cross-Entropy = InfoNCE
            loss = F.cross_entropy(logits, targets)
-            return loss
+            return loss
        else:
            raise ValueError(f"Unknown training mode: {self.training_mode}")
 
-
-
+    def prepare_feature_data(
+        self, data: dict | pd.DataFrame | DataLoader, features: list, batch_size: int
+    ) -> DataLoader:
        """Prepare data loader for specific features."""
        if isinstance(data, DataLoader):
            return data
-
+
        feature_data = {}
        for feature in features:
            if isinstance(data, dict):
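The compute_loss hunk above spells out the in-batch negative objective: for pairwise/listwise training the towers return raw embeddings, the [B, B] logit matrix puts each positive pair on the diagonal, and cross-entropy against targets 0..B-1 is exactly InfoNCE. The same objective standalone:

    import torch
    import torch.nn.functional as F

    B, D = 32, 16
    user_emb = torch.randn(B, D)
    item_emb = torch.randn(B, D)  # row i is the positive item for user i

    logits = torch.matmul(user_emb, item_emb.t()) / 0.2  # [B, B], temperature 0.2
    targets = torch.arange(B)                # positives sit on the diagonal
    loss = F.cross_entropy(logits, targets)  # in-batch InfoNCE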
@@ -1497,13 +2165,21 @@ class BaseMatchModel(BaseModel):
            elif isinstance(data, pd.DataFrame):
                if feature.name in data.columns:
                    feature_data[feature.name] = data[feature.name].values
-        return self.prepare_data_loader(
+        return self.prepare_data_loader(
+            feature_data, batch_size=batch_size, shuffle=False
+        )
 
-    def encode_user(
+    def encode_user(
+        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+    ) -> np.ndarray:
        self.eval()
-        all_user_features =
+        all_user_features = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)
-
+
        embeddings_list = []
        with torch.no_grad():
            for batch_data in tqdm.tqdm(data_loader, desc="Encoding users"):
@@ -1512,12 +2188,18 @@ class BaseMatchModel(BaseModel):
                user_emb = self.user_tower(user_input)
                embeddings_list.append(user_emb.cpu().numpy())
        return np.concatenate(embeddings_list, axis=0)
-
-    def encode_item(
+
+    def encode_item(
+        self, data: dict | pd.DataFrame | DataLoader, batch_size: int = 512
+    ) -> np.ndarray:
        self.eval()
-        all_item_features =
+        all_item_features = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)
-
+
        embeddings_list = []
        with torch.no_grad():
            for batch_data in tqdm.tqdm(data_loader, desc="Encoding items"):