nextrec-0.4.2-py3-none-any.whl → nextrec-0.4.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,35 @@ Date: create on 09/11/2025
  Author:
      Yang Zhou,zyaztec@gmail.com
  Reference:
- [1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response prediction[C]//ICDM. 2016: 1149-1154.
+ [1] Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response
+     prediction[C]//ICDM. 2016: 1149-1154. (https://arxiv.org/abs/1611.00144)
+
+ Product-based Neural Networks (PNN) are CTR prediction models that explicitly
+ encode feature interactions by combining:
+ (1) A linear signal from concatenated field embeddings
+ (2) A product signal capturing pairwise feature interactions (inner or outer)
+ The product layer augments the linear input to an MLP, enabling the network to
+ model both first-order and high-order feature interactions in a structured way.
+
+ Computation workflow:
+ - Embed each categorical/sequence field with a shared embedding dimension
+ - Linear signal: flatten and concatenate all field embeddings
+ - Product signal:
+     * Inner product: dot products over all field pairs
+     * Outer product: project embeddings then compute element-wise products
+ - Concatenate linear and product signals; feed into MLP for prediction
+
+ Key Advantages:
+ - Explicit pairwise interaction modeling without heavy feature engineering
+ - Flexible choice between inner/outer products to trade off capacity vs. cost
+ - Combines linear context with interaction signal for stronger expressiveness
+ - Simple architecture that integrates cleanly with standard MLP pipelines
+
+ PNN is a CTR prediction model that explicitly models feature interactions by
+ combining a linear signal with a product signal:
+ - Linear signal: concatenates the embeddings of all fields to retain first-order information
+ - Product signal: inner or outer products over all field pairs to capture second- and higher-order interactions
+ The two signals are then concatenated and fed into an MLP to predict the user
+ response. The inner-product variant is cheaper to compute, while the
+ outer-product variant is more expressive; choose according to the scenario.
  """

  import torch
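The inner-product signal described in the docstring also admits a fully vectorized form. The following is a minimal sketch, assuming only a `[B, F, D]` embedding tensor; it illustrates the same pairwise computation and is not the package's implementation (which loops over field pairs, as the hunks below show):

```python
import torch

def pairwise_inner_products(field_emb: torch.Tensor) -> torch.Tensor:
    """Vectorized dot products over all field pairs: [B, F, D] -> [B, F*(F-1)/2]."""
    num_fields = field_emb.size(1)
    rows, cols = torch.triu_indices(num_fields, num_fields, offset=1)
    # Gather both members of every (i, j) pair, multiply, reduce over the embedding dim.
    return (field_emb[:, rows, :] * field_emb[:, cols, :]).sum(dim=-1)

emb = torch.randn(2, 4, 8)                            # B=2, F=4, D=8
assert pairwise_inner_products(emb).shape == (2, 6)   # 4 fields -> 6 pairs
```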
@@ -15,6 +43,7 @@ from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature


  class PNN(BaseModel):
+
      @property
      def model_name(self):
          return "PNN"
@@ -25,16 +54,16 @@ class PNN(BaseModel):

      def __init__(
          self,
-         dense_features: list[DenseFeature] | list = [],
-         sparse_features: list[SparseFeature] | list = [],
-         sequence_features: list[SequenceFeature] | list = [],
-         mlp_params: dict = {},
-         product_type: str = "inner",
+         dense_features: list[DenseFeature] | None = None,
+         sparse_features: list[SparseFeature] | None = None,
+         sequence_features: list[SequenceFeature] | None = None,
+         mlp_params: dict | None = None,
+         product_type: str = "inner",  # "inner" (IPNN), "outer" (OPNN), "both" (PNN*)
          outer_product_dim: int | None = None,
-         target: list[str] | list = [],
+         target: list[str] | str | None = None,
          task: str | list[str] | None = None,
          optimizer: str = "adam",
-         optimizer_params: dict = {},
+         optimizer_params: dict | None = None,
          loss: str | nn.Module | None = "bce",
          loss_params: dict | list[dict] | None = None,
          device: str = "cpu",
@@ -45,6 +74,16 @@ class PNN(BaseModel):
          **kwargs,
      ):

+         dense_features = dense_features or []
+         sparse_features = sparse_features or []
+         sequence_features = sequence_features or []
+         mlp_params = mlp_params or {}
+         if outer_product_dim is not None and outer_product_dim <= 0:
+             raise ValueError("outer_product_dim must be a positive integer.")
+         optimizer_params = optimizer_params or {}
+         if loss is None:
+             loss = "bce"
+
          super(PNN, self).__init__(
              dense_features=dense_features,
              sparse_features=sparse_features,
@@ -59,16 +98,13 @@ class PNN(BaseModel):
              **kwargs,
          )

-         self.loss = loss
-         if self.loss is None:
-             self.loss = "bce"
-
-         self.field_features = sparse_features + sequence_features
+         self.field_features = dense_features + sparse_features + sequence_features
          if len(self.field_features) < 2:
              raise ValueError("PNN requires at least two sparse/sequence features.")

          self.embedding = EmbeddingLayer(features=self.field_features)
          self.num_fields = len(self.field_features)
+
          self.embedding_dim = self.field_features[0].embedding_dim
          if any(f.embedding_dim != self.embedding_dim for f in self.field_features):
              raise ValueError(
@@ -76,24 +112,34 @@ class PNN(BaseModel):
          )

          self.product_type = product_type.lower()
-         if self.product_type not in {"inner", "outer"}:
-             raise ValueError("product_type must be 'inner' or 'outer'.")
+         if self.product_type not in {"inner", "outer", "both"}:
+             raise ValueError("product_type must be 'inner', 'outer', or 'both'.")

          self.num_pairs = self.num_fields * (self.num_fields - 1) // 2
-         if self.product_type == "outer":
-             self.outer_dim = outer_product_dim or self.embedding_dim
-             self.kernel = nn.Linear(self.embedding_dim, self.outer_dim, bias=False)
-             product_dim = self.num_pairs * self.outer_dim
+         self.outer_product_dim = outer_product_dim or self.embedding_dim
+
+         if self.product_type in {"outer", "both"}:
+             self.kernel = nn.Parameter(
+                 torch.randn(self.embedding_dim, self.outer_product_dim)
+             )
+             nn.init.xavier_uniform_(self.kernel)
          else:
-             self.outer_dim = None
-             product_dim = self.num_pairs
+             self.kernel = None

          linear_dim = self.num_fields * self.embedding_dim
+
+         if self.product_type == "inner":
+             product_dim = self.num_pairs
+         elif self.product_type == "outer":
+             product_dim = self.num_pairs
+         else:
+             product_dim = 2 * self.num_pairs
+
          self.mlp = MLP(input_dim=linear_dim + product_dim, **mlp_params)
          self.prediction_layer = PredictionLayer(task_type=self.task)

          modules = ["mlp"]
-         if self.product_type == "outer":
+         if self.kernel is not None:
              modules.append("kernel")
          self.register_regularization_weights(
              embedding_attr="embedding", include_modules=modules
@@ -106,27 +152,48 @@ class PNN(BaseModel):
              loss_params=loss_params,
          )

+     def compute_inner_products(self, field_emb: torch.Tensor) -> torch.Tensor:
+         interactions = []
+         for i in range(self.num_fields - 1):
+             vi = field_emb[:, i, :]  # [B, D]
+             for j in range(i + 1, self.num_fields):
+                 vj = field_emb[:, j, :]  # [B, D]
+                 # <v_i, v_j> = sum_k v_i,k * v_j,k
+                 pij = torch.sum(vi * vj, dim=1, keepdim=True)  # [B, 1]
+                 interactions.append(pij)
+         return torch.cat(interactions, dim=1)  # [B, num_pairs]
+
+     def compute_outer_kernel_products(self, field_emb: torch.Tensor) -> torch.Tensor:
+         if self.kernel is None:
+             raise RuntimeError("kernel is not initialized for outer product.")
+
+         interactions = []
+         for i in range(self.num_fields - 1):
+             vi = field_emb[:, i, :]  # [B, D]
+             # Project vi with kernel -> [B, K]
+             vi_proj = torch.matmul(vi, self.kernel)  # [B, K]
+             for j in range(i + 1, self.num_fields):
+                 vj = field_emb[:, j, :]  # [B, D]
+                 vj_proj = torch.matmul(vj, self.kernel)  # [B, K]
+                 # g(vi, vj) = (v_i^T W) * (v_j^T W) summed over projection dim
+                 pij = torch.sum(vi_proj * vj_proj, dim=1, keepdim=True)  # [B, 1]
+                 interactions.append(pij)
+         return torch.cat(interactions, dim=1)  # [B, num_pairs]
+
      def forward(self, x):
+         # field_emb: [B, F, D]
          field_emb = self.embedding(x=x, features=self.field_features, squeeze_dim=False)
-         linear_signal = field_emb.flatten(start_dim=1)
+         # Z = [v_1; v_2; ...; v_F]
+         linear_signal = field_emb.flatten(start_dim=1)  # [B, F*D]

          if self.product_type == "inner":
-             interactions = []
-             for i in range(self.num_fields - 1):
-                 vi = field_emb[:, i, :]
-                 for j in range(i + 1, self.num_fields):
-                     vj = field_emb[:, j, :]
-                     interactions.append(torch.sum(vi * vj, dim=1, keepdim=True))
-             product_signal = torch.cat(interactions, dim=1)
+             product_signal = self.compute_inner_products(field_emb)
+         elif self.product_type == "outer":
+             product_signal = self.compute_outer_kernel_products(field_emb)
          else:
-             transformed = self.kernel(field_emb)  # [B, F, outer_dim]
-             interactions = []
-             for i in range(self.num_fields - 1):
-                 vi = transformed[:, i, :]
-                 for j in range(i + 1, self.num_fields):
-                     vj = transformed[:, j, :]
-                     interactions.append(vi * vj)
-             product_signal = torch.stack(interactions, dim=1).flatten(start_dim=1)
+             inner_p = self.compute_inner_products(field_emb)
+             outer_p = self.compute_outer_kernel_products(field_emb)
+             product_signal = torch.cat([inner_p, outer_p], dim=1)

          deep_input = torch.cat([linear_signal, product_signal], dim=1)
          y = self.mlp(deep_input)
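The MLP input width assembled in `forward` follows directly from the bookkeeping in `__init__`: the linear part contributes `F*D`, and each product type contributes one scalar per field pair. A small self-contained check of that arithmetic (the helper name is illustrative, not part of the package):

```python
def pnn_input_dim(num_fields: int, embedding_dim: int, product_type: str) -> int:
    """MLP input width for PNN under the dimension rules shown above."""
    num_pairs = num_fields * (num_fields - 1) // 2
    linear_dim = num_fields * embedding_dim
    product_dim = 2 * num_pairs if product_type == "both" else num_pairs
    return linear_dim + product_dim

assert pnn_input_dim(4, 8, "inner") == 32 + 6    # F=4 fields -> 6 pairs
assert pnn_input_dim(4, 8, "both") == 32 + 12    # inner and outer concatenated
```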
@@ -61,10 +61,10 @@ class WideDeep(BaseModel):
          sparse_features: list[SparseFeature],
          sequence_features: list[SequenceFeature],
          mlp_params: dict,
-         target: list[str] = [],
+         target: list[str] | str | None = None,
          task: str | list[str] | None = None,
          optimizer: str = "adam",
-         optimizer_params: dict = {},
+         optimizer_params: dict | None = None,
          loss: str | nn.Module | None = "bce",
          loss_params: dict | list[dict] | None = None,
          device: str = "cpu",
@@ -75,6 +75,12 @@ class WideDeep(BaseModel):
          **kwargs,
      ):

+         if target is None:
+             target = []
+         optimizer_params = optimizer_params or {}
+         if loss is None:
+             loss = "bce"
+
          super(WideDeep, self).__init__(
              dense_features=dense_features,
              sparse_features=sparse_features,
@@ -90,8 +96,6 @@ class WideDeep(BaseModel):
          )

          self.loss = loss
-         if self.loss is None:
-             self.loss = "bce"

          # Wide part: use all features for linear model
          self.wide_features = sparse_features + sequence_features
@@ -1,12 +1,54 @@
  """
  Date: create on 09/11/2025
  Author:
- Yang Zhou,zyaztec@gmail.com
+     Yang Zhou,zyaztec@gmail.com
  Reference:
- [1] Lian J, Zhou X, Zhang F, et al. xdeepfm: Combining explicit and implicit feature interactions
- for recommender systems[C]//Proceedings of the 24th ACM SIGKDD international conference on
- knowledge discovery & data mining. 2018: 1754-1763.
- (https://arxiv.org/abs/1803.05170)
+ [1] Lian J, Zhou X, Zhang F, et al. xdeepfm: Combining explicit and implicit feature interactions
+     for recommender systems[C]//Proceedings of the 24th ACM SIGKDD international conference on
+     knowledge discovery & data mining. 2018: 1754-1763.
+     (https://arxiv.org/abs/1803.05170)
+
+ xDeepFM is a CTR prediction model that unifies explicit and implicit
+ feature interaction learning. It extends DeepFM by adding the
+ Compressed Interaction Network (CIN) to explicitly model high-order
+ interactions at the vector-wise level, while an MLP captures implicit
+ non-linear crosses. A linear term retains first-order signals, and all
+ three parts are learned jointly end-to-end.
+
+ In the forward pass:
+ (1) Embedding Layer: transforms sparse/sequence fields into dense vectors
+ (2) Linear Part: captures first-order contributions of sparse/sequence fields
+ (3) CIN: explicitly builds higher-order feature crosses via convolution over
+     outer products of field embeddings, with optional split-half connections
+ (4) Deep Part (MLP): models implicit, non-linear interactions across all fields
+ (5) Combination: sums outputs from linear, CIN, and deep branches before the
+     task-specific prediction layer
+
+ Key Advantages:
+ - Jointly learns first-order, explicit high-order, and implicit interactions
+ - CIN offers interpretable vector-wise crosses with controlled complexity
+ - Deep branch enhances representation power for non-linear patterns
+ - End-to-end optimization eliminates heavy manual feature engineering
+ - Flexible design supports both sparse and sequence features
+
+ xDeepFM is a CTR prediction model that unifies explicit and implicit
+ feature-interaction learning in one framework. Building on DeepFM, it adds
+ a CIN (Compressed Interaction Network) to explicitly model high-order
+ vector-wise interactions, while the MLP handles implicit non-linear
+ interactions and the linear part retains first-order signals; the three
+ are trained jointly.
+
+ Forward pass:
+ (1) Embedding layer: maps sparse/sequence features to dense vectors
+ (2) Linear part: models the first-order contribution of sparse/sequence features
+ (3) CIN: takes outer products of field embeddings and convolves them to
+     explicitly capture high-order crosses, with optional split-half to
+     control the parameter count
+ (4) Deep part (MLP): models implicit non-linear interactions over all features
+ (5) Fusion: the linear, CIN, and MLP outputs are summed and passed to the
+     task-specific prediction layer
+
+ Key advantages:
+ - Learns first-order, explicit high-order, and implicit interactions simultaneously
+ - CIN provides interpretable vector-wise crosses with controllable complexity
+ - The deep branch improves non-linear expressiveness
+ - End-to-end training reduces manual feature engineering
+ - Supports modeling of both sparse and sequence features
  """

  import torch
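The CIN step sketched in the docstring reduces to a few lines of tensor algebra. Below is a hedged, minimal illustration of a single CIN layer (split-half and the final sum-pooling omitted), not the package's implementation; the shapes and names are assumptions for the example:

```python
import torch
import torch.nn as nn

B, F, D, H_k, H_next = 2, 4, 8, 6, 5      # batch, fields, emb dim, layer widths
x0 = torch.randn(B, F, D)                  # base field embeddings
xk = torch.randn(B, H_k, D)                # previous CIN layer output

# Vector-wise outer products of every (field, feature-map) pair ...
z = torch.einsum("bfd,bhd->bfhd", x0, xk).reshape(B, F * H_k, D)
# ... compressed back to H_next feature maps by a 1x1 convolution over pairs.
conv = nn.Conv1d(F * H_k, H_next, kernel_size=1)
x_next = conv(z)                           # [B, H_next, D]
assert x_next.shape == (B, H_next, D)
```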
@@ -76,12 +118,12 @@ class xDeepFM(BaseModel):
          sparse_features: list[SparseFeature],
          sequence_features: list[SequenceFeature],
          mlp_params: dict,
-         cin_size: list[int] = [128, 128],
+         cin_size: list[int] | None = None,
          split_half: bool = True,
-         target: list[str] = [],
+         target: list[str] | str | None = None,
          task: str | list[str] | None = None,
          optimizer: str = "adam",
-         optimizer_params: dict = {},
+         optimizer_params: dict | None = None,
          loss: str | nn.Module | None = "bce",
          loss_params: dict | list[dict] | None = None,
          device: str = "cpu",
@@ -92,6 +134,13 @@ class xDeepFM(BaseModel):
          **kwargs,
      ):

+         cin_size = cin_size or [128, 128]
+         if target is None:
+             target = []
+         optimizer_params = optimizer_params or {}
+         if loss is None:
+             loss = "bce"
+
          super(xDeepFM, self).__init__(
              dense_features=dense_features,
              sparse_features=sparse_features,
@@ -107,8 +156,6 @@ class xDeepFM(BaseModel):
          )

          self.loss = loss
-         if self.loss is None:
-             self.loss = "bce"

          # Linear part and CIN part: use sparse and sequence features
          self.linear_features = sparse_features + sequence_features
nextrec/utils/config.py CHANGED
@@ -28,9 +28,15 @@ def resolve_path(path_str: str | Path, base_dir: Path) -> Path:
      path = Path(path_str).expanduser()
      if path.is_absolute():
          return path
-     if path.exists():
-         return path.resolve()
-     return (base_dir / path).resolve()
+     # Prefer resolving relative to current working directory when the path (or its parent)
+     # already exists there; otherwise fall back to the config file's directory.
+     cwd_path = (Path.cwd() / path).resolve()
+     if cwd_path.exists() or cwd_path.parent.exists():
+         return cwd_path
+     base_dir_path = (base_dir / path).resolve()
+     if base_dir_path.exists() or base_dir_path.parent.exists():
+         return base_dir_path
+     return cwd_path


  def select_features(
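A quick illustration of the new precedence, with an invented directory layout: a relative path whose parent already exists under the current working directory wins, and only then is the config file's directory tried.

```python
import os
import tempfile
from pathlib import Path

from nextrec.utils.config import resolve_path

with tempfile.TemporaryDirectory() as tmp:
    os.chdir(tmp)
    Path(tmp, "data").mkdir()
    # "data/train.parquet" need not exist yet; its parent under the cwd does,
    # so the cwd candidate wins over the (hypothetical) config directory.
    out = resolve_path("data/train.parquet", base_dir=Path("/somewhere/else"))
    assert out == (Path(tmp) / "data" / "train.parquet").resolve()
```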
@@ -154,8 +160,11 @@ def build_feature_objects(
              SparseFeature(
                  name=name,
                  vocab_size=int(vocab_size),
+                 embedding_name=embed_cfg.get("embedding_name", name),
                  embedding_dim=embed_cfg.get("embedding_dim"),
                  padding_idx=embed_cfg.get("padding_idx"),
+                 init_type=embed_cfg.get("init_type", "xavier_uniform"),
+                 init_params=embed_cfg.get("init_params"),
                  l1_reg=embed_cfg.get("l1_reg", 0.0),
                  l2_reg=embed_cfg.get("l2_reg", 1e-5),
                  trainable=embed_cfg.get("trainable", True),
@@ -178,9 +187,12 @@
              name=name,
              vocab_size=int(vocab_size),
              max_len=embed_cfg.get("max_len") or proc_cfg.get("max_len", 50),
+             embedding_name=embed_cfg.get("embedding_name", name),
              embedding_dim=embed_cfg.get("embedding_dim"),
              padding_idx=embed_cfg.get("padding_idx"),
              combiner=embed_cfg.get("combiner", "mean"),
+             init_type=embed_cfg.get("init_type", "xavier_uniform"),
+             init_params=embed_cfg.get("init_params"),
              l1_reg=embed_cfg.get("l1_reg", 0.0),
              l2_reg=embed_cfg.get("l2_reg", 1e-5),
              trainable=embed_cfg.get("trainable", True),
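Because both builders read configuration purely through `embed_cfg.get(...)`, an embedding entry exercising the new keys could look like the sketch below; the key names come from the getters above, the values are invented, and the comments note the defaults the code falls back to.

```python
# Hypothetical embedding config consumed by build_feature_objects.
embed_cfg = {
    "embedding_name": "item_id",    # defaults to the feature name; set it to share a table
    "embedding_dim": 16,
    "padding_idx": 0,
    "init_type": "xavier_uniform",  # default when omitted
    "init_params": {"gain": 1.0},
    "l1_reg": 0.0,                  # default 0.0
    "l2_reg": 1e-5,                 # default 1e-5
    "trainable": True,              # default True
}
```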
nextrec/utils/file.py CHANGED
@@ -60,7 +60,8 @@ def read_table(path: str | Path, data_format: str | None = None) -> pd.DataFrame
      if fmt in {"parquet", ""}:
          return pd.read_parquet(data_path)
      if fmt in {"csv", "txt"}:
-         return pd.read_csv(data_path)
+         # Use low_memory=False to avoid mixed-type DtypeWarning on wide CSVs
+         return pd.read_csv(data_path, low_memory=False)
      raise ValueError(f"Unsupported data format: {data_path}")

@@ -5,10 +5,9 @@ Date: create on 13/11/2025
  Author: Yang Zhou, zyaztec@gmail.com
  """

- from typing import Any, Dict, Set, cast
+ from typing import Any, Dict, Set

  import torch.nn as nn
- from torch.nn.init import _NonlinearityType

  KNOWN_NONLINEARITIES: Set[str] = {
      "linear",
@@ -27,28 +26,25 @@ KNOWN_NONLINEARITIES: Set[str] = {
  }


- def resolve_nonlinearity(activation: str | _NonlinearityType) -> _NonlinearityType:
-     if isinstance(activation, str):
-         if activation in KNOWN_NONLINEARITIES:
-             return cast(_NonlinearityType, activation)
-         # Fall back to linear for custom activations (gain handled separately).
-         return "linear"
-     return activation
+ def resolve_nonlinearity(activation: str):
+     if activation in KNOWN_NONLINEARITIES:
+         return activation
+     return "linear"


- def resolve_gain(activation: str | _NonlinearityType, param: Dict[str, Any]) -> float:
+ def resolve_gain(activation: str, param: Dict[str, Any]) -> float:
      if "gain" in param:
          return param["gain"]
      nonlinearity = resolve_nonlinearity(activation)
      try:
-         return nn.init.calculate_gain(nonlinearity, param.get("param"))
+         return nn.init.calculate_gain(nonlinearity, param.get("param"))  # type: ignore
      except ValueError:
-         return 1.0  # custom activation with no gain estimate available
+         return 1.0


  def get_initializer(
      init_type: str = "normal",
-     activation: str | _NonlinearityType = "linear",
+     activation: str = "linear",
      param: Dict[str, Any] | None = None,
  ):
      param = param or {}
@@ -62,11 +58,11 @@ def get_initializer(
          nn.init.xavier_normal_(tensor, gain=gain)
      elif init_type == "kaiming_uniform":
          nn.init.kaiming_uniform_(
-             tensor, a=param.get("a", 0), nonlinearity=nonlinearity
+             tensor, a=param.get("a", 0), nonlinearity=nonlinearity  # type: ignore
          )
      elif init_type == "kaiming_normal":
          nn.init.kaiming_normal_(
-             tensor, a=param.get("a", 0), nonlinearity=nonlinearity
+             tensor, a=param.get("a", 0), nonlinearity=nonlinearity  # type: ignore
          )
      elif init_type == "orthogonal":
          nn.init.orthogonal_(tensor, gain=gain)
@@ -80,4 +76,4 @@ def get_initializer(
          raise ValueError(f"Unknown init_type: {init_type}")
      return tensor

- return initializer_fn
+ return initializer_fn
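For orientation, a hedged usage sketch of the factory above: `get_initializer` returns a function that initializes a tensor in place. Only the signature visible in these hunks is assumed; the module path is not shown in the diff, so the import is a guess.

```python
import torch.nn as nn

from nextrec.utils.initializer import get_initializer  # assumed path

# Build an initializer suited to ReLU layers and apply it to a Linear weight.
init_fn = get_initializer(init_type="kaiming_uniform", activation="relu")
layer = nn.Linear(16, 8)
init_fn(layer.weight)  # applies nn.init.kaiming_uniform_ with nonlinearity="relu"
```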
nextrec/utils/model.py CHANGED
@@ -20,3 +20,25 @@ def get_mlp_output_dim(params: dict, fallback: int) -> int:
      if dims:
          return dims[-1]
      return fallback
+
+
+ def select_features(
+     available_features: list,
+     names: list[str],
+     param_name: str,
+ ) -> list:
+     if not names:
+         return []
+
+     if len(names) != len(set(names)):
+         raise ValueError(f"{param_name} contains duplicate feature names: {names}")
+
+     feature_map = {feat.name: feat for feat in available_features}
+     missing = [name for name in names if name not in feature_map]
+     if missing:
+         raise ValueError(
+             f"{param_name} contains unknown feature names {missing}. "
+             f"Available features: {list(feature_map)}"
+         )
+
+     return [feature_map[name] for name in names]
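A short usage sketch of the new helper: anything exposing a `.name` attribute qualifies as a feature, so a stand-in namespace object is enough to show the lookup, the duplicate check, and the unknown-name error.

```python
from types import SimpleNamespace

from nextrec.utils.model import select_features

feats = [SimpleNamespace(name="user_id"), SimpleNamespace(name="item_id")]
picked = select_features(feats, ["item_id"], param_name="wide_features")
assert [f.name for f in picked] == ["item_id"]

# Both of these raise ValueError mentioning "wide_features":
# select_features(feats, ["item_id", "item_id"], param_name="wide_features")
# select_features(feats, ["unknown"], param_name="wide_features")
```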
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: nextrec
- Version: 0.4.2
+ Version: 0.4.4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
- ![Version](https://img.shields.io/badge/Version-0.4.2-orange.svg)
+ ![Version](https://img.shields.io/badge/Version-0.4.4-orange.svg)

  English | [中文文档](README_zh.md)

@@ -71,59 +71,78 @@ English | [中文文档](README_zh.md)

  </div>

+ ## Table of Contents
+
+ - [Introduction](#introduction)
+ - [Installation](#installation)
+ - [Architecture](#architecture)
+ - [5-Minute Quick Start](#5-minute-quick-start)
+ - [CLI Usage](#cli-usage)
+ - [Platform Compatibility](#platform-compatibility)
+ - [Supported Models](#supported-models)
+ - [Contributing](#contributing)
+
  ## Introduction

- NextRec is a modern recommendation framework built on PyTorch, delivering a unified experience for modeling, training, and evaluation. It follows a modular design with rich model implementations, data-processing utilities, and engineering-ready training components. NextRec focuses on large-scale industrial recall scenarios on Spark clusters, training on massive offline parquet features.
+ NextRec is a modern recommendation framework built on PyTorch, delivering a unified experience for modeling, training, and evaluation. Designed with rich model implementations, data-processing utilities, and engineering-ready training components, NextRec focuses on large-scale industrial recommendation scenarios on Spark clusters, training on massive offline features (`parquet`/`csv`).

  ## Why NextRec

- - **Unified feature engineering & data pipeline**: Dense/Sparse/Sequence feature definitions, persistent DataProcessor, and batch-optimized RecDataLoader, matching offline feature training/inference in industrial big-data settings.
+ - **Unified feature engineering & data pipeline**: NextRec provides unified Dense/Sparse/Sequence feature definitions, a DataProcessor, and a batch-optimized RecDataLoader, matching offline feature training/inference in industrial big-data settings.
  - **Multi-scenario coverage**: Ranking (CTR/CVR), retrieval, multi-task learning, and more marketing/rec models, with a continuously expanding model zoo.
- - **Developer-friendly experience**: Stream processing/training/inference for csv/parquet/pathlike data, plus GPU/MPS acceleration and visualization support.
+ - **Developer-friendly experience**: Stream processing, distributed training, and inference for `csv`/`parquet`/path-like data, plus GPU/MPS acceleration and visualization support.
  - **Efficient training & evaluation**: Standardized engine with optimizers, LR schedulers, early stopping, checkpoints, and detailed logging out of the box.

  ## Architecture

- NextRec adopts a modular and low-coupling engineering design, enabling full-pipeline reusability and scalability across data processing → model construction → training & evaluation → inference & deployment. Its core components include: a Feature-Spec-driven Embedding architecture, the BaseModel abstraction, a set of independent reusable Layers, a unified DataLoader for both training and inference, and a ready-to-use Model Zoo.
+ NextRec adopts a modular design, enabling full-pipeline reusability and scalability across data processing → model construction → training & evaluation → inference & deployment. Its core components include: a Feature-Spec-driven Embedding architecture, the BaseModel abstraction, a set of independent reusable Layers, a unified DataLoader for both training and inference, and a ready-to-use Model Zoo.

  ![NextRec Architecture](assets/nextrec_diagram_en.png)

- > The project borrows ideas from excellent open-source rec libraries. Early layers referenced [torch-rechub](https://github.com/datawhalechina/torch-rechub) but have been replaced with in-house implementations. torch-rechub remains mature in architecture and models; the author contributed a bit there—feel free to check it out.
+ > The project borrows ideas from excellent open-source rec libraries, for example [torch-rechub](https://github.com/datawhalechina/torch-rechub). torch-rechub remains mature in architecture and models; the author contributed a bit there—feel free to check it out.

  ---

  ## Installation

- You can quickly install the latest NextRec via `pip install nextrec`; Python 3.10+ is required.
+ You can quickly install the latest NextRec via `pip install nextrec`; Python 3.10+ is required. If you want to run the tutorial code, clone this project first:
+
+ ```bash
+ git clone https://github.com/zerolovesea/NextRec.git
+ cd NextRec/
+ pip install nextrec  # or pip install -e .
+ ```

  ## Tutorials

  See `tutorials/` for examples covering ranking, retrieval, multi-task learning, and data processing:

- - [movielen_ranking_deepfm.py](/tutorials/movielen_ranking_deepfm.py) — DeepFM training on MovieLens 100k
- - [example_ranking_din.py](/tutorials/example_ranking_din.py) — DIN training on the e-commerce dataset
- - [example_multitask.py](/tutorials/example_multitask.py) — ESMM multi-task training on the e-commerce dataset
- - [movielen_match_dssm.py](/tutorials/example_match_dssm.py) — DSSM retrieval on MovieLens 100k
+ - [movielen_ranking_deepfm.py](/tutorials/movielen_ranking_deepfm.py) — DeepFM training on the MovieLens 100k dataset
+ - [example_ranking_din.py](/tutorials/example_ranking_din.py) — DIN (Deep Interest Network) training on the e-commerce dataset
+ - [example_multitask.py](/tutorials/example_multitask.py) — ESMM multi-task learning on the e-commerce dataset
+ - [movielen_match_dssm.py](/tutorials/example_match_dssm.py) — DSSM retrieval model training on the MovieLens 100k dataset

- To dive deeper, Jupyter notebooks are available:
+ - [run_all_ranking_models.py](/tutorials/run_all_ranking_models.py) — Quickly validate availability of all ranking models
+ - [run_all_multitask_models.py](/tutorials/run_all_multitask_models.py) — Quickly validate availability of all multi-task models
+ - [run_all_match_models.py](/tutorials/run_all_match_models.py) — Quickly validate availability of all retrieval models
+
+ To dive deeper into the NextRec framework, Jupyter notebooks are available:

  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)

- > Current version [0.4.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
-
  ## 5-Minute Quick Start

- We provide a detailed quick start and paired datasets to help you learn the framework. In `datasets/` youll find an e-commerce sample dataset like this:
+ We provide a detailed quick-start guide and paired datasets to help you get familiar with the different features of the NextRec framework. In `datasets/` you'll find an e-commerce test dataset like this:

  | user_id | item_id | dense_0 | dense_1 | dense_2 | dense_3 | dense_4 | dense_5 | dense_6 | dense_7 | sparse_0 | sparse_1 | sparse_2 | sparse_3 | sparse_4 | sparse_5 | sparse_6 | sparse_7 | sparse_8 | sparse_9 | sequence_0 | sequence_1 | label |
  |--------|---------|-------------|-------------|-------------|------------|-------------|-------------|-------------|-------------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|-----------------------------------------------------------|-----------------------------------------------------------|-------|
  | 1 | 7817 | 0.14704075 | 0.31020382 | 0.77780896 | 0.944897 | 0.62315375 | 0.57124174 | 0.77009535 | 0.3211029 | 315 | 260 | 379 | 146 | 168 | 161 | 138 | 88 | 5 | 312 | [170,175,97,338,105,353,272,546,175,545,463,128,0,0,0] | [368,414,820,405,548,63,327,0,0,0,0,0,0,0,0] | 0 |
  | 1 | 3579 | 0.77811223 | 0.80359334 | 0.5185201 | 0.91091245 | 0.043562356 | 0.82142705 | 0.8803686 | 0.33748195 | 149 | 229 | 442 | 6 | 167 | 252 | 25 | 402 | 7 | 168 | [179,48,61,551,284,165,344,151,0,0,0,0,0,0,0] | [814,0,0,0,0,0,0,0,0,0,0,0,0,0,0] | 1 |

- Below is a short example showing how to train a DIN model. DIN (Deep Interest Network) won Best Paper at KDD 2018 for CTR prediction. You can also run `python tutorials/example_ranking_din.py` directly.
+ Below is a short example showing how to train a DIN (Deep Interest Network) model. You can also run `python tutorials/example_ranking_din.py` directly to execute the training and inference code.

- After training, detailed logs are available under `nextrec_logs/din_tutorial`.
+ After training starts, you can find detailed training logs at `nextrec_logs/din_tutorial`.

  ```python
  import pandas as pd
@@ -196,9 +215,26 @@ metrics = model.evaluate(
  )
  ```

+ ## CLI Usage
+
+ NextRec provides a powerful command-line interface for model training and prediction using YAML configuration files. For detailed CLI documentation, see:
+
+ - [NextRec CLI User Guide](/nextrec_cli_preset/NextRec-CLI.md) - Complete guide for using the CLI
+ - [NextRec CLI Configuration Examples](/nextrec_cli_preset/) - CLI configuration file examples
+
+ ```bash
+ # Train a model
+ nextrec --mode=train --train_config=path/to/train_config.yaml
+
+ # Run prediction
+ nextrec --mode=predict --predict_config=path/to/predict_config.yaml
+ ```
+
+ > As of version 0.4.4, the NextRec CLI supports single-machine training; distributed training features are currently under development.
+
  ## Platform Compatibility

- The current version is 0.4.2. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:
+ The current version is 0.4.4. All models and test code have been validated on the following platforms. If you encounter compatibility issues, please report them in the issue tracker with your system version:

  | Platform | Configuration |
  |----------|---------------|
@@ -247,14 +283,13 @@ The current version is 0.4.2. All models and test code have been validated on th
  | [ESMM](nextrec/models/multi_task/esmm.py) | Entire Space Multi-task Model | SIGIR 2018 | Supported |
  | [ShareBottom](nextrec/models/multi_task/share_bottom.py) | Multitask Learning | - | Supported |
  | [POSO](nextrec/models/multi_task/poso.py) | POSO: Personalized Cold-start Modules for Large-scale Recommender Systems | 2021 | Supported |
- | [POSO-IFLYTEK](nextrec/models/multi_task/poso_iflytek.py) | POSO with PLE-style gating for sequential marketing tasks | - | Supported |

  ### Generative Models

  | Model | Paper | Year | Status |
  |-------|-------|------|--------|
  | [TIGER](nextrec/models/generative/tiger.py) | Recommender Systems with Generative Retrieval | NeurIPS 2023 | In Progress |
- | [HSTU](nextrec/models/generative/hstu.py) | Hierarchical Sequential Transduction Units | - | In Progress |
+ | [HSTU](nextrec/models/generative/hstu.py) | Hierarchical Sequential Transduction Units | - | Supported |

  ---

@@ -270,7 +305,7 @@ We welcome contributions of any form!
  4. Push your branch (`git push origin feature/AmazingFeature`)
  5. Open a Pull Request

- > Before submitting a PR, please run tests using `pytest test/ -v` or `python -m pytest` to ensure everything passes.
+ > Before submitting a PR, please run `python test/run_tests.py` and `python scripts/format_code.py` to ensure all tests pass and code style is consistent.

  ### Code Style