nextrec 0.2.1.tar.gz → 0.2.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {nextrec-0.2.1 → nextrec-0.2.2}/PKG-INFO +2 -2
- {nextrec-0.2.1 → nextrec-0.2.2}/README.md +1 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/README_zh.md +1 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/conf.py +1 -1
- nextrec-0.2.2/docs/nextrec.loss.rst +45 -0
- nextrec-0.2.2/nextrec/__version__.py +1 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/layers.py +2 -2
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/model.py +80 -47
- nextrec-0.2.2/nextrec/loss/__init__.py +42 -0
- nextrec-0.2.2/nextrec/loss/listwise.py +164 -0
- nextrec-0.2.2/nextrec/loss/loss_utils.py +163 -0
- nextrec-0.2.2/nextrec/loss/pairwise.py +105 -0
- nextrec-0.2.2/nextrec/loss/pointwise.py +198 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/dssm.py +24 -15
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/dssm_v2.py +18 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/mind.py +16 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/sdm.py +15 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/youtube_dnn.py +21 -8
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/esmm.py +5 -5
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/mmoe.py +5 -5
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/ple.py +5 -5
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/share_bottom.py +5 -5
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/__init__.py +8 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/afm.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/autoint.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/dcn.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/deepfm.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/dien.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/din.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/fibinet.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/fm.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/masknet.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/pnn.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/widedeep.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/xdeepfm.py +3 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/__init__.py +5 -5
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/initializer.py +3 -3
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/optimizer.py +6 -6
- {nextrec-0.2.1 → nextrec-0.2.2}/pyproject.toml +1 -1
- nextrec-0.2.2/test/test_losses.py +114 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_ranking_din.py +11 -1
- {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/movielen_match_dssm.py +4 -1
- nextrec-0.2.1/docs/nextrec.loss.rst +0 -29
- nextrec-0.2.1/nextrec/__version__.py +0 -1
- nextrec-0.2.1/nextrec/loss/__init__.py +0 -35
- nextrec-0.2.1/nextrec/loss/listwise.py +0 -6
- nextrec-0.2.1/nextrec/loss/loss_utils.py +0 -135
- nextrec-0.2.1/nextrec/loss/match_losses.py +0 -293
- nextrec-0.2.1/nextrec/loss/pairwise.py +0 -6
- nextrec-0.2.1/nextrec/loss/pointwise.py +0 -6
- {nextrec-0.2.1 → nextrec-0.2.2}/.github/workflows/publish.yml +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/.github/workflows/tests.yml +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/.gitignore +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/.readthedocs.yaml +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/CODE_OF_CONDUCT.md +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/CONTRIBUTING.md +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/LICENSE +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/MANIFEST.in +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/dataset/match_task.csv +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/dataset/movielens_100k.csv +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/dataset/multitask_task.csv +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/dataset/ranking_task.csv +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/Makefile +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/index.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/make.bat +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/modules.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.basic.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.data.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.utils.rst +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/docs/requirements.txt +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/__init__.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/__init__.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/activation.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/callback.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/features.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/loggers.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/metrics.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/session.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/__init__.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/data_utils.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/dataloader.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/preprocessor.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/generative/hstu.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/generative/tiger.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/__init__.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/embedding.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/pytest.ini +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/requirements.txt +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/__init__.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/conftest.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/run_tests.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_data_preprocessor.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_dataloader.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_layers.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_match_models.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_multitask_models.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_ranking_models.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test/test_utils.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/test_requirements.txt +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_match_dssm.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_multitask.py +0 -0
- {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/movielen_ranking_deepfm.py +0 -0
{nextrec-0.2.1 → nextrec-0.2.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.2.1
+Version: 0.2.2
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -61,7 +61,7 @@ Description-Content-Type: text/markdown
 [badge image]
 [badge image]
 [badge image]
-[badge image]
+[badge image]
 
 English | [中文版](README_zh.md)
 
{nextrec-0.2.1 → nextrec-0.2.2}/README.md

@@ -5,7 +5,7 @@
 [badge image]
 [badge image]
 [badge image]
-[badge image]
+[badge image]
 
 English | [中文版](README_zh.md)
 
{nextrec-0.2.1 → nextrec-0.2.2}/README_zh.md

@@ -5,7 +5,7 @@
 [badge image]
 [badge image]
 [badge image]
-[badge image]
+[badge image]
 
 [English Version](README.md) | 中文版
 
{nextrec-0.2.1 → nextrec-0.2.2}/docs/conf.py

@@ -12,7 +12,7 @@ sys.path.insert(0, os.path.abspath('../nextrec'))
 project = "NextRec"
 copyright = "2025, Yang Zhou"
 author = "Yang Zhou"
-release = "0.2.1"
+release = "0.2.2"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
nextrec-0.2.2/docs/nextrec.loss.rst (new file)

@@ -0,0 +1,45 @@
+nextrec.loss package
+====================
+
+Submodules
+----------
+
+nextrec.loss.loss\_utils module
+-------------------------------
+
+.. automodule:: nextrec.loss.loss_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.pointwise module
+-----------------------------
+
+.. automodule:: nextrec.loss.pointwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.pairwise module
+----------------------------
+
+.. automodule:: nextrec.loss.pairwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.listwise module
+----------------------------
+
+.. automodule:: nextrec.loss.listwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: nextrec.loss
+   :members:
+   :undoc-members:
+   :show-inheritance:
nextrec-0.2.2/nextrec/__version__.py (new file)

@@ -0,0 +1 @@
+__version__ = "0.2.2"
{nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/layers.py

@@ -16,7 +16,7 @@ import torch.nn.functional as F
 
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.utils.initializer import …
+from nextrec.utils.initializer import get_initializer
 
 Feature = Union[DenseFeature, SparseFeature, SequenceFeature]
 
@@ -160,7 +160,7 @@ class EmbeddingLayer(nn.Module):
             )
             embedding.weight.requires_grad = feature.trainable
 
-            initialization = …
+            initialization = get_initializer(
                 init_type=feature.init_type,
                 activation="linear",
                 param=feature.init_params,
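The 0.2.1 side of both hunks above is truncated in this rendering, but the 0.2.2 side shows `get_initializer` being imported and called with `init_type`, `activation`, and `param`. A minimal sketch of what such a factory could look like, purely as an assumption (the registry and names below are illustrative, not the actual contents of `nextrec.utils.initializer`):

```python
# Hypothetical sketch of a get_initializer-style factory. Assumption:
# it returns a callable that initializes a tensor in place, matching
# the EmbeddingLayer call site above.
import torch.nn.init as init

def get_initializer_sketch(init_type: str, activation: str = "linear", param: dict | None = None):
    param = param or {}
    if init_type == "normal":
        return lambda t: init.normal_(t, mean=param.get("mean", 0.0), std=param.get("std", 0.01))
    if init_type == "uniform":
        return lambda t: init.uniform_(t, a=param.get("a", -0.05), b=param.get("b", 0.05))
    if init_type == "xavier_uniform":
        return lambda t: init.xavier_uniform_(t, gain=init.calculate_gain(activation))
    raise ValueError(f"Unknown init_type: {init_type}")
```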
{nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/model.py

@@ -6,18 +6,15 @@ Author: Yang Zhou,zyaztec@gmail.com
 """
 
 import os
-import …
+import tqdm
 import logging
-import os
-from pathlib import Path
-
 import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import tqdm
 
+from pathlib import Path
 from typing import Union, Literal
 from torch.utils.data import DataLoader, TensorDataset
 
@@ -25,11 +22,11 @@ from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureConfig
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics
 
-from nextrec.loss import get_loss_fn
+from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.data import get_column_data
 from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.basic.loggers import setup_logger, colorize
-from nextrec.utils import …
+from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.basic.session import resolve_save_path, create_session
 
 
@@ -400,7 +397,9 @@ class BaseModel(FeatureConfig, nn.Module):
                 optimizer_params: dict | None = None,
                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
-                loss: str | nn.Module | list[str | nn.Module] | None = "bce"):
+                loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                loss_params: dict | list[dict] | None = None):
+
         if optimizer_params is None:
             optimizer_params = {}
 
@@ -415,9 +414,10 @@ class BaseModel(FeatureConfig, nn.Module):
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
+        self._loss_params = loss_params
 
         # set optimizer
-        self.optimizer_fn = …
+        self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
             **optimizer_params
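The hunk above completes the previously truncated `self.optimizer_fn = …` assignment with a `get_optimizer` call. The helper's implementation is not part of this excerpt; a plausible resolver with that call shape, stated as an assumption:

```python
# Hypothetical sketch of a get_optimizer-style resolver; the real
# nextrec.utils.optimizer may differ. It accepts a name string, an
# optimizer class, or a ready instance, mirroring the call site above.
import torch

def get_optimizer_sketch(optimizer, params, **kwargs):
    if isinstance(optimizer, torch.optim.Optimizer):
        return optimizer  # already constructed, use as-is
    if isinstance(optimizer, str):
        registry = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW,
                    "sgd": torch.optim.SGD, "adagrad": torch.optim.Adagrad}
        return registry[optimizer.lower()](params, **kwargs)
    return optimizer(params, **kwargs)  # assume an optimizer class
```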
@@ -430,7 +430,12 @@ class BaseModel(FeatureConfig, nn.Module):
             # For ranking and multitask, use pointwise training
             training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
             # Use task_type directly, not self.task_type for single task
-            self.loss_fn = [get_loss_fn(…
+            self.loss_fn = [get_loss_fn(
+                task_type=task_type,
+                training_mode=training_mode,
+                loss=loss_value,
+                **get_loss_kwargs(loss_params)
+            )]
         else:
             self.loss_fn = []
             for i in range(self.nums_task):
@@ -443,10 +448,15 @@ class BaseModel(FeatureConfig, nn.Module):
 
                 # Multitask always uses pointwise training
                 training_mode = 'pointwise'
-                self.loss_fn.append(get_loss_fn(…
+                self.loss_fn.append(get_loss_fn(
+                    task_type=task_type,
+                    training_mode=training_mode,
+                    loss=loss_value,
+                    **get_loss_kwargs(loss_params, i)
+                ))
 
         # set scheduler
-        self.scheduler_fn = …
+        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
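Both branches now thread the new `loss_params` argument through `get_loss_kwargs(loss_params)` and `get_loss_kwargs(loss_params, i)`, so multi-task models can carry per-task loss settings. `loss_utils.py` is not shown in this excerpt; from the call sites, the helper plausibly behaves like this sketch (an assumption, not the package's actual code):

```python
# Hypothetical get_loss_kwargs semantics inferred from the call sites:
# normalize None / dict / list[dict] into the kwargs for task i's loss.
def get_loss_kwargs_sketch(loss_params, index: int = 0) -> dict:
    if loss_params is None:
        return {}
    if isinstance(loss_params, dict):
        return loss_params               # one dict shared by all tasks
    if index < len(loss_params):         # list[dict]: per-task params
        return loss_params[index] or {}
    return {}
```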
@@ -1130,10 +1140,13 @@ class BaseMatchModel(BaseModel):
     Base class for match (retrieval/recall) models
     Supports pointwise, pairwise, and listwise training modes
     """
-
+    @property
+    def model_name(self) -> str:
+        raise NotImplementedError
+
     @property
     def task_type(self) -> str:
-        …
+        raise NotImplementedError
 
     @property
     def support_training_modes(self) -> list[str]:
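With `model_name` and `task_type` now raising `NotImplementedError`, every concrete match model must declare both. A minimal subclass sketch (constructor and towers omitted, since they are not part of this hunk):

```python
# Minimal sketch of the contract the new abstract-style properties
# impose on BaseMatchModel subclasses; tower construction is omitted.
from nextrec.basic.model import BaseMatchModel

class MyTwoTower(BaseMatchModel):
    @property
    def model_name(self) -> str:
        return "my_two_tower"   # illustrative name

    @property
    def task_type(self) -> str:
        return "match"
```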
@@ -1209,45 +1222,47 @@ class BaseMatchModel(BaseModel):
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
+
+        self.user_feature_names = [f.name for f in (
+            self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+        )]
+        self.item_feature_names = [f.name for f in (
+            self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        )]
+
     def get_user_features(self, X_input: dict) -> dict:
-        …
-        …
-        …
-        if …
-        …
-        …
-        …
+        return {
+            name: X_input[name]
+            for name in self.user_feature_names
+            if name in X_input
+        }
+
     def get_item_features(self, X_input: dict) -> dict:
-        …
-        …
-        …
-        if …
-        …
-        …
-        …
+        return {
+            name: X_input[name]
+            for name in self.item_feature_names
+            if name in X_input
+        }
+
     def compile(self,
-                optimizer = "adam",
+                optimizer: str | torch.optim.Optimizer = "adam",
                 optimizer_params: dict | None = None,
                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
-                loss: str | nn.Module | list[str | nn.Module] | None = …
+                loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                loss_params: dict | list[dict] | None = None):
         """
         Compile match model with optimizer, scheduler, and loss function.
-        …
+        Mirrors BaseModel.compile while adding training_mode validation for match tasks.
         """
-        …
-        …
-        …
-        …
-        …
-        …
-            model_name=self.model_name
-        )
-
+        if self.training_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name} does not support training_mode='{self.training_mode}'. "
+                f"Supported modes: {self.support_training_modes}"
+            )
+
         # Call parent compile with match-specific logic
-        …
-            optimizer_params = {}
+        optimizer_params = optimizer_params or {}
 
         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self._optimizer_params = optimizer_params
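The new helpers reduce each tower's input selection to a dict comprehension over precomputed feature-name lists. Exercised standalone (the feature names below are illustrative; a real model derives them from its user/item feature configs in `__init__`):

```python
import torch

user_feature_names = ["user_id", "hist_item_ids"]   # illustrative names
batch = {
    "user_id": torch.tensor([1, 2]),
    "hist_item_ids": torch.tensor([[3, 4], [5, 6]]),
    "item_id": torch.tensor([7, 8]),
}
# Same filtering get_user_features now performs:
user_inputs = {name: batch[name] for name in user_feature_names if name in batch}
assert set(user_inputs) == {"user_id", "hist_item_ids"}  # item_id filtered out
```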
@@ -1260,24 +1275,42 @@ class BaseMatchModel(BaseModel):
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
+        self._loss_params = loss_params
 
         # set optimizer
-        self.optimizer_fn = …
+        self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
             **optimizer_params
         )
 
         # Set loss function based on training mode
-        …
+        default_losses = {
+            'pointwise': 'bce',
+            'pairwise': 'bpr',
+            'listwise': 'sampled_softmax',
+        }
+
+        if loss is None:
+            loss_value = default_losses.get(self.training_mode, "bce")
+        elif isinstance(loss, list):
+            loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(self.training_mode, "bce")
+        else:
+            loss_value = loss
+
+        # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
+        if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
+            loss_value = default_losses.get(self.training_mode, loss_value)
+
         self.loss_fn = [get_loss_fn(
             task_type='match',
             training_mode=self.training_mode,
-            loss=loss_value
+            loss=loss_value,
+            **get_loss_kwargs(loss_params, 0)
         )]
 
         # set scheduler
-        self.scheduler_fn = …
+        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
     def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
         if self.similarity_metric == 'dot':
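The default-loss selection above is self-contained enough to exercise directly; this snippet replays that logic to show the fallbacks it produces:

```python
# The loss-resolution rules added to BaseMatchModel.compile, replayed
# standalone; the dict and branches mirror the hunk above.
default_losses = {"pointwise": "bce", "pairwise": "bpr", "listwise": "sampled_softmax"}

def resolve_loss(training_mode, loss):
    if loss is None:
        loss_value = default_losses.get(training_mode, "bce")
    elif isinstance(loss, list):
        loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(training_mode, "bce")
    else:
        loss_value = loss
    # BCE is rejected for pairwise/listwise and replaced by the mode default
    if training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
        loss_value = default_losses.get(training_mode, loss_value)
    return loss_value

assert resolve_loss("pairwise", "bce") == "bpr"
assert resolve_loss("listwise", None) == "sampled_softmax"
assert resolve_loss("pointwise", "focal") == "focal"
```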
nextrec-0.2.2/nextrec/loss/__init__.py (new file)

@@ -0,0 +1,42 @@
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
+    InfoNCELoss,
+    ListMLELoss,
+    ListNetLoss,
+    SampledSoftmaxLoss,
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
+)
+from nextrec.loss.loss_utils import (
+    get_loss_fn,
+    get_loss_kwargs,
+    VALID_TASK_TYPES,
+)
+
+__all__ = [
+    # Pointwise
+    "CosineContrastiveLoss",
+    "WeightedBCELoss",
+    "FocalLoss",
+    "ClassBalancedFocalLoss",
+    # Pairwise
+    "BPRLoss",
+    "HingeLoss",
+    "TripletLoss",
+    # Listwise
+    "SampledSoftmaxLoss",
+    "InfoNCELoss",
+    "ListNetLoss",
+    "ListMLELoss",
+    "ApproxNDCGLoss",
+    # Utilities
+    "get_loss_fn",
+    "get_loss_kwargs",
+    "validate_training_mode",
+    "VALID_TASK_TYPES",
+]
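After this rewrite, the concrete loss classes import directly from the package root:

```python
from nextrec.loss import BPRLoss, FocalLoss, InfoNCELoss, get_loss_fn
```

Note that `__all__` advertises `validate_training_mode`, which the import block above does not bring in, so `from nextrec.loss import *` would raise on that name unless it is also imported in this `__init__`.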
nextrec-0.2.2/nextrec/loss/listwise.py (new file)

@@ -0,0 +1,164 @@
+"""
+Listwise loss functions for ranking and contrastive training.
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SampledSoftmaxLoss(nn.Module):
+    """
+    Softmax over one positive and multiple sampled negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+        pos_logits = pos_logits.unsqueeze(1)
+        all_logits = torch.cat([pos_logits, neg_logits], dim=1)
+        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
+        return loss
+
+
+class InfoNCELoss(nn.Module):
+    """
+    InfoNCE loss for contrastive learning with one positive and many negatives.
+    """
+
+    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(
+        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
+    ) -> torch.Tensor:
+        pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature
+        pos_sim = pos_sim.unsqueeze(1)
+        query_expanded = query.unsqueeze(1)
+        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
+        logits = torch.cat([pos_sim, neg_sim], dim=1)
+        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
+        loss = F.cross_entropy(logits, labels, reduction=self.reduction)
+        return loss
+
+
+class ListNetLoss(nn.Module):
+    """
+    ListNet loss using top-1 probability distribution.
+    Reference: Cao et al. (ICML 2007)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        pred_probs = F.softmax(scores / self.temperature, dim=1)
+        true_probs = F.softmax(labels / self.temperature, dim=1)
+        loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class ListMLELoss(nn.Module):
+    """
+    ListMLE (Maximum Likelihood Estimation) loss.
+    Reference: Xia et al. (ICML 2008)
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
+        batch_size, list_size = scores.shape
+        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        sorted_scores = scores[batch_indices, sorted_indices]
+
+        loss = torch.tensor(0.0, device=scores.device)
+        for i in range(list_size):
+            remaining_scores = sorted_scores[:, i:]
+            log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
+            loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
+
+        if self.reduction == "mean":
+            return loss / batch_size
+        if self.reduction == "sum":
+            return loss
+        return loss / batch_size
+
+
+class ApproxNDCGLoss(nn.Module):
+    """
+    Approximate NDCG loss for learning to rank.
+    Reference: Qin et al. (2010)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def _ideal_dcg(self, labels: torch.Tensor, k: Optional[int]) -> torch.Tensor:
+        # labels: [B, L]
+        sorted_labels, _ = torch.sort(labels, dim=1, descending=True)
+        if k is not None:
+            sorted_labels = sorted_labels[:, :k]
+
+        gains = torch.pow(2.0, sorted_labels) - 1.0  # [B, K]
+        positions = torch.arange(
+            1, gains.size(1) + 1, device=gains.device, dtype=torch.float32
+        )  # [K]
+        discounts = 1.0 / torch.log2(positions + 1.0)  # [K]
+        ideal_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+        return ideal_dcg
+
+    def forward(
+        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+    ) -> torch.Tensor:
+        """
+        scores: [B, L]
+        labels: [B, L]
+        """
+        batch_size, list_size = scores.shape
+        device = scores.device
+
+        # diff[b, i, j] = (s_j - s_i) / T
+        scores_i = scores.unsqueeze(2)  # [B, L, 1]
+        scores_j = scores.unsqueeze(1)  # [B, 1, L]
+        diff = (scores_j - scores_i) / self.temperature  # [B, L, L]
+
+        P_ji = torch.sigmoid(diff)  # [B, L, L]
+        eye = torch.eye(list_size, device=device).unsqueeze(0)  # [1, L, L]
+        P_ji = P_ji * (1.0 - eye)
+
+        exp_rank = 1.0 + P_ji.sum(dim=-1)  # [B, L]
+
+        discounts = 1.0 / torch.log2(exp_rank + 1.0)  # [B, L]
+
+        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
+        approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+
+        ideal_dcg = self._ideal_dcg(labels, k)  # [B]
+
+        ndcg = approx_dcg / (ideal_dcg + 1e-10)  # [B]
+        loss = 1.0 - ndcg
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss