recnexteval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. recnexteval/__init__.py +20 -0
  2. recnexteval/algorithms/__init__.py +99 -0
  3. recnexteval/algorithms/base.py +377 -0
  4. recnexteval/algorithms/baseline/__init__.py +10 -0
  5. recnexteval/algorithms/baseline/decay_popularity.py +110 -0
  6. recnexteval/algorithms/baseline/most_popular.py +72 -0
  7. recnexteval/algorithms/baseline/random.py +39 -0
  8. recnexteval/algorithms/baseline/recent_popularity.py +34 -0
  9. recnexteval/algorithms/itemknn/__init__.py +14 -0
  10. recnexteval/algorithms/itemknn/itemknn.py +119 -0
  11. recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
  12. recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
  13. recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
  14. recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
  15. recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
  16. recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
  17. recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
  18. recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
  19. recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
  20. recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
  21. recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
  22. recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
  23. recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
  24. recnexteval/algorithms/utils.py +51 -0
  25. recnexteval/datasets/__init__.py +109 -0
  26. recnexteval/datasets/base.py +316 -0
  27. recnexteval/datasets/config/__init__.py +113 -0
  28. recnexteval/datasets/config/amazon.py +188 -0
  29. recnexteval/datasets/config/base.py +72 -0
  30. recnexteval/datasets/config/lastfm.py +105 -0
  31. recnexteval/datasets/config/movielens.py +169 -0
  32. recnexteval/datasets/config/yelp.py +25 -0
  33. recnexteval/datasets/datasets/__init__.py +24 -0
  34. recnexteval/datasets/datasets/amazon.py +151 -0
  35. recnexteval/datasets/datasets/base.py +250 -0
  36. recnexteval/datasets/datasets/lastfm.py +121 -0
  37. recnexteval/datasets/datasets/movielens.py +93 -0
  38. recnexteval/datasets/datasets/test.py +46 -0
  39. recnexteval/datasets/datasets/yelp.py +103 -0
  40. recnexteval/datasets/metadata/__init__.py +58 -0
  41. recnexteval/datasets/metadata/amazon.py +68 -0
  42. recnexteval/datasets/metadata/base.py +38 -0
  43. recnexteval/datasets/metadata/lastfm.py +110 -0
  44. recnexteval/datasets/metadata/movielens.py +87 -0
  45. recnexteval/evaluators/__init__.py +189 -0
  46. recnexteval/evaluators/accumulator.py +167 -0
  47. recnexteval/evaluators/base.py +216 -0
  48. recnexteval/evaluators/builder/__init__.py +125 -0
  49. recnexteval/evaluators/builder/base.py +166 -0
  50. recnexteval/evaluators/builder/pipeline.py +111 -0
  51. recnexteval/evaluators/builder/stream.py +54 -0
  52. recnexteval/evaluators/evaluator_pipeline.py +287 -0
  53. recnexteval/evaluators/evaluator_stream.py +374 -0
  54. recnexteval/evaluators/state_management.py +310 -0
  55. recnexteval/evaluators/strategy.py +32 -0
  56. recnexteval/evaluators/util.py +124 -0
  57. recnexteval/matrix/__init__.py +48 -0
  58. recnexteval/matrix/exception.py +5 -0
  59. recnexteval/matrix/interaction_matrix.py +784 -0
  60. recnexteval/matrix/prediction_matrix.py +153 -0
  61. recnexteval/matrix/util.py +24 -0
  62. recnexteval/metrics/__init__.py +57 -0
  63. recnexteval/metrics/binary/__init__.py +4 -0
  64. recnexteval/metrics/binary/hit.py +49 -0
  65. recnexteval/metrics/core/__init__.py +10 -0
  66. recnexteval/metrics/core/base.py +126 -0
  67. recnexteval/metrics/core/elementwise_top_k.py +75 -0
  68. recnexteval/metrics/core/listwise_top_k.py +72 -0
  69. recnexteval/metrics/core/top_k.py +60 -0
  70. recnexteval/metrics/core/util.py +29 -0
  71. recnexteval/metrics/ranking/__init__.py +6 -0
  72. recnexteval/metrics/ranking/dcg.py +55 -0
  73. recnexteval/metrics/ranking/ndcg.py +78 -0
  74. recnexteval/metrics/ranking/precision.py +51 -0
  75. recnexteval/metrics/ranking/recall.py +42 -0
  76. recnexteval/models/__init__.py +4 -0
  77. recnexteval/models/base.py +69 -0
  78. recnexteval/preprocessing/__init__.py +37 -0
  79. recnexteval/preprocessing/filter.py +181 -0
  80. recnexteval/preprocessing/preprocessor.py +137 -0
  81. recnexteval/registries/__init__.py +67 -0
  82. recnexteval/registries/algorithm.py +68 -0
  83. recnexteval/registries/base.py +131 -0
  84. recnexteval/registries/dataset.py +37 -0
  85. recnexteval/registries/metric.py +57 -0
  86. recnexteval/settings/__init__.py +127 -0
  87. recnexteval/settings/base.py +414 -0
  88. recnexteval/settings/exception.py +8 -0
  89. recnexteval/settings/leave_n_out_setting.py +48 -0
  90. recnexteval/settings/processor.py +115 -0
  91. recnexteval/settings/schema.py +11 -0
  92. recnexteval/settings/single_time_point_setting.py +111 -0
  93. recnexteval/settings/sliding_window_setting.py +153 -0
  94. recnexteval/settings/splitters/__init__.py +14 -0
  95. recnexteval/settings/splitters/base.py +57 -0
  96. recnexteval/settings/splitters/n_last.py +39 -0
  97. recnexteval/settings/splitters/n_last_timestamp.py +76 -0
  98. recnexteval/settings/splitters/timestamp.py +82 -0
  99. recnexteval/settings/util.py +0 -0
  100. recnexteval/utils/__init__.py +115 -0
  101. recnexteval/utils/json_to_csv_converter.py +128 -0
  102. recnexteval/utils/logging_tools.py +159 -0
  103. recnexteval/utils/path.py +155 -0
  104. recnexteval/utils/url_certificate_installer.py +54 -0
  105. recnexteval/utils/util.py +166 -0
  106. recnexteval/utils/uuid_util.py +7 -0
  107. recnexteval/utils/yaml_tool.py +65 -0
  108. recnexteval-0.1.0.dist-info/METADATA +85 -0
  109. recnexteval-0.1.0.dist-info/RECORD +110 -0
  110. recnexteval-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,55 @@
