nextrec 0.2.1__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. {nextrec-0.2.1 → nextrec-0.2.2}/PKG-INFO +2 -2
  2. {nextrec-0.2.1 → nextrec-0.2.2}/README.md +1 -1
  3. {nextrec-0.2.1 → nextrec-0.2.2}/README_zh.md +1 -1
  4. {nextrec-0.2.1 → nextrec-0.2.2}/docs/conf.py +1 -1
  5. nextrec-0.2.2/docs/nextrec.loss.rst +45 -0
  6. nextrec-0.2.2/nextrec/__version__.py +1 -0
  7. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/layers.py +2 -2
  8. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/model.py +80 -47
  10. nextrec-0.2.2/nextrec/loss/__init__.py +41 -0
  10. nextrec-0.2.2/nextrec/loss/listwise.py +164 -0
  11. nextrec-0.2.2/nextrec/loss/loss_utils.py +163 -0
  12. nextrec-0.2.2/nextrec/loss/pairwise.py +105 -0
  13. nextrec-0.2.2/nextrec/loss/pointwise.py +198 -0
  14. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/dssm.py +24 -15
  15. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/dssm_v2.py +18 -0
  16. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/mind.py +16 -1
  17. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/sdm.py +15 -0
  18. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/youtube_dnn.py +21 -8
  19. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/esmm.py +5 -5
  20. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/mmoe.py +5 -5
  21. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/ple.py +5 -5
  22. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/multi_task/share_bottom.py +5 -5
  23. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/__init__.py +8 -0
  24. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/afm.py +3 -1
  25. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/autoint.py +3 -1
  26. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/dcn.py +3 -1
  27. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/deepfm.py +3 -1
  28. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/dien.py +3 -1
  29. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/din.py +3 -1
  30. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/fibinet.py +3 -1
  31. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/fm.py +3 -1
  32. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/masknet.py +3 -1
  33. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/pnn.py +3 -1
  34. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/widedeep.py +3 -1
  35. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/ranking/xdeepfm.py +3 -1
  36. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/__init__.py +5 -5
  37. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/initializer.py +3 -3
  38. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/optimizer.py +6 -6
  39. {nextrec-0.2.1 → nextrec-0.2.2}/pyproject.toml +1 -1
  40. nextrec-0.2.2/test/test_losses.py +114 -0
  41. {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_ranking_din.py +11 -1
  42. {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/movielen_match_dssm.py +4 -1
  43. nextrec-0.2.1/docs/nextrec.loss.rst +0 -29
  44. nextrec-0.2.1/nextrec/__version__.py +0 -1
  45. nextrec-0.2.1/nextrec/loss/__init__.py +0 -35
  46. nextrec-0.2.1/nextrec/loss/listwise.py +0 -6
  47. nextrec-0.2.1/nextrec/loss/loss_utils.py +0 -135
  48. nextrec-0.2.1/nextrec/loss/match_losses.py +0 -293
  49. nextrec-0.2.1/nextrec/loss/pairwise.py +0 -6
  50. nextrec-0.2.1/nextrec/loss/pointwise.py +0 -6
  51. {nextrec-0.2.1 → nextrec-0.2.2}/.github/workflows/publish.yml +0 -0
  52. {nextrec-0.2.1 → nextrec-0.2.2}/.github/workflows/tests.yml +0 -0
  53. {nextrec-0.2.1 → nextrec-0.2.2}/.gitignore +0 -0
  54. {nextrec-0.2.1 → nextrec-0.2.2}/.readthedocs.yaml +0 -0
  55. {nextrec-0.2.1 → nextrec-0.2.2}/CODE_OF_CONDUCT.md +0 -0
  56. {nextrec-0.2.1 → nextrec-0.2.2}/CONTRIBUTING.md +0 -0
  57. {nextrec-0.2.1 → nextrec-0.2.2}/LICENSE +0 -0
  58. {nextrec-0.2.1 → nextrec-0.2.2}/MANIFEST.in +0 -0
  59. {nextrec-0.2.1 → nextrec-0.2.2}/dataset/match_task.csv +0 -0
  60. {nextrec-0.2.1 → nextrec-0.2.2}/dataset/movielens_100k.csv +0 -0
  61. {nextrec-0.2.1 → nextrec-0.2.2}/dataset/multitask_task.csv +0 -0
  62. {nextrec-0.2.1 → nextrec-0.2.2}/dataset/ranking_task.csv +0 -0
  63. {nextrec-0.2.1 → nextrec-0.2.2}/docs/Makefile +0 -0
  64. {nextrec-0.2.1 → nextrec-0.2.2}/docs/index.rst +0 -0
  65. {nextrec-0.2.1 → nextrec-0.2.2}/docs/make.bat +0 -0
  66. {nextrec-0.2.1 → nextrec-0.2.2}/docs/modules.rst +0 -0
  67. {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.basic.rst +0 -0
  68. {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.data.rst +0 -0
  69. {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.rst +0 -0
  70. {nextrec-0.2.1 → nextrec-0.2.2}/docs/nextrec.utils.rst +0 -0
  71. {nextrec-0.2.1 → nextrec-0.2.2}/docs/requirements.txt +0 -0
  72. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/__init__.py +0 -0
  73. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/__init__.py +0 -0
  74. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/activation.py +0 -0
  75. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/callback.py +0 -0
  76. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/features.py +0 -0
  77. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/loggers.py +0 -0
  78. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/metrics.py +0 -0
  79. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/basic/session.py +0 -0
  80. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/__init__.py +0 -0
  81. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/data_utils.py +0 -0
  82. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/dataloader.py +0 -0
  83. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/data/preprocessor.py +0 -0
  84. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/generative/hstu.py +0 -0
  85. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/generative/tiger.py +0 -0
  86. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/models/match/__init__.py +0 -0
  87. {nextrec-0.2.1 → nextrec-0.2.2}/nextrec/utils/embedding.py +0 -0
  88. {nextrec-0.2.1 → nextrec-0.2.2}/pytest.ini +0 -0
  89. {nextrec-0.2.1 → nextrec-0.2.2}/requirements.txt +0 -0
  90. {nextrec-0.2.1 → nextrec-0.2.2}/test/__init__.py +0 -0
  91. {nextrec-0.2.1 → nextrec-0.2.2}/test/conftest.py +0 -0
  92. {nextrec-0.2.1 → nextrec-0.2.2}/test/run_tests.py +0 -0
  93. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_data_preprocessor.py +0 -0
  94. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_dataloader.py +0 -0
  95. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_layers.py +0 -0
  96. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_match_models.py +0 -0
  97. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_multitask_models.py +0 -0
  98. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_ranking_models.py +0 -0
  99. {nextrec-0.2.1 → nextrec-0.2.2}/test/test_utils.py +0 -0
  100. {nextrec-0.2.1 → nextrec-0.2.2}/test_requirements.txt +0 -0
  101. {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_match_dssm.py +0 -0
  102. {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/example_multitask.py +0 -0
  103. {nextrec-0.2.1 → nextrec-0.2.2}/tutorials/movielen_ranking_deepfm.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nextrec
-Version: 0.2.1
+Version: 0.2.2
 Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
 Project-URL: Homepage, https://github.com/zerolovesea/NextRec
 Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -61,7 +61,7 @@ Description-Content-Type: text/markdown
 ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
 ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
 ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
-![Version](https://img.shields.io/badge/Version-0.2.1-orange.svg)
+![Version](https://img.shields.io/badge/Version-0.2.2-orange.svg)
 
 English | [中文版](README_zh.md)
 
README.md
@@ -5,7 +5,7 @@
 ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
 ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
 ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
-![Version](https://img.shields.io/badge/Version-0.2.1-orange.svg)
+![Version](https://img.shields.io/badge/Version-0.2.2-orange.svg)
 
 English | [中文版](README_zh.md)
 
README_zh.md
@@ -5,7 +5,7 @@
 ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
 ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
 ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
-![Version](https://img.shields.io/badge/Version-0.2.1-orange.svg)
+![Version](https://img.shields.io/badge/Version-0.2.2-orange.svg)
 
 [English Version](README.md) | 中文版
 
docs/conf.py
@@ -12,7 +12,7 @@ sys.path.insert(0, os.path.abspath('../nextrec'))
 project = "NextRec"
 copyright = "2025, Yang Zhou"
 author = "Yang Zhou"
-release = "0.2.1"
+release = "0.2.2"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
docs/nextrec.loss.rst (new file)
@@ -0,0 +1,45 @@
+nextrec.loss package
+====================
+
+Submodules
+----------
+
+nextrec.loss.loss\_utils module
+-------------------------------
+
+.. automodule:: nextrec.loss.loss_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.pointwise module
+-----------------------------
+
+.. automodule:: nextrec.loss.pointwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.pairwise module
+----------------------------
+
+.. automodule:: nextrec.loss.pairwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+nextrec.loss.listwise module
+----------------------------
+
+.. automodule:: nextrec.loss.listwise
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: nextrec.loss
+   :members:
+   :undoc-members:
+   :show-inheritance:
nextrec/__version__.py (new file)
@@ -0,0 +1 @@
+__version__ = "0.2.2"
nextrec/basic/layers.py
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 
 from nextrec.basic.activation import activation_layer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.utils.initializer import get_initializer_fn
+from nextrec.utils.initializer import get_initializer
 
 Feature = Union[DenseFeature, SparseFeature, SequenceFeature]
 
@@ -160,7 +160,7 @@ class EmbeddingLayer(nn.Module):
         )
         embedding.weight.requires_grad = feature.trainable
 
-        initialization = get_initializer_fn(
+        initialization = get_initializer(
            init_type=feature.init_type,
            activation="linear",
            param=feature.init_params,
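
The rename from get_initializer_fn to get_initializer is mirrored by get_optimizer and get_scheduler in model.py below, so the utils helpers now share one naming scheme. A minimal sketch of the call site above, assuming get_initializer returns a callable that initializes a tensor in place (the hunk cuts off before `initialization` is used, so the application step and the init_type/param values are assumptions, not taken from the package):

import torch.nn as nn
from nextrec.utils.initializer import get_initializer

embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16)

# Keyword arguments mirror the EmbeddingLayer call site shown above;
# "normal" and the mean/std keys are illustrative placeholders.
initialization = get_initializer(
    init_type="normal",
    activation="linear",
    param={"mean": 0.0, "std": 0.01},
)
initialization(embedding.weight)  # assumed usage: apply the initializer to the weight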
nextrec/basic/model.py
@@ -6,18 +6,15 @@ Author: Yang Zhou,zyaztec@gmail.com
 """
 
 import os
-import datetime
+import tqdm
 import logging
-import os
-from pathlib import Path
-
 import numpy as np
 import pandas as pd
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import tqdm
 
+from pathlib import Path
 from typing import Union, Literal
 from torch.utils.data import DataLoader, TensorDataset
 
@@ -25,11 +22,11 @@ from nextrec.basic.callback import EarlyStopper
 from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature, FeatureConfig
 from nextrec.basic.metrics import configure_metrics, evaluate_metrics
 
-from nextrec.loss import get_loss_fn
+from nextrec.loss import get_loss_fn, get_loss_kwargs
 from nextrec.data import get_column_data
 from nextrec.data.dataloader import build_tensors_from_data
 from nextrec.basic.loggers import setup_logger, colorize
-from nextrec.utils import get_optimizer_fn, get_scheduler_fn
+from nextrec.utils import get_optimizer, get_scheduler
 from nextrec.basic.session import resolve_save_path, create_session
 
 
@@ -400,7 +397,9 @@ class BaseModel(FeatureConfig, nn.Module):
                 optimizer_params: dict | None = None,
                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
-                loss: str | nn.Module | list[str | nn.Module] | None = "bce"):
+                loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                loss_params: dict | list[dict] | None = None):
+
         if optimizer_params is None:
             optimizer_params = {}
 
@@ -415,9 +414,10 @@ class BaseModel(FeatureConfig, nn.Module):
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
+        self._loss_params = loss_params
 
         # set optimizer
-        self.optimizer_fn = get_optimizer_fn(
+        self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
             **optimizer_params
@@ -430,7 +430,12 @@ class BaseModel(FeatureConfig, nn.Module):
             # For ranking and multitask, use pointwise training
             training_mode = 'pointwise' if self.task_type in ['ranking', 'multitask'] else None
             # Use task_type directly, not self.task_type for single task
-            self.loss_fn = [get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value)]
+            self.loss_fn = [get_loss_fn(
+                task_type=task_type,
+                training_mode=training_mode,
+                loss=loss_value,
+                **get_loss_kwargs(loss_params)
+            )]
         else:
             self.loss_fn = []
             for i in range(self.nums_task):
@@ -443,10 +448,15 @@ class BaseModel(FeatureConfig, nn.Module):
 
                 # Multitask always uses pointwise training
                 training_mode = 'pointwise'
-                self.loss_fn.append(get_loss_fn(task_type=task_type, training_mode=training_mode, loss=loss_value))
+                self.loss_fn.append(get_loss_fn(
+                    task_type=task_type,
+                    training_mode=training_mode,
+                    loss=loss_value,
+                    **get_loss_kwargs(loss_params, i)
+                ))
 
         # set scheduler
-        self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
     def compute_loss(self, y_pred, y_true):
         if y_true is None:
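
The new loss_params argument threads constructor kwargs through to the loss registry; get_loss_kwargs(loss_params, i) appears to select the i-th dict when a list is given, one per task. A hedged usage sketch (the model objects are placeholders, and the "focal" alias with its gamma/alpha kwargs is an assumption based on the exported FocalLoss, not a confirmed signature):

# Single-task ranking model: one loss, optionally one params dict.
model.compile(optimizer="adam", loss="bce")

# Multi-task model: per-task losses with per-task params;
# loss_params[i] is forwarded to the i-th loss via get_loss_kwargs(loss_params, i).
multitask_model.compile(
    optimizer="adam",
    loss=["bce", "focal"],
    loss_params=[{}, {"gamma": 2.0, "alpha": 0.25}],
)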
@@ -1130,10 +1140,13 @@ class BaseMatchModel(BaseModel):
     Base class for match (retrieval/recall) models
     Supports pointwise, pairwise, and listwise training modes
     """
-
+    @property
+    def model_name(self) -> str:
+        raise NotImplementedError
+
     @property
     def task_type(self) -> str:
-        return 'match'
+        raise NotImplementedError
 
     @property
     def support_training_modes(self) -> list[str]:
@@ -1209,45 +1222,47 @@
         self.num_negative_samples = num_negative_samples
         self.temperature = temperature
         self.similarity_metric = similarity_metric
-
+
+        self.user_feature_names = [f.name for f in (
+            self.user_dense_features + self.user_sparse_features + self.user_sequence_features
+        )]
+        self.item_feature_names = [f.name for f in (
+            self.item_dense_features + self.item_sparse_features + self.item_sequence_features
+        )]
+
     def get_user_features(self, X_input: dict) -> dict:
-        user_input = {}
-        all_user_features = self.user_dense_features + self.user_sparse_features + self.user_sequence_features
-        for feature in all_user_features:
-            if feature.name in X_input:
-                user_input[feature.name] = X_input[feature.name]
-        return user_input
-
+        return {
+            name: X_input[name]
+            for name in self.user_feature_names
+            if name in X_input
+        }
+
     def get_item_features(self, X_input: dict) -> dict:
-        item_input = {}
-        all_item_features = self.item_dense_features + self.item_sparse_features + self.item_sequence_features
-        for feature in all_item_features:
-            if feature.name in X_input:
-                item_input[feature.name] = X_input[feature.name]
-        return item_input
-
+        return {
+            name: X_input[name]
+            for name in self.item_feature_names
+            if name in X_input
+        }
+
     def compile(self,
-                optimizer = "adam",
+                optimizer: str | torch.optim.Optimizer = "adam",
                 optimizer_params: dict | None = None,
                 scheduler: str | torch.optim.lr_scheduler._LRScheduler | type[torch.optim.lr_scheduler._LRScheduler] | None = None,
                 scheduler_params: dict | None = None,
-                loss: str | nn.Module | list[str | nn.Module] | None = None):
+                loss: str | nn.Module | list[str | nn.Module] | None = "bce",
+                loss_params: dict | list[dict] | None = None):
         """
         Compile match model with optimizer, scheduler, and loss function.
-        Validates that training_mode is supported by the model.
+        Mirrors BaseModel.compile while adding training_mode validation for match tasks.
         """
-        from nextrec.loss import validate_training_mode
-
-        # Validate training mode is supported
-        validate_training_mode(
-            training_mode=self.training_mode,
-            support_training_modes=self.support_training_modes,
-            model_name=self.model_name
-        )
-
+        if self.training_mode not in self.support_training_modes:
+            raise ValueError(
+                f"{self.model_name} does not support training_mode='{self.training_mode}'. "
+                f"Supported modes: {self.support_training_modes}"
+            )
+
         # Call parent compile with match-specific logic
-        if optimizer_params is None:
-            optimizer_params = {}
+        optimizer_params = optimizer_params or {}
 
         self._optimizer_name = optimizer if isinstance(optimizer, str) else optimizer.__class__.__name__
         self._optimizer_params = optimizer_params
@@ -1260,24 +1275,42 @@
         self._scheduler_name = None
         self._scheduler_params = scheduler_params or {}
         self._loss_config = loss
+        self._loss_params = loss_params
 
         # set optimizer
-        self.optimizer_fn = get_optimizer_fn(
+        self.optimizer_fn = get_optimizer(
             optimizer=optimizer,
             params=self.parameters(),
             **optimizer_params
         )
 
         # Set loss function based on training mode
-        loss_value = loss[0] if isinstance(loss, list) else loss
+        default_losses = {
+            'pointwise': 'bce',
+            'pairwise': 'bpr',
+            'listwise': 'sampled_softmax',
+        }
+
+        if loss is None:
+            loss_value = default_losses.get(self.training_mode, "bce")
+        elif isinstance(loss, list):
+            loss_value = loss[0] if loss and loss[0] is not None else default_losses.get(self.training_mode, "bce")
+        else:
+            loss_value = loss
+
+        # Pairwise/listwise modes do not support BCE, fall back to sensible defaults
+        if self.training_mode in {"pairwise", "listwise"} and loss_value in {"bce", "binary_crossentropy"}:
+            loss_value = default_losses.get(self.training_mode, loss_value)
+
         self.loss_fn = [get_loss_fn(
             task_type='match',
             training_mode=self.training_mode,
-            loss=loss_value
+            loss=loss_value,
+            **get_loss_kwargs(loss_params, 0)
        )]
 
         # set scheduler
-        self.scheduler_fn = get_scheduler_fn(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
+        self.scheduler_fn = get_scheduler(scheduler, self.optimizer_fn, **(scheduler_params or {})) if scheduler else None
 
     def compute_similarity(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
         if self.similarity_metric == 'dot':
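
The loss-resolution logic above reads as a small pure function: take the caller's loss if given, fall back to a per-mode default, and replace BCE when the training mode is pairwise or listwise. A self-contained restatement of the same behavior, outside the class, for reference:

DEFAULT_LOSSES = {"pointwise": "bce", "pairwise": "bpr", "listwise": "sampled_softmax"}

def resolve_match_loss(training_mode: str, loss=None) -> str:
    """Mirror of the loss selection in BaseMatchModel.compile."""
    if loss is None:
        value = DEFAULT_LOSSES.get(training_mode, "bce")
    elif isinstance(loss, list):
        value = loss[0] if loss and loss[0] is not None else DEFAULT_LOSSES.get(training_mode, "bce")
    else:
        value = loss
    # BCE is pointwise-only; pairwise/listwise fall back to their defaults.
    if training_mode in {"pairwise", "listwise"} and value in {"bce", "binary_crossentropy"}:
        value = DEFAULT_LOSSES.get(training_mode, value)
    return value

assert resolve_match_loss("pairwise", "bce") == "bpr"
assert resolve_match_loss("listwise", None) == "sampled_softmax"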
nextrec/loss/__init__.py (new file)
@@ -0,0 +1,41 @@
+from nextrec.loss.listwise import (
+    ApproxNDCGLoss,
+    InfoNCELoss,
+    ListMLELoss,
+    ListNetLoss,
+    SampledSoftmaxLoss,
+)
+from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
+from nextrec.loss.pointwise import (
+    ClassBalancedFocalLoss,
+    CosineContrastiveLoss,
+    FocalLoss,
+    WeightedBCELoss,
+)
+from nextrec.loss.loss_utils import (
+    get_loss_fn,
+    get_loss_kwargs,
+    VALID_TASK_TYPES,
+)
+
+__all__ = [
+    # Pointwise
+    "CosineContrastiveLoss",
+    "WeightedBCELoss",
+    "FocalLoss",
+    "ClassBalancedFocalLoss",
+    # Pairwise
+    "BPRLoss",
+    "HingeLoss",
+    "TripletLoss",
+    # Listwise
+    "SampledSoftmaxLoss",
+    "InfoNCELoss",
+    "ListNetLoss",
+    "ListMLELoss",
+    "ApproxNDCGLoss",
+    # Utilities
+    "get_loss_fn",
+    "get_loss_kwargs",
+    "VALID_TASK_TYPES",
+]
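
With the loss package split into pointwise/pairwise/listwise modules, losses can be imported directly or resolved by name. A short sketch (the get_loss_fn keywords follow the call sites in model.py above; that the returned object is a torch loss module is an assumption):

from nextrec.loss import InfoNCELoss, get_loss_fn

# Resolve by name, as BaseModel.compile does internally.
loss_fn = get_loss_fn(task_type="ranking", training_mode="pointwise", loss="bce")

# Or instantiate a loss class directly; temperature comes from the
# InfoNCELoss definition shown in listwise.py below.
info_nce = InfoNCELoss(temperature=0.07)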
nextrec/loss/listwise.py (new file)
@@ -0,0 +1,164 @@
+"""
+Listwise loss functions for ranking and contrastive training.
+"""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SampledSoftmaxLoss(nn.Module):
+    """
+    Softmax over one positive and multiple sampled negatives.
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, pos_logits: torch.Tensor, neg_logits: torch.Tensor) -> torch.Tensor:
+        pos_logits = pos_logits.unsqueeze(1)
+        all_logits = torch.cat([pos_logits, neg_logits], dim=1)
+        targets = torch.zeros(all_logits.size(0), dtype=torch.long, device=all_logits.device)
+        loss = F.cross_entropy(all_logits, targets, reduction=self.reduction)
+        return loss
+
+
+class InfoNCELoss(nn.Module):
+    """
+    InfoNCE loss for contrastive learning with one positive and many negatives.
+    """
+
+    def __init__(self, temperature: float = 0.07, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(
+        self, query: torch.Tensor, pos_key: torch.Tensor, neg_keys: torch.Tensor
+    ) -> torch.Tensor:
+        pos_sim = torch.sum(query * pos_key, dim=-1) / self.temperature
+        pos_sim = pos_sim.unsqueeze(1)
+        query_expanded = query.unsqueeze(1)
+        neg_sim = torch.sum(query_expanded * neg_keys, dim=-1) / self.temperature
+        logits = torch.cat([pos_sim, neg_sim], dim=1)
+        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
+        loss = F.cross_entropy(logits, labels, reduction=self.reduction)
+        return loss
+
+
+class ListNetLoss(nn.Module):
+    """
+    ListNet loss using top-1 probability distribution.
+    Reference: Cao et al. (ICML 2007)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        pred_probs = F.softmax(scores / self.temperature, dim=1)
+        true_probs = F.softmax(labels / self.temperature, dim=1)
+        loss = -torch.sum(true_probs * torch.log(pred_probs + 1e-10), dim=1)
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
+
+
+class ListMLELoss(nn.Module):
+    """
+    ListMLE (Maximum Likelihood Estimation) loss.
+    Reference: Xia et al. (ICML 2008)
+    """
+
+    def __init__(self, reduction: str = "mean"):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+        sorted_labels, sorted_indices = torch.sort(labels, descending=True, dim=1)
+        batch_size, list_size = scores.shape
+        batch_indices = torch.arange(batch_size, device=scores.device).unsqueeze(1).expand(-1, list_size)
+        sorted_scores = scores[batch_indices, sorted_indices]
+
+        loss = torch.tensor(0.0, device=scores.device)
+        for i in range(list_size):
+            remaining_scores = sorted_scores[:, i:]
+            log_sum_exp = torch.logsumexp(remaining_scores, dim=1)
+            loss = loss + (log_sum_exp - sorted_scores[:, i]).sum()
+
+        if self.reduction == "mean":
+            return loss / batch_size
+        if self.reduction == "sum":
+            return loss
+        return loss / batch_size
+
+
+class ApproxNDCGLoss(nn.Module):
+    """
+    Approximate NDCG loss for learning to rank.
+    Reference: Qin et al. (2010)
+    """
+
+    def __init__(self, temperature: float = 1.0, reduction: str = "mean"):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+
+    def _ideal_dcg(self, labels: torch.Tensor, k: Optional[int]) -> torch.Tensor:
+        # labels: [B, L]
+        sorted_labels, _ = torch.sort(labels, dim=1, descending=True)
+        if k is not None:
+            sorted_labels = sorted_labels[:, :k]
+
+        gains = torch.pow(2.0, sorted_labels) - 1.0  # [B, K]
+        positions = torch.arange(
+            1, gains.size(1) + 1, device=gains.device, dtype=torch.float32
+        )  # [K]
+        discounts = 1.0 / torch.log2(positions + 1.0)  # [K]
+        ideal_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+        return ideal_dcg
+
+    def forward(
+        self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+    ) -> torch.Tensor:
+        """
+        scores: [B, L]
+        labels: [B, L]
+        """
+        batch_size, list_size = scores.shape
+        device = scores.device
+
+        # diff[b, i, j] = (s_j - s_i) / T
+        scores_i = scores.unsqueeze(2)  # [B, L, 1]
+        scores_j = scores.unsqueeze(1)  # [B, 1, L]
+        diff = (scores_j - scores_i) / self.temperature  # [B, L, L]
+
+        P_ji = torch.sigmoid(diff)  # [B, L, L]
+        eye = torch.eye(list_size, device=device).unsqueeze(0)  # [1, L, L]
+        P_ji = P_ji * (1.0 - eye)
+
+        exp_rank = 1.0 + P_ji.sum(dim=-1)  # [B, L]
+
+        discounts = 1.0 / torch.log2(exp_rank + 1.0)  # [B, L]
+
+        gains = torch.pow(2.0, labels) - 1.0  # [B, L]
+        approx_dcg = torch.sum(gains * discounts, dim=1)  # [B]
+
+        ideal_dcg = self._ideal_dcg(labels, k)  # [B]
+
+        ndcg = approx_dcg / (ideal_dcg + 1e-10)  # [B]
+        loss = 1.0 - ndcg
+
+        if self.reduction == "mean":
+            return loss.mean()
+        if self.reduction == "sum":
+            return loss.sum()
+        return loss
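
The shape conventions in this module are consistent: SampledSoftmaxLoss takes pos_logits [B] and neg_logits [B, N]; InfoNCELoss takes query/pos_key [B, D] and neg_keys [B, N, D]; the remaining losses take scores/labels [B, L]. A quick smoke test with toy tensors, based only on the class definitions above:

import torch
from nextrec.loss import ApproxNDCGLoss, InfoNCELoss, SampledSoftmaxLoss

B, N, D, L = 4, 8, 16, 10

pos_logits = torch.randn(B)       # one positive score per example
neg_logits = torch.randn(B, N)    # N sampled negatives per example
print(SampledSoftmaxLoss()(pos_logits, neg_logits))  # scalar

query = torch.randn(B, D)
pos_key = torch.randn(B, D)
neg_keys = torch.randn(B, N, D)   # broadcasts against query.unsqueeze(1)
print(InfoNCELoss(temperature=0.07)(query, pos_key, neg_keys))  # scalar

scores = torch.randn(B, L)
labels = torch.randint(0, 3, (B, L)).float()  # graded relevance labels
print(ApproxNDCGLoss()(scores, labels, k=5))  # scalar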