nextrec 0.3.2__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. {nextrec-0.3.2 → nextrec-0.3.3}/PKG-INFO +3 -3
  2. {nextrec-0.3.2 → nextrec-0.3.3}/README.md +2 -2
  3. {nextrec-0.3.2 → nextrec-0.3.3}/README_zh.md +2 -2
  4. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/conf.py +1 -1
  5. nextrec-0.3.3/nextrec/__version__.py +1 -0
  6. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/features.py +10 -23
  7. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/layers.py +18 -61
  8. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/metrics.py +55 -33
  9. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/model.py +247 -389
  10. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/data/__init__.py +2 -2
  11. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/data/data_utils.py +80 -4
  12. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/data/dataloader.py +36 -57
  13. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/data/preprocessor.py +5 -4
  14. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/generative/hstu.py +1 -1
  15. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/dssm.py +2 -2
  16. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/dssm_v2.py +2 -2
  17. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/mind.py +2 -2
  18. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/sdm.py +2 -2
  19. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/youtube_dnn.py +2 -2
  20. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/multi_task/esmm.py +1 -1
  21. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/multi_task/mmoe.py +1 -1
  22. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/multi_task/ple.py +1 -1
  23. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/multi_task/poso.py +1 -1
  24. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/multi_task/share_bottom.py +1 -1
  25. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/afm.py +1 -1
  26. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/autoint.py +1 -1
  27. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/dcn.py +1 -1
  28. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/deepfm.py +1 -1
  29. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/dien.py +1 -1
  30. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/din.py +1 -1
  31. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/fibinet.py +1 -1
  32. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/fm.py +1 -1
  33. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/masknet.py +2 -2
  34. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/pnn.py +1 -1
  35. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/widedeep.py +1 -1
  36. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/xdeepfm.py +1 -1
  37. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/utils/__init__.py +2 -1
  38. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/utils/common.py +21 -2
  39. {nextrec-0.3.2 → nextrec-0.3.3}/pyproject.toml +1 -1
  40. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_losses.py +1 -1
  41. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_multitask_models.py +1 -1
  42. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/example_match_dssm.py +1 -1
  43. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/example_multitask.py +1 -1
  44. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/example_ranking_din.py +1 -5
  45. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/movielen_match_dssm.py +3 -2
  46. nextrec-0.3.3/tutorials/run_all_tutorials.py +59 -0
  47. nextrec-0.3.2/nextrec/__version__.py +0 -1
  48. {nextrec-0.3.2 → nextrec-0.3.3}/.github/workflows/publish.yml +0 -0
  49. {nextrec-0.3.2 → nextrec-0.3.3}/.github/workflows/tests.yml +0 -0
  50. {nextrec-0.3.2 → nextrec-0.3.3}/.gitignore +0 -0
  51. {nextrec-0.3.2 → nextrec-0.3.3}/.readthedocs.yaml +0 -0
  52. {nextrec-0.3.2 → nextrec-0.3.3}/CODE_OF_CONDUCT.md +0 -0
  53. {nextrec-0.3.2 → nextrec-0.3.3}/CONTRIBUTING.md +0 -0
  54. {nextrec-0.3.2 → nextrec-0.3.3}/LICENSE +0 -0
  55. {nextrec-0.3.2 → nextrec-0.3.3}/MANIFEST.in +0 -0
  56. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/Feature Configuration.png +0 -0
  57. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/Model Parameters.png +0 -0
  58. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/Training Configuration.png +0 -0
  59. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/Training logs.png +0 -0
  60. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/logo.png +0 -0
  61. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/mmoe_tutorial.png +0 -0
  62. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/nextrec_diagram_en.png +0 -0
  63. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/nextrec_diagram_zh.png +0 -0
  64. {nextrec-0.3.2 → nextrec-0.3.3}/asserts/test data.png +0 -0
  65. {nextrec-0.3.2 → nextrec-0.3.3}/dataset/ctcvr_task.csv +0 -0
  66. {nextrec-0.3.2 → nextrec-0.3.3}/dataset/match_task.csv +0 -0
  67. {nextrec-0.3.2 → nextrec-0.3.3}/dataset/movielens_100k.csv +0 -0
  68. {nextrec-0.3.2 → nextrec-0.3.3}/dataset/multitask_task.csv +0 -0
  69. {nextrec-0.3.2 → nextrec-0.3.3}/dataset/ranking_task.csv +0 -0
  70. {nextrec-0.3.2 → nextrec-0.3.3}/docs/en/Getting started guide.md +0 -0
  71. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/Makefile +0 -0
  72. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/index.md +0 -0
  73. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/make.bat +0 -0
  74. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/modules.rst +0 -0
  75. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/nextrec.basic.rst +0 -0
  76. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/nextrec.data.rst +0 -0
  77. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/nextrec.loss.rst +0 -0
  78. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/nextrec.rst +0 -0
  79. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/nextrec.utils.rst +0 -0
  80. {nextrec-0.3.2 → nextrec-0.3.3}/docs/rtd/requirements.txt +0 -0
  81. {nextrec-0.3.2 → nextrec-0.3.3}/docs/zh//345/277/253/351/200/237/344/270/212/346/211/213.md" +0 -0
  82. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/__init__.py +0 -0
  83. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/__init__.py +0 -0
  84. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/activation.py +0 -0
  85. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/callback.py +0 -0
  86. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/loggers.py +0 -0
  87. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/basic/session.py +0 -0
  88. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/loss/__init__.py +0 -0
  89. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/loss/listwise.py +0 -0
  90. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/loss/loss_utils.py +0 -0
  91. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/loss/pairwise.py +0 -0
  92. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/loss/pointwise.py +0 -0
  93. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/generative/__init__.py +0 -0
  94. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/generative/tiger.py +0 -0
  95. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/match/__init__.py +0 -0
  96. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/__init__.py +0 -0
  97. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/models/ranking/dcn_v2.py +0 -0
  98. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/utils/embedding.py +0 -0
  99. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/utils/initializer.py +0 -0
  100. {nextrec-0.3.2 → nextrec-0.3.3}/nextrec/utils/optimizer.py +0 -0
  101. {nextrec-0.3.2 → nextrec-0.3.3}/pytest.ini +0 -0
  102. {nextrec-0.3.2 → nextrec-0.3.3}/requirements.txt +0 -0
  103. {nextrec-0.3.2 → nextrec-0.3.3}/test/__init__.py +0 -0
  104. {nextrec-0.3.2 → nextrec-0.3.3}/test/conftest.py +0 -0
  105. {nextrec-0.3.2 → nextrec-0.3.3}/test/run_tests.py +0 -0
  106. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_layers.py +0 -0
  107. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_match_models.py +0 -0
  108. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_preprocessor.py +0 -0
  109. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_ranking_models.py +0 -0
  110. {nextrec-0.3.2 → nextrec-0.3.3}/test/test_utils.py +0 -0
  111. {nextrec-0.3.2 → nextrec-0.3.3}/test_requirements.txt +0 -0
  112. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/movielen_ranking_deepfm.py +0 -0
  113. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/notebooks/en/Hands on dataprocessor.ipynb +0 -0
  114. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/notebooks/en/Hands on nextrec.ipynb +0 -0
  115. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/notebooks/zh/Hands on dataprocessor.ipynb +0 -0
  116. {nextrec-0.3.2 → nextrec-0.3.3}/tutorials/notebooks/zh/Hands on nextrec.ipynb +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nextrec
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: A comprehensive recommendation library with match, ranking, and multi-task learning models
5
5
  Project-URL: Homepage, https://github.com/zerolovesea/NextRec
