recnexteval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recnexteval/__init__.py +20 -0
- recnexteval/algorithms/__init__.py +99 -0
- recnexteval/algorithms/base.py +377 -0
- recnexteval/algorithms/baseline/__init__.py +10 -0
- recnexteval/algorithms/baseline/decay_popularity.py +110 -0
- recnexteval/algorithms/baseline/most_popular.py +72 -0
- recnexteval/algorithms/baseline/random.py +39 -0
- recnexteval/algorithms/baseline/recent_popularity.py +34 -0
- recnexteval/algorithms/itemknn/__init__.py +14 -0
- recnexteval/algorithms/itemknn/itemknn.py +119 -0
- recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
- recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
- recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
- recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
- recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
- recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
- recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
- recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
- recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
- recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
- recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
- recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
- recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
- recnexteval/algorithms/utils.py +51 -0
- recnexteval/datasets/__init__.py +109 -0
- recnexteval/datasets/base.py +316 -0
- recnexteval/datasets/config/__init__.py +113 -0
- recnexteval/datasets/config/amazon.py +188 -0
- recnexteval/datasets/config/base.py +72 -0
- recnexteval/datasets/config/lastfm.py +105 -0
- recnexteval/datasets/config/movielens.py +169 -0
- recnexteval/datasets/config/yelp.py +25 -0
- recnexteval/datasets/datasets/__init__.py +24 -0
- recnexteval/datasets/datasets/amazon.py +151 -0
- recnexteval/datasets/datasets/base.py +250 -0
- recnexteval/datasets/datasets/lastfm.py +121 -0
- recnexteval/datasets/datasets/movielens.py +93 -0
- recnexteval/datasets/datasets/test.py +46 -0
- recnexteval/datasets/datasets/yelp.py +103 -0
- recnexteval/datasets/metadata/__init__.py +58 -0
- recnexteval/datasets/metadata/amazon.py +68 -0
- recnexteval/datasets/metadata/base.py +38 -0
- recnexteval/datasets/metadata/lastfm.py +110 -0
- recnexteval/datasets/metadata/movielens.py +87 -0
- recnexteval/evaluators/__init__.py +189 -0
- recnexteval/evaluators/accumulator.py +167 -0
- recnexteval/evaluators/base.py +216 -0
- recnexteval/evaluators/builder/__init__.py +125 -0
- recnexteval/evaluators/builder/base.py +166 -0
- recnexteval/evaluators/builder/pipeline.py +111 -0
- recnexteval/evaluators/builder/stream.py +54 -0
- recnexteval/evaluators/evaluator_pipeline.py +287 -0
- recnexteval/evaluators/evaluator_stream.py +374 -0
- recnexteval/evaluators/state_management.py +310 -0
- recnexteval/evaluators/strategy.py +32 -0
- recnexteval/evaluators/util.py +124 -0
- recnexteval/matrix/__init__.py +48 -0
- recnexteval/matrix/exception.py +5 -0
- recnexteval/matrix/interaction_matrix.py +784 -0
- recnexteval/matrix/prediction_matrix.py +153 -0
- recnexteval/matrix/util.py +24 -0
- recnexteval/metrics/__init__.py +57 -0
- recnexteval/metrics/binary/__init__.py +4 -0
- recnexteval/metrics/binary/hit.py +49 -0
- recnexteval/metrics/core/__init__.py +10 -0
- recnexteval/metrics/core/base.py +126 -0
- recnexteval/metrics/core/elementwise_top_k.py +75 -0
- recnexteval/metrics/core/listwise_top_k.py +72 -0
- recnexteval/metrics/core/top_k.py +60 -0
- recnexteval/metrics/core/util.py +29 -0
- recnexteval/metrics/ranking/__init__.py +6 -0
- recnexteval/metrics/ranking/dcg.py +55 -0
- recnexteval/metrics/ranking/ndcg.py +78 -0
- recnexteval/metrics/ranking/precision.py +51 -0
- recnexteval/metrics/ranking/recall.py +42 -0
- recnexteval/models/__init__.py +4 -0
- recnexteval/models/base.py +69 -0
- recnexteval/preprocessing/__init__.py +37 -0
- recnexteval/preprocessing/filter.py +181 -0
- recnexteval/preprocessing/preprocessor.py +137 -0
- recnexteval/registries/__init__.py +67 -0
- recnexteval/registries/algorithm.py +68 -0
- recnexteval/registries/base.py +131 -0
- recnexteval/registries/dataset.py +37 -0
- recnexteval/registries/metric.py +57 -0
- recnexteval/settings/__init__.py +127 -0
- recnexteval/settings/base.py +414 -0
- recnexteval/settings/exception.py +8 -0
- recnexteval/settings/leave_n_out_setting.py +48 -0
- recnexteval/settings/processor.py +115 -0
- recnexteval/settings/schema.py +11 -0
- recnexteval/settings/single_time_point_setting.py +111 -0
- recnexteval/settings/sliding_window_setting.py +153 -0
- recnexteval/settings/splitters/__init__.py +14 -0
- recnexteval/settings/splitters/base.py +57 -0
- recnexteval/settings/splitters/n_last.py +39 -0
- recnexteval/settings/splitters/n_last_timestamp.py +76 -0
- recnexteval/settings/splitters/timestamp.py +82 -0
- recnexteval/settings/util.py +0 -0
- recnexteval/utils/__init__.py +115 -0
- recnexteval/utils/json_to_csv_converter.py +128 -0
- recnexteval/utils/logging_tools.py +159 -0
- recnexteval/utils/path.py +155 -0
- recnexteval/utils/url_certificate_installer.py +54 -0
- recnexteval/utils/util.py +166 -0
- recnexteval/utils/uuid_util.py +7 -0
- recnexteval/utils/yaml_tool.py +65 -0
- recnexteval-0.1.0.dist-info/METADATA +85 -0
- recnexteval-0.1.0.dist-info/RECORD +110 -0
- recnexteval-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import time
|
|
3
|
+
from abc import abstractmethod
|
|
4
|
+
from typing import Any, Self, Union
|
|
5
|
+
from warnings import warn
|
|
6
|
+
|
|
7
|
+
from recnexteval.matrix import InteractionMatrix
|
|
8
|
+
from ..models import BaseModel, ParamMixin
|
|
9
|
+
from .exception import EOWSettingError
|
|
10
|
+
from .processor import PredictionDataProcessor
|
|
11
|
+
from .schema import SplitResult
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Module-level logger named after this module; Setting.split() reports progress and timing through it.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Setting(BaseModel, ParamMixin):
    """Base class for defining an evaluation setting.

    Core Attributes:
        - background_data: Data used for initial training of the model. Interval is
          [0, background_t).
        - unlabeled_data: List of unlabeled data. Each element is an InteractionMatrix
          object of interval [0, t).
        - ground_truth_data: List of ground truth data. Each element is an
          InteractionMatrix object of interval [t, t + window_size).
        - incremental_data: List of data used to incrementally update the model.
          Each element is an InteractionMatrix object of interval [t, t + window_size).
          Unique to SlidingWindowSetting.
        - data_timestamp_limit: List of timestamps that the splitter will slide over.

    We will use `background_data` as the initial training set, `incremental_data` as the data
    to incrementally update the model. However, for public methods, we will refer to both as
    `training_data` to avoid confusion.

    Args:
        seed: Seed for randomization. Defaults to 42.
    """

    def __init__(
        self,
        seed: int = 42,
    ) -> None:
        """Initialize the setting.

        Args:
            seed: Random seed for reproducibility.
        """
        self.seed = seed
        self.prediction_data_processor = PredictionDataProcessor()
        # Number of splits created from the sliding window. Defaults to 1 (no splits on
        # the training set); presumably raised by SlidingWindowSetting.
        self._num_split_set = 1

        self._sliding_window_setting = False
        # Flipped to True by split() once _split() and the sanity checks succeeded.
        self._split_complete = False
        # Index of the next split returned by __next__; reset by __iter__ and restore().
        # Excluded from get_params() so the identifier does not change while iterating.
        self.current_index = 0

        # The attributes below are only declared here (no value is assigned);
        # concrete settings assign them in _split().
        self._num_full_interactions: int
        self._unlabeled_data: InteractionMatrix | list[InteractionMatrix]
        # Ground truth interactions for _unlabeled_data. A list of InteractionMatrix
        # for SlidingWindowSetting.
        self._ground_truth_data: InteractionMatrix | list[InteractionMatrix]
        # Data used to incrementally update the model. Unique to SlidingWindowSetting.
        self._incremental_data: list[InteractionMatrix]
        # Initial set of interactions used to train the model.
        self._background_data: InteractionMatrix
        # Upper timestamp (cut-off) of the window in the split. Actual interactions may
        # have smaller timestamps because this is the cut-off value, not an observed one.
        self._t_window: Union[None, int, list[int]]
        # Number of last sequential interactions provided in unlabeled_data as data for
        # the model to make predictions.
        self.n_seq_data: int
        # Number of interactions per user selected for evaluation in ground_truth_data.
        self.top_K: int

    def __str__(self) -> str:
        pairs = ", ".join(f"{key}={value}" for key, value in self.params.items())
        return f"{self.__class__.__name__}({pairs})"

    def get_params(self) -> dict[str, Any]:
        """Get the parameters of the setting.

        Parameters are the public instance attributes (names not starting with
        ``_``), minus helper objects and iteration state.

        Returns:
            Mapping of parameter name to value.
        """
        # prediction_data_processor is a helper object and current_index is iteration
        # state; neither is part of the setting's configuration, and including
        # current_index would make the identifier change during iteration.
        exclude_attrs = {"prediction_data_processor", "current_index"}
        return {
            name: value
            for name, value in vars(self).items()
            if not name.startswith("_") and name not in exclude_attrs
        }

    @property
    def identifier(self) -> str:
        """Name of the setting followed by its non-``None`` parameters."""
        paramstring = ",".join(f"{k}={v}" for k, v in self.params.items() if v is not None)
        return self.name + "(" + paramstring + ")"

    @abstractmethod
    def _split(self, data: InteractionMatrix) -> None:
        """Split data according to the setting.

        This abstract method must be implemented by concrete setting classes
        to split data into background_data, ground_truth_data, and unlabeled_data.

        Args:
            data: Interaction matrix to be split.
        """

    def split(self, data: InteractionMatrix) -> None:
        """Split data according to the setting.

        Calling this method changes the state of the setting object to be ready
        for evaluation. The method splits data into background_data, ground_truth_data,
        and unlabeled_data.

        Note:
            SlidingWindowSetting will have an additional attribute incremental_data.

        Args:
            data: Interaction matrix to be split.
        """
        logger.debug("Splitting data...")
        # Not ready until the new split has been produced and checked; avoids exposing
        # half-overwritten data if a re-split raises.
        self._split_complete = False
        self._num_full_interactions = data.num_interactions
        # perf_counter is monotonic and intended for measuring elapsed time.
        start = time.perf_counter()
        self._split(data)
        end = time.perf_counter()
        logger.info(f"{self.name} data split - Took {end - start:.3}s")

        logger.debug("Checking split attribute and sizes.")
        self._check_split()

        self._split_complete = True
        logger.info(f"{self.name} data split complete.")

    def _check_split_complete(self) -> None:
        """Check if the setting is ready for evaluation.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        if not self.is_ready:
            raise KeyError("Setting has not been split yet. Call split() method before accessing the property.")

    @property
    def num_split(self) -> int:
        """Get number of splits created from dataset.

        This property defaults to 1 (no splits on training set) for typical settings.
        For SlidingWindowSetting, this is typically greater than 1 if there are
        multiple splits created from the sliding window.

        Returns:
            Number of splits created from dataset.
        """
        return self._num_split_set

    @property
    def is_ready(self) -> bool:
        """Check if setting is ready for evaluation.

        Returns:
            True if the setting has been split and is ready to use.
        """
        return self._split_complete

    @property
    def is_sliding_window_setting(self) -> bool:
        """Check if setting is SlidingWindowSetting.

        Returns:
            True if this is a SlidingWindowSetting instance.
        """
        return self._sliding_window_setting

    @property
    def background_data(self) -> InteractionMatrix:
        """Get background data for initial model training.

        Returns:
            InteractionMatrix of training interactions.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._background_data

    @property
    def t_window(self) -> Union[None, int, list[int]]:
        """Get the upper timestamp of the window in split.

        In settings that respect the global timeline, returns a timestamp value.
        In `SlidingWindowSetting`, returns a list of timestamp values.
        In settings like `LeaveNOutSetting`, returns None.

        Returns:
            Timestamp limit for the data (int, list of ints, or None).

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._t_window

    @property
    def unlabeled_data(self) -> InteractionMatrix | list[InteractionMatrix]:
        """Get unlabeled data for model predictions.

        Contains the user/item ID for prediction along with previous sequential
        interactions. Used to make predictions on ground truth data.

        Returns:
            Single InteractionMatrix or list of InteractionMatrix for sliding window setting.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._unlabeled_data

    @property
    def ground_truth_data(self) -> InteractionMatrix | list[InteractionMatrix]:
        """Get ground truth data for model evaluation.

        Contains the actual interactions of user-item that the model should predict.

        Returns:
            Single InteractionMatrix or list of InteractionMatrix for sliding window.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._ground_truth_data

    @property
    def incremental_data(self) -> list[InteractionMatrix]:
        """Get data for incrementally updating the model.

        Only available for SlidingWindowSetting.

        Returns:
            List of InteractionMatrix objects for incremental updates.

        Raises:
            KeyError: If the setting has not been split yet.
            AttributeError: If setting is not SlidingWindowSetting.
        """
        self._check_split_complete()

        if not self._sliding_window_setting:
            raise AttributeError("Incremental data is only available for sliding window setting.")
        return self._incremental_data

    def _check_split(self) -> None:
        """Checks that the splits have been done properly.

        Makes sure all expected attributes are set, then checks their sizes.

        Raises:
            AssertionError: If _split() left a required attribute unset or None.
        """
        logger.debug("Checking split attributes.")
        # These attributes are only declared (not assigned) in __init__, so they are
        # missing until _split() assigns them. An explicit raise (instead of `assert`)
        # keeps the check active under `python -O`.
        for attr_name in ("_background_data", "_unlabeled_data", "_ground_truth_data"):
            if getattr(self, attr_name, None) is None:
                raise AssertionError(f"Split attribute {attr_name} was not set by _split().")
        logger.debug("Split attributes are set.")

        self._check_size()

    def _check_size(self) -> None:
        """Warn the user if any of the split sets is unusually small or empty."""
        logger.debug("Checking size of split sets.")

        def check_empty(name, count) -> bool:
            # Warns and returns True when the set contains no interactions.
            if count == 0:
                warn(UserWarning(f"{name} resulting from {self.name} is empty (no interactions)."))
                return True
            return False

        def check_ratio(name, count, total, threshold) -> None:
            if check_empty(name, count):
                return
            # The small epsilon keeps the ratio defined when total is 0.
            if (count + 1e-9) / (total + 1e-9) < threshold:
                warn(UserWarning(f"{name} resulting from {self.name} is unusually small."))

        # check_ratio already warns about an empty set, so no separate check_empty call.
        n_background = self._background_data.num_interactions
        check_ratio("Background data", n_background, self._num_full_interactions, 0.05)

        if not self._sliding_window_setting:
            n_unlabel = self._unlabeled_data.num_interactions
            n_ground_truth = self._ground_truth_data.num_interactions

            check_empty("Unlabeled data", n_unlabel)
            check_ratio("Ground truth data", n_ground_truth, n_unlabel, 0.05)
        else:
            for dataset_idx in range(self._num_split_set):
                n_unlabel = self._unlabeled_data[dataset_idx].num_interactions
                n_ground_truth = self._ground_truth_data[dataset_idx].num_interactions

                check_empty(f"Unlabeled data[{dataset_idx}]", n_unlabel)
                check_empty(f"Ground truth data[{dataset_idx}]", n_ground_truth)
        logger.debug("Size of split sets are checked.")

    def restore(self, n: int = 0) -> None:
        """Restore last run.

        Args:
            n: Iteration number to restore to. Defaults to 0 (the beginning).
        """
        logger.debug(f"Restoring setting to iteration {n}")
        self.current_index = n

    def __iter__(self) -> Self:
        """Iterate over splits in the setting.

        Resets the index and returns self as the iterator.
        Each step yields a SplitResult: {'unlabeled', 'ground_truth', 't_window', 'incremental'}.
        """
        self.current_index = 0
        return self

    def __next__(self) -> SplitResult:
        """Get the next split.

        Delegates to :meth:`get_split_at`, so iteration and random access return
        identical results.

        Returns:
            SplitResult with split data.

        Raises:
            EOWSettingError: If no more splits.
        """
        # NOTE(review): the iterator protocol ends on StopIteration; unless
        # EOWSettingError subclasses it, a plain `for` loop over the setting ends by
        # raising this error. Confirm in .exception before relying on `for` loops.
        if self.current_index >= self.num_split:
            raise EOWSettingError("No more splits available, EOW reached.")

        result = self.get_split_at(self.current_index)
        self.current_index += 1
        return result

    def get_split_at(self, index: int) -> SplitResult:
        """Get the split data at a specific index.

        Args:
            index: The index of the split to retrieve.

        Returns:
            SplitResult with keys: 'unlabeled', 'ground_truth', 't_window', 'incremental'.

        Raises:
            KeyError: If the setting has not been split yet.
            IndexError: If index is out of range.
            ValueError: If the stored split data does not match the setting type.
        """
        self._check_split_complete()
        # Valid indices are 0 .. num_split - 1.
        if index < 0 or index >= self.num_split:
            raise IndexError(f"Index {index} out of range for {self.num_split} splits")

        if self._sliding_window_setting:
            if not (
                isinstance(self._unlabeled_data, list)
                and isinstance(self._ground_truth_data, list)
                and isinstance(self._t_window, list)
            ):
                raise ValueError("Expected list of InteractionMatrix for sliding window setting.")
            # Split `index` is paired with incremental window `index - 1`; the first
            # split has none since the model starts from the background data.
            has_incremental = 0 < index < len(self._incremental_data)
            return SplitResult(
                unlabeled=self._unlabeled_data[index],
                ground_truth=self._ground_truth_data[index],
                # TODO change this variable to training_data when refactoring
                incremental=self._incremental_data[index - 1] if has_incremental else None,
                t_window=self._t_window[index],
            )

        if index != 0:
            raise IndexError("Non-sliding setting has only one split at index 0")
        if (
            isinstance(self._unlabeled_data, list)
            or isinstance(self._ground_truth_data, list)
            or isinstance(self._t_window, list)
        ):
            raise ValueError("Expected single data for non-sliding setting.")
        return SplitResult(
            unlabeled=self._unlabeled_data,
            ground_truth=self._ground_truth_data,
            incremental=None,
            t_window=self._t_window,
        )
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from recnexteval.matrix import InteractionMatrix
|
|
4
|
+
from .base import Setting
|
|
5
|
+
from .splitters import (
|
|
6
|
+
NLastInteractionSplitter,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger named after this module; used by LeaveNOutSetting.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class LeaveNOutSetting(Setting):
    """Evaluation setting that splits data with :class:`NLastInteractionSplitter`.

    Presumably each user's last ``N`` interactions are held out as future
    interactions to predict. The split has no global timestamp cut-off, so
    :attr:`t_window` is ``None``.

    Args:
        n_seq_data: Number of last sequential interactions provided in the unlabeled
            data for the model to make predictions. Defaults to 1.
        N: Number of items to predict; stored as :attr:`top_K`. Defaults to 1.
        seed: Seed for randomization. Defaults to 42.
    """

    IS_BASE: bool = False

    def __init__(
        self,
        n_seq_data: int = 1,
        N: int = 1,
        seed: int = 42,
    ) -> None:
        super().__init__(seed=seed)
        self.n_seq_data = n_seq_data
        # we use top_K to denote the number of items to predict
        self.top_K = N

        # No data is split here; the splitter is only configured and used later by _split().
        self._splitter = NLastInteractionSplitter(N, n_seq_data)
        logger.debug("%s initialized with N=%s, n_seq_data=%s", self.__class__.__name__, N, n_seq_data)

    def _split(self, data: InteractionMatrix) -> None:
        """Split data into background data, unlabeled data and ground truth.

        The splitter separates ``data`` into background data and future
        interactions. :attr:`prediction_data_processor` then builds the unlabeled
        data (background interactions plus the masked prediction targets) and the
        ground truth from the future interactions.

        Args:
            data: Interaction matrix to be split. Must contain timestamps.
        """
        self._background_data, future_interaction = self._splitter.split(data)
        # we need to copy the data to avoid modifying the background data
        past_interaction = self._background_data.copy()

        self._unlabeled_data, self._ground_truth_data = self.prediction_data_processor.process(
            past_interaction=past_interaction,
            future_interaction=future_interaction,
            top_K=self.top_K,
        )
        # Leave-N-out has no global timestamp cut-off.
        self._t_window = None
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
from recnexteval.matrix import InteractionMatrix
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Processor(ABC):
    """Base class for processing data.

    Abstract class for processing a past/future interaction split. Subclasses
    implement :meth:`process` to produce the data handed to the model for
    prediction and the ground truth it is evaluated against.
    """

    @abstractmethod
    def process(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Process a past/future split into (prediction input, ground truth).

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Number of future interactions to select for prediction; the exact
                interpretation is up to the subclass. Defaults to 1.

        Returns:
            Tuple of the past interactions augmented with the IDs to predict, and the
            ground-truth future interactions.
        """
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class PredictionDataProcessor(Processor):
    """Injects the user ID to indicate ID for prediction.

    Operates on the past and future interaction matrices to inject the user
    ID to be predicted by the model into the past interaction matrix. The
    resulting past interaction matrix will contain the user ID to be
    predicted which will be derived from the set of user IDs in the future
    interaction matrix. Timestamp of the masked interactions will be preserved as
    the item ID are simply masked with "-1".

    The corresponding ground truth future interactions of the actual interaction
    will be returned as well in a tuple when `process` is called.
    """

    def _inject_user_id(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Inject the user IDs to predict, with their item IDs masked.

        The selected future interactions (``get_users_n_first_interaction(top_K)``,
        presumably the first ``top_K`` interactions of each user) are copied, their
        item column is overwritten with ``InteractionMatrix.MASKED_LABEL``
        (documented as "-1"), and the copies are appended to the past interactions.
        All other columns, including the timestamp, are preserved.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Number of future interactions to select per user. Defaults to 1.

        Returns:
            Tuple of the past interactions with the injected, masked rows and the
            selected ground-truth future interactions.
        """
        users_to_predict = future_interaction.get_users_n_first_interaction(top_K)
        masked_frame = users_to_predict.copy_df()
        # Hide the item the model must predict; the user ID and timestamp stay visible.
        masked_frame[InteractionMatrix.ITEM_IX] = InteractionMatrix.MASKED_LABEL
        return past_interaction.concat(masked_frame), users_to_predict

    def _inject_item_id(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Inject the item IDs to predict, with their user IDs masked.

        Mirror of :meth:`_inject_user_id`: the selected future interactions
        (``get_items_n_first_interaction(top_K)``) are copied, their user column is
        overwritten with ``InteractionMatrix.MASKED_LABEL`` (documented as "-1"), and
        the copies are appended to the past interactions. All other columns,
        including the timestamp, are preserved.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Number of future interactions to select per item. Defaults to 1.

        Returns:
            Tuple of the past interactions with the injected, masked rows and the
            selected ground-truth future interactions.
        """
        items_to_predict = future_interaction.get_items_n_first_interaction(top_K)
        masked_frame = items_to_predict.copy_df()
        # Hide the user; the item ID and timestamp stay visible.
        masked_frame[InteractionMatrix.USER_IX] = InteractionMatrix.MASKED_LABEL
        return past_interaction.concat(masked_frame), items_to_predict

    def process(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Build the unlabeled data and ground truth by injecting user IDs.

        Delegates to :meth:`_inject_user_id`.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Number of future interactions to select per user. Defaults to 1.

        Returns:
            Tuple of (past interactions with masked prediction rows, ground truth).
        """
        return self._inject_user_id(
            past_interaction=past_interaction,
            future_interaction=future_interaction,
            top_K=top_K,
        )
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from typing import NamedTuple
|
|
2
|
+
|
|
3
|
+
from recnexteval.matrix import InteractionMatrix
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SplitResult(NamedTuple):
    """Immutable bundle of the data that makes up one evaluation split.

    Field order (also the positional order) is: unlabeled, ground_truth,
    incremental, t_window.
    """

    # Interactions handed to the model to make predictions.
    unlabeled: InteractionMatrix
    # Interactions the predictions are evaluated against.
    ground_truth: InteractionMatrix
    # Data for incrementally updating the model before this split; None when absent.
    incremental: InteractionMatrix | None
    # Upper timestamp of the split's window; None when the setting has no global cut-off.
    t_window: int | None
|