nextrec 0.4.21-py3-none-any.whl → 0.4.23-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/activation.py +1 -1
- nextrec/basic/heads.py +2 -3
- nextrec/basic/metrics.py +1 -2
- nextrec/basic/model.py +115 -80
- nextrec/basic/summary.py +36 -2
- nextrec/data/preprocessor.py +137 -5
- nextrec/loss/__init__.py +0 -4
- nextrec/loss/grad_norm.py +3 -3
- nextrec/loss/listwise.py +19 -6
- nextrec/loss/pairwise.py +6 -4
- nextrec/loss/pointwise.py +8 -6
- nextrec/models/multi_task/esmm.py +3 -26
- nextrec/models/multi_task/mmoe.py +2 -24
- nextrec/models/multi_task/ple.py +13 -35
- nextrec/models/multi_task/poso.py +4 -28
- nextrec/models/multi_task/share_bottom.py +1 -24
- nextrec/models/ranking/afm.py +3 -27
- nextrec/models/ranking/autoint.py +5 -38
- nextrec/models/ranking/dcn.py +1 -26
- nextrec/models/ranking/dcn_v2.py +5 -33
- nextrec/models/ranking/deepfm.py +2 -29
- nextrec/models/ranking/dien.py +2 -28
- nextrec/models/ranking/din.py +2 -27
- nextrec/models/ranking/eulernet.py +3 -30
- nextrec/models/ranking/ffm.py +0 -26
- nextrec/models/ranking/fibinet.py +8 -32
- nextrec/models/ranking/fm.py +0 -29
- nextrec/models/ranking/lr.py +0 -30
- nextrec/models/ranking/masknet.py +4 -30
- nextrec/models/ranking/pnn.py +4 -28
- nextrec/models/ranking/widedeep.py +0 -32
- nextrec/models/ranking/xdeepfm.py +0 -30
- nextrec/models/retrieval/dssm.py +0 -24
- nextrec/models/retrieval/dssm_v2.py +0 -24
- nextrec/models/retrieval/mind.py +0 -20
- nextrec/models/retrieval/sdm.py +0 -20
- nextrec/models/retrieval/youtube_dnn.py +0 -21
- nextrec/models/sequential/hstu.py +0 -18
- nextrec/utils/__init__.py +5 -1
- nextrec/{loss/loss_utils.py → utils/loss.py} +17 -7
- nextrec/utils/model.py +79 -1
- nextrec/utils/types.py +62 -23
- {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/METADATA +8 -6
- nextrec-0.4.23.dist-info/RECORD +81 -0
- nextrec-0.4.21.dist-info/RECORD +0 -81
- {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/WHEEL +0 -0
- {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.21.dist-info → nextrec-0.4.23.dist-info}/licenses/LICENSE +0 -0
nextrec/__version__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.4.21"
+__version__ = "0.4.23"
nextrec/basic/activation.py
CHANGED
nextrec/basic/heads.py
CHANGED
@@ -15,6 +15,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from nextrec.basic.layers import PredictionLayer
+from nextrec.utils.types import TaskTypeName
 
 
 class TaskHead(nn.Module):
@@ -27,9 +28,7 @@ class TaskHead(nn.Module):
 
     def __init__(
         self,
-        task_type: (
-            Literal["binary", "regression"] | list[Literal["binary", "regression"]]
-        ) = "binary",
+        task_type: TaskTypeName | list[TaskTypeName] = "binary",
         task_dims: int | list[int] | None = None,
         use_bias: bool = True,
         return_logits: bool = False,
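The inline `Literal` annotation is replaced by the shared `TaskTypeName` alias from `nextrec/utils/types.py`. That module's new contents are not shown in this diff; a minimal sketch consistent with the values the old annotation spelled out would be:

```python
# Hypothetical sketch only: the real TaskTypeName in nextrec/utils/types.py is not
# shown in this diff and may cover more task names than the two listed here.
from typing import Literal

TaskTypeName = Literal["binary", "regression"]

def normalize_tasks(task: TaskTypeName | list[TaskTypeName]) -> list[TaskTypeName]:
    # TaskHead and BaseModel accept either a single task or one task per head.
    return task if isinstance(task, list) else [task]
```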
nextrec/basic/metrics.py
CHANGED
@@ -2,7 +2,7 @@
 Metrics computation and configuration for model evaluation.
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 29/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
@@ -39,7 +39,6 @@ REGRESSION_METRICS = {"mse", "mae", "rmse", "r2", "mape", "msle"}
 TASK_DEFAULT_METRICS = {
     "binary": ["auc", "gauc", "ks", "logloss", "accuracy", "precision", "recall", "f1"],
     "regression": ["mse", "mae", "rmse", "r2", "mape"],
-    "multilabel": ["auc", "hamming_loss", "subset_accuracy", "micro_f1", "macro_f1"],
     "matching": ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
     + [f"recall@{k}" for k in (5, 10, 20)]
     + [f"ndcg@{k}" for k in (5, 10, 20)]
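Spelled out, the two comprehensions simply append per-k variants, so the remaining "matching" default expands to:

```python
# Expansion of the "matching" entry in TASK_DEFAULT_METRICS after this change.
base = ["auc", "gauc", "precision@10", "hitrate@10", "map@10", "cosine"]
matching_defaults = (
    base
    + [f"recall@{k}" for k in (5, 10, 20)]  # recall@5, recall@10, recall@20
    + [f"ndcg@{k}" for k in (5, 10, 20)]    # ndcg@5, ndcg@10, ndcg@20
)
```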
nextrec/basic/model.py
CHANGED
@@ -2,13 +2,14 @@
 Base Model & Base Match Model Class
 
 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 29/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """
 
 import getpass
 import logging
 import os
+import sys
 import pickle
 import socket
 from pathlib import Path
@@ -16,6 +17,16 @@ from typing import Any, Literal
 
 import numpy as np
 import pandas as pd
+
+try:
+    import swanlab  # type: ignore
+except ModuleNotFoundError:
+    swanlab = None
+try:
+    import wandb  # type: ignore
+except ModuleNotFoundError:
+    wandb = None
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -60,8 +71,8 @@ from nextrec.loss import (
     InfoNCELoss,
     SampledSoftmaxLoss,
     TripletLoss,
-    get_loss_fn,
 )
+from nextrec.utils.loss import get_loss_fn
 from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
@@ -74,8 +85,20 @@ from nextrec.utils.torch_utils import (
     to_tensor,
 )
 from nextrec.utils.config import safe_value
-from nextrec.utils.model import
-
+from nextrec.utils.model import (
+    compute_ranking_loss,
+    get_loss_list,
+    resolve_loss_weights,
+    get_training_modes,
+)
+from nextrec.utils.types import (
+    LossName,
+    OptimizerName,
+    SchedulerName,
+    TrainingModeName,
+    TaskTypeName,
+    MetricsName,
+)
 
 
 class BaseModel(SummarySet, FeatureSet, nn.Module):
@@ -84,7 +107,7 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         raise NotImplementedError
 
     @property
-    def default_task(self) ->
+    def default_task(self) -> TaskTypeName | list[TaskTypeName]:
         raise NotImplementedError
 
     def __init__(
@@ -94,11 +117,8 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         sequence_features: list[SequenceFeature] | None = None,
         target: list[str] | str | None = None,
         id_columns: list[str] | str | None = None,
-        task:
-        training_mode: (
-            Literal["pointwise", "pairwise", "listwise"]
-            | list[Literal["pointwise", "pairwise", "listwise"]]
-        ) = "pointwise",
+        task: TaskTypeName | list[TaskTypeName] | None = None,
+        training_mode: TrainingModeName | list[TrainingModeName] = "pointwise",
         embedding_l1_reg: float = 0.0,
         dense_l1_reg: float = 0.0,
         embedding_l2_reg: float = 0.0,
@@ -136,6 +156,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             world_size: Number of processes (defaults to env WORLD_SIZE).
             local_rank: Local rank for selecting CUDA device (defaults to env LOCAL_RANK).
             ddp_find_unused_parameters: Default False, set it True only when exist unused parameters in ddp model, in most cases should be False.
+
+        Note:
+            Optimizer, scheduler, and loss are configured via compile().
         """
         super(BaseModel, self).__init__()
 
@@ -168,25 +191,12 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
             dense_features, sparse_features, sequence_features, target, id_columns
         )
 
-        self.task = self.default_task
+        self.task = task or self.default_task
         self.nums_task = len(self.task) if isinstance(self.task, list) else 1
-
-
-        if
-
-                "[BaseModel-init Error] training_mode list length must match number of tasks."
-            )
-        else:
-            training_modes = [training_mode] * self.nums_task
-        if any(
-            mode not in {"pointwise", "pairwise", "listwise"}
-            for mode in training_modes
-        ):
-            raise ValueError(
-                "[BaseModel-init Error] training_mode must be one of {'pointwise', 'pairwise', 'listwise'}."
-            )
-        self.training_modes = training_modes
-        self.training_mode = training_modes if self.nums_task > 1 else training_modes[0]
+        self.training_modes = get_training_modes(training_mode, self.nums_task)
+        self.training_mode = (
+            self.training_modes if self.nums_task > 1 else self.training_modes[0]
+        )
 
         self.embedding_l1_reg = embedding_l1_reg
         self.dense_l1_reg = dense_l1_reg
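The validation that used to live inline in `__init__` now sits in `nextrec.utils.model.get_training_modes`. Its body is not part of this diff; judging from the removed code, a sketch of the behaviour would be:

```python
# Hypothetical reconstruction based on the removed inline checks; the actual
# helper in nextrec/utils/model.py may differ in details and error wording.
VALID_MODES = {"pointwise", "pairwise", "listwise"}

def get_training_modes(training_mode, nums_task):
    if isinstance(training_mode, list):
        if len(training_mode) != nums_task:
            raise ValueError("training_mode list length must match number of tasks.")
        modes = list(training_mode)
    else:
        modes = [training_mode] * nums_task
    if any(mode not in VALID_MODES for mode in modes):
        raise ValueError("training_mode must be one of {'pointwise', 'pairwise', 'listwise'}.")
    return modes
```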
@@ -194,7 +204,9 @@ class BaseModel(SummarySet, FeatureSet, nn.Module):
         self.dense_l2_reg = dense_l2_reg
         self.regularization_weights = []
         self.embedding_params = []
-
+
+        self.ignore_label = None
+        self.compiled = False
 
         self.max_gradient_norm = 1.0
         self.logger_initialized = False
@@ -407,6 +419,7 @@
         loss: LossName | nn.Module | list[LossName | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
         loss_weights: int | float | list[int | float] | dict | str | None = None,
+        ignore_label: int | float | None = -1,
     ):
         """
         Configure the model for training.
@@ -419,34 +432,17 @@
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
                 Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
+            ignore_label: Label value to ignore when computing loss. Use this to skip gradients for unknown labels.
         """
+        self.ignore_label = ignore_label
         default_losses = {
             "pointwise": "bce",
             "pairwise": "bpr",
             "listwise": "listnet",
         }
-
-
-
-        elif isinstance(effective_loss, list):
-            if not effective_loss:
-                loss_list = [default_losses[mode] for mode in self.training_modes]
-            else:
-                if len(effective_loss) != self.nums_task:
-                    raise ValueError(
-                        f"[BaseModel-compile Error] Number of loss functions ({len(effective_loss)}) must match number of tasks ({self.nums_task})."
-                    )
-                loss_list = list(effective_loss)
-        else:
-            loss_list = [effective_loss] * self.nums_task
-
-        for idx, mode in enumerate(self.training_modes):
-            if isinstance(loss_list[idx], str) and loss_list[idx] in {
-                "bce",
-                "binary_crossentropy",
-            }:
-                if mode in {"pairwise", "listwise"}:
-                    loss_list[idx] = default_losses[mode]
+        loss_list = get_loss_list(
+            loss, self.training_modes, self.nums_task, default_losses
+        )
         self.loss_params = loss_params or {}
         optimizer_params = optimizer_params or {}
         self.optimizer_name = (
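Taken together, `compile()` now accepts an `ignore_label` and delegates loss selection to `get_loss_list`. A typical multi-task call might look like the sketch below; the wrapper function, the learning-rate value, and the idea of two "bce" tasks are illustrative, only the keyword names appear in this diff.

```python
def configure_two_task_model(model) -> None:
    """Illustrative compile() call for a two-task setup; nothing here is prescriptive."""
    model.compile(
        optimizer="adam",
        optimizer_params={"lr": 1e-3},   # assumed Adam kwarg, not shown in the diff
        loss=["bce", "bce"],             # one loss per task, or None for the per-mode defaults
        loss_weights=[1.0, 0.5],         # or "grad_norm" to enable GradNorm balancing
        ignore_label=-1,                 # rows labelled -1 contribute no gradient
    )
```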
@@ -510,36 +506,16 @@
                 nums_task=self.nums_task, device=self.device, **grad_norm_params
             )
             self.loss_weights = None
-        elif loss_weights is None:
-            self.loss_weights = None
-        elif self.nums_task == 1:
-            if isinstance(loss_weights, (list, tuple)):
-                if len(loss_weights) != 1:
-                    raise ValueError(
-                        "[BaseModel-compile Error] loss_weights list must have exactly one element for single-task setup."
-                    )
-                loss_weights = loss_weights[0]
-            self.loss_weights = [float(loss_weights)]  # type: ignore
         else:
-
-
-        elif isinstance(loss_weights, (list, tuple)):
-            weights = [float(w) for w in loss_weights]
-            if len(weights) != self.nums_task:
-                raise ValueError(
-                    f"[BaseModel-compile Error] Number of loss_weights ({len(weights)}) must match number of tasks ({self.nums_task})."
-                )
-        else:
-            raise TypeError(
-                f"[BaseModel-compile Error] loss_weights must be int, float, list or tuple, got {type(loss_weights)}"
-            )
-        self.loss_weights = weights
+            self.loss_weights = resolve_loss_weights(loss_weights, self.nums_task)
+        self.compiled = True
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
             raise ValueError(
                 "[BaseModel-compute_loss Error] Ground truth labels (y_true) are required."
             )
+
         # single-task
         if self.nums_task == 1:
             if y_pred.dim() == 1:
@@ -547,13 +523,24 @@
             if y_true.dim() == 1:
                 y_true = y_true.view(-1, 1)
             if y_pred.shape != y_true.shape:
-                raise ValueError(f"Shape mismatch: {y_pred.shape} vs {y_true.shape}")
-            loss_fn = self.loss_fn[0] if getattr(self, "loss_fn", None) else None
-            if loss_fn is None:
                 raise ValueError(
-                    "[BaseModel-compute_loss Error]
+                    f"[BaseModel-compute_loss Error] Shape mismatch: {y_pred.shape} vs {y_true.shape}"
                 )
+
+            loss_fn = self.loss_fn[0]
+
+            if self.ignore_label is not None:
+                valid_mask = y_true != self.ignore_label
+                if valid_mask.dim() > 1:
+                    valid_mask = valid_mask.all(dim=1)
+                if not torch.any(valid_mask):  # if no valid labels, return zero loss
+                    return y_pred.sum() * 0.0
+
+                y_pred = y_pred[valid_mask]
+                y_true = y_true[valid_mask]
+
             mode = self.training_modes[0]
+
             task_dim = (
                 self.task_dims[0] if hasattr(self, "task_dims") else y_pred.shape[1]  # type: ignore
             )
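The masking added here drops rows whose label equals `ignore_label` before the loss is evaluated, and returns `y_pred.sum() * 0.0` when nothing is left so the result stays connected to the graph and `backward()` still works. A standalone illustration of the same idea, not the nextrec implementation itself:

```python
import torch
import torch.nn.functional as F

def masked_bce(y_pred: torch.Tensor, y_true: torch.Tensor, ignore_label: float = -1.0):
    # Keep only rows whose every label differs from ignore_label, mirroring valid_mask above.
    valid = y_true != ignore_label
    if valid.dim() > 1:
        valid = valid.all(dim=1)
    if not torch.any(valid):
        return y_pred.sum() * 0.0  # graph-connected zero, safe to call backward() on
    return F.binary_cross_entropy(y_pred[valid], y_true[valid])

y_pred = torch.tensor([[0.9], [0.2], [0.7]], requires_grad=True)
y_true = torch.tensor([[1.0], [-1.0], [0.0]])  # middle row is "unknown"
masked_bce(y_pred, y_true).backward()
```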
@@ -584,7 +571,19 @@
         for i, (start, end) in enumerate(slices):  # type: ignore
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
+            # mask ignored labels
+            if self.ignore_label is not None:
+                valid_mask = y_true_i != self.ignore_label
+                if valid_mask.dim() > 1:
+                    valid_mask = valid_mask.all(dim=1)
+                if not torch.any(valid_mask):
+                    task_losses.append(y_pred_i.sum() * 0.0)
+                    continue
+                y_pred_i = y_pred_i[valid_mask]
+                y_true_i = y_true_i[valid_mask]
+
             mode = self.training_modes[i]
+
             if mode in {"pairwise", "listwise"}:
                 task_loss = compute_ranking_loss(
                     training_mode=mode,
@@ -594,7 +593,11 @@
                 )
             else:
                 task_loss = self.loss_fn[i](y_pred_i, y_true_i)
+                # task_loss = normalize_task_loss(
+                #     task_loss, valid_count, total_count
+                # )  # normalize by valid samples to avoid loss scale issues
             task_losses.append(task_loss)
+
         if self.grad_norm is not None:
             if self.grad_norm_shared_params is None:
                 self.grad_norm_shared_params = get_grad_norm_shared_params(
@@ -651,7 +654,7 @@
         train_data=None,
         valid_data=None,
         metrics: (
-            list[
+            list[MetricsName] | dict[str, list[MetricsName]] | None
         ) = None,  # ['auc', 'logloss'] or {'target1': ['auc', 'logloss'], 'target2': ['mse']}
         epochs: int = 1,
         shuffle: bool = True,
@@ -665,6 +668,8 @@
         use_tensorboard: bool = True,
         use_wandb: bool = False,
         use_swanlab: bool = False,
+        wandb_api: str | None = None,
+        swanlab_api: str | None = None,
         wandb_kwargs: dict | None = None,
         swanlab_kwargs: dict | None = None,
         auto_ddp_sampler: bool = True,
@@ -694,6 +699,8 @@
             use_tensorboard: Enable tensorboard logging.
             use_wandb: Enable Weights & Biases logging.
             use_swanlab: Enable SwanLab logging.
+            wandb_api: W&B API key for non-tty login.
+            swanlab_api: SwanLab API key for non-tty login.
             wandb_kwargs: Optional kwargs for wandb.init(...).
             swanlab_kwargs: Optional kwargs for swanlab.init(...).
             auto_ddp_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
@@ -711,6 +718,16 @@
         )
         self.to(self.device)
 
+        if not self.compiled:
+            self.compile(
+                optimizer="adam",
+                optimizer_params={},
+                scheduler=None,
+                scheduler_params={},
+                loss=None,
+                loss_params={},
+            )
+
         if (
             self.distributed
             and dist.is_available()
@@ -785,6 +802,24 @@
         }
         training_config: dict = safe_value(training_config)  # type: ignore
 
+        if self.is_main_process:
+            is_tty = sys.stdin.isatty() and sys.stdout.isatty()
+            if not is_tty:
+                if use_wandb and wandb_api:
+                    if wandb is None:
+                        logging.warning(
+                            "[BaseModel-fit] wandb not installed, skip wandb login."
+                        )
+                    else:
+                        wandb.login(key=wandb_api)
+                if use_swanlab and swanlab_api:
+                    if swanlab is None:
+                        logging.warning(
+                            "[BaseModel-fit] swanlab not installed, skip swanlab login."
+                        )
+                    else:
+                        swanlab.login(api_key=swanlab_api)
+
         self.training_logger = (
             TrainingLogger(
                 session=self.session,
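With these additions, `fit()` can be called without an explicit `compile()` (it falls back to Adam plus the per-mode default loss) and can log in to W&B or SwanLab from a non-interactive job by passing the API key directly. A usage sketch under those assumptions; the wrapper and data arguments are placeholders, only the keyword names come from this diff:

```python
import os

def train_headless(model, train_df, valid_df) -> None:
    """Sketch of a non-interactive (no tty) training job."""
    model.fit(
        train_data=train_df,
        valid_data=valid_df,
        metrics=["auc", "logloss"],
        epochs=3,
        use_wandb=True,
        wandb_api=os.environ.get("WANDB_API_KEY"),  # consumed only when no tty is attached
    )
```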
@@ -2164,7 +2199,7 @@ class BaseMatchModel(BaseModel):
             scheduler_params: Parameters for the scheduler. e.g., {'step_size': 10, 'gamma': 0.1}.
             loss: Loss function(s) to use (name, instance, or list). e.g., 'bce'.
             loss_params: Parameters for the loss function(s). e.g., {'reduction': 'mean'}.
-            loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
+            loss_weights: Weights for the loss function(s). e.g., 1.0 or [0.7, 0.3].
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
nextrec/basic/summary.py
CHANGED
@@ -1,5 +1,9 @@
 """
 Summary utilities for BaseModel.
+
+Date: create on 03/12/2025
+Checkpoint: edit on 29/12/2025
+Author: Yang Zhou,zyaztec@gmail.com
 """
 
 from __future__ import annotations
@@ -12,9 +16,39 @@ from torch.utils.data import DataLoader
 
 from nextrec.basic.loggers import colorize, format_kv
 from nextrec.data.data_processing import extract_label_arrays, get_data_length
+from nextrec.utils.types import TaskTypeName
 
 
 class SummarySet:
+    model_name: str
+    dense_features: list[Any]
+    sparse_features: list[Any]
+    sequence_features: list[Any]
+    task: TaskTypeName | list[TaskTypeName]
+    target_columns: list[str]
+    nums_task: int
+    metrics: Any
+    device: Any
+    optimizer_name: str
+    optimizer_params: dict[str, Any]
+    scheduler_name: str | None
+    scheduler_params: dict[str, Any]
+    loss_config: Any
+    loss_weights: Any
+    grad_norm: Any
+    embedding_l1_reg: float
+    embedding_l2_reg: float
+    dense_l1_reg: float
+    dense_l2_reg: float
+    early_stop_patience: int
+    max_gradient_norm: float | None
+    metrics_sample_limit: int | None
+    session_id: str | None
+    features_config_path: str
+    checkpoint_path: str
+    train_data_summary: dict[str, Any] | None
+    valid_data_summary: dict[str, Any] | None
+
     def build_data_summary(
         self, data: Any, data_loader: DataLoader | None, sample_key: str
     ):
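SummarySet is a mixin: the attributes it reads are assigned elsewhere (in BaseModel.__init__), so the block of class-level annotations added above most likely exists to tell type checkers and readers what the mixin expects without creating any attributes at runtime. The same pattern in miniature, with generic names not taken from nextrec:

```python
# Annotation-only mixin: the annotations declare expectations but assign nothing.
class ReportMixin:
    model_name: str   # no attribute is created at runtime by this line
    nums_task: int

    def describe(self) -> str:
        return f"{self.model_name} ({self.nums_task} task(s))"

class Model(ReportMixin):
    def __init__(self, model_name: str, nums_task: int) -> None:
        self.model_name = model_name
        self.nums_task = nums_task
```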
@@ -305,7 +339,7 @@
             lines = details.get("lines", [])
             logger.info(f"{target_name}:")
             for label, value in lines:
-                logger.info(format_kv(label, value))
+                logger.info(f" {format_kv(label, value)}")
 
         if self.valid_data_summary:
             if self.train_data_summary:
@@ -320,4 +354,4 @@
             lines = details.get("lines", [])
             logger.info(f"{target_name}:")
             for label, value in lines:
-                logger.info(format_kv(label, value))
+                logger.info(f" {format_kv(label, value)}")