6
6
  Project-URL: Repository, https://github.com/zerolovesea/NextRec
@@ -63,7 +63,7 @@ Description-Content-Type: text/markdown
63
63
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
64
64
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
65
65
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
66
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
66
+ ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
67
67
 
68
68
  English | [中文文档](README_zh.md)
69
69
 
@@ -110,7 +110,7 @@ To dive deeper, Jupyter notebooks are available:
110
110
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
111
111
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
112
112
 
113
- > Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
113
+ > Current version [0.3.3]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
114
114
 
115
115
  ## 5-Minute Quick Start
116
116
 
@@ -7,7 +7,7 @@
7
7
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
8
8
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
9
9
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
10
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
10
+ ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
11
11
 
12
12
  English | [中文文档](README_zh.md)
13
13
 
@@ -54,7 +54,7 @@ To dive deeper, Jupyter notebooks are available:
54
54
  - [Hands on the NextRec framework](/tutorials/notebooks/en/Hands%20on%20nextrec.ipynb)
55
55
  - [Using the data processor for preprocessing](/tutorials/notebooks/en/Hands%20on%20dataprocessor.ipynb)
56
56
 
57
- > Current version [0.3.2]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
57
+ > Current version [0.3.3]: the matching module is not fully polished yet and may have compatibility issues or unexpected errors. Please raise an issue if you run into problems.
58
58
 
59
59
  ## 5-Minute Quick Start