1
+ # Adopted from RecPack, An Experimentation Toolkit for Top-N Recommendation
2
+ # Copyright (C) 2020 Froomle N.V.
3
+ # License: GNU AGPLv3 - https://gitlab.com/recpack-maintainers/recpack/-/blob/master/LICENSE
4
+ # Author:
5
+ # Lien Michiels
6
+ # Robin Verachtert
7
+
8
+ import logging
9
+
10
+ import numpy as np
11
+ from scipy.sparse import csr_matrix
12
+
13
+ from ..core.listwise_top_k import ListwiseMetricK
14
+ from ..core.util import sparse_divide_nonzero
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class DCGK(ListwiseMetricK):
    """Computes the sum of gains of all items in a recommendation list.

    Relevant items that are ranked higher in the Top-K recommendations have a higher gain.

    The Discounted Cumulative Gain (DCG) is computed for every user as

    .. math::

        \\text{DiscountedCumulativeGain}(u) = \\sum\\limits_{i \\in Top-K(u)} \\frac{y^{true}_{u,i}}{\\log_2 (\\text{rank}(u,i) + 1)}

    :param K: Size of the recommendation list consisting of the Top-K item predictions.
    :type K: int

    This code is adapted from RecPack :cite:`recpack`
    """

    IS_BASE: bool = False

    def _calculate(self, y_true: csr_matrix, y_pred: csr_matrix) -> None:
        """Compute the per-user DCG and store it in ``self._scores``.

        :param y_true: Ground-truth interactions, used directly as the gain
            (binary relevance).
        :type y_true: csr_matrix
        :param y_pred: Predictions for the same users and items.
            NOTE(review): presumably holds the rank of each Top-K item, as the
            ``log2(rank + 1)`` discount implies - confirm against
            ``ListwiseMetricK``.
        :type y_pred: csr_matrix
        """
        logger.debug("DCGK compute started - %s", self.name)
        logger.debug("Shape of matrix: (%d, %d)", y_true.shape[0], y_true.shape[1])
        logger.debug("Number of ground truth interactions: %d", y_true.nnz)

        # Keep only the predictions that are also true interactions.
        denominator = y_pred.multiply(y_true)
        # Denominator: log2(rank_i + 1)
        denominator.data = np.log2(denominator.data + 1)
        # Binary relevance
        # Numerator: rel_i
        numerator = y_true

        dcg = sparse_divide_nonzero(numerator, denominator)

        # One summed score per user (row).
        self._scores = csr_matrix(dcg.sum(axis=1))

        logger.debug("DCGK compute complete - %s", self.name)
