nextrec 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. nextrec/__version__.py +1 -1
  2. nextrec/basic/activation.py +4 -8
  3. nextrec/basic/callback.py +1 -1
  4. nextrec/basic/features.py +33 -25
  5. nextrec/basic/layers.py +164 -601
  6. nextrec/basic/loggers.py +3 -4
  7. nextrec/basic/metrics.py +39 -115
  8. nextrec/basic/model.py +248 -174
  9. nextrec/basic/session.py +1 -5
  10. nextrec/data/__init__.py +12 -0
  11. nextrec/data/data_utils.py +3 -27
  12. nextrec/data/dataloader.py +26 -34
  13. nextrec/data/preprocessor.py +2 -1
  14. nextrec/loss/listwise.py +6 -4
  15. nextrec/loss/loss_utils.py +10 -6
  16. nextrec/loss/pairwise.py +5 -3
  17. nextrec/loss/pointwise.py +7 -13
  18. nextrec/models/match/mind.py +110 -1
  19. nextrec/models/multi_task/esmm.py +46 -27
  20. nextrec/models/multi_task/mmoe.py +48 -30
  21. nextrec/models/multi_task/ple.py +156 -141
  22. nextrec/models/multi_task/poso.py +413 -0
  23. nextrec/models/multi_task/share_bottom.py +43 -26
  24. nextrec/models/ranking/__init__.py +2 -0
  25. nextrec/models/ranking/autoint.py +1 -1
  26. nextrec/models/ranking/dcn.py +20 -1
  27. nextrec/models/ranking/dcn_v2.py +84 -0
  28. nextrec/models/ranking/deepfm.py +44 -18
  29. nextrec/models/ranking/dien.py +130 -27
  30. nextrec/models/ranking/masknet.py +13 -67
  31. nextrec/models/ranking/widedeep.py +39 -18
  32. nextrec/models/ranking/xdeepfm.py +34 -1
  33. nextrec/utils/common.py +26 -1
  34. nextrec-0.3.1.dist-info/METADATA +306 -0
  35. nextrec-0.3.1.dist-info/RECORD +56 -0
  36. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/WHEEL +1 -1
  37. nextrec-0.2.6.dist-info/METADATA +0 -281
  38. nextrec-0.2.6.dist-info/RECORD +0 -54
  39. {nextrec-0.2.6.dist-info → nextrec-0.3.1.dist-info}/licenses/LICENSE +0 -0
nextrec/basic/session.py CHANGED
@@ -13,8 +13,6 @@ Date: create on 23/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """

- from __future__ import annotations
-
  import os
  import tempfile
  from dataclasses import dataclass
@@ -95,7 +93,7 @@ def create_session(experiment_id: str | Path | None = None) -> Session:
      return Session(experiment_id=exp_id, root=root)

  def resolve_save_path(
-     path: str | Path | None,
+     path: str | os.PathLike | Path | None,
      default_dir: str | Path,
      default_name: str,
      suffix: str,
@@ -146,5 +144,3 @@ def resolve_save_path(
      file_stem = f"{file_stem}_{timestamp}"

      return (base_dir / f"{file_stem}{normalized_suffix}").resolve()
-
-
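The widened annotation means resolve_save_path now accepts any path-like object, not just str or pathlib.Path. A minimal sketch of a call site, assuming any parameters after suffix have defaults; the argument values are illustrative, not taken from the package:

    from pathlib import Path
    from nextrec.basic.session import resolve_save_path

    # `path` may be a str, os.PathLike, pathlib.Path, or None (then default_dir/default_name apply).
    ckpt_path = resolve_save_path(
        path=Path("experiments") / "run1",
        default_dir="checkpoints",
        default_name="model",
        suffix=".pt",
    )
    print(ckpt_path)  # an absolute path ending in the normalized suffix, per the return statement above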
nextrec/data/__init__.py CHANGED
@@ -20,6 +20,13 @@ from nextrec.data.data_utils import (
  )
  from nextrec.basic.features import FeatureSpecMixin
  from nextrec.data import data_utils
+ from nextrec.data.dataloader import (
+     TensorDictDataset,
+     FileDataset,
+     RecDataLoader,
+     build_tensors_from_data,
+ )
+ from nextrec.data.preprocessor import DataProcessor

  __all__ = [
      'collate_fn',
@@ -33,4 +40,9 @@ __all__ = [
      'load_dataframes',
      'FeatureSpecMixin',
      'data_utils',
+     'TensorDictDataset',
+     'FileDataset',
+     'RecDataLoader',
+     'build_tensors_from_data',
+     'DataProcessor',
  ]
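With these re-exports, the 0.3.1 loader and preprocessing entry points are importable straight from nextrec.data:

    from nextrec.data import (
        RecDataLoader,
        TensorDictDataset,
        FileDataset,
        DataProcessor,
        build_tensors_from_data,
    )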
nextrec/data/data_utils.py CHANGED
@@ -6,19 +6,17 @@ import pandas as pd
  import pyarrow.parquet as pq
  from pathlib import Path

-
  def _stack_section(batch: list[dict], section: str):
      """Stack one section of the batch (features/labels/ids)."""
      entries = [item.get(section) for item in batch if item.get(section) is not None]
      if not entries:
          return None
      merged: dict = {}
-     for name in entries[0]:
+     for name in entries[0]: # type: ignore
          tensors = [item[section][name] for item in batch if item.get(section) is not None and name in item[section]]
          merged[name] = torch.stack(tensors, dim=0)
      return merged

-
  def collate_fn(batch):
      """
      Collate a list of sample dicts into the unified batch format:
@@ -66,7 +64,6 @@ def collate_fn(batch):
          result.append(stacked)
      return tuple(result)

-
  def get_column_data(data: dict | pd.DataFrame, name: str):
      """Extract column data from various data structures."""
      if isinstance(data, dict):
@@ -80,7 +77,6 @@ def get_column_data(data: dict | pd.DataFrame, name: str):
          return getattr(data, name)
      raise KeyError(f"Unsupported data type for extracting column {name}")

-
  def resolve_file_paths(path: str) -> tuple[list[str], str]:
      """Resolve file or directory path into a sorted list of files and file type."""
      path_obj = Path(path)
@@ -106,7 +102,6 @@ def resolve_file_paths(path: str) -> tuple[list[str], str]:

      raise ValueError(f"Invalid path: {path}")

-
  def iter_file_chunks(file_path: str, file_type: str, chunk_size: int):
      """Yield DataFrame chunks for CSV/Parquet without loading the whole file."""
      if file_type == "csv":
@@ -116,19 +111,16 @@ def iter_file_chunks(file_path: str, file_type: str, chunk_size: int):
          for batch in parquet_file.iter_batches(batch_size=chunk_size):
              yield batch.to_pandas()

-
  def read_table(file_path: str, file_type: str) -> pd.DataFrame:
      """Read a single CSV/Parquet file."""
      if file_type == "csv":
          return pd.read_csv(file_path)
      return pd.read_parquet(file_path)

-
  def load_dataframes(file_paths: list[str], file_type: str) -> list[pd.DataFrame]:
      """Load multiple files of the same type into DataFrames."""
      return [read_table(fp, file_type) for fp in file_paths]

-
  def default_output_dir(path: str) -> Path:
      """Generate a default output directory path based on the input path."""
      path_obj = Path(path)
@@ -136,19 +128,16 @@ def default_output_dir(path: str) -> Path:
          return path_obj.parent / f"{path_obj.stem}_preprocessed"
      return path_obj.with_name(f"{path_obj.name}_preprocessed")

-
- def split_dict_random(data_dict: dict, test_size: float=0.2, random_state:int|None=None):
+ def split_dict_random(data_dict: dict, test_size: float = 0.2, random_state: int | None = None):
      """Randomly split a dictionary of data into training and testing sets."""
      lengths = [len(v) for v in data_dict.values()]
      if len(set(lengths)) != 1:
          raise ValueError(f"Length mismatch: {lengths}")
      n = lengths[0]
-
      rng = np.random.default_rng(random_state)
      perm = rng.permutation(n)
      cut = int(round(n * (1 - test_size)))
      train_idx, test_idx = perm[:cut], perm[cut:]
-
      def take(v, idx):
          if isinstance(v, np.ndarray):
              return v[idx]
@@ -157,12 +146,10 @@ def split_dict_random(data_dict: dict, test_size: float=0.2, random_state:int|No
          else:
              v_arr = np.asarray(v, dtype=object)
              return v_arr[idx]
-
      train_dict = {k: take(v, train_idx) for k, v in data_dict.items()}
      test_dict = {k: take(v, test_idx) for k, v in data_dict.items()}
      return train_dict, test_dict

-
  def build_eval_candidates(
      df_all: pd.DataFrame,
      user_col: str,
@@ -179,37 +166,26 @@ def build_eval_candidates(

      users = df_all[user_col].unique()
      all_items = item_features[item_col].unique()
-
      rows = []
-
-     user_hist_items = {
-         u: df_all[df_all[user_col] == u][item_col].unique()
-         for u in users
-     }
-
+     user_hist_items = {u: df_all[df_all[user_col] == u][item_col].unique() for u in users}
      for u in users:
          df_user = df_all[df_all[user_col] == u]
          pos_items = df_user[df_user[label_col] == 1][item_col].unique()
          if len(pos_items) == 0:
              continue
-
          pos_items = pos_items[:num_pos_per_user]
          seen_items = set(user_hist_items[u])
-
          neg_pool = np.setdiff1d(all_items, np.fromiter(seen_items, dtype=all_items.dtype))
          if len(neg_pool) == 0:
              continue
-
          for pos in pos_items:
              if len(neg_pool) <= num_neg_per_pos:
                  neg_items = neg_pool
              else:
                  neg_items = rng.choice(neg_pool, size=num_neg_per_pos, replace=False)
-
              rows.append((u, pos, 1))
              for ni in neg_items:
                  rows.append((u, ni, 0))
-
      eval_df = pd.DataFrame(rows, columns=[user_col, item_col, label_col])
      eval_df = eval_df.merge(user_features, on=user_col, how='left')
      eval_df = eval_df.merge(item_features, on=item_col, how='left')
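split_dict_random permutes every column with the same index, so parallel arrays stay row-aligned. A small self-contained sketch (column names are made up for illustration):

    import numpy as np
    from nextrec.data.data_utils import split_dict_random

    data = {
        "user_id": np.arange(10),
        "item_id": np.arange(10, 20),
        "label": np.zeros(10, dtype=np.int64),
    }
    train, test = split_dict_random(data, test_size=0.2, random_state=42)
    # cut = round(10 * 0.8) = 8 rows go to train, 2 to test, identically permuted in every column
    assert len(train["user_id"]) == 8 and len(test["item_id"]) == 2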
nextrec/data/dataloader.py CHANGED
@@ -2,7 +2,7 @@
  Dataloader definitions

  Date: create on 27/10/2025
- Update: 25/11/2025
+ Checkpoint: edit on 29/11/2025
  Author: Yang Zhou,zyaztec@gmail.com
  """
  import os
@@ -14,7 +14,7 @@ import pandas as pd
  import pyarrow.parquet as pq

  from pathlib import Path
- from typing import Iterator, Literal, Union, Optional
+ from typing import cast

  from torch.utils.data import DataLoader, Dataset, IterableDataset
  from nextrec.data.preprocessor import DataProcessor
@@ -35,15 +35,14 @@ class TensorDictDataset(Dataset):
          self.labels = tensors.get("labels")
          self.ids = tensors.get("ids")
          if not self.features:
-             raise ValueError("Dataset requires at least one feature tensor.")
+             raise ValueError("[TensorDictDataset Error] Dataset requires at least one feature tensor.")
          lengths = [tensor.shape[0] for tensor in self.features.values()]
          if not lengths:
-             raise ValueError("Feature tensors are empty.")
+             raise ValueError("[TensorDictDataset Error] Feature tensors are empty.")
          self.length = lengths[0]
          for length in lengths[1:]:
              if length != self.length:
-                 raise ValueError("All feature tensors must have the same length.")
-
+                 raise ValueError("[TensorDictDataset Error] All feature tensors must have the same length.")
      def __len__(self) -> int:
          return self.length

@@ -53,7 +52,6 @@
          sample_ids = {name: tensor[idx] for name, tensor in self.ids.items()} if self.ids else None
          return {"features": sample_features, "labels": sample_labels, "ids": sample_ids}

-
  class FileDataset(FeatureSpecMixin, IterableDataset):
      def __init__(self,
                   file_paths: list[str], # file paths to read, containing CSV or Parquet files
@@ -109,18 +107,14 @@ class FileDataset(FeatureSpecMixin, IterableDataset):
      def _dataframe_to_tensors(self, df: pd.DataFrame) -> dict | None:
          if self.processor is not None:
              if not self.processor.is_fitted:
-                 raise ValueError("DataProcessor must be fitted before using in streaming mode")
+                 raise ValueError("[DataLoader Error] DataProcessor must be fitted before using in streaming mode")
              transformed_data = self.processor.transform(df, return_dict=True)
          else:
              transformed_data = df
-
-         batch = build_tensors_from_data(
-             data=transformed_data,
-             raw_data=df,
-             features=self.all_features,
-             target_columns=self.target_columns,
-             id_columns=self.id_columns,
-         )
+         if isinstance(transformed_data, list):
+             raise TypeError("[DataLoader Error] DataProcessor.transform returned file paths; use return_dict=True with in-memory data for streaming.")
+         safe_data = cast(dict | pd.DataFrame, transformed_data)
+         batch = build_tensors_from_data(data=safe_data, raw_data=df, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns)
          if batch is not None:
              batch["_already_batched"] = True
          return batch
@@ -133,12 +127,12 @@ class RecDataLoader(FeatureSpecMixin):
                   sequence_features: list[SequenceFeature] | None = None,
                   target: list[str] | None | str = None,
                   id_columns: str | list[str] | None = None,
-                  processor: Optional['DataProcessor'] = None):
+                  processor: DataProcessor | None = None):
          self.processor = processor
          self._set_feature_config(dense_features, sparse_features, sequence_features, target, id_columns)

      def create_dataloader(self,
-                           data: Union[dict, pd.DataFrame, str, DataLoader],
+                           data: dict | pd.DataFrame | str | DataLoader,
                            batch_size: int = 32,
                            shuffle: bool = True,
                            load_full: bool = True,
@@ -150,21 +144,21 @@
          elif isinstance(data, (dict, pd.DataFrame)):
              return self._create_from_memory(data=data, batch_size=batch_size, shuffle=shuffle)
          else:
-             raise ValueError(f"Unsupported data type: {type(data)}")
+             raise ValueError(f"[RecDataLoader Error] Unsupported data type: {type(data)}")

      def _create_from_memory(self,
-                             data: Union[dict, pd.DataFrame],
+                             data: dict | pd.DataFrame,
                              batch_size: int,
                              shuffle: bool) -> DataLoader:
          raw_data = data

          if self.processor is not None:
              if not self.processor.is_fitted:
-                 raise ValueError("DataProcessor must be fitted before transforming data in memory")
-             data = self.processor.transform(data, return_dict=True)
+                 raise ValueError("[RecDataLoader Error] DataProcessor must be fitted before transforming data in memory")
+             data = self.processor.transform(data, return_dict=True) # type: ignore
          tensors = build_tensors_from_data(data=data,raw_data=raw_data, features=self.all_features, target_columns=self.target_columns, id_columns=self.id_columns,)
          if tensors is None:
-             raise ValueError("No valid tensors could be built from the provided data.")
+             raise ValueError("[RecDataLoader Error] No valid tensors could be built from the provided data.")
          dataset = TensorDictDataset(tensors)
          return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

@@ -188,11 +182,11 @@ class RecDataLoader(FeatureSpecMixin):
                      df = read_table(file_path, file_type)
                      dfs.append(df)
                  except MemoryError as exc:
-                     raise MemoryError(f"Out of memory while reading {file_path}. Consider using load_full=False with streaming.") from exc
+                     raise MemoryError(f"[RecDataLoader Error] Out of memory while reading {file_path}. Consider using load_full=False with streaming.") from exc
              try:
                  combined_df = pd.concat(dfs, ignore_index=True)
              except MemoryError as exc:
-                 raise MemoryError(f"Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size.") from exc
+                 raise MemoryError(f"[RecDataLoader Error] Out of memory while concatenating loaded data (approx {total_bytes / (1024**3):.2f} GB). Use load_full=False to stream or reduce chunk_size.") from exc
              return self._create_from_memory(combined_df, batch_size, shuffle,)
          else:
              return self._load_files_streaming(file_paths, file_type, batch_size, chunk_size, shuffle)
@@ -204,9 +198,9 @@
                              chunk_size: int,
                              shuffle: bool) -> DataLoader:
          if shuffle:
-             logging.warning("Shuffle is ignored in streaming mode (IterableDataset).")
+             logging.warning("[RecDataLoader Warning] Shuffle is ignored in streaming mode (IterableDataset).")
          if batch_size != 1:
-             logging.warning("Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
+             logging.warning("[RecDataLoader Warning] Streaming mode enforces batch_size=1; tune chunk_size to control memory/throughput.")
          dataset = FileDataset(
              file_paths=file_paths,
              dense_features=self.dense_features,
@@ -230,22 +224,20 @@ def _normalize_sequence_column(column, feature: SequenceFeature) -> np.ndarray:
      if column.ndim == 0:
          column = column.reshape(1)
      if column.dtype == object and any(isinstance(v, str) for v in column.ravel()):
-         raise TypeError(f"Sequence feature '{feature.name}' expects numeric sequences; found string values.")
+         raise TypeError(f"[RecDataLoader Error] Sequence feature '{feature.name}' expects numeric sequences; found string values.")
      if column.dtype == object and len(column) > 0 and isinstance(column[0], (list, tuple, np.ndarray)):
          sequences = []
          for seq in column:
              if isinstance(seq, str):
-                 raise TypeError(f"Sequence feature '{feature.name}' expects numeric sequences; found string values.")
+                 raise TypeError(f"[RecDataLoader Error] Sequence feature '{feature.name}' expects numeric sequences; found string values.")
              if isinstance(seq, (list, tuple, np.ndarray)):
                  arr = np.asarray(seq, dtype=np.int64)
              else:
                  arr = np.asarray([seq], dtype=np.int64)
              sequences.append(arr)
-
          max_len = getattr(feature, "max_len", 0)
          if max_len <= 0:
              max_len = max((len(seq) for seq in sequences), default=1)
-
          pad_value = getattr(feature, "padding_idx", 0)
          padded = []
          for seq in sequences:
@@ -270,7 +262,7 @@ def build_tensors_from_data( # noqa: C901
      for feature in features:
          column = get_column_data(data, feature.name)
          if column is None:
-             raise ValueError(f"Feature column '{feature.name}' not found in data")
+             raise ValueError(f"[RecDataLoader Error] Feature column '{feature.name}' not found in data")
          if isinstance(feature, SequenceFeature):
              tensor = torch.from_numpy(_normalize_sequence_column(column, feature))
          elif isinstance(feature, DenseFeature):
@@ -301,11 +293,11 @@ def build_tensors_from_data( # noqa: C901
          if column is None:
              column = get_column_data(data, id_col)
          if column is None:
-             raise KeyError(f"ID column '{id_col}' not found in provided data.")
+             raise KeyError(f"[RecDataLoader Error] ID column '{id_col}' not found in provided data.")
          try:
              id_arr = np.asarray(column, dtype=np.int64)
          except Exception as exc:
-             raise TypeError( f"ID column '{id_col}' must contain numeric values. Received dtype={np.asarray(column).dtype}, error: {exc}") from exc
+             raise TypeError( f"[RecDataLoader Error] ID column '{id_col}' must contain numeric values. Received dtype={np.asarray(column).dtype}, error: {exc}") from exc
          id_tensors[id_col] = torch.from_numpy(id_arr)
      if not feature_tensors:
          return None
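TensorDictDataset expects a dict with "features"/"labels"/"ids" sections, each mapping a name to an equal-length tensor, and collate_fn stacks those sections per batch. A hedged sketch built by hand (feature names are illustrative, and the collated batch layout is inferred from the docstring above):

    import torch
    from torch.utils.data import DataLoader
    from nextrec.data import TensorDictDataset, collate_fn

    tensors = {
        "features": {"user_id": torch.arange(8), "price": torch.rand(8)},
        "labels": {"label": torch.randint(0, 2, (8,))},
        "ids": None,
    }
    dataset = TensorDictDataset(tensors)
    dl = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
    batch = next(iter(dl))  # stacked feature/label/id sections in the unified batch format

For file-backed data, RecDataLoader.create_dataloader routes a CSV/Parquet path to this in-memory path when load_full=True, or to FileDataset when load_full=False; in that streaming mode, as the warnings above note, shuffle is ignored and batch_size is forced to 1, with chunk_size controlling memory and throughput.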
nextrec/data/preprocessor.py CHANGED
@@ -12,6 +12,7 @@ import logging
  import numpy as np
  import pandas as pd

+ import tqdm
  from pathlib import Path
  from typing import Dict, Union, Optional, Literal, Any
  from sklearn.preprocessing import (
@@ -665,7 +666,7 @@ class DataProcessor(FeatureSpecMixin):
          output_root = base_output_dir / "transformed_data"
          output_root.mkdir(parents=True, exist_ok=True)
          saved_paths = []
-         for file_path in file_paths:
+         for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
              df = read_table(file_path, file_type)
              transformed_df = self._transform_in_memory(
                  df,
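The only behavioral change here is progress reporting: the per-file transform loop now goes through tqdm. The same pattern in isolation (the file names and the pandas round-trip are placeholders for what DataProcessor does internally):

    import tqdm
    import pandas as pd

    file_paths = ["part-0001.parquet", "part-0002.parquet"]  # hypothetical inputs
    for file_path in tqdm.tqdm(file_paths, desc="Transforming files", unit="file"):
        df = pd.read_parquet(file_path)  # stands in for read_table(file_path, file_type)
        # ... transform df and write it under the output directory, as the surrounding code does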
nextrec/loss/listwise.py CHANGED
@@ -1,8 +1,10 @@
  """
  Listwise loss functions for ranking and contrastive training.
- """

- from typing import Optional
+ Date: create on 27/10/2025
+ Checkpoint: edit on 29/11/2025
+ Author: Yang Zhou, zyaztec@gmail.com
+ """

  import torch
  import torch.nn as nn
@@ -112,7 +114,7 @@ class ApproxNDCGLoss(nn.Module):
          self.temperature = temperature
          self.reduction = reduction

-     def _ideal_dcg(self, labels: torch.Tensor, k: Optional[int]) -> torch.Tensor:
+     def _ideal_dcg(self, labels: torch.Tensor, k: int | None) -> torch.Tensor:
          # labels: [B, L]
          sorted_labels, _ = torch.sort(labels, dim=1, descending=True)
          if k is not None:
@@ -127,7 +129,7 @@ class ApproxNDCGLoss(nn.Module):
          return ideal_dcg

      def forward(
-         self, scores: torch.Tensor, labels: torch.Tensor, k: Optional[int] = None
+         self, scores: torch.Tensor, labels: torch.Tensor, k: int | None = None
      ) -> torch.Tensor:
          """
          scores: [B, L]
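Only the annotation style changes here (Optional[int] becomes int | None); the call pattern is unchanged. A hedged sketch, assuming the constructor's temperature/reduction arguments have defaults (get_loss_fn("approx_ndcg") with no kwargs suggests they do):

    import torch
    from nextrec.loss.listwise import ApproxNDCGLoss

    scores = torch.randn(4, 10)                    # [B, L] predicted scores
    labels = torch.randint(0, 2, (4, 10)).float()  # [B, L] relevance labels
    loss_fn = ApproxNDCGLoss()                     # defaults assumed
    loss = loss_fn(scores, labels, k=5)            # k truncates the ideal DCG, as in _ideal_dcg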
nextrec/loss/loss_utils.py CHANGED
@@ -1,5 +1,9 @@
  """
  Loss utilities for NextRec.
+
+ Date: create on 27/10/2025
+ Checkpoint: edit on 29/11/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  """

  from typing import Literal
@@ -39,7 +43,7 @@ def get_loss_fn(loss=None, **kw):
      if isinstance(loss, nn.Module):
          return loss
      if loss is None:
-         raise ValueError("loss must be provided explicitly")
+         raise ValueError("[Loss Error] loss must be provided explicitly")
      if loss in ["bce", "binary_crossentropy"]:
          return nn.BCELoss(**kw)
      if loss == "weighted_bce":
@@ -75,15 +79,15 @@ def get_loss_fn(loss=None, **kw):
      if loss == "approx_ndcg":
          return ApproxNDCGLoss(**kw)

-     raise ValueError(f"Unsupported loss: {loss}")
+     raise ValueError(f"[Loss Error] Unsupported loss: {loss}")

  def get_loss_kwargs(loss_params: dict | list[dict] | None, index: int = 0) -> dict:
      """
-     解析每个 head 对应的 loss_kwargs。
+     Parse loss_kwargs for each head.

-     - loss_params None -> {}
-     - loss_params dict -> 所有 head 共用
-     - loss_params list[dict] -> loss_params[index](若存在且非 None),否则 {}
+     - loss_params is None -> {}
+     - loss_params is dict -> shared by all heads
+     - loss_params is list[dict] -> use loss_params[index] (if exists and not None), else {}
      """
      if loss_params is None:
          return {}
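get_loss_fn resolves string names (and passes nn.Module instances through unchanged), while get_loss_kwargs selects the per-head kwargs. For a two-head setup, using only names and kwargs visible in this file and in listwise.py:

    from torch import nn
    from nextrec.loss.loss_utils import get_loss_fn, get_loss_kwargs

    loss_names = ["bce", "approx_ndcg"]
    loss_params = [None, {"temperature": 0.1}]   # per-head kwargs; a None entry falls back to {}
    loss_fns = [get_loss_fn(name, **get_loss_kwargs(loss_params, index=i))
                for i, name in enumerate(loss_names)]

    assert isinstance(get_loss_fn(nn.BCELoss()), nn.BCELoss)  # modules are returned as-is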
nextrec/loss/pairwise.py CHANGED
@@ -1,5 +1,9 @@
  """
  Pairwise loss functions for learning-to-rank and matching tasks.
+
+ Date: create on 27/10/2025
+ Checkpoint: edit on 29/11/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  """

  from typing import Literal
@@ -32,7 +36,6 @@ class BPRLoss(nn.Module):
              return loss.sum()
          return loss

-
  class HingeLoss(nn.Module):
      """
      Hinge loss for pairwise ranking.
@@ -56,7 +59,6 @@ class HingeLoss(nn.Module):
              return loss.sum()
          return loss

-
  class TripletLoss(nn.Module):
      """
      Triplet margin loss with cosine or euclidean distance.
@@ -95,7 +97,7 @@ class TripletLoss(nn.Module):
              if neg_dist.dim() == 2:
                  pos_dist = pos_dist.unsqueeze(1)
          else:
-             raise ValueError(f"Unsupported distance: {self.distance}")
+             raise ValueError(f"[Loss Error] Unsupported distance: {self.distance}")

          loss = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
          if self.reduction == "mean":
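The prefixed error message is the only behavioral change to TripletLoss; the interfaces follow the usual pairwise conventions. A sketch with the forward signatures assumed (the hunks only show the reduction and distance handling, not the full forward definitions):

    import torch
    from nextrec.loss.pairwise import BPRLoss, TripletLoss

    pos_scores = torch.randn(32)                  # scores for positive items
    neg_scores = torch.randn(32)                  # scores for sampled negatives
    bpr_loss = BPRLoss()(pos_scores, neg_scores)  # assumed signature: forward(pos_score, neg_score)

    anchor, positive, negative = (torch.randn(32, 64) for _ in range(3))
    triplet = TripletLoss(margin=1.0, distance="cosine")  # margin/distance names assumed from self.margin/self.distance
    triplet_loss = triplet(anchor, positive, negative)    # assumed signature: forward(anchor, positive, negative)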
nextrec/loss/pointwise.py CHANGED
@@ -1,5 +1,9 @@
  """
  Pointwise loss functions, including imbalance-aware variants.
+
+ Date: create on 27/10/2025
+ Checkpoint: edit on 29/11/2025
+ Author: Yang Zhou, zyaztec@gmail.com
  """

  from typing import Optional, Sequence
@@ -55,10 +59,7 @@ class WeightedBCELoss(nn.Module):
          self.auto_balance = auto_balance

          if pos_weight is not None:
-             self.register_buffer(
-                 "pos_weight",
-                 torch.as_tensor(pos_weight, dtype=torch.float32),
-             )
+             self.register_buffer("pos_weight", torch.as_tensor(pos_weight, dtype=torch.float32),)
          else:
              self.pos_weight = None

@@ -128,9 +129,7 @@ class FocalLoss(nn.Module):
          else:
              targets = targets.float()
          if self.logits:
-             ce_loss = F.binary_cross_entropy_with_logits(
-                 inputs, targets, reduction="none"
-             )
+             ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
              probs = torch.sigmoid(inputs)
          else:
              ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
@@ -140,7 +139,6 @@ class FocalLoss(nn.Module):
          alpha_factor = self._get_binary_alpha(targets, inputs.device)
          focal_weight = (1.0 - p_t) ** self.gamma
          loss = alpha_factor * focal_weight * ce_loss
-
          if self.reduction == "mean":
              return loss.mean()
          if self.reduction == "sum":
@@ -163,13 +161,11 @@ class FocalLoss(nn.Module):
          alpha_tensor = torch.tensor(self.alpha, device=device, dtype=targets.dtype)
          return torch.where(targets == 1, alpha_tensor, 1 - alpha_tensor)

-
  class ClassBalancedFocalLoss(nn.Module):
      """
      Focal loss weighted by effective number of samples per class.
      Reference: "Class-Balanced Loss Based on Effective Number of Samples"
      """
-
      def __init__(
          self,
          class_counts: Sequence[int] | torch.Tensor,
@@ -187,9 +183,7 @@ class ClassBalancedFocalLoss(nn.Module):
          self.register_buffer("class_weights", weights)

      def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
-         focal = FocalLoss(
-             gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True
-         )
+         focal = FocalLoss(gamma=self.gamma, alpha=self.class_weights, reduction="none", logits=True)
          loss = focal(inputs, targets)
          if self.reduction == "mean":
              return loss.mean()
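FocalLoss keeps its gamma/alpha/reduction/logits keywords (the inner FocalLoss call in ClassBalancedFocalLoss.forward shows them), and ClassBalancedFocalLoss is driven by raw class counts. A short sketch; the numeric values are illustrative and the remaining ClassBalancedFocalLoss constructor arguments are assumed to have defaults:

    import torch
    from nextrec.loss.pointwise import FocalLoss, ClassBalancedFocalLoss

    logits = torch.randn(16)
    targets = torch.randint(0, 2, (16,)).float()

    focal = FocalLoss(gamma=2.0, alpha=0.25, reduction="mean", logits=True)
    loss = focal(logits, targets)

    cb_focal = ClassBalancedFocalLoss(class_counts=[900, 100])  # weights from the effective number of samples
    cb_loss = cb_focal(logits, targets)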
nextrec/models/match/mind.py CHANGED
@@ -13,7 +13,116 @@ from typing import Literal

  from nextrec.basic.model import BaseMatchModel
  from nextrec.basic.features import DenseFeature, SparseFeature, SequenceFeature
- from nextrec.basic.layers import MLP, EmbeddingLayer, CapsuleNetwork
+ from nextrec.basic.layers import MLP, EmbeddingLayer
+
+ class MultiInterestSA(nn.Module):
+     """Multi-interest self-attention extractor from MIND (Li et al., 2019)."""
+
+     def __init__(self, embedding_dim, interest_num, hidden_dim=None):
+         super(MultiInterestSA, self).__init__()
+         self.embedding_dim = embedding_dim
+         self.interest_num = interest_num
+         if hidden_dim == None:
+             self.hidden_dim = self.embedding_dim * 4
+         self.W1 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.hidden_dim), requires_grad=True)
+         self.W2 = torch.nn.Parameter(torch.rand(self.hidden_dim, self.interest_num), requires_grad=True)
+         self.W3 = torch.nn.Parameter(torch.rand(self.embedding_dim, self.embedding_dim), requires_grad=True)
+
+     def forward(self, seq_emb, mask=None):
+         H = torch.einsum('bse, ed -> bsd', seq_emb, self.W1).tanh()
+         if mask != None:
+             A = torch.einsum('bsd, dk -> bsk', H, self.W2) + -1.e9 * (1 - mask.float())
+             A = F.softmax(A, dim=1)
+         else:
+             A = F.softmax(torch.einsum('bsd, dk -> bsk', H, self.W2), dim=1)
+         A = A.permute(0, 2, 1)
+         multi_interest_emb = torch.matmul(A, seq_emb)
+         return multi_interest_emb
+
+
+ class CapsuleNetwork(nn.Module):
+     """Dynamic routing capsule network used in MIND (Li et al., 2019)."""
+
+     def __init__(self, embedding_dim, seq_len, bilinear_type=2, interest_num=4, routing_times=3, relu_layer=False):
+         super(CapsuleNetwork, self).__init__()
+         self.embedding_dim = embedding_dim # h
+         self.seq_len = seq_len # s
+         self.bilinear_type = bilinear_type
+         self.interest_num = interest_num
+         self.routing_times = routing_times
+
+         self.relu_layer = relu_layer
+         self.stop_grad = True
+         self.relu = nn.Sequential(nn.Linear(self.embedding_dim, self.embedding_dim, bias=False), nn.ReLU())
+         if self.bilinear_type == 0: # MIND
+             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+         elif self.bilinear_type == 1:
+             self.linear = nn.Linear(self.embedding_dim, self.embedding_dim * self.interest_num, bias=False)
+         else:
+             self.w = nn.Parameter(torch.Tensor(1, self.seq_len, self.interest_num * self.embedding_dim, self.embedding_dim))
+             nn.init.xavier_uniform_(self.w)
+
+     def forward(self, item_eb, mask):
+         if self.bilinear_type == 0:
+             item_eb_hat = self.linear(item_eb)
+             item_eb_hat = item_eb_hat.repeat(1, 1, self.interest_num)
+         elif self.bilinear_type == 1:
+             item_eb_hat = self.linear(item_eb)
+         else:
+             u = torch.unsqueeze(item_eb, dim=2)
+             item_eb_hat = torch.sum(self.w[:, :self.seq_len, :, :] * u, dim=3)
+
+         item_eb_hat = torch.reshape(item_eb_hat, (-1, self.seq_len, self.interest_num, self.embedding_dim))
+         item_eb_hat = torch.transpose(item_eb_hat, 1, 2).contiguous()
+         item_eb_hat = torch.reshape(item_eb_hat, (-1, self.interest_num, self.seq_len, self.embedding_dim))
+
+         if self.stop_grad:
+             item_eb_hat_iter = item_eb_hat.detach()
+         else:
+             item_eb_hat_iter = item_eb_hat
+
+         if self.bilinear_type > 0:
+             capsule_weight = torch.zeros(item_eb_hat.shape[0],
+                                          self.interest_num,
+                                          self.seq_len,
+                                          device=item_eb.device,
+                                          requires_grad=False)
+         else:
+             capsule_weight = torch.randn(item_eb_hat.shape[0],
+                                          self.interest_num,
+                                          self.seq_len,
+                                          device=item_eb.device,
+                                          requires_grad=False)
+
+         for i in range(self.routing_times): # dynamic routing, propagated for routing_times (3) iterations
+             atten_mask = torch.unsqueeze(mask, 1).repeat(1, self.interest_num, 1)
+             paddings = torch.zeros_like(atten_mask, dtype=torch.float)
+
+             capsule_softmax_weight = F.softmax(capsule_weight, dim=-1)
+             capsule_softmax_weight = torch.where(torch.eq(atten_mask, 0), paddings, capsule_softmax_weight)
+             capsule_softmax_weight = torch.unsqueeze(capsule_softmax_weight, 2)
+
+             if i < 2:
+                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat_iter)
+                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                 interest_capsule = scalar_factor * interest_capsule
+
+                 delta_weight = torch.matmul(item_eb_hat_iter, torch.transpose(interest_capsule, 2, 3).contiguous())
+                 delta_weight = torch.reshape(delta_weight, (-1, self.interest_num, self.seq_len))
+                 capsule_weight = capsule_weight + delta_weight
+             else:
+                 interest_capsule = torch.matmul(capsule_softmax_weight, item_eb_hat)
+                 cap_norm = torch.sum(torch.square(interest_capsule), -1, True)
+                 scalar_factor = cap_norm / (1 + cap_norm) / torch.sqrt(cap_norm + 1e-9)
+                 interest_capsule = scalar_factor * interest_capsule
+
+         interest_capsule = torch.reshape(interest_capsule, (-1, self.interest_num, self.embedding_dim))
+
+         if self.relu_layer:
+             interest_capsule = self.relu(interest_capsule)
+
+         return interest_capsule


  class MIND(BaseMatchModel):
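Both extractors added above are self-contained and can be exercised directly; shapes follow the h/s comments in __init__ (B batch, s history length, h embedding dim). A usage sketch, assuming torch/nn/F are imported in the module as its code requires:

    import torch
    from nextrec.models.match.mind import CapsuleNetwork, MultiInterestSA

    B, s, h = 4, 20, 32
    seq_emb = torch.randn(B, s, h)           # embedded user behavior sequence
    mask = (torch.rand(B, s) > 0.2).long()   # 1 for real items, 0 for padding

    capsule = CapsuleNetwork(embedding_dim=h, seq_len=s, bilinear_type=2, interest_num=4, routing_times=3)
    interests = capsule(seq_emb, mask)       # -> [B, interest_num, h] interest capsules

    sa = MultiInterestSA(embedding_dim=h, interest_num=4)
    interests_sa = sa(seq_emb, mask.unsqueeze(-1))  # mask broadcast over the interest axis -> [B, interest_num, h]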