nextrec 0.4.11__py3-none-any.whl → 0.4.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nextrec/__version__.py +1 -1
- nextrec/basic/callback.py +44 -54
- nextrec/basic/features.py +35 -22
- nextrec/basic/layers.py +64 -68
- nextrec/basic/loggers.py +2 -2
- nextrec/basic/metrics.py +9 -5
- nextrec/basic/model.py +208 -110
- nextrec/cli.py +17 -5
- nextrec/data/preprocessor.py +4 -4
- nextrec/loss/__init__.py +3 -0
- nextrec/loss/grad_norm.py +232 -0
- nextrec/loss/loss_utils.py +1 -1
- nextrec/models/multi_task/esmm.py +1 -0
- nextrec/models/multi_task/mmoe.py +1 -0
- nextrec/models/multi_task/ple.py +1 -0
- nextrec/models/multi_task/poso.py +4 -0
- nextrec/models/multi_task/share_bottom.py +1 -0
- nextrec/models/ranking/eulernet.py +44 -75
- nextrec/models/ranking/ffm.py +275 -0
- nextrec/models/ranking/lr.py +1 -3
- nextrec/utils/__init__.py +2 -1
- nextrec/utils/console.py +9 -1
- nextrec/utils/model.py +14 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/METADATA +7 -7
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/RECORD +28 -27
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/WHEEL +0 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/entry_points.txt +0 -0
- {nextrec-0.4.11.dist-info → nextrec-0.4.13.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/model.py
CHANGED
@@ -2,7 +2,7 @@
 Base Model & Base Match Model Class

 Date: create on 27/10/2025
-Checkpoint: edit on
+Checkpoint: edit on 20/12/2025
 Author: Yang Zhou,zyaztec@gmail.com
 """

@@ -50,12 +50,14 @@ from nextrec.data.dataloader import (
 )
 from nextrec.loss import (
     BPRLoss,
+    GradNormLossWeighting,
     HingeLoss,
     InfoNCELoss,
     SampledSoftmaxLoss,
     TripletLoss,
     get_loss_fn,
 )
+from nextrec.loss.grad_norm import get_grad_norm_shared_params
 from nextrec.utils.console import display_metrics_table, progress
 from nextrec.utils.torch_utils import (
     add_distributed_sampler,
@@ -169,6 +171,7 @@ class BaseModel(FeatureSet, nn.Module):
         self.loss_weight = None

         self.early_stop_patience = early_stop_patience
+        # max samples to keep for training metrics, in case of large training set
         self.max_metrics_samples = (
             None if max_metrics_samples is None else int(max_metrics_samples)
         )
@@ -176,6 +179,8 @@ class BaseModel(FeatureSet, nn.Module):
         self.logger_initialized = False
         self.training_logger = None
         self.callbacks = CallbackList(callbacks) if callbacks else CallbackList()
+        self.grad_norm: GradNormLossWeighting | None = None
+        self.grad_norm_shared_params: list[torch.nn.Parameter] | None = None

     def register_regularization_weights(
         self,
@@ -376,7 +381,7 @@ class BaseModel(FeatureSet, nn.Module):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -389,6 +394,7 @@ class BaseModel(FeatureSet, nn.Module):
             loss: Loss function name, instance, or list for multi-task. e.g., 'bce', 'mse', or torch.nn.BCELoss(), you can also use custom loss functions.
             loss_params: Loss function parameters, or list for multi-task. e.g., {'weight': tensor([0.25, 0.75])}.
             loss_weights: Weights for each task loss, int/float for single-task or list for multi-task. e.g., 1.0, or [1.0, 0.5].
+                Use "grad_norm" or {"method": "grad_norm", ...} to enable GradNorm for multi-task loss balancing.
             callbacks: Additional callbacks to add to the existing callback list. e.g., [EarlyStopper(), CheckpointSaver()].
         """
         if loss_params is None:
@@ -442,7 +448,31 @@ class BaseModel(FeatureSet, nn.Module):
             for i in range(self.nums_task)
         ]

-        if loss_weights is None:
+        self.grad_norm = None
+        self.grad_norm_shared_params = None
+        if isinstance(loss_weights, str) and loss_weights.lower() == "grad_norm":
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device
+            )
+            self.loss_weights = None
+        elif (
+            isinstance(loss_weights, dict) and loss_weights.get("method") == "grad_norm"
+        ):
+            if self.nums_task == 1:
+                raise ValueError(
+                    "[BaseModel-compile Error] GradNorm requires multi-task setup."
+                )
+            grad_norm_params = dict(loss_weights)
+            grad_norm_params.pop("method", None)
+            self.grad_norm = GradNormLossWeighting(
+                num_tasks=self.nums_task, device=self.device, **grad_norm_params
+            )
+            self.loss_weights = None
+        elif loss_weights is None:
             self.loss_weights = None
         elif self.nums_task == 1:
             if isinstance(loss_weights, (list, tuple)):
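With the parsing above, a multi-task model can request GradNorm balancing through loss_weights instead of fixed per-task weights. A hedged usage sketch (the model object and per-task losses are placeholders; in the dict form every key other than "method" is forwarded to GradNormLossWeighting, so the extra key shown here is an assumption about that class's constructor):

# String form: enable GradNorm with its defaults.
model.compile(
    loss=["bce", "bce"],          # one loss per task
    loss_weights="grad_norm",
)

# Dict form: "method" selects GradNorm, remaining keys are passed through
# as GradNormLossWeighting(**kwargs); "alpha" is an assumed kwarg name.
model.compile(
    loss=["bce", "bce"],
    loss_weights={"method": "grad_norm", "alpha": 1.5},
)

Calling either form on a single-task model raises the "[BaseModel-compile Error] GradNorm requires multi-task setup." error added above.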
@@ -507,9 +537,20 @@ class BaseModel(FeatureSet, nn.Module):
             y_pred_i = y_pred[:, start:end]
             y_true_i = y_true[:, start:end]
             task_loss = self.loss_fn[i](y_pred_i, y_true_i)
-            if isinstance(self.loss_weights, (list, tuple)):
-                task_loss *= self.loss_weights[i]
             task_losses.append(task_loss)
+        if self.grad_norm is not None:
+            if self.grad_norm_shared_params is None:
+                self.grad_norm_shared_params = get_grad_norm_shared_params(
+                    self, getattr(self, "grad_norm_shared_modules", None)
+                )
+            return self.grad_norm.compute_weighted_loss(
+                task_losses, self.grad_norm_shared_params
+            )
+        if isinstance(self.loss_weights, (list, tuple)):
+            task_losses = [
+                task_loss * self.loss_weights[i]
+                for i, task_loss in enumerate(task_losses)
+            ]
         return torch.stack(task_losses).sum()

     def prepare_data_loader(
@@ -563,6 +604,7 @@ class BaseModel(FeatureSet, nn.Module):
         num_workers: int = 0,
         tensorboard: bool = True,
         auto_distributed_sampler: bool = True,
+        log_interval: int = 1,
     ):
         """
         Train the model.
@@ -579,6 +621,7 @@ class BaseModel(FeatureSet, nn.Module):
             num_workers: DataLoader worker count.
             tensorboard: Enable tensorboard logging.
             auto_distributed_sampler: Attach DistributedSampler automatically when distributed, set False to when data is already sharded per rank.
+            log_interval: Log validation metrics every N epochs (still computes metrics each epoch).

         Notes:
             - Distributed training uses DDP; init occurs via env vars (RANK/WORLD_SIZE/LOCAL_RANK).
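A hedged example of the new log_interval knob in fit (the data arguments and their positional order are placeholders; epochs and the keyword names shown are the ones visible in this diff):

# Validation metrics are still computed every epoch, but the Valid table and
# the training-logger entries are only emitted every `log_interval` epochs
# (and always on the final epoch).
model.fit(
    train_data,            # placeholder for training data
    valid_data,            # placeholder for validation data
    epochs=20,
    log_interval=5,        # prints at epochs 5, 10, 15 and 20
    num_workers=4,
    tensorboard=True,
)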
@@ -630,6 +673,9 @@ class BaseModel(FeatureSet, nn.Module):
             )
         )  # ['auc', 'logloss'], {'target1': ['auc', 'logloss'], 'target2': ['mse']}, 'max'

+        if log_interval < 1:
+            raise ValueError("[BaseModel-fit Error] log_interval must be >= 1.")
+
         # Setup default callbacks if missing
         if self.nums_task == 1:
             monitor_metric = f"val_{self.metrics[0]}"
@@ -911,23 +957,27 @@ class BaseModel(FeatureSet, nn.Module):
                 user_ids=valid_user_ids if self.needs_user_ids else None,
                 num_workers=num_workers,
             )
-
-            epoch
-
-
-
-
-
-
-
-
-
-
-
-
+            should_log_valid = (epoch + 1) % log_interval == 0 or (
+                epoch + 1
+            ) == epochs
+            if should_log_valid:
+                display_metrics_table(
+                    epoch=epoch + 1,
+                    epochs=epochs,
+                    split="Valid",
+                    loss=None,
+                    metrics=val_metrics,
+                    target_names=self.target_columns,
+                    base_metrics=(
+                        self.metrics
+                        if isinstance(getattr(self, "metrics", None), list)
+                        else None
+                    ),
+                    is_main_process=self.is_main_process,
+                    colorize=lambda s: colorize(" " + s, color="cyan"),
+                )
             self.callbacks.on_validation_end()
-            if val_metrics and self.training_logger:
+            if should_log_valid and val_metrics and self.training_logger:
                 self.training_logger.log_metrics(
                     val_metrics, step=epoch + 1, split="valid"
                 )
@@ -1043,6 +1093,8 @@ class BaseModel(FeatureSet, nn.Module):
                 params = model.parameters() if self.ddp_model is not None else self.parameters()  # type: ignore # ddp model parameters or self parameters
                 nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
                 self.optimizer_fn.step()
+                if self.grad_norm is not None:
+                    self.grad_norm.step()
                 accumulated_loss += loss.item()

                 if (
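For reference, the per-batch ordering this hunk establishes when GradNorm is active, paraphrased from the surrounding code rather than quoted verbatim: model gradients are clipped and stepped first, then the task-weight update runs.

loss.backward()
nn.utils.clip_grad_norm_(params, self.max_gradient_norm)
self.optimizer_fn.step()        # updates model parameters
if self.grad_norm is not None:
    self.grad_norm.step()       # updates the per-task loss weights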
@@ -1207,7 +1259,7 @@ class BaseModel(FeatureSet, nn.Module):
             user_id_column: Column name for user IDs if user_ids is not provided. e.g. 'user_id'
             num_workers: DataLoader worker count.
         """
-        model = self.ddp_model if
+        model = self.ddp_model if self.ddp_model is not None else self
         model.eval()
         eval_metrics = metrics if metrics is not None else self.metrics
         if eval_metrics is None:
@@ -1233,6 +1285,10 @@ class BaseModel(FeatureSet, nn.Module):
                 batch_count += 1
                 batch_dict = batch_to_dict(batch_data)
                 X_input, y_true = self.get_input(batch_dict, require_labels=True)
+                if X_input is None:
+                    raise ValueError(
+                        "[BaseModel-evaluate Error] No input features found in the evaluation data."
+                    )
                 y_pred = model(X_input)
                 if y_true is not None:
                     y_true_list.append(y_true.cpu().numpy())
@@ -1322,7 +1378,7 @@ class BaseModel(FeatureSet, nn.Module):
         return_dataframe: bool = True,
         streaming_chunk_size: int = 10000,
         num_workers: int = 0,
-    ) -> pd.DataFrame | np.ndarray:
+    ) -> pd.DataFrame | np.ndarray | Path | None:
         """
         Note: predict does not support distributed mode currently, consider it as a single-process operation.
         Make predictions on the given data.
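Because predict can now return a Path (when results are written to disk) or None, in addition to a DataFrame/ndarray, callers may want to branch on the result type. A hedged sketch (test_data and the positional calling convention are placeholders):

from pathlib import Path

result = model.predict(test_data, return_dataframe=False)
if result is None:
    print("no predictions produced")
elif isinstance(result, Path):
    print(f"predictions written to: {result}")
else:
    # pd.DataFrame when return_dataframe=True, otherwise np.ndarray
    print(f"received {len(result)} rows in memory")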
@@ -1497,7 +1553,7 @@ class BaseModel(FeatureSet, nn.Module):
         streaming_chunk_size: int,
         return_dataframe: bool,
         id_columns: list[str] | None = None,
-    ) -> pd.DataFrame:
+    ) -> pd.DataFrame | Path:
         if isinstance(data, (str, os.PathLike)):
             rec_loader = RecDataLoader(
                 dense_features=self.dense_features,
@@ -1624,11 +1680,11 @@ class BaseModel(FeatureSet, nn.Module):
         )
         model_path = Path(target_path)

-
-
-
-
-
+        ddp_model = getattr(self, "ddp_model", None)
+        if ddp_model is not None:
+            model_to_save = ddp_model.module
+        else:
+            model_to_save = self
         torch.save(model_to_save.state_dict(), model_path)
         # torch.save(self.state_dict(), model_path)

@@ -2025,33 +2081,18 @@ class BaseMatchModel(BaseModel):
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
-
-
-
-
-
-
-
-
-
-
-
-                self.item_dense_features
-                + self.item_sparse_features
-                + self.item_sequence_features
-            )
-        ]
-
-    def get_user_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.user_feature_names if name in X_input
-        }
-
-    def get_item_features(self, X_input: dict) -> dict:
-        return {
-            name: X_input[name] for name in self.item_feature_names if name in X_input
-        }
+        self.user_features_all = (
+            self.user_dense_features
+            + self.user_sparse_features
+            + self.user_sequence_features
+        )
+        self.item_features_all = (
+            self.item_dense_features
+            + self.item_sparse_features
+            + self.item_sequence_features
+        )
+        self.user_feature_names = {feature.name for feature in self.user_features_all}
+        self.item_feature_names = {feature.name for feature in self.item_features_all}

     def compile(
         self,
@@ -2068,13 +2109,11 @@ class BaseMatchModel(BaseModel):
         scheduler_params: dict | None = None,
         loss: str | nn.Module | list[str | nn.Module] | None = "bce",
         loss_params: dict | list[dict] | None = None,
-        loss_weights: int | float | list[int | float] | None = None,
+        loss_weights: int | float | list[int | float] | dict | str | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
         Configure the match model for training.
-
-        This mirrors `BaseModel.compile()` and additionally validates `training_mode`.
         """
         if self.training_mode not in self.support_training_modes:
             raise ValueError(
@@ -2090,7 +2129,7 @@ class BaseMatchModel(BaseModel):
         effective_loss: str | nn.Module | list[str | nn.Module] | None = loss
         if effective_loss is None:
             effective_loss = default_loss_by_mode[self.training_mode]
-        elif isinstance(effective_loss,
+        elif isinstance(effective_loss, str):
             if self.training_mode in {"pairwise", "listwise"} and effective_loss in {
                 "bce",
                 "binary_crossentropy",
@@ -2124,6 +2163,7 @@ class BaseMatchModel(BaseModel):
     def inbatch_logits(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
+        """Compute in-batch logits matrix between user and item embeddings."""
         if self.similarity_metric == "dot":
             logits = torch.matmul(user_emb, item_emb.t())
         elif self.similarity_metric == "cosine":
@@ -2131,8 +2171,8 @@ class BaseMatchModel(BaseModel):
             item_norm = F.normalize(item_emb, p=2, dim=-1)
             logits = torch.matmul(user_norm, item_norm.t())
         elif self.similarity_metric == "euclidean":
-            user_sq = (user_emb**2
-            item_sq = (item_emb**2
+            user_sq = torch.sum(user_emb**2, dim=1, keepdim=True)  # [B, 1]
+            item_sq = torch.sum(item_emb**2, dim=1, keepdim=True).t()  # [1, B]
             logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
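The rewritten euclidean branch uses the expansion ||u - v||^2 = ||u||^2 + ||v||^2 - 2*u.v to build the full [B, B] in-batch matrix without materialising pairwise differences. A quick self-contained check of that identity:

import torch

B, D = 4, 8
user_emb = torch.randn(B, D)
item_emb = torch.randn(B, D)

user_sq = torch.sum(user_emb**2, dim=1, keepdim=True)      # [B, 1]
item_sq = torch.sum(item_emb**2, dim=1, keepdim=True).t()  # [1, B]
logits = -(user_sq + item_sq - 2.0 * torch.matmul(user_emb, item_emb.t()))

direct = -torch.cdist(user_emb, item_emb, p=2) ** 2        # negative squared distances
assert torch.allclose(logits, direct, atol=1e-4)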
@@ -2141,56 +2181,43 @@ class BaseMatchModel(BaseModel):
     def compute_similarity(
         self, user_emb: torch.Tensor, item_emb: torch.Tensor
     ) -> torch.Tensor:
-
-
-
-                similarity = torch.sum(
-                    user_emb * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                # [batch_size, emb_dim] @ [batch_size, num_items, emb_dim]
-                user_emb_expanded = user_emb.unsqueeze(1)  # [batch_size, 1, emb_dim]
-                similarity = torch.sum(
-                    user_emb_expanded * item_emb, dim=-1
-                )  # [batch_size, num_items]
-            else:
-                similarity = torch.sum(user_emb * item_emb, dim=-1)  # [batch_size]
+        """Compute similarity score between user and item embeddings."""
+        if user_emb.dim() == 2 and item_emb.dim() == 3:
+            user_emb = user_emb.unsqueeze(1)

+        if self.similarity_metric == "dot":
+            similarity = torch.sum(user_emb * item_emb, dim=-1)
         elif self.similarity_metric == "cosine":
-
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                similarity = F.cosine_similarity(user_emb_expanded, item_emb, dim=-1)
-            else:
-                similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
-
+            similarity = F.cosine_similarity(user_emb, item_emb, dim=-1)
         elif self.similarity_metric == "euclidean":
-
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            elif user_emb.dim() == 2 and item_emb.dim() == 3:
-                user_emb_expanded = user_emb.unsqueeze(1)
-                distance = torch.sum((user_emb_expanded - item_emb) ** 2, dim=-1)
-            else:
-                distance = torch.sum((user_emb - item_emb) ** 2, dim=-1)
-            similarity = -distance
-
+            similarity = -torch.sum((user_emb - item_emb) ** 2, dim=-1)
         else:
             raise ValueError(f"Unknown similarity metric: {self.similarity_metric}")
         similarity = similarity / self.temperature
         return similarity

     def user_tower(self, user_input: dict) -> torch.Tensor:
+        """User tower to encode user features into embeddings."""
         raise NotImplementedError

     def item_tower(self, item_input: dict) -> torch.Tensor:
+        """Item tower to encode item features into embeddings."""
         raise NotImplementedError

     def forward(
         self, X_input: dict
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-
-
+        """Rewrite forward to handle user and item features separately."""
+        user_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.user_feature_names
+        }
+        item_input = {
+            name: tensor
+            for name, tensor in X_input.items()
+            if name in self.item_feature_names
+        }

         user_emb = self.user_tower(user_input)  # [B, D]
         item_emb = self.item_tower(item_input)  # [B, D]
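After the rewrite, compute_similarity supports two input layouts: paired embeddings ([B, D] vs [B, D], one score per row) and one-to-many scoring ([B, D] vs [B, N, D], N candidates per user, handled by the unsqueeze(1) broadcast). A small shape check of the dot-product branch:

import torch

B, N, D = 4, 6, 16
user_emb = torch.randn(B, D)
cand_items = torch.randn(B, N, D)
paired_items = torch.randn(B, D)

# One-to-many: broadcast [B, 1, D] against [B, N, D] -> [B, N] scores.
scores = torch.sum(user_emb.unsqueeze(1) * cand_items, dim=-1)
assert scores.shape == (B, N)

# Paired: elementwise product summed over the embedding dim -> [B] scores.
paired_scores = torch.sum(user_emb * paired_items, dim=-1)
assert paired_scores.shape == (B,)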
@@ -2254,11 +2281,35 @@ class BaseMatchModel(BaseModel):
         raise ValueError(f"Unknown training mode: {self.training_mode}")

     def prepare_feature_data(
-        self,
+        self,
+        data,
+        features: list,
+        batch_size: int,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> DataLoader:
         """Prepare data loader for specific features."""
         if isinstance(data, DataLoader):
             return data
+        if isinstance(data, (str, os.PathLike)):
+            dense_features = [f for f in features if isinstance(f, DenseFeature)]
+            sparse_features = [f for f in features if isinstance(f, SparseFeature)]
+            sequence_features = [f for f in features if isinstance(f, SequenceFeature)]
+            rec_loader = RecDataLoader(
+                dense_features=dense_features,
+                sparse_features=sparse_features,
+                sequence_features=sequence_features,
+                target=[],
+                id_columns=[],
+            )
+            return rec_loader.create_dataloader(
+                data=data,
+                batch_size=batch_size,
+                shuffle=False,
+                streaming=True,
+                chunk_size=streaming_chunk_size,
+                num_workers=num_workers,
+            )
         tensors = build_tensors_from_data(
             data=data,
             raw_data=data,
@@ -2276,44 +2327,91 @@ class BaseMatchModel(BaseModel):
             batch_size=batch_size,
             shuffle=False,
             collate_fn=collate_fn,
+            num_workers=num_workers,
         )

+    def build_feature_tensors(self, feature_source: dict, features: list) -> dict:
+        """Convert feature values to tensors on the model device."""
+        tensors = {}
+        for feature in features:
+            if feature.name not in feature_source:
+                raise KeyError(
+                    f"[BaseMatchModel-feature Error] Feature '{feature.name}' not found in input data."
+                )
+            feature_data = get_column_data(feature_source, feature.name)
+            tensors[feature.name] = to_tensor(
+                feature_data,
+                dtype=(
+                    torch.float32 if isinstance(feature, DenseFeature) else torch.long
+                ),
+                device=self.device,
+            )
+        return tensors
+
     def encode_user(
-        self,
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-
-
-
-
+        data_loader = self.prepare_feature_data(
+            data,
+            self.user_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
         )
-        data_loader = self.prepare_feature_data(data, all_user_features, batch_size)

         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding users"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                user_input = self.
+                user_input = self.build_feature_tensors(
+                    batch_dict["features"], self.user_features_all
+                )
                 user_emb = self.user_tower(user_input)
                 embeddings_list.append(user_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)

     def encode_item(
-        self,
+        self,
+        data: (
+            dict
+            | pd.DataFrame
+            | DataLoader
+            | str
+            | os.PathLike
+            | list[str | os.PathLike]
+        ),
+        batch_size: int = 512,
+        num_workers: int = 0,
+        streaming_chunk_size: int = 10000,
     ) -> np.ndarray:
         self.eval()
-
-
-
-
+        data_loader = self.prepare_feature_data(
+            data,
+            self.item_features_all,
+            batch_size,
+            num_workers=num_workers,
+            streaming_chunk_size=streaming_chunk_size,
         )
-        data_loader = self.prepare_feature_data(data, all_item_features, batch_size)

         embeddings_list = []
         with torch.no_grad():
             for batch_data in progress(data_loader, description="Encoding items"):
                 batch_dict = batch_to_dict(batch_data, include_ids=False)
-                item_input = self.
+                item_input = self.build_feature_tensors(
+                    batch_dict["features"], self.item_features_all
+                )
                 item_emb = self.item_tower(item_input)
                 embeddings_list.append(item_emb.cpu().numpy())
         return np.concatenate(embeddings_list, axis=0)
nextrec/cli.py
CHANGED
@@ -380,6 +380,7 @@ def train_model(train_config_path: str) -> None:
         optimizer_params=train_cfg.get("optimizer_params", {}),
         loss=train_cfg.get("loss", "focal"),
         loss_params=train_cfg.get("loss_params", {}),
+        loss_weights=train_cfg.get("loss_weights"),
     )

     model.fit(
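On the CLI side, loss_weights now flows from the training config straight into compile. Expressed as the dict that train_cfg would hold after the config file is parsed (only the keys read in this hunk are confirmed; the values and the surrounding layout are illustrative):

train_cfg = {
    "loss": ["bce", "bce"],
    "loss_params": {},
    "loss_weights": [1.0, 0.5],               # fixed weights ...
    # "loss_weights": "grad_norm",            # ... or GradNorm balancing
    # "loss_weights": {"method": "grad_norm"},
}

model.compile(
    loss=train_cfg.get("loss", "focal"),
    loss_params=train_cfg.get("loss_params", {}),
    loss_weights=train_cfg.get("loss_weights"),
)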
@@ -416,7 +417,7 @@ def predict_model(predict_config_path: str) -> None:
     # Auto-infer session_id from checkpoint directory name
     session_cfg = cfg.get("session", {}) or {}
     session_id = session_cfg.get("id") or session_dir.name
-
+
     setup_logger(session_id=session_id)

     log_cli_section("CLI")
@@ -436,7 +437,7 @@ def predict_model(predict_config_path: str) -> None:
     processor_path = session_dir / "processor" / "processor.pkl"

     predict_cfg = cfg.get("predict", {}) or {}
-
+
     # Auto-find model_config in checkpoint directory if not specified
     if "model_config" in cfg:
         model_cfg_path = resolve_path(cfg["model_config"], config_dir)
@@ -563,7 +564,12 @@ def predict_model(predict_config_path: str) -> None:
     log_kv_lines(
         [
             ("Data path", data_path),
-            (
+            (
+                "Format",
+                predict_cfg.get(
+                    "source_data_format", predict_cfg.get("data_format", "auto")
+                ),
+            ),
             ("Batch size", batch_size),
             ("Chunk size", predict_cfg.get("chunk_size", 20000)),
             ("Streaming", predict_cfg.get("streaming", True)),
@@ -579,7 +585,9 @@ def predict_model(predict_config_path: str) -> None:
     )

     # Build output path: {checkpoint_path}/predictions/{name}.{save_data_format}
-    save_format = predict_cfg.get(
+    save_format = predict_cfg.get(
+        "save_data_format", predict_cfg.get("save_format", "csv")
+    )
     pred_name = predict_cfg.get("name", "pred")
     # Pass filename with extension to let model.predict handle path resolution
     save_path = f"{pred_name}.{save_format}"
@@ -597,7 +605,11 @@ def predict_model(predict_config_path: str) -> None:
     )
     duration = time.time() - start
     # When return_dataframe=False, result is the actual file path
-    output_path =
+    output_path = (
+        result
+        if isinstance(result, Path)
+        else checkpoint_base / "predictions" / save_path
+    )
     logger.info(f"Prediction completed, results saved to: {output_path}")
     logger.info(f"Total time: {duration:.2f} seconds")

nextrec/data/preprocessor.py
CHANGED
@@ -610,7 +610,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]],
         output_path: Optional[str],
         warn_missing: bool = True,
-    )
+    ):
         logger = logging.getLogger()
         is_dataframe = isinstance(data, pd.DataFrame)
         data_dict = data if isinstance(data, dict) else None
@@ -705,7 +705,7 @@ class DataProcessor(FeatureSet):
         output_path: Optional[str],
         save_format: Optional[Literal["csv", "parquet"]],
         chunk_size: int = 200000,
-    )
+    ):
         """Transform data from files under a path and save them to a new location.

         Uses chunked reading/writing to keep peak memory bounded for large files.
@@ -852,7 +852,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    )
+    ):
         if not self.is_fitted:
             raise ValueError(
                 "[Data Processor Error] DataProcessor must be fitted before transform"
@@ -880,7 +880,7 @@ class DataProcessor(FeatureSet):
         save_format: Optional[Literal["csv", "parquet"]] = None,
         output_path: Optional[str] = None,
         chunk_size: int = 200000,
-    )
+    ):
         self.fit(data, chunk_size=chunk_size)
         return self.transform(
             data,
nextrec/loss/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from nextrec.loss.listwise import (
     ListNetLoss,
     SampledSoftmaxLoss,
 )
+from nextrec.loss.grad_norm import GradNormLossWeighting
 from nextrec.loss.loss_utils import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
@@ -30,6 +31,8 @@ __all__ = [
     "ListNetLoss",
     "ListMLELoss",
     "ApproxNDCGLoss",
+    # Multi-task weighting
+    "GradNormLossWeighting",
     # Utilities
     "get_loss_fn",
     "get_loss_kwargs",
|