nextrec-0.4.20-py3-none-any.whl → nextrec-0.4.21-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +9 -4
- nextrec/basic/callback.py +39 -87
- nextrec/basic/features.py +149 -28
- nextrec/basic/heads.py +4 -1
- nextrec/basic/layers.py +375 -94
- nextrec/basic/loggers.py +236 -39
- nextrec/basic/model.py +209 -316
- nextrec/basic/session.py +2 -2
- nextrec/basic/summary.py +323 -0
- nextrec/cli.py +3 -3
- nextrec/data/data_processing.py +45 -1
- nextrec/data/dataloader.py +2 -2
- nextrec/data/preprocessor.py +2 -2
- nextrec/loss/loss_utils.py +5 -30
- nextrec/models/multi_task/esmm.py +4 -6
- nextrec/models/multi_task/mmoe.py +4 -6
- nextrec/models/multi_task/ple.py +6 -8
- nextrec/models/multi_task/poso.py +5 -7
- nextrec/models/multi_task/share_bottom.py +6 -8
- nextrec/models/ranking/afm.py +4 -6
- nextrec/models/ranking/autoint.py +4 -6
- nextrec/models/ranking/dcn.py +8 -7
- nextrec/models/ranking/dcn_v2.py +4 -6
- nextrec/models/ranking/deepfm.py +5 -7
- nextrec/models/ranking/dien.py +8 -7
- nextrec/models/ranking/din.py +8 -7
- nextrec/models/ranking/eulernet.py +5 -7
- nextrec/models/ranking/ffm.py +5 -7
- nextrec/models/ranking/fibinet.py +4 -6
- nextrec/models/ranking/fm.py +4 -6
- nextrec/models/ranking/lr.py +4 -6
- nextrec/models/ranking/masknet.py +8 -9
- nextrec/models/ranking/pnn.py +4 -6
- nextrec/models/ranking/widedeep.py +5 -7
- nextrec/models/ranking/xdeepfm.py +8 -7
- nextrec/models/retrieval/dssm.py +4 -10
- nextrec/models/retrieval/dssm_v2.py +0 -6
- nextrec/models/retrieval/mind.py +4 -10
- nextrec/models/retrieval/sdm.py +4 -10
- nextrec/models/retrieval/youtube_dnn.py +4 -10
- nextrec/models/sequential/hstu.py +1 -3
- nextrec/utils/__init__.py +12 -14
- nextrec/utils/config.py +15 -5
- nextrec/utils/console.py +2 -2
- nextrec/utils/feature.py +2 -2
- nextrec/utils/torch_utils.py +57 -112
- nextrec/utils/types.py +59 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/METADATA +7 -5
- nextrec-0.4.21.dist-info/RECORD +81 -0
- nextrec-0.4.20.dist-info/RECORD +0 -79
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/WHEEL +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.20.dist-info → nextrec-0.4.21.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 28/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -12,7 +12,7 @@ import os
 import pickle
 import socket
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal

 import numpy as np
 import pandas as pd
@@ -26,7 +26,6 @@ from torch.utils.data.distributed import DistributedSampler

 from nextrec import __version__
 from nextrec.basic.callback import (
-    Callback,
     CallbackList,
     CheckpointSaver,
     EarlyStopper,
@@ -41,9 +40,13 @@ from nextrec.basic.features import (
 from nextrec.basic.heads import RetrievalHead
 from nextrec.basic.loggers import TrainingLogger, colorize, format_kv, setup_logger
 from nextrec.basic.metrics import check_user_id, configure_metrics, evaluate_metrics
-from nextrec.basic.
+from nextrec.basic.summary import SummarySet
+from nextrec.basic.session import create_session, get_save_path
 from nextrec.data.batch_utils import batch_to_dict, collate_fn
-from nextrec.data.data_processing import
+from nextrec.data.data_processing import (
+    get_column_data,
+    get_user_ids,
+)
 from nextrec.data.dataloader import (
     RecDataLoader,
     TensorDictDataset,
@@ -63,17 +66,19 @@ from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
     add_distributed_sampler,
-
+    get_device,
     gather_numpy,
     get_optimizer,
     get_scheduler,
     init_process_group,
     to_tensor,
 )
+from nextrec.utils.config import safe_value
 from nextrec.utils.model import compute_ranking_loss
+from nextrec.utils.types import LossName, OptimizerName, SchedulerName


-class BaseModel(FeatureSet, nn.Module):
+class BaseModel(SummarySet, FeatureSet, nn.Module):
     @property
     def model_name(self) -> str:
         raise NotImplementedError
@@ -99,11 +104,7 @@ class BaseModel(FeatureSet, nn.Module):
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
         device: str = "cpu",
-        early_stop_patience: int = 20,
-        early_stop_monitor_task: str | None = None,
-        metrics_sample_limit: int | None = 200000,
         session_id: str | None = None,
-        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -128,11 +129,7 @@ class BaseModel(FeatureSet, nn.Module):
             dense_l2_reg: L2 regularization strength for dense params. e.g., 1e-4.

             device: Torch device string or torch.device. e.g., 'cpu', 'cuda:0'.
-            early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
-            early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
-            metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
             session_id: Session id for logging. If None, a default id with timestamps will be created. e.g., 'session_tutorial'.
-            callbacks: List of callback instances. If None, default callbacks will be created. e.g., [EarlyStopper(), CheckpointSaver()].

             distributed: Enable DistributedDataParallel flow, set True to enable distributed training.
             rank: Global rank (defaults to env RANK).
@@ -152,8 +149,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.local_rank = env_local_rank if local_rank is None else local_rank
         self.is_main_process = self.rank == 0
         self.ddp_find_unused_parameters = ddp_find_unused_parameters
-        self.ddp_model
-        self.device =
+        self.ddp_model = None
+        self.device = get_device(self.distributed, self.local_rank, device)

         self.session_id = session_id
         self.session = create_session(session_id)
@@ -174,21 +171,22 @@ class BaseModel(FeatureSet, nn.Module):
         self.task = self.default_task if task is None else task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
         if isinstance(training_mode, list):
-
+            training_modes = list(training_mode)
+            if len(training_modes) != self.nums_task:
             raise ValueError(
                 "[BaseModel-init Error] training_mode list length must match number of tasks."
             )
-            self.training_modes = list(training_mode)
         else:
-
-
-
-
-
-
-
-
-
+            training_modes = [training_mode] * self.nums_task
+        if any(
+            mode not in {"pointwise", "pairwise", "listwise"}
+            for mode in training_modes
+        ):
+            raise ValueError(
+                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
+            )
+        self.training_modes = training_modes
+        self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]

         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
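The rewrite above normalizes training_mode into one mode per task and validates it eagerly. A standalone sketch of the same rule, outside the class (this mirrors the diff's logic; it is not the library's API):

    def normalize_training_modes(training_mode, nums_task: int) -> list[str]:
        # Broadcast a single mode, or take the per-task list as-is.
        modes = list(training_mode) if isinstance(training_mode, list) else [training_mode] * nums_task
        if len(modes) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
        if any(m not in {"pointwise", "pairwise", "listwise"} for m in modes):
            raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
        return modes

    assert normalize_training_modes("pointwise", 2) == ["pointwise", "pointwise"]
    assert normalize_training_modes(["pointwise", "pairwise"], 2) == ["pointwise", "pairwise"]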
@@ -198,25 +196,20 @@ class BaseModel(FeatureSet, nn.Module):
         self.embedding_params = []
         self.loss_weight = None

-        self.early_stop_patience = early_stop_patience
-        self.early_stop_monitor_task = early_stop_monitor_task
-        # max samples to keep for training metrics, in case of large training set
-        self.metrics_sample_limit = (
-            None if metrics_sample_limit is None else int(metrics_sample_limit)
-        )
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
         self.training_logger = None
-        self.callbacks = CallbackList(
-
-        self.
+        self.callbacks = CallbackList()
+
+        self.train_data_summary = None
+        self.valid_data_summary = None

     def register_regularization_weights(
         self,
         embedding_attr: str = "embedding",
         exclude_modules: list[str] | None = None,
         include_modules: list[str] | None = None,
-    )
+    ):
         exclude_modules = exclude_modules or []
         include_modules = include_modules or []
         embedding_layer = getattr(self, embedding_attr, None)
@@ -264,24 +257,24 @@ class BaseModel(FeatureSet, nn.Module):

     def add_reg_loss(self) -> torch.Tensor:
         reg_loss = torch.tensor(0.0, device=self.device)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if self.embedding_l1_reg > 0:
+            reg_loss += self.embedding_l1_reg * sum(
+                param.abs().sum() for param in self.embedding_params
+            )
+        if self.embedding_l2_reg > 0:
+            reg_loss += self.embedding_l2_reg * sum(
+                (param**2).sum() for param in self.embedding_params
+            )
+
+        if self.dense_l1_reg > 0:
+            reg_loss += self.dense_l1_reg * sum(
+                param.abs().sum() for param in self.regularization_weights
+            )
+        if self.dense_l2_reg > 0:
+            reg_loss += self.dense_l2_reg * sum(
+                (param**2).sum() for param in self.regularization_weights
+            )
         return reg_loss

     def get_input(self, input_data: dict, require_labels: bool = True):
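add_reg_loss now builds the penalty inline: L1/L2 terms over the embedding parameters plus L1/L2 terms over the registered dense weights. The same computation in isolation (a sketch, not the library's code):

    import torch

    def reg_penalty(embed_params, dense_params, l1_e=0.0, l2_e=0.0, l1_d=0.0, l2_d=0.0):
        # reg = l1_e*sum|w_e| + l2_e*sum(w_e^2) + l1_d*sum|w_d| + l2_d*sum(w_d^2)
        loss = torch.tensor(0.0)
        if l1_e > 0:
            loss = loss + l1_e * sum(p.abs().sum() for p in embed_params)
        if l2_e > 0:
            loss = loss + l2_e * sum((p ** 2).sum() for p in embed_params)
        if l1_d > 0:
            loss = loss + l1_d * sum(p.abs().sum() for p in dense_params)
        if l2_d > 0:
            loss = loss + l2_d * sum((p ** 2).sum() for p in dense_params)
        return loss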
@@ -341,10 +334,10 @@ class BaseModel(FeatureSet, nn.Module):
         )
         return X_input, y

-    def
+    def handle_valid_split(
         self,
         train_data: dict | pd.DataFrame,
-
+        valid_split: float,
         batch_size: int,
         shuffle: bool,
         num_workers: int = 0,
@@ -352,11 +345,11 @@ class BaseModel(FeatureSet, nn.Module):
         """
         This function will split training data into training and validation sets when:
         1. valid_data is None;
-        2.
+        2. valid_split is provided.
         """
-        if not (0 <
+        if not (0 < valid_split < 1):
             raise ValueError(
-                f"[BaseModel-validation Error]
+                f"[BaseModel-validation Error] valid_split must be between 0 and 1, got {valid_split}"
             )
         if isinstance(train_data, pd.DataFrame):
             total_length = len(train_data)
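handle_valid_split cuts the training data with a fixed-seed permutation, so repeated runs produce the same split. The core index arithmetic, as a self-contained sketch:

    import numpy as np

    def split_indices(total_length: int, valid_split: float):
        if not (0 < valid_split < 1):
            raise ValueError(f"valid_split must be between 0 and 1, got {valid_split}")
        rng = np.random.default_rng(42)           # same fixed seed as the diff
        indices = rng.permutation(total_length)
        split_idx = int(total_length * (1 - valid_split))
        return indices[:split_idx], indices[split_idx:]

    train_idx, valid_idx = split_indices(1000, 0.1)   # 900 train rows, 100 validation rows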
@@ -370,37 +363,40 @@ class BaseModel(FeatureSet, nn.Module):
             )
         else:
             raise TypeError(
-                f"[BaseModel-validation Error] If you want to use
+                f"[BaseModel-validation Error] If you want to use valid_split, train_data must be a pandas DataFrame or a dict instead of {type(train_data)}"
             )
         rng = np.random.default_rng(42)
         indices = rng.permutation(total_length)
-        split_idx = int(total_length * (1 -
+        split_idx = int(total_length * (1 - valid_split))
         train_indices = indices[:split_idx]
         valid_indices = indices[split_idx:]
         if isinstance(train_data, pd.DataFrame):
-
-
+            train_split_data = train_data.iloc[train_indices].reset_index(drop=True)
+            valid_split_data = train_data.iloc[valid_indices].reset_index(drop=True)
         else:
-
+            train_split_data = {
                 k: np.asarray(v)[train_indices] for k, v in train_data.items()
             }
-
+            valid_split_data = {
                 k: np.asarray(v)[valid_indices] for k, v in train_data.items()
             }
         train_loader = self.prepare_data_loader(
-
+            train_split_data,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=num_workers,
         )
         logging.info(
             f"Split data: {len(train_indices)} training samples, {len(valid_indices)} validation samples"
         )
-        return train_loader,
+        return train_loader, valid_split_data

     def compile(
         self,
-        optimizer:
+        optimizer: OptimizerName | torch.optim.Optimizer = "adam",
         optimizer_params: dict | None = None,
         scheduler: (
-
+            SchedulerName
             | torch.optim.lr_scheduler._LRScheduler
             | torch.optim.lr_scheduler.LRScheduler
             | type[torch.optim.lr_scheduler._LRScheduler]
@@ -408,10 +404,9 @@ class BaseModel(FeatureSet, nn.Module):
             | None
         ) = None,
         scheduler_params: dict | None = None,
-        loss:
+        loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | dict | str | None = None,
-        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the model for training.
@@ -424,7 +419,6 @@ class BaseModel(FeatureSet, nn.Module):
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
                 Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
-            callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         default_losses = {
             "pointwise": "bce",
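For pairwise and listwise tasks, compile() overrides the configured loss with a mode-appropriate default. A minimal sketch of that defaulting rule (same mapping as the diff, extracted into a plain function):

    default_losses = {"pointwise": "bce", "pairwise": "bpr", "listwise": "sampled_softmax"}

    def resolve_losses(loss: str, modes: list[str]) -> list[str]:
        loss_list = [loss] * len(modes)
        for idx, mode in enumerate(modes):
            if mode in {"pairwise", "listwise"}:
                loss_list[idx] = default_losses[mode]   # mode-specific default wins
        return loss_list

    assert resolve_losses("bce", ["pointwise", "pairwise"]) == ["bce", "bpr"]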
@@ -453,10 +447,7 @@ class BaseModel(FeatureSet, nn.Module):
         }:
             if mode in {"pairwise", "listwise"}:
                 loss_list[idx] = default_losses[mode]
-
-            self.loss_params = {}
-        else:
-            self.loss_params = loss_params
+        self.loss_params = loss_params or {}
         optimizer_params = optimizer_params or {}
         self.optimizer_name = (
             optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
@@ -483,7 +474,6 @@ class BaseModel(FeatureSet, nn.Module):
         )

         self.loss_config = loss_list if self.nums_task > 1 else loss_list[0]
-        self.loss_params = loss_params or {}
         if isinstance(self.loss_params, dict):
             loss_params_list = [self.loss_params] * self.nums_task
         else:
@@ -545,11 +535,6 @@ class BaseModel(FeatureSet, nn.Module):
         )
         self.loss_weights = weights

-        # Add callbacks from compile if provided
-        if callbacks:
-            for callback in callbacks:
-                self.callbacks.append(callback)
-
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
@@ -672,28 +657,49 @@ class BaseModel(FeatureSet, nn.Module):
         shuffle: bool = True,
         batch_size: int = 32,
         user_id_column: str | None = None,
-
+        valid_split: float | None = None,
+        early_stop_patience: int = 20,
+        early_stop_monitor_task: str | None = None,
+        metrics_sample_limit: int | None = 200000,
         num_workers: int = 0,
         use_tensorboard: bool = True,
+        use_wandb: bool = False,
+        use_swanlab: bool = False,
+        wandb_kwargs: dict | None = None,
+        swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
         log_interval: int = 1,
+        summary_sections: (
+            list[Literal["feature", "model", "train", "data"]] | None
+        ) = None,
     ):
         """
         Train the model.

         Args:
             train_data: Training data (dict/df/DataLoader). If distributed, each rank uses its own sampler/batches.
-            valid_data: Optional validation data; if None and
+            valid_data: Optional validation data; if None and valid_split is set, a split is created.
             metrics: Metrics names or per-target dict. e.g. {'target1': ['auc', 'logloss'], 'target2': ['mse']} or ['auc', 'logloss'].
             epochs: Training epochs.
             shuffle: Whether to shuffle training data (ignored when a sampler enforces order).
             batch_size: Batch size (per process when distributed).
             user_id_column: Column name for GAUC-style metrics.
-
+            valid_split: Ratio to split training data when valid_data is None. e.g., 0.1 for 10% validation.
+
+            early_stop_patience: Epochs for early stopping. 0 to disable. e.g., 20.
+            early_stop_monitor_task: Task name to monitor for early stopping in multi-task scenario. If None, uses first target. e.g., 'click'.
+            metrics_sample_limit: Max samples to keep for training metrics. None disables limit.
             num_workers: DataLoader worker count.
+
             use_tensorboard: Enable tensorboard logging.
+            use_wandb: Enable Weights & Biases logging.
+            use_swanlab: Enable SwanLab logging.
+            wandb_kwargs: Optional kwargs for wandb.init(...).
+            swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed; set False when data is already sharded per rank.
             log_interval: Log validation metrics every N epochs (still computes metrics each epoch).
+            summary_sections: Optional summary sections to print. Choose from
+                ["feature", "model", "train", "data"]. Defaults to all.

         Notes:
             - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
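With this change, early stopping, metric sampling, and experiment tracking are configured per fit() call instead of at construction time. A hedged usage sketch (the model and DataFrame are assumed to exist; all values are illustrative):

    # Assumes `model` is a compiled nextrec model and `train_df` a pandas DataFrame.
    model.fit(
        train_data=train_df,
        valid_split=0.1,                    # hold out 10% for validation
        early_stop_patience=20,
        metrics=["auc", "logloss"],
        metrics_sample_limit=200_000,
        use_wandb=True,
        wandb_kwargs={"project": "nextrec-demo"},  # hypothetical project name
        summary_sections=["feature", "train"],
    )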
@@ -733,20 +739,65 @@ class BaseModel(FeatureSet, nn.Module):
         ):  # only main process initializes logger
             setup_logger(session_id=self.session_id)
             self.logger_initialized = True
-        self.training_logger = (
-            TrainingLogger(session=self.session, use_tensorboard=use_tensorboard)
-            if self.is_main_process
-            else None
-        )
-
         self.metrics, self.task_specific_metrics, self.best_metrics_mode = (
             configure_metrics(
                 task=self.task, metrics=metrics, target_names=self.target_columns
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'

-
-
+        self.early_stop_patience = early_stop_patience
+        self.early_stop_monitor_task = early_stop_monitor_task
+        # max samples to keep for training metrics, in case of large training set
+        self.metrics_sample_limit = (
+            None if metrics_sample_limit is None else int(metrics_sample_limit)
+        )
+
+        training_config = {}
+        if self.is_main_process:
+            training_config = {
+                "model_name": getattr(self, "model_name", self.__class__.__name__),
+                "task": self.task,
+                "target_columns": self.target_columns,
+                "batch_size": batch_size,
+                "epochs": epochs,
+                "shuffle": shuffle,
+                "num_workers": num_workers,
+                "valid_split": valid_split,
+                "optimizer": getattr(self, "optimizer_name", None),
+                "optimizer_params": getattr(self, "optimizer_params", None),
+                "scheduler": getattr(self, "scheduler_name", None),
+                "scheduler_params": getattr(self, "scheduler_params", None),
+                "loss": getattr(self, "loss_config", None),
+                "loss_weights": getattr(self, "loss_weights", None),
+                "early_stop_patience": self.early_stop_patience,
+                "max_gradient_norm": self.max_gradient_norm,
+                "metrics_sample_limit": self.metrics_sample_limit,
+                "embedding_l1_reg": self.embedding_l1_reg,
+                "embedding_l2_reg": self.embedding_l2_reg,
+                "dense_l1_reg": self.dense_l1_reg,
+                "dense_l2_reg": self.dense_l2_reg,
+                "session_id": self.session_id,
+                "distributed": self.distributed,
+                "device": str(self.device),
+                "dense_feature_count": len(self.dense_features),
+                "sparse_feature_count": len(self.sparse_features),
+                "sequence_feature_count": len(self.sequence_features),
+            }
+            training_config: dict = safe_value(training_config)  # type: ignore
+
+        self.training_logger = (
+            TrainingLogger(
+                session=self.session,
+                use_tensorboard=use_tensorboard,
+                use_wandb=use_wandb,
+                use_swanlab=use_swanlab,
+                config=training_config,
+                wandb_kwargs=wandb_kwargs,
+                swanlab_kwargs=swanlab_kwargs,
+            )
+            if self.is_main_process
+            else None
+        )

         # Setup default callbacks if missing
         if self.nums_task == 1:
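The training configuration dict is assembled only on the main process and passed through safe_value before reaching the logger, presumably to make every entry serializable for the wandb/SwanLab backends. A minimal sketch of that idea (an assumption about safe_value, not its actual implementation):

    def safe_value_sketch(value):
        # Coerce arbitrary config values into JSON-friendly types.
        if isinstance(value, dict):
            return {k: safe_value_sketch(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [safe_value_sketch(v) for v in value]
        if value is None or isinstance(value, (str, int, float, bool)):
            return value
        return str(value)   # devices, paths, etc. fall back to strings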
@@ -830,9 +881,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )

-        train_sampler
-        if
-        train_loader, valid_data = self.
+        train_sampler = None
+        if valid_split is not None and valid_data is None:
+            train_loader, valid_data = self.handle_valid_split(train_data=train_data, valid_split=valid_split, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)  # type: ignore
         if use_ddp_sampler:
             base_dataset = getattr(train_loader, "dataset", None)
             if base_dataset is not None and not isinstance(
@@ -867,7 +918,6 @@ class BaseModel(FeatureSet, nn.Module):
                 default_batch_size=batch_size,
                 is_main_process=self.is_main_process,
             )
-            # train_loader, train_sampler = add_distributed_sampler(train_data, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
         else:
             train_loader = train_data
         else:
@@ -911,8 +961,6 @@ class BaseModel(FeatureSet, nn.Module):
             raise NotImplementedError(
                 "[BaseModel-fit Error] auto_ddp_sampler with pre-defined DataLoader is not supported yet."
             )
-            # train_loader, train_sampler = add_distributed_sampler(train_loader, distributed=self.distributed, world_size=self.world_size, rank=self.rank, shuffle=shuffle, drop_last=True, default_batch_size=batch_size, is_main_process=self.is_main_process)
-
         valid_loader, valid_user_ids = self.prepare_validation_data(
             valid_data=valid_data,
             batch_size=batch_size,
@@ -937,7 +985,17 @@ class BaseModel(FeatureSet, nn.Module):
         )

         if self.is_main_process:
-            self.
+            self.train_data_summary = (
+                None
+                if is_streaming
+                else self.build_train_data_summary(train_data, train_loader)
+            )
+            self.valid_data_summary = (
+                None
+                if valid_loader is None
+                else self.build_valid_data_summary(valid_data, valid_loader)
+            )
+            self.summary(summary_sections)
             logging.info("")
             tb_dir = (
                 self.training_logger.tensorboard_logdir
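Data summaries are built on the main process and printed through the new summary(summary_sections) entry point inherited from SummarySet. A hedged usage sketch (assuming `model` exists, and that None means all sections, per the fit() docstring):

    model.summary(["feature", "model"])   # print only the feature and parameter sections
    model.summary(None)                   # print feature, model, train, and data sections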
@@ -1017,11 +1075,7 @@ class BaseModel(FeatureSet, nn.Module):
                 loss=train_loss,
                 metrics=train_metrics,
                 target_names=self.target_columns,
-                base_metrics=(
-                    self.metrics
-                    if isinstance(getattr(self, "metrics", None), list)
-                    else None
-                ),
+                base_metrics=(self.metrics if isinstance(self.metrics, list) else None),
                 is_main_process=self.is_main_process,
                 colorize=lambda s: colorize(s),
             )
@@ -1048,9 +1102,7 @@ class BaseModel(FeatureSet, nn.Module):
                     metrics=val_metrics,
                     target_names=self.target_columns,
                     base_metrics=(
-                        self.metrics
-                        if isinstance(getattr(self, "metrics", None), list)
-                        else None
+                        self.metrics if isinstance(self.metrics, list) else None
                     ),
                     is_main_process=self.is_main_process,
                     colorize=lambda s: colorize(" " + s, color="cyan"),
@@ -1122,11 +1174,13 @@ class BaseModel(FeatureSet, nn.Module):
             self.training_logger.close()
         return self

-    def train_epoch(
-        self, train_loader: DataLoader, is_streaming: bool = False
-    ) -> Union[float, np.ndarray, tuple[Union[float, np.ndarray], dict]]:
+    def train_epoch(self, train_loader: DataLoader, is_streaming: bool = False):
         # use ddp model for distributed training
-        model =
+        model = (
+            self.ddp_model
+            if hasattr(self, "ddp_model") and self.ddp_model is not None
+            else self
+        )
         accumulated_loss = 0.0
         model.train()  # type: ignore
         num_batches = 0
@@ -1263,7 +1317,7 @@ class BaseModel(FeatureSet, nn.Module):
         user_id_column: str | None = "user_id",
         num_workers: int = 0,
         auto_ddp_sampler: bool = True,
-    )
+    ):
         if valid_data is None:
             return None, None
         if isinstance(valid_data, DataLoader):
@@ -1607,7 +1661,7 @@ class BaseModel(FeatureSet, nn.Module):

         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
@@ -1655,7 +1709,7 @@ class BaseModel(FeatureSet, nn.Module):
         stream_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
-    )
+    ):
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1702,7 +1756,7 @@ class BaseModel(FeatureSet, nn.Module):

         suffix = FILE_FORMAT_CONFIG[save_format]["extension"][0]

-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session.predictions_dir,
             default_name="predictions",
@@ -1779,12 +1833,8 @@ class BaseModel(FeatureSet, nn.Module):
                 # Non-streaming formats: collect all data
                 collected_frames.append(df_batch)

-            if return_dataframe:
-
-                save_format in ["csv", "parquet"]
-                and df_batch not in collected_frames
-            ):
-                collected_frames.append(df_batch)
+            if return_dataframe and save_format in ["csv", "parquet"]:
+                collected_frames.append(df_batch)

         # Close writers
         if parquet_writer is not None:
@@ -1816,7 +1866,7 @@ class BaseModel(FeatureSet, nn.Module):
         verbose: bool = True,
     ):
         add_timestamp = False if add_timestamp is None else add_timestamp
-        target_path =
+        target_path = get_save_path(
             path=save_path,
             default_dir=self.session_path,
             default_name=self.model_name.upper(),
@@ -1825,7 +1875,7 @@ class BaseModel(FeatureSet, nn.Module):
         )
         model_path = Path(target_path)

-        ddp_model =
+        ddp_model = self.ddp_model if hasattr(self, "ddp_model") else None
         if ddp_model is not None:
             model_to_save = ddp_model.module
         else:
@@ -1967,150 +2017,6 @@ class BaseModel(FeatureSet, nn.Module):
         model.load_model(model_file, map_location=map_location, verbose=verbose)
         return model

-    def summary(self):
-        logger = logging.getLogger()
-
-        logger.info("")
-        logger.info(
-            colorize(
-                f"Model Summary: {self.model_name.upper()}",
-                color="bright_blue",
-                bold=True,
-            )
-        )
-        logger.info("")
-
-        logger.info("")
-        logger.info(colorize("[1] Feature Configuration", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        if self.dense_features:
-            logger.info(f"Dense Features ({len(self.dense_features)}):")
-            for i, feat in enumerate(self.dense_features, 1):
-                embed_dim = feat.embedding_dim if hasattr(feat, "embedding_dim") else 1
-                logger.info(f"  {i}. {feat.name:20s}")
-
-        if self.sparse_features:
-            logger.info(f"\nSparse Features ({len(self.sparse_features)}):")
-
-            max_name_len = max(len(feat.name) for feat in self.sparse_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sparse_features
-            )
-            name_width = max(max_name_len, 10) + 2
-            embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10}"
-            )
-            logger.info(
-                f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10}"
-            )
-            for i, feat in enumerate(self.sparse_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                logger.info(
-                    f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10}"
-                )
-
-        if self.sequence_features:
-            logger.info(f"\nSequence Features ({len(self.sequence_features)}):")
-
-            max_name_len = max(len(feat.name) for feat in self.sequence_features)
-            max_embed_name_len = max(
-                len(feat.embedding_name) for feat in self.sequence_features
-            )
-            name_width = max(max_name_len, 10) + 2
-            embed_name_width = max(max_embed_name_len, 15) + 2
-
-            logger.info(
-                f"  {'#':<4} {'Name':<{name_width}} {'Vocab Size':>12} {'Embed Name':>{embed_name_width}} {'Embed Dim':>10} {'Max Len':>10}"
-            )
-            logger.info(
-                f"  {'-'*4} {'-'*name_width} {'-'*12} {'-'*embed_name_width} {'-'*10} {'-'*10}"
-            )
-            for i, feat in enumerate(self.sequence_features, 1):
-                vocab_size = feat.vocab_size if hasattr(feat, "vocab_size") else "N/A"
-                embed_dim = (
-                    feat.embedding_dim if hasattr(feat, "embedding_dim") else "N/A"
-                )
-                max_len = feat.max_len if hasattr(feat, "max_len") else "N/A"
-                logger.info(
-                    f"  {i:<4} {feat.name:<{name_width}} {str(vocab_size):>12} {feat.embedding_name:>{embed_name_width}} {str(embed_dim):>10} {str(max_len):>10}"
-                )
-
-        logger.info("")
-        logger.info(colorize("[2] Model Parameters", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        # Model Architecture
-        logger.info("Model Architecture:")
-        logger.info(str(self))
-        logger.info("")
-
-        total_params = sum(p.numel() for p in self.parameters())
-        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
-        non_trainable_params = total_params - trainable_params
-
-        logger.info(f"Total Parameters: {total_params:,}")
-        logger.info(f"Trainable Parameters: {trainable_params:,}")
-        logger.info(f"Non-trainable Parameters: {non_trainable_params:,}")
-
-        logger.info("Layer-wise Parameters:")
-        for name, module in self.named_children():
-            layer_params = sum(p.numel() for p in module.parameters())
-            if layer_params > 0:
-                logger.info(f"  {name:30s}: {layer_params:,}")
-
-        logger.info("")
-        logger.info(colorize("[3] Training Configuration", color="cyan", bold=True))
-        logger.info(colorize("-" * 80, color="cyan"))
-
-        logger.info(f"Task Type: {self.task}")
-        logger.info(f"Number of Tasks: {self.nums_task}")
-        logger.info(f"Metrics: {self.metrics}")
-        logger.info(f"Target Columns: {self.target_columns}")
-        logger.info(f"Device: {self.device}")
-
-        if hasattr(self, "optimizer_name"):
-            logger.info(f"Optimizer: {self.optimizer_name}")
-            if self.optimizer_params:
-                for key, value in self.optimizer_params.items():
-                    logger.info(f"  {key:25s}: {value}")
-
-        if hasattr(self, "scheduler_name") and self.scheduler_name:
-            logger.info(f"Scheduler: {self.scheduler_name}")
-            if self.scheduler_params:
-                for key, value in self.scheduler_params.items():
-                    logger.info(f"  {key:25s}: {value}")
-
-        if hasattr(self, "loss_config"):
-            logger.info(f"Loss Function: {self.loss_config}")
-        if hasattr(self, "loss_weights"):
-            logger.info(f"Loss Weights: {self.loss_weights}")
-        if hasattr(self, "grad_norm"):
-            logger.info(f"GradNorm Enabled: {self.grad_norm is not None}")
-            if self.grad_norm is not None:
-                grad_lr = self.grad_norm.optimizer.param_groups[0].get("lr")
-                logger.info(f"  GradNorm alpha: {self.grad_norm.alpha}")
-                logger.info(f"  GradNorm lr: {grad_lr}")
-
-        logger.info("Regularization:")
-        logger.info(f"  Embedding L1: {self.embedding_l1_reg}")
-        logger.info(f"  Embedding L2: {self.embedding_l2_reg}")
-        logger.info(f"  Dense L1: {self.dense_l1_reg}")
-        logger.info(f"  Dense L2: {self.dense_l2_reg}")
-
-        logger.info("Other Settings:")
-        logger.info(f"  Early Stop Patience: {self.early_stop_patience}")
-        logger.info(f"  Max Gradient Norm: {self.max_gradient_norm}")
-        logger.info(f"  Max Metrics Samples: {self.metrics_sample_limit}")
-        logger.info(f"  Session ID: {self.session_id}")
-        logger.info(f"  Features Config Path: {self.features_config_path}")
-        logger.info(f"  Latest Checkpoint: {self.checkpoint_path}")
-

 class BaseMatchModel(BaseModel):
     """
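The deleted method above is the old inline report; per the file list, it moved into the new nextrec/basic/summary.py (SummarySet). The parameter accounting it printed boils down to this self-contained sketch:

    import torch.nn as nn

    def count_params(model: nn.Module) -> tuple[int, int]:
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        return total, trainable

    total, trainable = count_params(nn.Linear(16, 4))
    assert (total, trainable) == (68, 68)   # 16*4 weights + 4 biases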
@@ -2156,12 +2062,10 @@ class BaseMatchModel(BaseModel):
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
         dense_l2_reg: float = 0.0,
-        early_stop_patience: int = 20,
         target: list[str] | str | None = "label",
         id_columns: list[str] | str | None = None,
         task: str | list[str] | None = None,
         session_id: str | None = None,
-        callbacks: list[Callback] | None = None,
         distributed: bool = False,
         rank: int | None = None,
         world_size: int | None = None,
@@ -2170,22 +2074,16 @@ class BaseMatchModel(BaseModel):
         **kwargs,
     ):

-
-
-
-
-
-
-
-
-
-
-        if item_sparse_features:
-            all_sparse_features.extend(item_sparse_features)
-        if user_sequence_features:
-            all_sequence_features.extend(user_sequence_features)
-        if item_sequence_features:
-            all_sequence_features.extend(item_sequence_features)
+        user_dense_features = list(user_dense_features or [])
+        user_sparse_features = list(user_sparse_features or [])
+        user_sequence_features = list(user_sequence_features or [])
+        item_dense_features = list(item_dense_features or [])
+        item_sparse_features = list(item_sparse_features or [])
+        item_sequence_features = list(item_sequence_features or [])
+
+        all_dense_features = user_dense_features + item_dense_features
+        all_sparse_features = user_sparse_features + item_sparse_features
+        all_sequence_features = user_sequence_features + item_sequence_features

         super(BaseMatchModel, self).__init__(
             dense_features=all_dense_features,
@@ -2199,9 +2097,7 @@ class BaseMatchModel(BaseModel):
             dense_l1_reg=dense_l1_reg,
             embedding_l2_reg=embedding_l2_reg,
             dense_l2_reg=dense_l2_reg,
-            early_stop_patience=early_stop_patience,
             session_id=session_id,
-            callbacks=callbacks,
             distributed=distributed,
             rank=rank,
             world_size=world_size,
@@ -2210,25 +2106,13 @@ class BaseMatchModel(BaseModel):
             **kwargs,
         )

-        self.user_dense_features =
-
-
-        self.user_sparse_features = (
-            list(user_sparse_features) if user_sparse_features else []
-        )
-        self.user_sequence_features = (
-            list(user_sequence_features) if user_sequence_features else []
-        )
+        self.user_dense_features = user_dense_features
+        self.user_sparse_features = user_sparse_features
+        self.user_sequence_features = user_sequence_features

-        self.item_dense_features =
-
-
-        self.item_sparse_features = (
-            list(item_sparse_features) if item_sparse_features else []
-        )
-        self.item_sequence_features = (
-            list(item_sequence_features) if item_sequence_features else []
-        )
+        self.item_dense_features = item_dense_features
+        self.item_sparse_features = item_sparse_features
+        self.item_sequence_features = item_sequence_features

         self.training_mode = training_mode
         self.num_negative_samples = num_negative_samples
@@ -2255,10 +2139,10 @@ class BaseMatchModel(BaseModel):

     def compile(
         self,
-        optimizer:
+        optimizer: OptimizerName | torch.optim.Optimizer = "adam",
         optimizer_params: dict | None = None,
         scheduler: (
-
+            SchedulerName
             | torch.optim.lr_scheduler._LRScheduler
             | torch.optim.lr_scheduler.LRScheduler
             | type[torch.optim.lr_scheduler._LRScheduler]
@@ -2266,26 +2150,34 @@ class BaseMatchModel(BaseModel):
             | None
         ) = None,
         scheduler_params: dict | None = None,
-        loss:
+        loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | dict | str | None = None,
-        callbacks: list[Callback] | None = None,
     ):
         """
         Configure the match model for training.
+
+        Args:
+            optimizer: Optimizer to use (name or instance). e.g., 'adam', 'sgd'.
+            optimizer_params: Parameters for the optimizer. e.g., {'lr': 0.001}.
+            scheduler: Learning rate scheduler (name, instance, or class). e.g., 'step_lr'.
+            scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
+            loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
+            loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
+            loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
                 f"{self.model_name.upper()} does not support training_mode='{self.training_mode}'. Supported modes: {self.support_training_modes}"
             )

-        default_loss_by_mode
+        default_loss_by_mode = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "sampled_softmax",
         }

-        effective_loss
+        effective_loss = loss
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
         elif isinstance(effective_loss, str):
@@ -2316,7 +2208,6 @@ class BaseMatchModel(BaseModel):
             loss=effective_loss,
             loss_params=loss_params,
             loss_weights=loss_weights,
-            callbacks=callbacks,
         )

     def inbatch_logits(
@@ -2406,7 +2297,9 @@ class BaseMatchModel(BaseModel):
             batch_size, batch_size - 1
         )  # [B, B-1]

-        loss_fn =
+        loss_fn = (
+            self.loss_fn[0] if hasattr(self, "loss_fn") and self.loss_fn else None
+        )
         if isinstance(loss_fn, SampledSoftmaxLoss):
             loss = loss_fn(pos_logits, neg_logits)
         elif isinstance(loss_fn, (BPRLoss, HingeLoss)):