@@ -0,0 +1,78 @@
1
+ # Adopted from RecPack, An Experimentation Toolkit for Top-N Recommendation
2
+ # Copyright (C) 2020 Froomle N.V.
3
+ # License: GNU AGPLv3 - https://gitlab.com/recpack-maintainers/recpack/-/blob/master/LICENSE
4
+ # Author:
5
+ # Lien Michiels
6
+ # Robin Verachtert
7
+
8
+ import logging
9
+ from typing import Optional
10
+
11
+ import numpy as np
12
+ from scipy.sparse import csr_matrix
13
+
14
+ from ..core.listwise_top_k import ListwiseMetricK
15
+ from ..core.util import sparse_divide_nonzero
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class NDCGK(ListwiseMetricK):
    """Computes the normalized sum of gains of all items in a recommendation list.

    The normalized Discounted Cumulative Gain (nDCG) takes the DCG of a
    recommendation list and divides it by the best DCG that a list of
    length K could achieve for a user with history length N.

    Scores are always in the interval [0, 1]

    .. math::

        \\text{NormalizedDiscountedCumulativeGain}(u) = \\frac{\\text{DCG}(u)}{\\text{IDCG}(u)}

    where IDCG stands for Ideal Discounted Cumulative Gain, computed as:

    .. math::

        \\text{IDCG}(u) = \\sum\\limits_{j=1}^{\\text{min}(K, |y^{true}_u|)} \\frac{1}{\\log_2 (j + 1)}

    :param K: Size of the recommendation list consisting of the Top-K item predictions.
    :type K: int

    This code is adapted from RecPack :cite:`recpack`
    """

    IS_BASE: bool = False

    def _calculate(self, y_true: csr_matrix, y_pred: csr_matrix) -> None:
        """Compute the per-user nDCG and store it in ``self._scores``.

        Also (re)sets ``self.discount_template`` and ``self.IDCG_cache``.

        :param y_true: Ground-truth interactions (binary relevance).
        :type y_true: csr_matrix
        :param y_pred: Predictions for the same users and items.
            NOTE(review): presumably rank values - confirm against
            ``ListwiseMetricK``.
        :type y_pred: csr_matrix
        """
        logger.debug(f"NDCGK compute started - {self.name}")
        logger.debug(f"Number of users: {y_true.shape[0]}")
        logger.debug(f"Number of ground truth interactions: {y_true.nnz}")

        # Discount for list positions 1..K, i.e. 1 / log2(position + 1).
        log_arguments = np.arange(2, self.K + 2)
        self.discount_template = 1.0 / np.log2(log_arguments)
        # IDCG_cache[n] is the ideal DCG for n relevant items; slot 0 holds 1.
        partial_sums = np.cumsum(self.discount_template)
        self.IDCG_cache = np.concatenate([[1], partial_sums])

        # Correct predictions only, turned into their log2(rank + 1) discount.
        discounts = y_pred.multiply(y_true)
        discounts.data = np.log2(discounts.data + 1)

        # Binary relevance: the gain of each item is its y_true entry.
        gains = sparse_divide_nonzero(y_true, discounts)
        per_user_dcg = gains.sum(axis=1)

        # History length per user, capped at K, selects the matching IDCG.
        relevant_counts = np.minimum(y_true.sum(axis=1).astype(np.int32), self.K)
        ideal_dcg = self.IDCG_cache[relevant_counts]

        self._scores = sparse_divide_nonzero(
            csr_matrix(per_user_dcg),
            csr_matrix(ideal_dcg),
        )

        logger.debug(f"NDCGK compute complete - {self.name}")
