nextrec-0.4.20-py3-none-any.whl → nextrec-0.4.22-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +3 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +259 -326
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/__init__.py +0 -4
- nextrec/loss/grad_norm.py +3 -3
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +17 -15
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/{loss/loss_utils.py → utils/loss.py} +21 -36
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +63 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/METADATA +8 -6
- nextrec-0.4.22.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.22.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 28/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -12,7 +12,7 @@ import os
 import pickle
 import socket
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal

 import numpy as np
 import pandas as pd
@@ -26,7 +26,6 @@ from torch.utils.data.distributed import DistributedSampler

 from nextrec import __version__
 from nextrec.basic.callback import (
-    Callback,
     CallbackList,
     CheckpointSaver,
     EarlyStopper,
@@ -41,9 +40,13 @@ from nextrec.basic.features import (
 from nextrec.basic.heads import RetrievalHead
 from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
 from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
-from nextrec.basic.
+from nextrec.basic.summary import SummarySet
+from nextrec.basic.session import create_session, get_save_path
 from nextrec.data.batch_utils import batch_to_dict, collate_fn
-from nextrec.data.data_processing import
+from nextrec.data.data_processing import (
+    get_column_data,
+    get_user_ids,
+)
 from nextrec.data.dataloader import (
     RecDataLoader,
     TensorDictDataset,
@@ -57,23 +60,31 @@ from nextrec.loss import (
     InfoNCELoss,
     SampledSoftmaxLoss,
     TripletLoss,
-    get_loss_fn,
 )
+from nextrec.utils.loss import get_loss_fn
 from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
     add_distributed_sampler,
-
+    get_device,
     gather_numpy,
     get_optimizer,
     get_scheduler,
     init_process_group,
     to_tensor,
 )
+from nextrec.utils.config import safe_value
 from nextrec.utils.model import compute_ranking_loss
+from nextrec.utils.types import (
+    LossName,
+    OptimizerName,
+    SchedulerName,
+    TrainingModeName,
+    TaskTypeName,
+)


-class BaseModel(FeatureSet, nn.Module):
+class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
@@ -89,21 +100,14 @@ class BaseModel(FeatureSet, nn.Module):
         sequence_features: list[SequenceFeature] | None = None,
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
-        task:
-        training_mode:
-            Literal["pointwise", "pairwise", "listwise"]
-            | list[Literal["pointwise", "pairwise", "listwise"]]
-        ) = "pointwise",
+        task: TaskTypeName | list[TaskTypeName] | None = None,
+        training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         device: str = "cpu",
-        early_stop_patience: int = 20,
-        early_stop_monitor_task: str | None = None,
-        metrics_sample_limit: int | None = 200000,
         session_id: str | None = None,
-        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -128,11 +132,7 @@ class BaseModel(FeatureSet, nn.Module):
            dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.

            device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
-           early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-           early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
-           metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
            session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
-           callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].

            distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
            rank: Global rank (defaults to env RANK).
@@ -152,8 +152,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.local_rank = env_local_rank if local_rank is None else local_rank
         self.is_main_process = self.rank == 0
         self.ddp_find_unused_parameters = ddp_find_unused_parameters
-        self.ddp_model
-        self.device =
+        self.ddp_model = None
+        self.device = get_device(self.distributed, self.local_rank, device)

         self.session_id = session_id
         self.session = create_session(session_id)
@@ -174,21 +174,21 @@ class BaseModel(FeatureSet, nn.Module):
         self.task = self.default_task if task is None else task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
         if isinstance(training_mode, list):
-
+            training_modes = list(training_mode)
+            if len(training_modes) != self.nums_task:
                 raise ValueError(
                     "[BaseModel-init Error] training_mode list length must match number of tasks."
                 )
-            self.training_modes = list(training_mode)
         else:
-
-
-
-
-
-
-
-
+            training_modes = [training_mode] * self.nums_task
+        if any(
+            mode not in {"pointwise", "pairwise", "listwise"} for mode in training_modes
+        ):
+            raise ValueError(
+                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+            )
+        self.training_modes = training_modes
+        self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
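The hunk above replaces the old ad-hoc handling with a validate-and-broadcast step for training_mode. A standalone sketch of the same pattern (hypothetical helper, not part of the nextrec API):

    VALID_MODES = {"pointwise", "pairwise", "listwise"}

    def normalize_training_modes(training_mode, nums_task):
        # Broadcast a single mode to every task, or validate a per-task list.
        if isinstance(training_mode, list):
            modes = list(training_mode)
            if len(modes) != nums_task:
                raise ValueError("training_mode list length must match number of tasks.")
        else:
            modes = [training_mode] * nums_task
        if any(mode not in VALID_MODES for mode in modes):
            raise ValueError(f"training_mode must be one of {VALID_MODES}.")
        return modes

    assert normalize_training_modes("pointwise", 3) == ["pointwise"] * 3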
@@ -197,26 +197,22 @@ class BaseModel(FeatureSet, nn.Module):
         self.regularization_weights = []
         self.embedding_params = []
         self.loss_weight = None
+        self.ignore_label = None

-        self.early_stop_patience = early_stop_patience
-        self.early_stop_monitor_task = early_stop_monitor_task
-        # max samples to keep for training metrics, in case of large training set
-        self.metrics_sample_limit = (
-            None if metrics_sample_limit is None else int(metrics_sample_limit)
-        )
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
-        self.callbacks = CallbackList(
-
-            self.
+        self.callbacks = CallbackList()
+
+        self.train_data_summary = None
+        self.valid_data_summary = None

     def register_regularization_weights(
         self,
         embedding_attr: str = "embedding",
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
-    )
+    ):
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
@@ -264,24 +260,24 @@ class BaseModel(FeatureSet, nn.Module):

     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if self.embedding_l1_reg > 0:
+            reg_loss += self.embedding_l1_reg * sum(
+                param.abs().sum() for param in self.embedding_params
+            )
+        if self.embedding_l2_reg > 0:
+            reg_loss += self.embedding_l2_reg * sum(
+                (param**2).sum() for param in self.embedding_params
+            )
+
+        if self.dense_l1_reg > 0:
+            reg_loss += self.dense_l1_reg * sum(
+                param.abs().sum() for param in self.regularization_weights
+            )
+        if self.dense_l2_reg > 0:
+            reg_loss += self.dense_l2_reg * sum(
+                (param**2).sum() for param in self.regularization_weights
+            )
         return reg_loss

     def get_input(self, input_data: dict, require_labels: bool = True):
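The restored add_reg_loss body sums L1/L2 penalties over the registered parameter lists. A minimal worked example of the same accumulation on plain tensors (illustrative values only):

    import torch

    params = [torch.ones(2, 3), torch.full((4,), 2.0)]
    l1 = sum(p.abs().sum() for p in params)   # 6*1 + 4*2 = 14
    l2 = sum((p**2).sum() for p in params)    # 6*1 + 4*4 = 22
    reg_loss = 1e-5 * l1 + 1e-4 * l2          # weighted penalty added to the training loss
    print(reg_loss.item())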
@@ -341,10 +337,10 @@ class BaseModel(FeatureSet, nn.Module):
         )
         return X_input, y

-    def
+    def handle_valid_split(
         self,
         train_data: dict | pd.DataFrame,
-
+        valid_split: float,
         batch_size: int,
         shuffle: bool,
         num_workers: int = 0,
@@ -352,11 +348,11 @@ class BaseModel(FeatureSet, nn.Module):
         """
         This function will split training data into training and validation sets when:
         1. valid_data is None;
-        2.
+        2. valid_split is provided.
         """
-        if not (0 <
+        if not (0 < valid_split < 1):
             raise ValueError(
-                f"[BaseModel-validation Error]
+                f"[BaseModel-validation Error] valid_split must be between 0 and 1, got {valid_split}"
             )
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
@@ -370,37 +366,40 @@ class BaseModel(FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
-        split_idx = int(total_length * (1 -
+        split_idx = int(total_length * (1 - valid_split))
         train_indices = indices[:split_idx]
         valid_indices = indices[split_idx:]
         if isinstance(train_data, pd.DataFrame):
-
-
+            train_split_data = train_data.iloc[train_indices].reset_index(drop=True)
+            valid_split_data = train_data.iloc[valid_indices].reset_index(drop=True)
         else:
-
+            train_split_data = {
                 k: np.asarray(v)[train_indices] for k, v in train_data.items()
             }
-
+            valid_split_data = {
                 k: np.asarray(v)[valid_indices] for k, v in train_data.items()
             }
         train_loader = self.prepare_data_loader(
-
+            train_split_data,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers,
         )
         logging.info(
             f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
         )
-        return train_loader,
+        return train_loader, valid_split_data

     def compile(
         self,
-        optimizer:
+        optimizer: OptimizerName | torch.optim.Optimizer = "adam",
         optimizer_params: dict | None = None,
         scheduler: (
-
+            SchedulerName
            | torch.optim.lr_scheduler._LRScheduler
            | torch.optim.lr_scheduler.LRScheduler
            | type[torch.optim.lr_scheduler._LRScheduler]
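handle_valid_split draws a fixed-seed permutation, so repeated fits produce the same split. A standalone sketch of that seeded split logic:

    import numpy as np

    total_length, valid_split = 10, 0.2
    rng = np.random.default_rng(42)                 # fixed seed -> reproducible split
    indices = rng.permutation(total_length)
    split_idx = int(total_length * (1 - valid_split))
    train_indices, valid_indices = indices[:split_idx], indices[split_idx:]
    print(len(train_indices), len(valid_indices))   # 8 2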
@@ -408,10 +407,10 @@ class BaseModel(FeatureSet, nn.Module):
            | None
         ) = None,
         scheduler_params: dict | None = None,
-        loss:
+        loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | dict | str | None = None,
-
+        ignore_label: int | float | None = -1,
     ):
         """
         Configure the model for training.
@@ -424,8 +423,9 @@ class BaseModel(FeatureSet, nn.Module):
            loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
            loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
                Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
-
+           ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
+        self.ignore_label = ignore_label
         default_losses = {
             "pointwise": "bce",
             "pairwise": "bpr",
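The new ignore_label option masks rows whose label equals the sentinel before the loss is computed. A minimal sketch of the masking semantics (standalone, using -1 as in the default):

    import torch

    y_true = torch.tensor([[1.0], [-1.0], [0.0]])   # middle row carries the sentinel
    y_pred = torch.tensor([[0.9], [0.5], [0.2]])
    valid = (y_true != -1).all(dim=1)               # drop rows with ignored labels
    loss = torch.nn.functional.binary_cross_entropy(y_pred[valid], y_true[valid])
    print(loss.item())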
@@ -453,10 +453,7 @@ class BaseModel(FeatureSet, nn.Module):
         }:
             if mode in {"pairwise", "listwise"}:
                 loss_list[idx] = default_losses[mode]
-
-            self.loss_params = {}
-        else:
-            self.loss_params = loss_params
+        self.loss_params = loss_params or {}
         optimizer_params = optimizer_params or {}
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
@@ -483,7 +480,6 @@ class BaseModel(FeatureSet, nn.Module):
         )

         self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
-        self.loss_params = loss_params or {}
         if isinstance(self.loss_params, dict):
             loss_params_list = [self.loss_params] * self.nums_task
         else:
@@ -545,16 +541,12 @@ class BaseModel(FeatureSet, nn.Module):
             )
         self.loss_weights = weights

-        # Add callbacks from compile if provided
-        if callbacks:
-            for callback in callbacks:
-                self.callbacks.append(callback)
-
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
             )
+
         # single-task
         if self.nums_task == 1:
             if y_pred.dim() == 1:
@@ -562,13 +554,24 @@ class BaseModel(FeatureSet, nn.Module):
             if y_true.dim() == 1:
                 y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
-                raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
-            if loss_fn is None:
                 raise ValueError(
-                    "[BaseModel-compute_loss Error]
+                    f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
                 )
+
+            loss_fn = self.loss_fn[0]
+
+            if self.ignore_label is not None:
+                valid_mask = y_true != self.ignore_label
+                if valid_mask.dim() > 1:
+                    valid_mask = valid_mask.all(dim=1)
+                if not torch.any(valid_mask):  # if no valid labels, return zero loss
+                    return y_pred.sum() * 0.0
+
+                y_pred = y_pred[valid_mask]
+                y_true = y_true[valid_mask]
+
             mode = self.training_modes[0]
+
             task_dim = (
                 self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
             )
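When every label in the batch is ignored, the branch above returns y_pred.sum() * 0.0 rather than a detached constant: the product keeps the autograd graph alive, so backward() still runs and produces zero gradients. A minimal sketch:

    import torch

    y_pred = torch.tensor([[0.3], [0.7]], requires_grad=True)
    zero_loss = y_pred.sum() * 0.0                  # zero-valued but still connected to the graph
    zero_loss.backward()
    print(y_pred.grad)                              # zeros, not None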
@@ -599,7 +602,25 @@ class BaseModel(FeatureSet, nn.Module):
             for i, (start, end) in enumerate(slices):  # type: ignore
                 y_pred_i = y_pred[:, start:end]
                 y_true_i = y_true[:, start:end]
+                total_count = y_true_i.shape[0]
+                # valid_count = None
+
+                # mask ignored labels
+                if self.ignore_label is not None:
+                    valid_mask = y_true_i != self.ignore_label
+                    if valid_mask.dim() > 1:
+                        valid_mask = valid_mask.all(dim=1)
+                    if not torch.any(valid_mask):
+                        task_losses.append(y_pred_i.sum() * 0.0)
+                        continue
+                    # valid_count = valid_mask.sum().to(dtype=y_true_i.dtype)
+                    y_pred_i = y_pred_i[valid_mask]
+                    y_true_i = y_true_i[valid_mask]
+                # else:
+                #     valid_count = y_true_i.new_tensor(float(total_count))
+
                 mode = self.training_modes[i]
+
                 if mode in {"pairwise", "listwise"}:
                     task_loss = compute_ranking_loss(
                         training_mode=mode,
@@ -609,7 +630,11 @@ class BaseModel(FeatureSet, nn.Module):
                     )
                 else:
                     task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+                    # task_loss = normalize_task_loss(
+                    #     task_loss, valid_count, total_count
+                    # )  # normalize by valid samples to avoid loss scale issues
                 task_losses.append(task_loss)
+
             if self.grad_norm is not None:
                 if self.grad_norm_shared_params is None:
                     self.grad_norm_shared_params = get_grad_norm_shared_params(
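In the multi-task branch, each task reads its own column slice of the stacked predictions before masking. A shape-only sketch of that slicing:

    import torch

    y_pred = torch.rand(4, 3)                       # batch of 4, two tasks with dims (1, 2)
    slices = [(0, 1), (1, 3)]
    for start, end in slices:
        print(y_pred[:, start:end].shape)           # [4, 1] then [4, 2]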
@@ -672,28 +697,49 @@ class BaseModel(FeatureSet, nn.Module):
         shuffle: bool = True,
         batch_size: int = 32,
         user_id_column: str | None = None,
-
+        valid_split: float | None = None,
+        early_stop_patience: int = 20,
+        early_stop_monitor_task: str | None = None,
+        metrics_sample_limit: int | None = 200000,
         num_workers: int = 0,
         use_tensorboard: bool = True,
+        use_wandb: bool = False,
+        use_swanlab: bool = False,
+        wandb_kwargs: dict | None = None,
+        swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
         log_interval: int = 1,
+        summary_sections: (
+            list[Literal["feature", "model", "train", "data"]] | None
+        ) = None,
     ):
         """
         Train the model.

         Args:
            train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
-           valid_data: Optional validation data; if None and
+           valid_data: Optional validation data; if None and valid_split is set, a split is created.
            metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
            epochs: Training epochs.
            shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
            batch_size: Batch size (per process when distributed).
            user_id_column: Column name for GAUC-style metrics;.
-
+           valid_split: Ratio to split training data when valid_data is None. e.g., 0.1 for 10% validation.
+
+           early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+           early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+           metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
            num_workers: DataLoader worker count.
+
            use_tensorboard: Enable tensorboard logging.
+           use_wandb: Enable Weights & Biases logging.
+           use_swanlab: Enable SwanLab logging.
+           wandb_kwargs: Optional kwargs for wandb.init(...).
+           swanlab_kwargs: Optional kwargs for swanlab.init(...).
            auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
            log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+           summary_sections: Optional summary sections to print. Choose from
+               ["feature", "model", "train", "data"]. Defaults to all.

         Notes:
            - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
@@ -733,20 +779,65 @@ class BaseModel(FeatureSet, nn.Module):
         ):  # only main process initializes logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger = (
-            TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
-            if self.is_main_process
-            else None
-        )
-
         self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
             configure_metrics(
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'

-
-
+        self.early_stop_patience = early_stop_patience
+        self.early_stop_monitor_task = early_stop_monitor_task
+        # max samples to keep for training metrics, in case of large training set
+        self.metrics_sample_limit = (
+            None if metrics_sample_limit is None else int(metrics_sample_limit)
+        )
+
+        training_config = {}
+        if self.is_main_process:
+            training_config = {
+                "model_name": getattr(self, "model_name", self.__class__.__name__),
+                "task": self.task,
+                "target_columns": self.target_columns,
+                "batch_size": batch_size,
+                "epochs": epochs,
+                "shuffle": shuffle,
+                "num_workers": num_workers,
+                "valid_split": valid_split,
+                "optimizer": getattr(self, "optimizer_name", None),
+                "optimizer_params": getattr(self, "optimizer_params", None),
+                "scheduler": getattr(self, "scheduler_name", None),
+                "scheduler_params": getattr(self, "scheduler_params", None),
+                "loss": getattr(self, "loss_config", None),
+                "loss_weights": getattr(self, "loss_weights", None),
+                "early_stop_patience": self.early_stop_patience,
+                "max_gradient_norm": self.max_gradient_norm,
+                "metrics_sample_limit": self.metrics_sample_limit,
+                "embedding_l1_reg": self.embedding_l1_reg,
+                "embedding_l2_reg": self.embedding_l2_reg,
+                "dense_l1_reg": self.dense_l1_reg,
+                "dense_l2_reg": self.dense_l2_reg,
+                "session_id": self.session_id,
+                "distributed": self.distributed,
+                "device": str(self.device),
+                "dense_feature_count": len(self.dense_features),
+                "sparse_feature_count": len(self.sparse_features),
+                "sequence_feature_count": len(self.sequence_features),
+            }
+            training_config: dict = safe_value(training_config)  # type: ignore
+
+        self.training_logger = (
+            TrainingLogger(
+                session=self.session,
+                use_tensorboard=use_tensorboard,
+                use_wandb=use_wandb,
+                use_swanlab=use_swanlab,
+                config=training_config,
+                wandb_kwargs=wandb_kwargs,
+                swanlab_kwargs=swanlab_kwargs,
+            )
+            if self.is_main_process
+            else None
+        )

         # Setup default callbacks if missing
         if self.nums_task == 1:
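The training_config dict is passed through safe_value before reaching the experiment trackers. A hypothetical stand-in showing the kind of coercion such a helper needs (the real nextrec.utils.config.safe_value may differ):

    import json

    def safe_value_sketch(config: dict) -> dict:
        # Fall back to str() for values json cannot encode (devices, classes, ...).
        out = {}
        for key, value in config.items():
            try:
                json.dumps(value)
                out[key] = value
            except (TypeError, ValueError):
                out[key] = str(value)
        return out

    print(safe_value_sketch({"device": object(), "epochs": 5}))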
@@ -830,9 +921,9 @@ class BaseModel(FeatureSet, nn.Module):
                 )
             )

-        train_sampler
-        if
-            train_loader, valid_data = self.
+        train_sampler = None
+        if valid_split is not None and valid_data is None:
+            train_loader, valid_data = self.handle_valid_split(train_data=train_data, valid_split=valid_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
         if use_ddp_sampler:
             base_dataset = getattr(train_loader, "dataset", None)
             if base_dataset is not None and not isinstance(
@@ -867,7 +958,6 @@ class BaseModel(FeatureSet, nn.Module):
                     default_batch_size=batch_size,
                     is_main_process=self.is_main_process,
                 )
-                # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
             else:
                 train_loader = train_data
         else:
@@ -911,8 +1001,6 @@ class BaseModel(FeatureSet, nn.Module):
                 raise NotImplementedError(
                     "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
                 )
-            # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
-
         valid_loader, valid_user_ids = self.prepare_validation_data(
             valid_data=valid_data,
             batch_size=batch_size,
@@ -937,7 +1025,17 @@ class BaseModel(FeatureSet, nn.Module):
         )

         if self.is_main_process:
-            self.
+            self.train_data_summary = (
+                None
+                if is_streaming
+                else self.build_train_data_summary(train_data, train_loader)
+            )
+            self.valid_data_summary = (
+                None
+                if valid_loader is None
+                else self.build_valid_data_summary(valid_data, valid_loader)
+            )
+            self.summary(summary_sections)
             logging.info("")
             tb_dir = (
                 self.training_logger.tensorboard_logdir
@@ -1017,11 +1115,7 @@ class BaseModel(FeatureSet, nn.Module):
                 loss=train_loss,
                 metrics=train_metrics,
                 target_names=self.target_columns,
-                base_metrics=(
-                    self.metrics
-                    if isinstance(getattr(self, "metrics", None), list)
-                    else None
-                ),
+                base_metrics=(self.metrics if isinstance(self.metrics, list) else None),
                 is_main_process=self.is_main_process,
                 colorize=lambda s: colorize(s),
             )
@@ -1048,9 +1142,7 @@ class BaseModel(FeatureSet, nn.Module):
                 metrics=val_metrics,
                 target_names=self.target_columns,
                 base_metrics=(
-                    self.metrics
-                    if isinstance(getattr(self, "metrics", None), list)
-                    else None
+                    self.metrics if isinstance(self.metrics, list) else None
                 ),
                 is_main_process=self.is_main_process,
                 colorize=lambda s: colorize(" " + s, color="cyan"),
@@ -1122,11 +1214,13 @@ class BaseModel(FeatureSet, nn.Module):
             self.training_logger.close()
         return self

-    def train_epoch(
-        self, train_loader: DataLoader, is_streaming: bool = False
-    ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+    def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False):
         # use ddp model for distributed training
-        model =
+        model = (
+            self.ddp_model
+            if hasattr(self, "ddp_model") and self.ddp_model is not None
+            else self
+        )
         accumulated_loss = 0.0
         model.train()  # type: ignore
         num_batches = 0
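train_epoch now picks the DDP wrapper when one exists and falls back to the bare module otherwise. A standalone sketch of that selection pattern:

    import torch.nn as nn

    class TrainerSketch:                            # hypothetical holder, not the nextrec class
        def __init__(self, module: nn.Module):
            self.module = module
            self.ddp_model = None                   # set when DDP is initialized

        def pick_model(self) -> nn.Module:
            return self.ddp_model if self.ddp_model is not None else self.module

    t = TrainerSketch(nn.Linear(4, 1))
    print(type(t.pick_model()).__name__)            # Linear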
@@ -1263,7 +1357,7 @@ class BaseModel(FeatureSet, nn.Module):
         user_id_column: str | None = "user_id",
         num_workers: int = 0,
         auto_ddp_sampler: bool = True,
-    )
+    ):
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
@@ -1607,7 +1701,7 @@ class BaseModel(FeatureSet, nn.Module):

         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
@@ -1655,7 +1749,7 @@ class BaseModel(FeatureSet, nn.Module):
         stream_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
-    )
+    ):
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1702,7 +1796,7 @@ class BaseModel(FeatureSet, nn.Module):

         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
@@ -1779,12 +1873,8 @@ class BaseModel(FeatureSet, nn.Module):
                 # Non-streaming formats: collect all data
                 collected_frames.append(df_batch)

-            if return_dataframe:
-
-                save_format in ["csv", "parquet"]
-                and df_batch not in collected_frames
-            ):
-                collected_frames.append(df_batch)
+            if return_dataframe and save_format in ["csv", "parquet"]:
+                collected_frames.append(df_batch)

         # Close writers
         if parquet_writer is not None:
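A note on the simplified guard above: the removed `df_batch not in collected_frames` test is fragile because `in` relies on ==, and comparing two DataFrames elementwise yields an object with an ambiguous truth value. A small demonstration:

    import pandas as pd

    frames = [pd.DataFrame({"a": [1]})]
    try:
        pd.DataFrame({"a": [1]}) in frames          # membership test triggers DataFrame.__eq__
    except ValueError as err:
        print(err)                                  # "The truth value of a DataFrame is ambiguous..."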
@@ -1816,7 +1906,7 @@ class BaseModel(FeatureSet, nn.Module):
         verbose: bool = True,
     ):
         add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session_path,
             default_name=self.model_name.upper(),
@@ -1825,7 +1915,7 @@ class BaseModel(FeatureSet, nn.Module):
         )
         model_path = Path(target_path)

-        ddp_model =
+        ddp_model = self.ddp_model if hasattr(self, "ddp_model") else None
         if ddp_model is not None:
             model_to_save = ddp_model.module
         else:
@@ -1967,150 +2057,6 @@ class BaseModel(FeatureSet, nn.Module):
         model.load_model(model_file, map_location=map_location, verbose=verbose)
         return model

-    def summary(self):
-        logger = logging.getLogger()
-
-        logger.info("")
-        logger.info(
-            colorize(
-                f"Model Summary: {self.model_name.upper()}",
-                color="bright_blue",
-                bold=True,
-            )
-        )
-        logger.info("")
-
-        logger.info("")
-        logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        if self.dense_features:
-            logger.info(f"Dense Features ({len(self.dense_features)}):")
-            for i, feat in enumerate(self.dense_features, 1):
-                embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
-                logger.info(f" {i}. {feat.name:20s}")
-
-        if self.sparse_features:
-            logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
-
-            max_name_len = max(len(feat.name) for feat in self.sparse_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sparse_features
-            )
-            name_width = max(max_name_len, 10) + 2
-            embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
-            )
-            logger.info(
-                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
-            )
-            for i, feat in enumerate(self.sparse_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                logger.info(
-                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
-                )
-
-        if self.sequence_features:
-            logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
-
-            max_name_len = max(len(feat.name) for feat in self.sequence_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sequence_features
-            )
-            name_width = max(max_name_len, 10) + 2
-            embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f" {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
-            )
-            logger.info(
-                f" {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
-            )
-            for i, feat in enumerate(self.sequence_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
-                logger.info(
-                    f" {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
-                )
-
-        logger.info("")
-        logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        # Model Architecture
-        logger.info("Model Architecture:")
-        logger.info(str(self))
-        logger.info("")
-
-        total_params = sum(p.numel() for p in self.parameters())
-        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
-        non_trainable_params = total_params - trainable_params
-
-        logger.info(f"Total Parameters: {total_params:,}")
-        logger.info(f"Trainable Parameters: {trainable_params:,}")
-        logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
-        logger.info("Layer-wise Parameters:")
-        for name, module in self.named_children():
-            layer_params = sum(p.numel() for p in module.parameters())
-            if layer_params > 0:
-                logger.info(f" {name:30s}: {layer_params:,}")
-
-        logger.info("")
-        logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        logger.info(f"Task Type: {self.task}")
-        logger.info(f"Number of Tasks: {self.nums_task}")
-        logger.info(f"Metrics: {self.metrics}")
-        logger.info(f"Target Columns: {self.target_columns}")
-        logger.info(f"Device: {self.device}")
-
-        if hasattr(self, "optimizer_name"):
-            logger.info(f"Optimizer: {self.optimizer_name}")
-            if self.optimizer_params:
-                for key, value in self.optimizer_params.items():
-                    logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self, "scheduler_name") and self.scheduler_name:
-            logger.info(f"Scheduler: {self.scheduler_name}")
-            if self.scheduler_params:
-                for key, value in self.scheduler_params.items():
-                    logger.info(f" {key:25s}: {value}")
-
-        if hasattr(self, "loss_config"):
-            logger.info(f"Loss Function: {self.loss_config}")
-        if hasattr(self, "loss_weights"):
-            logger.info(f"Loss Weights: {self.loss_weights}")
-        if hasattr(self, "grad_norm"):
-            logger.info(f"GradNorm Enabled: {self.grad_norm is not None}")
-            if self.grad_norm is not None:
-                grad_lr = self.grad_norm.optimizer.param_groups[0].get("lr")
-                logger.info(f" GradNorm alpha: {self.grad_norm.alpha}")
-                logger.info(f" GradNorm lr: {grad_lr}")
-
-        logger.info("Regularization:")
-        logger.info(f" Embedding L1: {self.embedding_l1_reg}")
-        logger.info(f" Embedding L2: {self.embedding_l2_reg}")
-        logger.info(f" Dense L1: {self.dense_l1_reg}")
-        logger.info(f" Dense L2: {self.dense_l2_reg}")
-
-        logger.info("Other Settings:")
-        logger.info(f" Early Stop Patience: {self.early_stop_patience}")
-        logger.info(f" Max Gradient Norm: {self.max_gradient_norm}")
-        logger.info(f" Max Metrics Samples: {self.metrics_sample_limit}")
-        logger.info(f" Session ID: {self.session_id}")
-        logger.info(f" Features Config Path: {self.features_config_path}")
-        logger.info(f" Latest Checkpoint: {self.checkpoint_path}")
-

 class BaseMatchModel(BaseModel):
     """
@@ -2156,12 +2102,10 @@ class BaseMatchModel(BaseModel):
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
-        early_stop_patience: int = 20,
         target: list[str] | str | None = "label",
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
         session_id: str | None = None,
-        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -2170,22 +2114,16 @@ class BaseMatchModel(BaseModel):
         **kwargs,
     ):

-
-
-
-
-
-
-
-
-
-
-        if item_sparse_features:
-            all_sparse_features.extend(item_sparse_features)
-        if user_sequence_features:
-            all_sequence_features.extend(user_sequence_features)
-        if item_sequence_features:
-            all_sequence_features.extend(item_sequence_features)
+        user_dense_features = list(user_dense_features or [])
+        user_sparse_features = list(user_sparse_features or [])
+        user_sequence_features = list(user_sequence_features or [])
+        item_dense_features = list(item_dense_features or [])
+        item_sparse_features = list(item_sparse_features or [])
+        item_sequence_features = list(item_sequence_features or [])
+
+        all_dense_features = user_dense_features + item_dense_features
+        all_sparse_features = user_sparse_features + item_sparse_features
+        all_sequence_features = user_sequence_features + item_sequence_features

         super(BaseMatchModel, self).__init__(
             dense_features=all_dense_features,
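BaseMatchModel now normalizes each optional feature list to [] and concatenates the user and item groups. A minimal sketch of that merge:

    user_sparse = ["user_id", "gender"]             # illustrative feature names
    item_sparse = None

    user_sparse = list(user_sparse or [])
    item_sparse = list(item_sparse or [])
    all_sparse = user_sparse + item_sparse
    print(all_sparse)                               # ['user_id', 'gender']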
@@ -2199,9 +2137,7 @@ class BaseMatchModel(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             session_id=session_id,
-            callbacks=callbacks,
             distributed=distributed,
             rank=rank,
             world_size=world_size,
@@ -2210,25 +2146,13 @@ class BaseMatchModel(BaseModel):
             **kwargs,
         )

-        self.user_dense_features =
-
-
-        self.user_sparse_features = (
-            list(user_sparse_features) if user_sparse_features else []
-        )
-        self.user_sequence_features = (
-            list(user_sequence_features) if user_sequence_features else []
-        )
+        self.user_dense_features = user_dense_features
+        self.user_sparse_features = user_sparse_features
+        self.user_sequence_features = user_sequence_features

-        self.item_dense_features =
-
-
-        self.item_sparse_features = (
-            list(item_sparse_features) if item_sparse_features else []
-        )
-        self.item_sequence_features = (
-            list(item_sequence_features) if item_sequence_features else []
-        )
+        self.item_dense_features = item_dense_features
+        self.item_sparse_features = item_sparse_features
+        self.item_sequence_features = item_sequence_features

         self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
@@ -2255,10 +2179,10 @@ class BaseMatchModel(BaseModel):

     def compile(
         self,
-        optimizer:
+        optimizer: OptimizerName | torch.optim.Optimizer = "adam",
         optimizer_params: dict | None = None,
         scheduler: (
-
+            SchedulerName
            | torch.optim.lr_scheduler._LRScheduler
            | torch.optim.lr_scheduler.LRScheduler
            | type[torch.optim.lr_scheduler._LRScheduler]
@@ -2266,26 +2190,34 @@ class BaseMatchModel(BaseModel):
            | None
         ) = None,
         scheduler_params: dict | None = None,
-        loss:
+        loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | dict | str | None = None,
-        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the match model for training.
+
+        Args:
+            optimizer: Optimizer to use (name or instance). e.g., 'adam', 'sgd'.
+            optimizer_params: Parameters for the optimizer. e.g., {'lr': 0.001}.
+            scheduler: Learning rate scheduler (name, instance, or class). e.g., 'step_lr'.
+            scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
+            loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
+            loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
+            loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )

-        default_loss_by_mode
+        default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }

-        effective_loss
+        effective_loss = loss
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
         elif isinstance(effective_loss, str):
@@ -2316,7 +2248,6 @@ class BaseMatchModel(BaseModel):
             loss=effective_loss,
             loss_params=loss_params,
             loss_weights=loss_weights,
-            callbacks=callbacks,
         )

     def inbatch_logits(
@@ -2406,7 +2337,9 @@ class BaseMatchModel(BaseModel):
                 batch_size, batch_size - 1
             )  # [B, B-1]

-            loss_fn =
+            loss_fn = (
+                self.loss_fn[0] if hasattr(self, "loss_fn") and self.loss_fn else None
+            )
             if isinstance(loss_fn, SampledSoftmaxLoss):
                 loss = loss_fn(pos_logits, neg_logits)
             elif isinstance(loss_fn, (BPRLoss, HingeLoss)):