60
60
 
@@ -7,7 +7,7 @@
7
7
  ![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)
8
8
  ![PyTorch](https://img.shields.io/badge/PyTorch-1.10+-ee4c2c.svg)
9
9
  ![License](https://img.shields.io/badge/License-Apache%202.0-green.svg)
10
- ![Version](https://img.shields.io/badge/Version-0.3.2-orange.svg)
10
+ ![Version](https://img.shields.io/badge/Version-0.3.3-orange.svg)
11
11
 
12
12
  [English Version](README.md) | 中文文档
13
13
 
@@ -54,7 +54,7 @@ NextRec采用模块化、低耦合的工程设计,使得推荐系统从数据
54
54
  - [如何上手NextRec框架](/tutorials/notebooks/zh/Hands%20on%20nextrec.ipynb)
55
55
  - [如何使用数据处理器进行数据预处理](/tutorials/notebooks/zh/Hands%20on%20dataprocessor.ipynb)
56
56
 
57
- > 当前版本[0.3.2],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
57
+ > 当前版本[0.3.3],召回模型模块尚不完善,可能存在一些兼容性问题或意外报错,如果遇到问题,欢迎开发者在Issue区提出问题。
58
58
 
59
59
  ## 5分钟快速上手
60
60
 
@@ -11,7 +11,7 @@ sys.path.insert(0, str(PROJECT_ROOT / "nextrec"))
11
11
  project = "NextRec"
12
12
  copyright = "2025, Yang Zhou"
13
13
  author = "Yang Zhou"
14
- release = "0.3.2"
14
+ release = "0.3.3"
15
15
 
16
16
  extensions = [
17
17
  "myst_parser",
@@ -0,0 +1 @@
1
+ __version__ = "0.3.3"
@@ -2,19 +2,16 @@
2
2
  Feature definitions
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 02/12/2025
6
6
  Author: Yang Zhou, zyaztec@gmail.com
7
7
  """
8
8
  import torch
9
9
  from nextrec.utils.embedding import get_auto_embedding_dim
10
+ from nextrec.utils.common import normalize_to_list
10
11
 
11
12
  class BaseFeature(object):
12
13
  def __repr__(self):
13
- params = {
14
- k: v
15
- for k, v in self.__dict__.items()
16
- if not k.startswith("_")
17
- }
14
+ params = {k: v for k, v in self.__dict__.items() if not k.startswith("_") }
18
15
  param_str = ", ".join(f"{k}={v!r}" for k, v in params.items())
19
16
  return f"{self.__class__.__name__}({param_str})"
20
17
 
@@ -93,11 +90,8 @@ class DenseFeature(BaseFeature):
93
90
  else:
94
91
  self.use_embedding = use_embedding # user decides for dim <= 1
95
92
 
96
- class FeatureSpecMixin:
97
- """
98
- Mixin that normalizes dense/sparse/sequence feature lists and target/id columns.
99
- """
100
- def _set_feature_config(
93
+ class FeatureSet:
94
+ def set_all_features(
101
95
  self,
102
96
  dense_features: list[DenseFeature] | None = None,
103
97
  sparse_features: list[SparseFeature] | None = None,
@@ -111,21 +105,14 @@ class FeatureSpecMixin:
111
105
 
112
106
  self.all_features = self.dense_features + self.sparse_features + self.sequence_features
113
107
  self.feature_names = [feat.name for feat in self.all_features]
114
- self.target_columns = self._normalize_to_list(target)
115
- self.id_columns = self._normalize_to_list(id_columns)
108
+ self.target_columns = normalize_to_list(target)
109
+ self.id_columns = normalize_to_list(id_columns)
116
110
 
117
- def _set_target_id_config(
111
+ def set_target_id(
118
112
  self,
119
113
  target: str | list[str] | None = None,
120
114
  id_columns: str | list[str] | None = None,
121
115
  ) -> None:
122
- self.target_columns = self._normalize_to_list(target)
123
- self.id_columns = self._normalize_to_list(id_columns)
116
+ self.target_columns = normalize_to_list(target)
117
+ self.id_columns = normalize_to_list(id_columns)
124
118
 
125
- @staticmethod
126
- def _normalize_to_list(value: str | list[str] | None) -> list[str]:
127
- if value is None:
128
- return []
129
- if isinstance(value, str):
130
- return [value]
131
- return list(value)
@@ -18,23 +18,6 @@ from nextrec.basic.features import DenseFeature, SequenceFeature, SparseFeature
18
18
  from nextrec.utils.initializer import get_initializer
19
19
  from nextrec.basic.activation import activation_layer
20
20
 
21
- __all__ = [
22
- "PredictionLayer",
23
- "EmbeddingLayer",
24
- "InputMask",
25
- "LR",
26
- "ConcatPooling",
27
- "AveragePooling",
28
- "SumPooling",
29
- "MLP",
30
- "FM",
31
- "CrossLayer",
32
- "SENETLayer",
33
- "BiLinearInteractionLayer",
34
- "MultiHeadSelfAttention",
35
- "AttentionPoolingLayer",
36
- ]
37
-
38
21
  class PredictionLayer(nn.Module):
39
22
  def __init__(
40
23
  self,
@@ -44,12 +27,10 @@ class PredictionLayer(nn.Module):
44
27
  return_logits: bool = False,
45
28
  ):
46
29
  super().__init__()
47
- if isinstance(task_type, str):
48
- self.task_types = [task_type]
49
- else:
50
- self.task_types = list(task_type)
30
+ self.task_types = [task_type] if isinstance(task_type, str) else list(task_type)
51
31
  if len(self.task_types) == 0:
52
32
  raise ValueError("At least one task_type must be specified.")
33
+
53
34
  if task_dims is None:
54
35
  dims = [1] * len(self.task_types)
55
36
  elif isinstance(task_dims, int):
@@ -64,7 +45,7 @@ class PredictionLayer(nn.Module):
64
45
  self.total_dim = sum(self.task_dims)
65
46
  self.return_logits = return_logits
66
47
 
67
- # Keep slice offsets per task
48
+ # slice offsets per task
68
49
  start = 0
69
50
  self._task_slices: list[tuple[int, int]] = []
70
51
  for dim in self.task_dims:
@@ -85,27 +66,25 @@ class PredictionLayer(nn.Module):
85
66
  logits = x if self.bias is None else x + self.bias
86
67
  outputs = []
87
68
  for task_type, (start, end) in zip(self.task_types, self._task_slices):
88
- task_logits = logits[..., start:end] # Extract logits for the current task
69
+ task_logits = logits[..., start:end] # logits for the current task
89
70
  if self.return_logits:
90
71
  outputs.append(task_logits)
91
72
  continue
92
- activation = self._get_activation(task_type)
73
+ task = task_type.lower()
74
+ if task == 'binary':
75
+ activation = torch.sigmoid
76
+ elif task == 'regression':
77
+ activation = lambda x: x
78
+ elif task == 'multiclass':
79
+ activation = lambda x: torch.softmax(x, dim=-1)
80
+ else:
81
+ raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
93
82
  outputs.append(activation(task_logits))
94
83
  result = torch.cat(outputs, dim=-1)
95
84
  if result.shape[-1] == 1:
96
85
  result = result.squeeze(-1)
97
86
  return result
98
87
 
99
- def _get_activation(self, task_type: str):
100
- task = task_type.lower()
101
- if task == 'binary':
102
- return torch.sigmoid
103
- if task == 'regression':
104
- return lambda x: x
105
- if task == 'multiclass':
106
- return lambda x: torch.softmax(x, dim=-1)
107
- raise ValueError(f"[PredictionLayer Error]: Unsupported task_type '{task_type}'.")
108
-
109
88
  class EmbeddingLayer(nn.Module):
110
89
  def __init__(self, features: list):
111
90
  super().__init__()
@@ -145,7 +124,7 @@ class EmbeddingLayer(nn.Module):
145
124
  self.dense_input_dims[feature.name] = in_dim
146
125
  else:
147
126
  raise TypeError(f"[EmbeddingLayer Error]: Unsupported feature type: {type(feature)}")
148
- self.output_dim = self._compute_output_dim()
127
+ self.output_dim = self.compute_output_dim()
149
128
 
150
129
  def forward(
151
130
  self,
@@ -181,7 +160,7 @@ class EmbeddingLayer(nn.Module):
181
160
  sparse_embeds.append(pooling_layer(seq_emb, feature_mask).unsqueeze(1))
182
161
 
183
162
  elif isinstance(feature, DenseFeature):
184
- dense_embeds.append(self._project_dense(feature, x))
163
+ dense_embeds.append(self.project_dense(feature, x))
185
164
 
186
165
  if squeeze_dim:
187
166
  flattened_sparse = [emb.flatten(start_dim=1) for emb in sparse_embeds]
@@ -212,7 +191,7 @@ class EmbeddingLayer(nn.Module):
212
191
  raise ValueError("[EmbeddingLayer Error]: squeeze_dim=False requires at least one sparse/sequence feature or dense features with identical projected dimensions.")
213
192
  return torch.cat(output_embeddings, dim=1)
214
193
 
215
- def _project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
194
+ def project_dense(self, feature: DenseFeature, x: dict[str, torch.Tensor]) -> torch.Tensor:
216
195
  if feature.name not in x:
217
196
  raise KeyError(f"[EmbeddingLayer Error]:Dense feature '{feature.name}' is missing from input.")
218
197
  value = x[feature.name].float()
@@ -228,11 +207,7 @@ class EmbeddingLayer(nn.Module):
228
207
  dense_layer = self.dense_transforms[feature.name]
229
208
  return dense_layer(value)
230
209
 
231
- def _compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
232
- """
233
- Compute flattened embedding dimension for provided features or all tracked features.
234
- Deduplicates by feature name to avoid double-counting shared embeddings.
235
- """
210
+ def compute_output_dim(self, features: list[DenseFeature | SequenceFeature | SparseFeature] | None = None) -> int:
236
211
  candidates = list(features) if features is not None else self.features
237
212
  unique_feats = OrderedDict((feat.name, feat) for feat in candidates) # type: ignore[assignment]
238
213
  dim = 0
@@ -249,14 +224,13 @@ class EmbeddingLayer(nn.Module):
249
224
  return dim
250
225
 
251
226
  def get_input_dim(self, features: list[object] | None = None) -> int:
252
- return self._compute_output_dim(features) # type: ignore[assignment]
227
+ return self.compute_output_dim(features) # type: ignore[assignment]
253
228
 
254
229
  @property
255
230
  def input_dim(self) -> int:
256
231
  return self.output_dim
257
232
 
258
233
  class InputMask(nn.Module):
259
- """Utility module to build sequence masks for pooling layers."""
260
234
  def __init__(self):
261
235
  super().__init__()
262
236
 
@@ -271,7 +245,6 @@ class InputMask(nn.Module):
271
245
  return mask.unsqueeze(1).float()
272
246
 
273
247
  class LR(nn.Module):
274
- """Wide component from Wide&Deep (Cheng et al., 2016)."""
275
248
  def __init__(
276
249
  self,
277
250
  input_dim: int,
@@ -287,7 +260,6 @@ class LR(nn.Module):
287
260
  return self.fc(x)
288
261
 
289
262
  class ConcatPooling(nn.Module):
290
- """Concatenates sequence embeddings along the temporal dimension."""
291
263
  def __init__(self):
292
264
  super().__init__()
293
265
 
@@ -295,7 +267,6 @@ class ConcatPooling(nn.Module):
295
267
  return x.flatten(start_dim=1, end_dim=2)
296
268
 
297
269
  class AveragePooling(nn.Module):
298
- """Mean pooling with optional padding mask."""
299
270
  def __init__(self):
300
271
  super().__init__()
301
272
 
@@ -308,7 +279,6 @@ class AveragePooling(nn.Module):
308
279
  return sum_pooling_matrix / (non_padding_length.float() + 1e-16)
309
280
 
310
281
  class SumPooling(nn.Module):
311
- """Sum pooling with optional padding mask."""
312
282
  def __init__(self):
313
283
  super().__init__()
314
284
 
@@ -319,7 +289,6 @@ class SumPooling(nn.Module):
319
289
  return torch.bmm(mask, x).squeeze(1)
320
290
 
321
291
  class MLP(nn.Module):
322
- """Stacked fully connected layers used in the deep component."""
323
292
  def __init__(
324
293
  self,
325
294
  input_dim: int,
@@ -345,7 +314,6 @@ class MLP(nn.Module):
345
314
  return self.mlp(x)
346
315
 
347
316
  class FM(nn.Module):
348
- """Factorization Machine (Rendle, 2010) second-order interaction term."""
349
317
  def __init__(self, reduce_sum: bool = True):
350
318
  super().__init__()
351
319
  self.reduce_sum = reduce_sum
@@ -359,7 +327,6 @@ class FM(nn.Module):
359
327
  return 0.5 * ix
360
328
 
361
329
  class CrossLayer(nn.Module):
362
- """Single cross layer used in DCN (Wang et al., 2017)."""
363
330
  def __init__(self, input_dim: int):
364
331
  super(CrossLayer, self).__init__()
365
332
  self.w = torch.nn.Linear(input_dim, 1, bias=False)
@@ -370,7 +337,6 @@ class CrossLayer(nn.Module):
370
337
  return x
371
338
 
372
339
  class SENETLayer(nn.Module):
373
- """Squeeze-and-Excitation block adopted by FiBiNET (Huang et al., 2019)."""
374
340
  def __init__(
375
341
  self,
376
342
  num_fields: int,
@@ -388,7 +354,6 @@ class SENETLayer(nn.Module):
388
354
  return v
389
355
 
390
356
  class BiLinearInteractionLayer(nn.Module):
391
- """Bilinear feature interaction from FiBiNET (Huang et al., 2019)."""
392
357
  def __init__(
393
358
  self,
394
359
  input_dim: int,
@@ -416,7 +381,6 @@ class BiLinearInteractionLayer(nn.Module):
416
381
  return torch.cat(bilinear_list, dim=1)
417
382
 
418
383
  class MultiHeadSelfAttention(nn.Module):
419
- """Multi-head self-attention layer from AutoInt (Song et al., 2019)."""
420
384
  def __init__(
421
385
  self,
422
386
  embedding_dim: int,
@@ -438,13 +402,6 @@ class MultiHeadSelfAttention(nn.Module):
438
402
  self.dropout = nn.Dropout(dropout)
439
403
 
440
404
  def forward(self, x: torch.Tensor) -> torch.Tensor:
441
- """
442
- Args:
443
- x (torch.Tensor): Tensor of shape (batch_size, num_fields, embedding_dim)
444
-
445
- Returns:
446
- torch.Tensor: Output tensor of shape (batch_size, num_fields, embedding_dim)
447
- """
448
405
  batch_size, num_fields, _ = x.shape
449
406
  Q = self.W_Q(x) # [batch_size, num_fields, embedding_dim]
450
407
  K = self.W_K(x)
@@ -2,10 +2,12 @@
2
2
  Metrics computation and configuration for model evaluation.
3
3
 
4
4
  Date: create on 27/10/2025
5
- Checkpoint: edit on 29/11/2025
5
+ Checkpoint: edit on 02/12/2025
6
6
  Author: Yang Zhou,zyaztec@gmail.com
7
7
  """
8
8
  import logging
9
+ from typing import Any
10
+
9
11
  import numpy as np
10
12
  from sklearn.metrics import (
11
13
  roc_auc_score, log_loss, mean_squared_error, mean_absolute_error,
@@ -21,6 +23,32 @@ TASK_DEFAULT_METRICS = {
21
23
  'matching': ['auc', 'gauc', 'precision@10', 'hitrate@10', 'map@10','cosine']+ [f'recall@{k}' for k in (5,10,20)] + [f'ndcg@{k}' for k in (5,10,20)] + [f'mrr@{k}' for k in (5,10,20)]
22
24
  }
23
25
 
26
+
27
+ def check_user_id(*metric_sources: Any) -> bool:
28
+ """Return True when GAUC or ranking@K metrics appear in the provided sources."""
29
+ metric_names: set[str] = set()
30
+ stack: list[Any] = list(metric_sources)
31
+ while stack:
32
+ item = stack.pop()
33
+ if not item:
34
+ continue
35
+ if isinstance(item, dict):
36
+ stack.extend(item.values())
37
+ continue
38
+ if isinstance(item, str):
39
+ metric_names.add(item.lower())
40
+ continue
41
+ try:
42
+ stack.extend(item)
43
+ except TypeError:
44
+ continue
45
+ for name in metric_names:
46
+ if name == "gauc":
47
+ return True
48
+ if name.startswith(("recall@", "precision@", "hitrate@", "hr@", "mrr@", "ndcg@", "map@")):
49
+ return True
50
+ return False
51
+
24
52
  def compute_ks(y_true: np.ndarray, y_pred: np.ndarray) -> float:
25
53
  """Compute Kolmogorov-Smirnov statistic."""
26
54
  sorted_indices = np.argsort(y_pred)[::-1]
@@ -80,7 +108,7 @@ def compute_gauc(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray |
80
108
  gauc = float(np.sum(user_aucs * user_weights) / np.sum(user_weights))
81
109
  return gauc
82
110
 
83
- def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
111
+ def group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndarray]:
84
112
  """Group sample indices by user_id. If user_ids is None, treat all as one group."""
85
113
  if user_ids is None:
86
114
  return [np.arange(n_samples)]
@@ -92,13 +120,13 @@ def _group_indices_by_user(user_ids: np.ndarray, n_samples: int) -> list[np.ndar
92
120
  groups = [np.where(user_ids == u)[0] for u in unique_users]
93
121
  return groups
94
122
 
95
- def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
123
+ def compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
96
124
  """Compute Precision@K."""
97
125
  if user_ids is None:
98
126
  raise ValueError("[Metrics Error: Precision@K] user_ids must be provided for Precision@K computation.")
99
127
  y_true = (y_true > 0).astype(int)
100
128
  n = len(y_true)
101
- groups = _group_indices_by_user(user_ids, n)
129
+ groups = group_indices_by_user(user_ids, n)
102
130
  precisions = []
103
131
  for idx in groups:
104
132
  if idx.size == 0:
@@ -112,13 +140,13 @@ def _compute_precision_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np
112
140
  precisions.append(hits / float(k_user))
113
141
  return float(np.mean(precisions)) if precisions else 0.0
114
142
 
115
- def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
143
+ def compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
116
144
  """Compute Recall@K."""
117
145
  if user_ids is None:
118
146
  raise ValueError("[Metrics Error: Recall@K] user_ids must be provided for Recall@K computation.")
119
147
  y_true = (y_true > 0).astype(int)
120
148
  n = len(y_true)
121
- groups = _group_indices_by_user(user_ids, n)
149
+ groups = group_indices_by_user(user_ids, n)
122
150
  recalls = []
123
151
  for idx in groups:
124
152
  if idx.size == 0:
@@ -135,13 +163,13 @@ def _compute_recall_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.nd
135
163
  recalls.append(hits / float(num_pos))
136
164
  return float(np.mean(recalls)) if recalls else 0.0
137
165
 
138
- def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
166
+ def compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
139
167
  """Compute HitRate@K."""
140
168
  if user_ids is None:
141
169
  raise ValueError("[Metrics Error: HitRate@K] user_ids must be provided for HitRate@K computation.")
142
170
  y_true = (y_true > 0).astype(int)
143
171
  n = len(y_true)
144
- groups = _group_indices_by_user(user_ids, n)
172
+ groups = group_indices_by_user(user_ids, n)
145
173
  hits_per_user = []
146
174
  for idx in groups:
147
175
  if idx.size == 0:
@@ -157,13 +185,13 @@ def _compute_hitrate_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.n
157
185
  hits_per_user.append(1.0 if hits > 0 else 0.0)
158
186
  return float(np.mean(hits_per_user)) if hits_per_user else 0.0
159
187
 
160
- def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
188
+ def compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
161
189
  """Compute MRR@K."""
162
190
  if user_ids is None:
163
191
  raise ValueError("[Metrics Error: MRR@K] user_ids must be provided for MRR@K computation.")
164
192
  y_true = (y_true > 0).astype(int)
165
193
  n = len(y_true)
166
- groups = _group_indices_by_user(user_ids, n)
194
+ groups = group_indices_by_user(user_ids, n)
167
195
  mrrs = []
168
196
  for idx in groups:
169
197
  if idx.size == 0:
@@ -184,7 +212,7 @@ def _compute_mrr_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
184
212
  mrrs.append(rr)
185
213
  return float(np.mean(mrrs)) if mrrs else 0.0
186
214
 
187
- def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
215
+ def compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
188
216
  k_user = min(k, labels.size)
189
217
  if k_user == 0:
190
218
  return 0.0
@@ -192,13 +220,13 @@ def _compute_dcg_at_k(labels: np.ndarray, k: int) -> float:
192
220
  discounts = np.log2(np.arange(2, k_user + 2))
193
221
  return float(np.sum(gains / discounts))
194
222
 
195
- def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
223
+ def compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
196
224
  """Compute NDCG@K."""
197
225
  if user_ids is None:
198
226
  raise ValueError("[Metrics Error: NDCG@K] user_ids must be provided for NDCG@K computation.")
199
227
  y_true = (y_true > 0).astype(int)
200
228
  n = len(y_true)
201
- groups = _group_indices_by_user(user_ids, n)
229
+ groups = group_indices_by_user(user_ids, n)
202
230
  ndcgs = []
203
231
  for idx in groups:
204
232
  if idx.size == 0:
@@ -209,23 +237,23 @@ def _compute_ndcg_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndar
209
237
  scores = y_pred[idx]
210
238
  order = np.argsort(scores)[::-1]
211
239
  ranked_labels = labels[order]
212
- dcg = _compute_dcg_at_k(ranked_labels, k)
240
+ dcg = compute_dcg_at_k(ranked_labels, k)
213
241
  # ideal DCG
214
242
  ideal_labels = np.sort(labels)[::-1]
215
- idcg = _compute_dcg_at_k(ideal_labels, k)
243
+ idcg = compute_dcg_at_k(ideal_labels, k)
216
244
  if idcg == 0.0:
217
245
  continue
218
246
  ndcgs.append(dcg / idcg)
219
247
  return float(np.mean(ndcgs)) if ndcgs else 0.0
220
248
 
221
249
 
222
- def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
250
+ def compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarray, k: int) -> float:
223
251
  """Mean Average Precision@K."""
224
252
  if user_ids is None:
225
253
  raise ValueError("[Metrics Error: MAP@K] user_ids must be provided for MAP@K computation.")
226
254
  y_true = (y_true > 0).astype(int)
227
255
  n = len(y_true)
228
- groups = _group_indices_by_user(user_ids, n)
256
+ groups = group_indices_by_user(user_ids, n)
229
257
  aps = []
230
258
  for idx in groups:
231
259
  if idx.size == 0:
@@ -250,7 +278,7 @@ def _compute_map_at_k(y_true: np.ndarray, y_pred: np.ndarray, user_ids: np.ndarr
250
278
  return float(np.mean(aps)) if aps else 0.0
251
279
 
252
280
 
253
- def _compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
281
+ def compute_cosine_separation(y_true: np.ndarray, y_pred: np.ndarray) -> float:
254
282
  """Compute Cosine Separation."""
255
283
  y_true = (y_true > 0).astype(int)
256
284
  pos_mask = y_true == 1
@@ -310,10 +338,10 @@ def configure_metrics(
310
338
  if primary_task not in TASK_DEFAULT_METRICS:
311
339
  raise ValueError(f"Unsupported task type: {primary_task}")
312
340
  metrics_list = TASK_DEFAULT_METRICS[primary_task]
313
- best_metrics_mode = get_best_metric_mode(metrics_list[0], primary_task)
341
+ best_metrics_mode = getbest_metric_mode(metrics_list[0], primary_task)
314
342
  return metrics_list, task_specific_metrics, best_metrics_mode
315
343
 
316
- def get_best_metric_mode(first_metric: str, primary_task: str) -> str:
344
+ def getbest_metric_mode(first_metric: str, primary_task: str) -> str:
317
345
  """Determine if metric should be maximized or minimized."""
318
346
  first_metric_lower = first_metric.lower()
319
347
  # Metrics that should be maximized
@@ -350,34 +378,28 @@ def compute_single_metric(
350
378
  y_p_binary = (y_pred > 0.5).astype(int)
351
379
  try:
352
380
  metric_lower = metric.lower()
353
- # recall@K
354
381
  if metric_lower.startswith('recall@'):
355
382
  k = int(metric_lower.split('@')[1])
356
- return _compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
357
- # precision@K
383
+ return compute_recall_at_k(y_true, y_pred, user_ids, k) # type: ignore
358
384
  if metric_lower.startswith('precision@'):
359
385
  k = int(metric_lower.split('@')[1])
360
- return _compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
361
- # hitrate@K / hr@K
386
+ return compute_precision_at_k(y_true, y_pred, user_ids, k) # type: ignore
362
387
  if metric_lower.startswith('hitrate@') or metric_lower.startswith('hr@'):
363
388
  k_str = metric_lower.split('@')[1]
364
389
  k = int(k_str)
365
- return _compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
366
- # mrr@K
390
+ return compute_hitrate_at_k(y_true, y_pred, user_ids, k) # type: ignore
367
391
  if metric_lower.startswith('mrr@'):
368
392
  k = int(metric_lower.split('@')[1])
369
- return _compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
370
- # ndcg@K
393
+ return compute_mrr_at_k(y_true, y_pred, user_ids, k) # type: ignore
371
394
  if metric_lower.startswith('ndcg@'):
372
395
  k = int(metric_lower.split('@')[1])
373
- return _compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
374
- # map@K
396
+ return compute_ndcg_at_k(y_true, y_pred, user_ids, k) # type: ignore
375
397
  if metric_lower.startswith('map@'):
376
398
  k = int(metric_lower.split('@')[1])
377
- return _compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
399
+ return compute_map_at_k(y_true, y_pred, user_ids, k) # type: ignore
378
400
  # cosine for matching task
379
401
  if metric_lower == 'cosine':
380
- return _compute_cosine_separation(y_true, y_pred)
402
+ return compute_cosine_separation(y_true, y_pred)
381
403
  if metric == 'auc':
382
404
  value = float(roc_auc_score(y_true, y_pred, average='macro' if task_type == 'multilabel' else None))
383
405
  elif metric == 'gauc':