@@ -0,0 +1,51 @@
1
+ import logging
2
+
3
+ import scipy.sparse
4
+ from scipy.sparse import csr_matrix
5
+
6
+ from ..core.listwise_top_k import ListwiseMetricK
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class PrecisionK(ListwiseMetricK):
    """Computes the fraction of top-K recommendations that correspond
    to true interactions.

    The prediction and true-interaction matrices are multiplied elementwise,
    so that true positives end up non-zero and false positives end up zero.
    The number of true positives per user is then divided by K to obtain the
    precision on user level.

    In simple terms, precision is the ratio of correctly predicted positive
    observations to the total predictions made.

    Precision is computed per user as:

    .. math::

        \\text{Precision}(u) = \\frac{\\sum\\limits_{i \\in \\text{Top-K}(u)} y^{true}_{u,i}}{K}

    ref: RecPack

    :param K: Size of the recommendation list consisting of the Top-K item predictions.
    :type K: int
    """

    IS_BASE: bool = False

    def _calculate(self, y_true: csr_matrix, y_pred: csr_matrix) -> None:
        """Compute the per-user precision and store it in ``self._scores``.

        :param y_true: Ground-truth interactions.
        :type y_true: csr_matrix
        :param y_pred: Predictions for the same users and items.
        :type y_pred: csr_matrix
        """
        logger.debug("Precision compute started - %s", self.name)
        logger.debug("Shape of matrix: (%d, %d)", y_true.shape[0], y_true.shape[1])
        logger.debug("Number of ground truth interactions: %d", y_true.nnz)

        # A cell is a true positive when it is non-zero in both matrices.
        true_positive_mask = y_pred.multiply(y_true).astype(bool)

        indicator = scipy.sparse.lil_matrix(y_pred.shape)
        indicator[true_positive_mask] = 1
        hits_per_user = indicator.tocsr().sum(axis=1)

        # true positive/total predictions
        self._scores = csr_matrix(hits_per_user) / self.K
        logger.debug("Precision compute complete - %s", self.name)
@@ -0,0 +1,42 @@
1
+ import logging
2
+
3
+ import scipy.sparse
4
+ from scipy.sparse import csr_matrix
5
+
6
+ from ..core.listwise_top_k import ListwiseMetricK
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class RecallK(ListwiseMetricK):
    """Computes the fraction of true interactions that made it into
    the Top-K recommendations.

    Recall per user is computed as:

    .. math::

        \\text{Recall}(u) = \\frac{\\sum\\limits_{i \\in \\text{Top-K}(u)} y^{true}_{u,i} }{\\sum\\limits_{j \\in I} y^{true}_{u,j}}

    Users without any true interaction get a recall of 0.

    ref: RecPack

    :param K: Size of the recommendation list consisting of the Top-K item predictions.
    :type K: int
    """

    IS_BASE: bool = False

    def _calculate(self, y_true: csr_matrix, y_pred: csr_matrix) -> None:
        """Compute the per-user recall and store it in ``self._scores``.

        :param y_true: Ground-truth interactions.
        :type y_true: csr_matrix
        :param y_pred: Predictions for the same users and items.
        :type y_pred: csr_matrix
        """
        # Local import: this module only imports scipy at file level.
        import numpy as np

        scores = scipy.sparse.lil_matrix(y_pred.shape)

        logger.debug("Recall compute started - %s", self.name)
        logger.debug("Shape of matrix: (%d, %d)", y_true.shape[0], y_true.shape[1])
        logger.debug("Number of ground truth interactions: %d", y_true.nnz)

        # obtain true positives
        scores[y_pred.multiply(y_true).astype(bool)] = 1
        scores = scores.tocsr()

        hits_per_user = np.asarray(scores.sum(axis=1), dtype=float)
        truth_per_user = np.asarray(y_true.sum(axis=1), dtype=float)

        # true positive/total actual interactions; rows without ground truth
        # keep the 0 from ``out`` instead of producing 0/0 = NaN.
        recall = np.divide(
            hits_per_user,
            truth_per_user,
            out=np.zeros_like(hits_per_user),
            where=truth_per_user != 0,
        )

        self._scores = csr_matrix(recall)
        logger.debug("Recall compute complete - %s", self.name)
@@ -0,0 +1,4 @@
1
+ from .base import BaseModel, ParamMixin
2
+
3
+
4
+ __all__ = ["BaseModel", "ParamMixin"]
@@ -0,0 +1,69 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+
5
class BaseModel(ABC):
    """Common ancestor of all recnexteval components.

    Supplies the ``name`` property and the universal ``IS_BASE`` flag.
    """

    # True on this class; the concrete classes in this package set it to False.
    IS_BASE: bool = True

    @property
    def name(self) -> str:
        """Class name of this object.

        :return: Name of the object's class
        :rtype: str
        """
        return type(self).__name__
