nextrec 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/callback.py +30 -15
  3. nextrec/basic/features.py +1 -0
  4. nextrec/basic/layers.py +6 -8
  5. nextrec/basic/loggers.py +14 -7
  6. nextrec/basic/metrics.py +6 -76
  7. nextrec/basic/model.py +312 -318
  8. nextrec/cli.py +5 -10
  9. nextrec/data/__init__.py +13 -16
  10. nextrec/data/batch_utils.py +3 -2
  11. nextrec/data/data_processing.py +10 -2
  12. nextrec/data/data_utils.py +9 -14
  13. nextrec/data/dataloader.py +12 -13
  14. nextrec/data/preprocessor.py +328 -255
  15. nextrec/loss/__init__.py +1 -5
  16. nextrec/loss/loss_utils.py +2 -8
  17. nextrec/models/generative/__init__.py +1 -8
  18. nextrec/models/generative/hstu.py +6 -4
  19. nextrec/models/multi_task/esmm.py +2 -2
  20. nextrec/models/multi_task/mmoe.py +2 -2
  21. nextrec/models/multi_task/ple.py +2 -2
  22. nextrec/models/multi_task/poso.py +2 -3
  23. nextrec/models/multi_task/share_bottom.py +2 -2
  24. nextrec/models/ranking/afm.py +2 -2
  25. nextrec/models/ranking/autoint.py +2 -2
  26. nextrec/models/ranking/dcn.py +2 -2
  27. nextrec/models/ranking/dcn_v2.py +2 -2
  28. nextrec/models/ranking/deepfm.py +2 -2
  29. nextrec/models/ranking/dien.py +3 -3
  30. nextrec/models/ranking/din.py +3 -3
  31. nextrec/models/ranking/ffm.py +0 -0
  32. nextrec/models/ranking/fibinet.py +5 -5
  33. nextrec/models/ranking/fm.py +3 -7
  34. nextrec/models/ranking/lr.py +0 -0
  35. nextrec/models/ranking/masknet.py +2 -2
  36. nextrec/models/ranking/pnn.py +2 -2
  37. nextrec/models/ranking/widedeep.py +2 -2
  38. nextrec/models/ranking/xdeepfm.py +2 -2
  39. nextrec/models/representation/__init__.py +9 -0
  40. nextrec/models/{generative → representation}/rqvae.py +9 -9
  41. nextrec/models/retrieval/__init__.py +0 -0
  42. nextrec/models/{match → retrieval}/dssm.py +8 -3
  43. nextrec/models/{match → retrieval}/dssm_v2.py +8 -3
  44. nextrec/models/{match → retrieval}/mind.py +4 -3
  45. nextrec/models/{match → retrieval}/sdm.py +4 -3
  46. nextrec/models/{match → retrieval}/youtube_dnn.py +8 -3
  47. nextrec/utils/__init__.py +60 -46
  48. nextrec/utils/config.py +8 -7
  49. nextrec/utils/console.py +371 -0
  50. nextrec/utils/{synthetic_data.py → data.py} +102 -15
  51. nextrec/utils/feature.py +15 -0
  52. nextrec/utils/torch_utils.py +411 -0
  53. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/METADATA +6 -6
  54. nextrec-0.4.9.dist-info/RECORD +70 -0
  55. nextrec/utils/cli_utils.py +0 -58
  56. nextrec/utils/device.py +0 -78
  57. nextrec/utils/distributed.py +0 -141
  58. nextrec/utils/file.py +0 -92
  59. nextrec/utils/initializer.py +0 -79
  60. nextrec/utils/optimizer.py +0 -75
  61. nextrec/utils/tensor.py +0 -72
  62. nextrec-0.4.8.dist-info/RECORD +0 -71
  63. /nextrec/models/{match/__init__.py → ranking/eulernet.py} +0 -0
  64. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/WHEEL +0 -0
  65. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/entry_points.txt +0 -0
  66. {nextrec-0.4.8.dist-info → nextrec-0.4.9.dist-info}/licenses/LICENSE +0 -0
nextrec/loss/__init__.py CHANGED
@@ -5,6 +5,7 @@ from nextrec.loss.listwise import (
     ListNetLoss,
     SampledSoftmaxLoss,
 )
+from nextrec.loss.loss_utils import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
 from nextrec.loss.pointwise import (
     ClassBalancedFocalLoss,
@@ -12,11 +13,6 @@ from nextrec.loss.pointwise import (
     FocalLoss,
     WeightedBCELoss,
 )
-from nextrec.loss.loss_utils import (
-    get_loss_fn,
-    get_loss_kwargs,
-    VALID_TASK_TYPES,
-)
 
 __all__ = [
     # Pointwise
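With this change the loss helpers are re-exported at the package root, so downstream code no longer has to reach into nextrec.loss.loss_utils. A minimal sketch of the new import path (the commented-out call is only a guess at the helper's shape, not confirmed by this diff):

# New in 0.4.9: helpers importable directly from nextrec.loss
from nextrec.loss import VALID_TASK_TYPES, get_loss_fn, get_loss_kwargs

task = "binary"
print(task in VALID_TASK_TYPES)  # True; see the loss_utils.py hunk below for the 0.4.9 list
# loss_fn = get_loss_fn(task)    # hypothetical call shape; the signature is not shown in this diff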
nextrec/loss/loss_utils.py CHANGED
@@ -2,7 +2,7 @@
 Loss utilities for NextRec.
 
 Date: create on 27/10/2025
-Checkpoint: edit on 17/12/2025
+Checkpoint: edit on 19/12/2025
 Author: Yang Zhou, zyaztec@gmail.com
 """
 
@@ -18,16 +18,10 @@ from nextrec.loss.listwise import (
     SampledSoftmaxLoss,
 )
 from nextrec.loss.pairwise import BPRLoss, HingeLoss, TripletLoss
-from nextrec.loss.pointwise import (
-    ClassBalancedFocalLoss,
-    FocalLoss,
-    WeightedBCELoss,
-)
-
+from nextrec.loss.pointwise import ClassBalancedFocalLoss, FocalLoss, WeightedBCELoss
 
 VALID_TASK_TYPES = [
     "binary",
-    "multiclass",
     "multilabel",
     "regression",
 ]
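Since "multiclass" is dropped from VALID_TASK_TYPES, callers that previously passed it should now fail fast. A hedged sketch of such a guard on the caller side (NextRec's own internal validation is not shown in this diff):

from nextrec.loss import VALID_TASK_TYPES

def check_task_type(task: str) -> str:
    # "multiclass" was removed in 0.4.9; "binary", "multilabel" and "regression" remain.
    if task not in VALID_TASK_TYPES:
        raise ValueError(f"Unsupported task type {task!r}; expected one of {VALID_TASK_TYPES}")
    return task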
nextrec/models/generative/__init__.py CHANGED
@@ -5,12 +5,5 @@ This module contains generative models for recommendation tasks.
 """
 
 from nextrec.models.generative.hstu import HSTU
-from nextrec.models.generative.rqvae import (
-    RQVAE,
-    RQ,
-    VQEmbedding,
-    BalancedKmeans,
-    kmeans,
-)
 
-__all__ = ["HSTU", "RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
+__all__ = ["HSTU"]
nextrec/models/generative/hstu.py CHANGED
@@ -54,10 +54,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import RMSNorm, EmbeddingLayer
 from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-
+from nextrec.basic.layers import EmbeddingLayer, RMSNorm
+from nextrec.basic.model import BaseModel
 from nextrec.utils.model import select_features
 
 
@@ -302,7 +301,7 @@ class HSTU(BaseModel):
 
     @property
     def default_task(self) -> str:
-        return "multiclass"
+        return "binary"
 
     def __init__(
         self,
@@ -336,6 +335,9 @@ class HSTU(BaseModel):
         device: str = "cpu",
         **kwargs,
     ):
+        raise NotImplementedError(
+            "[HSTU Error] NextRec no longer supports multiclass tasks; HSTU is disabled."
+        )
         if not sequence_features:
             raise ValueError(
                 "[HSTU Error] HSTU requires at least one SequenceFeature (user behavior history)."
nextrec/models/multi_task/esmm.py CHANGED
@@ -44,9 +44,9 @@ CVR 预测 P(conversion|click),二者相乘得到 CTCVR 并在曝光标签上
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class ESMM(BaseModel):
nextrec/models/multi_task/mmoe.py CHANGED
@@ -45,9 +45,9 @@ MMoE(Multi-gate Mixture-of-Experts)是多任务学习框架,通过多个
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class MMOE(BaseModel):
nextrec/models/multi_task/ple.py CHANGED
@@ -48,9 +48,9 @@ PLE(Progressive Layered Extraction)通过堆叠 CGC 模块,联合共享与
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 from nextrec.utils.model import get_mlp_output_dim
 
 
nextrec/models/multi_task/poso.py CHANGED
@@ -42,11 +42,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
 from nextrec.basic.activation import activation_layer
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-
 from nextrec.utils.model import select_features
 
 
nextrec/models/multi_task/share_bottom.py CHANGED
@@ -42,9 +42,9 @@ Share-Bottom(硬共享底层)是多任务学习的经典基线:所有任
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class ShareBottom(BaseModel):
nextrec/models/ranking/afm.py CHANGED
@@ -39,9 +39,9 @@ AFM 在 FM 的二阶交互上引入注意力,为每个特征对学习重要性
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import EmbeddingLayer, InputMask, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, PredictionLayer, InputMask
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class AFM(BaseModel):
nextrec/models/ranking/autoint.py CHANGED
@@ -57,9 +57,9 @@ AutoInt 通过对所有特征 embedding 进行注意力计算,捕捉特征之
 import torch
 import torch.nn as nn
 
-from nextrec.basic.model import BaseModel
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import EmbeddingLayer, MultiHeadSelfAttention, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.model import BaseModel
 
 
 class AutoInt(BaseModel):
nextrec/models/ranking/dcn.py CHANGED
@@ -53,9 +53,9 @@ Deep 分支提升表达能力;最终将 Cross(及 Deep)结果送入线性
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class CrossNetwork(nn.Module):
nextrec/models/ranking/dcn_v2.py CHANGED
@@ -46,9 +46,9 @@ DCN v2 在原始 DCN 基础上,将标量交叉权重升级为向量/矩阵参
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class CrossNetV2(nn.Module):
nextrec/models/ranking/deepfm.py CHANGED
@@ -45,9 +45,9 @@ embedding,无需手工构造交叉特征即可端到端训练,常用于 CTR/
 
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import FM, LR, MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import FM, LR, EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class DeepFM(BaseModel):
nextrec/models/ranking/dien.py CHANGED
@@ -50,14 +50,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nextrec.basic.model import BaseModel
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import (
-    EmbeddingLayer,
     MLP,
     AttentionPoolingLayer,
+    EmbeddingLayer,
     PredictionLayer,
 )
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.model import BaseModel
 
 
 class AUGRU(nn.Module):
nextrec/models/ranking/din.py CHANGED
@@ -50,14 +50,14 @@ DIN 是一个 CTR 预估模型,通过对用户历史行为序列进行目标
 import torch
 import torch.nn as nn
 
-from nextrec.basic.model import BaseModel
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import (
-    EmbeddingLayer,
     MLP,
     AttentionPoolingLayer,
+    EmbeddingLayer,
     PredictionLayer,
 )
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.model import BaseModel
 
 
 class DIN(BaseModel):
nextrec/models/ranking/ffm.py: File without changes
nextrec/models/ranking/fibinet.py CHANGED
@@ -43,17 +43,17 @@ FiBiNET 是一个 CTR 预估模型,通过 SENET 重新分配特征字段的重
 import torch
 import torch.nn as nn
 
-from nextrec.basic.model import BaseModel
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import (
-    BiLinearInteractionLayer,
-    HadamardInteractionLayer,
-    EmbeddingLayer,
     LR,
     MLP,
+    BiLinearInteractionLayer,
+    EmbeddingLayer,
+    HadamardInteractionLayer,
     PredictionLayer,
     SENETLayer,
 )
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.model import BaseModel
 
 
 class FiBiNET(BaseModel):
nextrec/models/ranking/fm.py CHANGED
@@ -40,14 +40,10 @@ FM 是一种通过分解二阶特征交互矩阵、以线性复杂度建模特
 
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import FM as FMInteraction
+from nextrec.basic.layers import LR, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import (
-    EmbeddingLayer,
-    FM as FMInteraction,
-    LR,
-    PredictionLayer,
-)
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class FM(BaseModel):
nextrec/models/ranking/lr.py: File without changes
nextrec/models/ranking/masknet.py CHANGED
@@ -57,9 +57,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class InstanceGuidedMask(nn.Module):
nextrec/models/ranking/pnn.py CHANGED
@@ -37,9 +37,9 @@ PNN 是一种 CTR 预估模型,通过将线性信号与乘积信号结合,
 import torch
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class PNN(BaseModel):
nextrec/models/ranking/widedeep.py CHANGED
@@ -41,9 +41,9 @@ Wide & Deep 同时使用宽线性部分(记忆共现/手工交叉)与深网
 
 import torch.nn as nn
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import LR, MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import LR, EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class WideDeep(BaseModel):
nextrec/models/ranking/xdeepfm.py CHANGED
@@ -55,9 +55,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
+from nextrec.basic.layers import LR, MLP, EmbeddingLayer, PredictionLayer
 from nextrec.basic.model import BaseModel
-from nextrec.basic.layers import LR, EmbeddingLayer, MLP, PredictionLayer
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
 
 
 class CIN(nn.Module):
nextrec/models/representation/__init__.py ADDED
@@ -0,0 +1,9 @@
+from nextrec.models.representation.rqvae import (
+    RQ,
+    RQVAE,
+    BalancedKmeans,
+    VQEmbedding,
+    kmeans,
+)
+
+__all__ = ["RQVAE", "RQ", "VQEmbedding", "BalancedKmeans", "kmeans"]
nextrec/models/{generative → representation}/rqvae.py RENAMED
@@ -46,21 +46,21 @@ RQ-VAE 通过残差量化学习分层离散表示,将连续嵌入(如物品/
 
 from __future__ import annotations
 
+import logging
 import math
+from typing import cast
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.cluster import KMeans
-from typing import cast
-import logging
-import tqdm
-
 from torch.utils.data import DataLoader
 
 from nextrec.basic.features import DenseFeature
+from nextrec.basic.loggers import colorize, setup_logger
 from nextrec.basic.model import BaseModel
 from nextrec.data.batch_utils import batch_to_dict
-from nextrec.basic.loggers import colorize, setup_logger
+from nextrec.utils.console import progress
 
 
 def kmeans(
@@ -729,9 +729,9 @@ class RQVAE(BaseModel):
         else:
             tqdm_disable = not self.is_main_process
             batch_iter = enumerate(
-                tqdm.tqdm(
+                progress(
                     train_loader,
-                    desc=f"Epoch {epoch + 1}/{epochs}",
+                    description=f"Epoch {epoch + 1}/{epochs}",
                     total=steps_per_epoch,
                     disable=tqdm_disable,
                 )
@@ -777,9 +777,9 @@ class RQVAE(BaseModel):
             logging.info(colorize(train_log))
 
         if self.is_main_process:
-            logging.info(" ")
+            logging.info("")
             logging.info(colorize("Training finished.", bold=True))
-            logging.info(" ")
+            logging.info("")
 
         return self
 
     def predict(
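Inside RQVAE training the tqdm dependency is replaced by the new nextrec.utils.console.progress helper. A minimal usage sketch based only on the call shown in this diff (tqdm's desc= becomes description=; other parameters of progress are not assumed):

from nextrec.utils.console import progress

items = range(100)  # stands in for a DataLoader or any iterable
for step, item in enumerate(
    progress(
        items,
        description="Epoch 1/10",  # replaces tqdm's desc=
        total=100,
        disable=False,             # e.g. True on non-main distributed ranks
    )
):
    pass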
nextrec/models/retrieval/__init__.py: File without changes
nextrec/models/{match → retrieval}/dssm.py RENAMED
@@ -7,13 +7,14 @@ Reference:
 //Proceedings of the 22nd ACM international conference on Information & Knowledge Management. 2013: 2333-2338.
 """
 
+from typing import Literal
+
 import torch
 import torch.nn as nn
-from typing import Literal
 
-from nextrec.basic.model import BaseMatchModel
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.model import BaseMatchModel
 
 
 class DSSM(BaseMatchModel):
@@ -28,6 +29,10 @@ class DSSM(BaseMatchModel):
     def model_name(self) -> str:
         return "DSSM"
 
+    @property
+    def support_training_modes(self) -> list[str]:
+        return ["pointwise", "pairwise", "listwise"]
+
     def __init__(
         self,
        user_dense_features: list[DenseFeature] | None = None,
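The retrieval towers (DSSM here, and DSSM_v2 and YoutubeDNN below) now advertise the training modes they accept through the new support_training_modes property. A hedged sketch of how a caller might use it (model construction is omitted; only the property shown in this diff is relied on):

from nextrec.models.retrieval.dssm import DSSM

def pick_training_mode(model: DSSM, preferred: str = "pairwise") -> str:
    # In 0.4.9 this returns ["pointwise", "pairwise", "listwise"].
    modes = model.support_training_modes
    return preferred if preferred in modes else modes[0]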
nextrec/models/{match → retrieval}/dssm_v2.py RENAMED
@@ -6,13 +6,14 @@ Reference:
 DSSM v2 - DSSM with pairwise training using BPR loss
 """
 
+from typing import Literal
+
 import torch
 import torch.nn as nn
-from typing import Literal
 
-from nextrec.basic.model import BaseMatchModel
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.model import BaseMatchModel
 
 
 class DSSM_v2(BaseMatchModel):
@@ -24,6 +25,10 @@ class DSSM_v2(BaseMatchModel):
     def model_name(self) -> str:
         return "DSSM_v2"
 
+    @property
+    def support_training_modes(self) -> list[str]:
+        return ["pointwise", "pairwise", "listwise"]
+
     def __init__(
         self,
         user_dense_features: list[DenseFeature] | None = None,
nextrec/models/{match → retrieval}/mind.py RENAMED
@@ -7,14 +7,15 @@ Reference:
 //Proceedings of the 28th ACM international conference on information and knowledge management. 2019: 2615-2623.
 """
 
+from typing import Literal
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Literal
 
-from nextrec.basic.model import BaseMatchModel
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.model import BaseMatchModel
 
 
 class MultiInterestSA(nn.Module):
nextrec/models/{match → retrieval}/sdm.py RENAMED
@@ -7,14 +7,15 @@ Reference:
 //IJCAI. 2018: 3926-3932.
 """
 
+from typing import Literal
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Literal
 
-from nextrec.basic.model import BaseMatchModel
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.model import BaseMatchModel
 
 
 class SDM(BaseMatchModel):
nextrec/models/{match → retrieval}/youtube_dnn.py RENAMED
@@ -7,13 +7,14 @@ Reference:
 //Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198.
 """
 
+from typing import Literal
+
 import torch
 import torch.nn as nn
-from typing import Literal
 
-from nextrec.basic.model import BaseMatchModel
-from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
+from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
 from nextrec.basic.layers import MLP, EmbeddingLayer
+from nextrec.basic.model import BaseMatchModel
 
 
 class YoutubeDNN(BaseMatchModel):
@@ -28,6 +29,10 @@ class YoutubeDNN(BaseMatchModel):
     def model_name(self) -> str:
         return "YouTubeDNN"
 
+    @property
+    def support_training_modes(self) -> list[str]:
+        return ["pointwise", "pairwise", "listwise"]
+
     def __init__(
         self,
         user_dense_features: list[DenseFeature] | None = None,