21
+
22
+
23
class ParamMixin(ABC):
    """Mixin for recnexteval components that carry parameters.

    Supplies the ``name``, ``params`` and ``identifier`` properties on top of
    an abstract ``get_params`` that subclasses must implement.
    """

    @property
    def name(self) -> str:
        """Class name of this object.

        :return: Name of the object's class
        :rtype: str
        """
        return self.__class__.__name__

    @abstractmethod
    def get_params(self) -> dict[str, Any]:
        """Return the parameters of the object.

        :return: Parameters of the object
        :rtype: dict
        """
        ...

    @property
    def params(self) -> dict[str, Any]:
        """Parameters of the object, as returned by ``get_params``.

        :return: Parameters of the object
        :rtype: dict
        """
        return self.get_params()

    @property
    def identifier(self) -> str:
        """Identifier of the object.

        Built from the class name followed by the parameters in parentheses,
        so that it reads like the initialisation call.
        Example: `Algorithm(param_1=value)`

        :return: Identifier of the object
        :rtype: str
        """
        rendered_pairs = [f"{key}={value}" for key, value in self.get_params().items()]
        return f"{self.name}({','.join(rendered_pairs)})"
@@ -0,0 +1,37 @@
1
+ """Preprocessing module for data preparation.
2
+
3
+ This module contains filters and preprocessors for preparing data before
4
+ transforming it into an InteractionMatrix object.
5
+
6
+ ## Filters
7
+
8
+ Filters are used to filter data before transforming it into an InteractionMatrix
9
+ object. Filter implementations must extend the abstract `Filter` class.
10
+
11
+ Available filters:
12
+
13
+ - `Filter`: Abstract base class for filter implementations
14
+ - `MinItemsPerUser`: Filter requiring minimum interactions per user
15
+ - `MinUsersPerItem`: Filter requiring minimum interactions per item
16
+
17
+ ## Preprocessor
18
+
19
+ The preprocessor allows adding filters for data preprocessing and manages ID
20
+ mappings. After applying filters, it updates item and user ID mappings to
21
+ internal IDs to reduce computation load and enable easy matrix representation.
22
+
23
+ Available preprocessor:
24
+
25
+ - `DataFramePreprocessor`: Preprocesses pandas DataFrames into InteractionMatrix
26
+ """
27
+
28
+ from recnexteval.preprocessing.filter import Filter, MinItemsPerUser, MinUsersPerItem
29
+ from recnexteval.preprocessing.preprocessor import DataFramePreprocessor
30
+
31
+
32
+ __all__ = [
33
+ "Filter",
34
+ "MinItemsPerUser",
35
+ "MinUsersPerItem",
36
+ "DataFramePreprocessor",
37
+ ]
@@ -0,0 +1,181 @@
1
+ """Data filtering module.
2
+
3
+ This module provides abstract base class and filter implementations for
4
+ removing interactions from a DataFrame based on various criteria.
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+
9
+ import pandas as pd
10
+
11
+
12
class Filter(ABC):
    """Abstract base class for filter implementations.

    Subclasses provide an `apply` method that receives a pandas DataFrame
    and hands back the processed pandas DataFrame.
    """

    @abstractmethod
    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply filter to the DataFrame.

        Args:
            df: DataFrame to filter.

        Returns:
            Filtered DataFrame.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        # Render as ClassName(attr=value, ...) in attribute-assignment order.
        rendered_attrs = ", ".join(f"{key}={value}" for key, value in vars(self).items())
        return f"{type(self).__name__}({rendered_attrs})"
34
+
35
+
36
class MinItemsPerUser(Filter):
    """Filter requiring users to have minimum interaction count.

    Removes users who have interacted with fewer than the specified minimum
    number of items. Adapted from RecPack.

    Args:
        min_items_per_user: Minimum number of items a user must interact with.
        item_ix: Column name containing item identifiers.
        user_ix: Column name containing user identifiers.
        count_duplicates: Whether to count multiple interactions with the same
            item. When False, a repeated (user, item) pair is counted once.
            Defaults to True.

    Example:
        Original interactions:
        ```
        user | item
        1 | a
        1 | b
        1 | c
        2 | a
        2 | b
        3 | a
        ```

        After `MinItemsPerUser(3, item_ix="item", user_ix="user")`:
        ```
        user | item
        1 | a
        1 | b
        1 | c
        ```

        Users 2 and 3 are removed (have fewer than 3 items).
    """

    def __init__(
        self,
        min_items_per_user: int,
        item_ix: str,
        user_ix: str,
        count_duplicates: bool = True,
    ) -> None:
        # Assignment order is kept: Filter.__str__ lists attributes in this order.
        self.min_items_per_user = min_items_per_user
        self.count_duplicates = count_duplicates
        self.item_ix = item_ix
        self.user_ix = user_ix

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply minimum items per user filter.

        Args:
            df: DataFrame to filter.

        Returns:
            A copy of the rows belonging to users with sufficient interactions.
        """
        # One entry per interaction to count; duplicates of the same
        # (user, item) pair are collapsed first when they must not count.
        uids = (
            df[self.user_ix]
            if self.count_duplicates
            else df.drop_duplicates([self.user_ix, self.item_ix])[self.user_ix]
        )

        cnt_items_per_user = uids.value_counts()
        users_of_interest = cnt_items_per_user[cnt_items_per_user >= self.min_items_per_user].index
        return df[df[self.user_ix].isin(users_of_interest)].copy()
108
+
109
+
110
class MinUsersPerItem(Filter):
    """Filter requiring items to have minimum user interaction count.

    Drops every item that fewer than the specified minimum number of users
    have interacted with. Adapted from RecPack.

    Args:
        min_users_per_item: Minimum number of users that must interact with item.
        item_ix: Column name containing item identifiers.
        user_ix: Column name containing user identifiers.
        count_duplicates: Whether to count multiple interactions from the same
            user. Defaults to True.

    Example:
        Original interactions:
        ```
        user | item
        1 | a
        1 | b
        1 | c
        2 | a
        2 | b
        2 | d
        3 | a
        3 | b
        3 | d
        ```

        After `MinUsersPerItem(3)`:
        ```
        user | item
        1 | a
        1 | b
        2 | a
        2 | b
        3 | a
        3 | b
        ```

        Items with fewer than 3 users are removed (c and d).
    """

    def __init__(
        self,
        min_users_per_item: int,
        item_ix: str,
        user_ix: str,
        count_duplicates: bool = True,
    ) -> None:
        # Assignment order is kept: Filter.__str__ lists attributes in this order.
        self.item_ix = item_ix
        self.user_ix = user_ix
        self.min_users_per_item = min_users_per_item
        self.count_duplicates = count_duplicates

    def apply(self, df: pd.DataFrame) -> pd.DataFrame:
        """Apply minimum users per item filter.

        Args:
            df: DataFrame to filter.

        Returns:
            DataFrame containing only items with sufficient user interactions.
        """
        if self.count_duplicates:
            item_column = df[self.item_ix]
        else:
            # Count each (user, item) pair once.
            deduplicated = df.drop_duplicates([self.user_ix, self.item_ix])
            item_column = deduplicated[self.item_ix]

        counts = item_column.value_counts()
        frequent_enough = counts >= self.min_users_per_item
        kept_items = list(counts[frequent_enough].index)

        row_mask = df[self.item_ix].isin(kept_items)
        return df[row_mask].copy()
@@ -0,0 +1,137 @@
1
+ """Data preprocessing module.
2
+
3
+ This module provides the DataFramePreprocessor class for converting pandas
4
+ DataFrames into InteractionMatrix objects with optional filtering.
5
+ """
6
+
7
+ import logging
8
+ from typing import Literal
9
+
10
+ import pandas as pd
11
+
12
+ from recnexteval.matrix import InteractionMatrix
13
+ from recnexteval.preprocessing.filter import Filter
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class DataFramePreprocessor:
    """Preprocesses pandas DataFrames into InteractionMatrix objects.

    Allows adding filters for data preprocessing before transforming data into
    an InteractionMatrix object. After applying filters, updates item and user
    ID mappings to internal IDs to reduce computation load and enable easy
    matrix representation.

    Args:
        item_ix: Column name containing item identifiers.
        user_ix: Column name containing user identifiers.
        timestamp_ix: Column name containing timestamps.
    """

    def __init__(self, item_ix: str, user_ix: str, timestamp_ix: str) -> None:
        # internal ID -> original ID, filled by `_update_id_mappings`.
        self._item_id_mapping: dict = {}
        self._user_id_mapping: dict = {}
        self.item_ix = item_ix
        self.user_ix = user_ix
        self.timestamp_ix = timestamp_ix
        self.filters: list[Filter] = []

    @property
    def item_id_mapping(self) -> pd.DataFrame:
        """Map from original item IDs to internal item IDs.

        Returns:
            DataFrame with columns for internal item IDs and original item IDs.
        """
        return pd.DataFrame.from_records(
            list(self._item_id_mapping.items()),
            columns=[InteractionMatrix.ITEM_IX, self.item_ix],
        )

    @property
    def user_id_mapping(self) -> pd.DataFrame:
        """Map from original user IDs to internal user IDs.

        Returns:
            DataFrame with columns for internal user IDs and original user IDs.
        """
        return pd.DataFrame.from_records(
            list(self._user_id_mapping.items()),
            columns=[InteractionMatrix.USER_IX, self.user_ix],
        )

    def add_filter(self, filter_: Filter) -> None:
        """Add a preprocessing filter to be applied.

        The filter will be applied before transforming to an InteractionMatrix
        object. Filters are applied in order of addition and different orderings
        can lead to different results.

        Args:
            filter_: The filter to be applied.
        """
        self.filters.append(filter_)

    def _print_log_message(
        self,
        step: Literal["before", "after"],
        stage: Literal["preprocess", "filter"],
        df: pd.DataFrame,
    ) -> None:
        """Log preprocessing progress.

        Logs the number of interactions, items, and users in the DataFrame at
        the current stage, at DEBUG level.

        Args:
            step: Indicates whether log is before or after preprocessing.
            stage: Current stage of preprocessing (preprocess or filter).
            df: The DataFrame being processed.
        """
        # nunique() scans the whole column; skip it when nothing would be logged.
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug("\tinteractions %s %s: %s", step, stage, len(df.index))
        logger.debug("\titems %s %s: %s", step, stage, df[self.item_ix].nunique())
        logger.debug("\tusers %s %s: %s", step, stage, df[self.user_ix].nunique())

    def _update_id_mappings(self, df: pd.DataFrame) -> None:
        """Update internal ID mappings for users and items.

        Internal ID mappings are updated to reduce computation load and enable
        easy matrix representation. IDs are assigned by timestamp order.

        Note:
            `df` is modified in place: it is sorted by timestamp, its index is
            reset, and the user and item columns are overwritten with the
            internal IDs.

        Args:
            df: DataFrame to update ID mappings for.
        """
        # Sort by timestamp to incrementally assign user and item ids by timestamp
        df.sort_values(by=[self.timestamp_ix], inplace=True, ignore_index=True)

        # Categories follow first appearance, so code i is the i-th distinct ID.
        user_index = pd.CategoricalIndex(df[self.user_ix], categories=df[self.user_ix].unique())
        self._user_id_mapping = dict(enumerate(user_index.drop_duplicates()))
        df[self.user_ix] = user_index.codes

        item_index = pd.CategoricalIndex(df[self.item_ix], categories=df[self.item_ix].unique())
        self._item_id_mapping = dict(enumerate(item_index.drop_duplicates()))
        df[self.item_ix] = item_index.codes

    def process(self, df: pd.DataFrame) -> InteractionMatrix:
        """Process DataFrame through filters and convert to InteractionMatrix.

        The DataFrame passed in is left untouched; sorting and ID re-encoding
        happen on a copy.

        Args:
            df: DataFrame to process.

        Returns:
            InteractionMatrix object created from processed DataFrame.
        """
        self._print_log_message("before", "preprocess", df)

        for filter_ in self.filters:
            logger.debug("applying filter: %s", filter_)
            df = filter_.apply(df)
            self._print_log_message("after", "filter", df)

        # `_update_id_mappings` sorts and overwrites columns in place. Without
        # a copy, the caller's frame would be rewritten whenever no filter (or
        # a filter returning its input or a view) ran before this point.
        df = df.copy()
        self._update_id_mappings(df)
        self._print_log_message("after", "preprocess", df)

        # Convert input data into internal data objects
        return InteractionMatrix(df, self.item_ix, self.user_ix, self.timestamp_ix)