recnexteval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110) hide show
  1. recnexteval/__init__.py +20 -0
  2. recnexteval/algorithms/__init__.py +99 -0
  3. recnexteval/algorithms/base.py +377 -0
  4. recnexteval/algorithms/baseline/__init__.py +10 -0
  5. recnexteval/algorithms/baseline/decay_popularity.py +110 -0
  6. recnexteval/algorithms/baseline/most_popular.py +72 -0
  7. recnexteval/algorithms/baseline/random.py +39 -0
  8. recnexteval/algorithms/baseline/recent_popularity.py +34 -0
  9. recnexteval/algorithms/itemknn/__init__.py +14 -0
  10. recnexteval/algorithms/itemknn/itemknn.py +119 -0
  11. recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
  12. recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
  13. recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
  14. recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
  15. recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
  16. recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
  17. recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
  18. recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
  19. recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
  20. recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
  21. recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
  22. recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
  23. recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
  24. recnexteval/algorithms/utils.py +51 -0
  25. recnexteval/datasets/__init__.py +109 -0
  26. recnexteval/datasets/base.py +316 -0
  27. recnexteval/datasets/config/__init__.py +113 -0
  28. recnexteval/datasets/config/amazon.py +188 -0
  29. recnexteval/datasets/config/base.py +72 -0
  30. recnexteval/datasets/config/lastfm.py +105 -0
  31. recnexteval/datasets/config/movielens.py +169 -0
  32. recnexteval/datasets/config/yelp.py +25 -0
  33. recnexteval/datasets/datasets/__init__.py +24 -0
  34. recnexteval/datasets/datasets/amazon.py +151 -0
  35. recnexteval/datasets/datasets/base.py +250 -0
  36. recnexteval/datasets/datasets/lastfm.py +121 -0
  37. recnexteval/datasets/datasets/movielens.py +93 -0
  38. recnexteval/datasets/datasets/test.py +46 -0
  39. recnexteval/datasets/datasets/yelp.py +103 -0
  40. recnexteval/datasets/metadata/__init__.py +58 -0
  41. recnexteval/datasets/metadata/amazon.py +68 -0
  42. recnexteval/datasets/metadata/base.py +38 -0
  43. recnexteval/datasets/metadata/lastfm.py +110 -0
  44. recnexteval/datasets/metadata/movielens.py +87 -0
  45. recnexteval/evaluators/__init__.py +189 -0
  46. recnexteval/evaluators/accumulator.py +167 -0
  47. recnexteval/evaluators/base.py +216 -0
  48. recnexteval/evaluators/builder/__init__.py +125 -0
  49. recnexteval/evaluators/builder/base.py +166 -0
  50. recnexteval/evaluators/builder/pipeline.py +111 -0
  51. recnexteval/evaluators/builder/stream.py +54 -0
  52. recnexteval/evaluators/evaluator_pipeline.py +287 -0
  53. recnexteval/evaluators/evaluator_stream.py +374 -0
  54. recnexteval/evaluators/state_management.py +310 -0
  55. recnexteval/evaluators/strategy.py +32 -0
  56. recnexteval/evaluators/util.py +124 -0
  57. recnexteval/matrix/__init__.py +48 -0
  58. recnexteval/matrix/exception.py +5 -0
  59. recnexteval/matrix/interaction_matrix.py +784 -0
  60. recnexteval/matrix/prediction_matrix.py +153 -0
  61. recnexteval/matrix/util.py +24 -0
  62. recnexteval/metrics/__init__.py +57 -0
  63. recnexteval/metrics/binary/__init__.py +4 -0
  64. recnexteval/metrics/binary/hit.py +49 -0
  65. recnexteval/metrics/core/__init__.py +10 -0
  66. recnexteval/metrics/core/base.py +126 -0
  67. recnexteval/metrics/core/elementwise_top_k.py +75 -0
  68. recnexteval/metrics/core/listwise_top_k.py +72 -0
  69. recnexteval/metrics/core/top_k.py +60 -0
  70. recnexteval/metrics/core/util.py +29 -0
  71. recnexteval/metrics/ranking/__init__.py +6 -0
  72. recnexteval/metrics/ranking/dcg.py +55 -0
  73. recnexteval/metrics/ranking/ndcg.py +78 -0
  74. recnexteval/metrics/ranking/precision.py +51 -0
  75. recnexteval/metrics/ranking/recall.py +42 -0
  76. recnexteval/models/__init__.py +4 -0
  77. recnexteval/models/base.py +69 -0
  78. recnexteval/preprocessing/__init__.py +37 -0
  79. recnexteval/preprocessing/filter.py +181 -0
  80. recnexteval/preprocessing/preprocessor.py +137 -0
  81. recnexteval/registries/__init__.py +67 -0
  82. recnexteval/registries/algorithm.py +68 -0
  83. recnexteval/registries/base.py +131 -0
  84. recnexteval/registries/dataset.py +37 -0
  85. recnexteval/registries/metric.py +57 -0
  86. recnexteval/settings/__init__.py +127 -0
  87. recnexteval/settings/base.py +414 -0
  88. recnexteval/settings/exception.py +8 -0
  89. recnexteval/settings/leave_n_out_setting.py +48 -0
  90. recnexteval/settings/processor.py +115 -0
  91. recnexteval/settings/schema.py +11 -0
  92. recnexteval/settings/single_time_point_setting.py +111 -0
  93. recnexteval/settings/sliding_window_setting.py +153 -0
  94. recnexteval/settings/splitters/__init__.py +14 -0
  95. recnexteval/settings/splitters/base.py +57 -0
  96. recnexteval/settings/splitters/n_last.py +39 -0
  97. recnexteval/settings/splitters/n_last_timestamp.py +76 -0
  98. recnexteval/settings/splitters/timestamp.py +82 -0
  99. recnexteval/settings/util.py +0 -0
  100. recnexteval/utils/__init__.py +115 -0
  101. recnexteval/utils/json_to_csv_converter.py +128 -0
  102. recnexteval/utils/logging_tools.py +159 -0
  103. recnexteval/utils/path.py +155 -0
  104. recnexteval/utils/url_certificate_installer.py +54 -0
  105. recnexteval/utils/util.py +166 -0
  106. recnexteval/utils/uuid_util.py +7 -0
  107. recnexteval/utils/yaml_tool.py +65 -0
  108. recnexteval-0.1.0.dist-info/METADATA +85 -0
  109. recnexteval-0.1.0.dist-info/RECORD +110 -0
  110. recnexteval-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,414 @@
1
+ import logging
2
+ import time
3
+ from abc import abstractmethod
4
+ from typing import Any, Self, Union
5
+ from warnings import warn
6
+
7
+ from recnexteval.matrix import InteractionMatrix
8
+ from ..models import BaseModel, ParamMixin
9
+ from .exception import EOWSettingError
10
+ from .processor import PredictionDataProcessor
11
+ from .schema import SplitResult
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class Setting(BaseModel, ParamMixin):
    """Base class for defining an evaluation setting.

    Core Attributes:
        - background_data: Data used for initial training of the model.
          Interval is [0, background_t).
        - unlabeled_data: List of unlabeled data. Each element is an
          InteractionMatrix object of interval [0, t).
        - ground_truth_data: List of ground truth data. Each element is an
          InteractionMatrix object of interval [t, t + window_size).
        - incremental_data: List of data used to incrementally update the model.
          Each element is an InteractionMatrix object of interval
          [t, t + window_size). Unique to SlidingWindowSetting.
        - t_window: Upper timestamp bound of the split window (a list of
          bounds for SlidingWindowSetting, None when no timestamp applies).

    We will use `background_data` as the initial training set, `incremental_data` as the data
    to incrementally update the model. However, for public methods, we will refer to both as
    `training_data` to avoid confusion.

    Iteration:
        ``iter(setting)`` rewinds ``current_index`` to 0 and every ``next(setting)``
        returns one ``SplitResult``. Exhaustion is signalled with
        ``EOWSettingError`` rather than ``StopIteration``, so callers are
        expected to catch that exception explicitly.

    Args:
        seed: Seed for randomization. Defaults to 42.
    """

    def __init__(
        self,
        seed: int = 42,
    ) -> None:
        """Initialize the setting.

        Args:
            seed: Random seed for reproducibility.
        """
        self.seed = seed
        self.prediction_data_processor = PredictionDataProcessor()
        #: Index of the split that the next call to ``next()`` returns.
        #: Initialised here so that ``next()`` works without a prior ``iter()``.
        #: It is iteration state, not a parameter, hence excluded in ``get_params``.
        self.current_index: int = 0
        #: Number of splits created from sliding window. Defaults to 1 (no splits on training set).
        self._num_split_set = 1

        self._sliding_window_setting = False
        self._split_complete = False
        # The annotations below only declare the attributes; concrete settings
        # assign them in ``_split`` (``n_seq_data``/``top_K`` in ``__init__``).
        self._num_full_interactions: int
        self._unlabeled_data: InteractionMatrix | list[InteractionMatrix]
        #: Data containing the ground truth interactions to :attr:`_unlabeled_data`.
        #: If :class:`SlidingWindowSetting`, then it will be a list of :class:`InteractionMatrix`.
        self._ground_truth_data: InteractionMatrix | list[InteractionMatrix]
        #: Data that is used to incrementally update the model. Unique to :class:`SlidingWindowSetting`.
        self._incremental_data: list[InteractionMatrix]
        #: Data used as the initial set of interactions to train the model.
        self._background_data: InteractionMatrix
        #: Upper timestamp of the window in split. The actual interactions might have a
        #: smaller timestamp value than this because this is the cut-off value.
        self._t_window: Union[None, int, list[int]]
        #: Number of last sequential interactions to provide in :attr:`unlabeled_data`
        #: as data for the model to make predictions.
        self.n_seq_data: int
        #: Number of interactions per user that should be selected for evaluation
        #: purposes in :attr:`ground_truth_data`.
        self.top_K: int

    def __str__(self) -> str:
        rendered = ", ".join(f"{k}={v}" for k, v in self.params.items())
        return f"{self.__class__.__name__}({rendered})"

    def get_params(self) -> dict[str, Any]:
        """Get the parameters of the setting.

        Returns:
            Every public instance attribute (name without a leading underscore),
            except helper objects and iteration state.
        """
        # ``current_index`` is mutated while iterating; keeping it out of the
        # parameters keeps ``identifier``/``__str__`` stable across iterations.
        excluded = {"prediction_data_processor", "current_index"}
        return {
            name: value
            for name, value in vars(self).items()
            if not name.startswith("_") and name not in excluded
        }

    @property
    def identifier(self) -> str:
        """Name of the setting followed by its non-None parameters."""
        paramstring = ",".join(f"{k}={v}" for k, v in self.params.items() if v is not None)
        return f"{self.name}({paramstring})"

    @abstractmethod
    def _split(self, data: InteractionMatrix) -> None:
        """Split data according to the setting.

        This abstract method must be implemented by concrete setting classes
        to split data into background_data, ground_truth_data, and unlabeled_data.

        Args:
            data: Interaction matrix to be split.
        """

    def split(self, data: InteractionMatrix) -> None:
        """Split data according to the setting.

        Calling this method changes the state of the setting object to be ready
        for evaluation. The method splits data into background_data, ground_truth_data,
        and unlabeled_data.

        Note:
            SlidingWindowSetting will have an additional attribute incremental_data.

        Args:
            data: Interaction matrix to be split.

        Raises:
            AssertionError: If the concrete ``_split`` left a required attribute unset.
        """
        logger.debug("Splitting data...")
        self._num_full_interactions = data.num_interactions
        # perf_counter is monotonic, unlike time.time(), so the duration is reliable.
        start = time.perf_counter()
        self._split(data)
        elapsed = time.perf_counter() - start
        logger.info(f"{self.name} data split - Took {elapsed:.3f}s")

        logger.debug("Checking split attribute and sizes.")
        self._check_split()

        self._split_complete = True
        logger.info(f"{self.name} data split complete.")

    def _check_split_complete(self) -> None:
        """Check if the setting is ready for evaluation.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        if not self.is_ready:
            raise KeyError("Setting has not been split yet. Call split() method before accessing the property.")

    @property
    def num_split(self) -> int:
        """Get number of splits created from dataset.

        This property defaults to 1 (no splits on training set) for typical settings.
        For SlidingWindowSetting, this is typically greater than 1 if there are
        multiple splits created from the sliding window.

        Returns:
            Number of splits created from dataset.
        """
        return self._num_split_set

    @property
    def is_ready(self) -> bool:
        """Check if setting is ready for evaluation.

        Returns:
            True if the setting has been split and is ready to use.
        """
        return self._split_complete

    @property
    def is_sliding_window_setting(self) -> bool:
        """Check if setting is SlidingWindowSetting.

        Returns:
            True if this is a SlidingWindowSetting instance.
        """
        return self._sliding_window_setting

    @property
    def background_data(self) -> InteractionMatrix:
        """Get background data for initial model training.

        Returns:
            InteractionMatrix of training interactions.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._background_data

    @property
    def t_window(self) -> Union[None, int, list[int]]:
        """Get the upper timestamp of the window in split.

        In settings that respect the global timeline, returns a timestamp value.
        In `SlidingWindowSetting`, returns a list of timestamp values.
        In settings like `LeaveNOutSetting`, returns None.

        Returns:
            Timestamp limit for the data (int, list of ints, or None).

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._t_window

    @property
    def unlabeled_data(self) -> InteractionMatrix | list[InteractionMatrix]:
        """Get unlabeled data for model predictions.

        Contains the user/item ID for prediction along with previous sequential
        interactions. Used to make predictions on ground truth data.

        Returns:
            Single InteractionMatrix or list of InteractionMatrix for sliding window setting.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._unlabeled_data

    @property
    def ground_truth_data(self) -> InteractionMatrix | list[InteractionMatrix]:
        """Get ground truth data for model evaluation.

        Contains the actual interactions of user-item that the model should predict.

        Returns:
            Single InteractionMatrix or list of InteractionMatrix for sliding window.

        Raises:
            KeyError: If the setting has not been split yet.
        """
        self._check_split_complete()
        return self._ground_truth_data

    @property
    def incremental_data(self) -> list[InteractionMatrix]:
        """Get data for incrementally updating the model.

        Only available for SlidingWindowSetting.

        Returns:
            List of InteractionMatrix objects for incremental updates.

        Raises:
            KeyError: If the setting has not been split yet.
            AttributeError: If setting is not SlidingWindowSetting.
        """
        self._check_split_complete()

        if not self._sliding_window_setting:
            raise AttributeError("Incremental data is only available for sliding window setting.")
        return self._incremental_data

    def _check_split(self) -> None:
        """Check that the split populated every required attribute.

        Raises:
            AssertionError: If background, unlabeled or ground truth data is
                missing or None. Raised explicitly (instead of via ``assert``)
                so the check is not stripped when Python runs with ``-O``.
        """
        logger.debug("Checking split attributes.")
        for attr_name in ("_background_data", "_unlabeled_data", "_ground_truth_data"):
            if getattr(self, attr_name, None) is None:
                raise AssertionError(f"{attr_name} was not set by {type(self).__name__}._split().")
        logger.debug("Split attributes are set.")

        self._check_size()

    def _check_size(self) -> None:
        """Warn the user if any of the split sets is unusually small or empty."""
        logger.debug("Checking size of split sets.")

        def check_empty(name: str, count: int) -> bool:
            """Emit a warning and return True when ``count`` is zero."""
            if count == 0:
                warn(UserWarning(f"{name} resulting from {self.name} is empty (no interactions)."))
                return True
            return False

        def check_ratio(name: str, count: int, total: int, threshold: float) -> None:
            """Warn when the set is empty or ``count / total`` is below ``threshold``."""
            if check_empty(name, count):
                return
            # The epsilon keeps the division defined when ``total`` is zero.
            if (count + 1e-9) / (total + 1e-9) < threshold:
                warn(UserWarning(f"{name} resulting from {self.name} is unusually small."))

        n_background = self._background_data.num_interactions
        check_ratio("Background data", n_background, self._num_full_interactions, 0.05)

        if not self._sliding_window_setting:
            n_unlabel = self._unlabeled_data.num_interactions
            n_ground_truth = self._ground_truth_data.num_interactions

            check_empty("Unlabeled data", n_unlabel)
            check_ratio("Ground truth data", n_ground_truth, n_unlabel, 0.05)
        else:
            for dataset_idx in range(self._num_split_set):
                n_unlabel = self._unlabeled_data[dataset_idx].num_interactions
                n_ground_truth = self._ground_truth_data[dataset_idx].num_interactions

                check_empty(f"Unlabeled data[{dataset_idx}]", n_unlabel)
                check_empty(f"Ground truth data[{dataset_idx}]", n_ground_truth)
        logger.debug("Size of split sets are checked.")

    def restore(self, n: int = 0) -> None:
        """Move the iteration cursor so the next ``next()`` returns split ``n``.

        Args:
            n: Index of the split to resume from. Defaults to 0 (the beginning).
        """
        logger.debug(f"Restoring setting to iteration {n}")
        self.current_index = n

    def __iter__(self) -> Self:
        """Iterate over splits in the setting.

        Resets the index and returns self as the iterator. Each ``next()`` call
        then returns a SplitResult with the fields 'unlabeled', 'ground_truth',
        'incremental' and 't_window'.
        """
        self.current_index = 0
        return self

    def _sliding_split_at(self, index: int) -> SplitResult:
        """Build the SplitResult of a sliding window setting for ``index``.

        Shared by ``__next__`` and ``get_split_at`` so that both hand out the
        same incremental data for the same index.

        Args:
            index: Position of the split; range checks are done by the caller.

        Returns:
            SplitResult whose ``incremental`` is ``_incremental_data[index - 1]``
            for ``index >= 1`` and None for the first split.

        Raises:
            ValueError: If the split attributes are not lists.
        """
        unlabeled = self._unlabeled_data
        ground_truth = self._ground_truth_data
        t_window = self._t_window
        if not (isinstance(unlabeled, list) and isinstance(ground_truth, list) and isinstance(t_window, list)):
            raise ValueError("Expected list of InteractionMatrix for sliding window setting.")

        # Split ``index`` is preceded by the window stored at ``index - 1``;
        # the first split has no preceding window and therefore no update data.
        incremental = self._incremental_data[index - 1] if 0 < index < len(self._incremental_data) else None
        return SplitResult(
            unlabeled=unlabeled[index],
            ground_truth=ground_truth[index],
            incremental=incremental,
            t_window=t_window[index],
        )

    def __next__(self) -> SplitResult:
        """Get the next split.

        Returns:
            SplitResult with split data.

        Raises:
            KeyError: If the setting has not been split yet.
            EOWSettingError: If no more splits.
            ValueError: If the stored data does not match the setting type.
        """
        self._check_split_complete()
        if self.current_index >= self.num_split:
            raise EOWSettingError("No more splits available, EOW reached.")

        if self._sliding_window_setting:
            result = self._sliding_split_at(self.current_index)
        else:
            if (
                isinstance(self._unlabeled_data, list)
                or isinstance(self._ground_truth_data, list)
                or isinstance(self._t_window, list)
            ):
                raise ValueError("Expected single InteractionMatrix for non-sliding window setting.")
            result = SplitResult(
                unlabeled=self._unlabeled_data,
                ground_truth=self._ground_truth_data,
                t_window=self._t_window,
                incremental=None,
            )

        self.current_index += 1
        return result

    def get_split_at(self, index: int) -> SplitResult:
        """Get the split data at a specific index.

        Does not touch ``current_index``.

        Args:
            index: The index of the split to retrieve, in ``[0, num_split)``.

        Returns:
            SplitResult with fields: 'unlabeled', 'ground_truth', 'incremental', 't_window'.

        Raises:
            KeyError: If the setting has not been split yet.
            IndexError: If index is out of range.
            ValueError: If the stored data does not match the setting type.
        """
        self._check_split_complete()
        # Valid indices are 0 .. num_split - 1, so ``num_split`` itself is out of range.
        if index < 0 or index >= self.num_split:
            raise IndexError(f"Index {index} out of range for {self.num_split} splits")

        if self._sliding_window_setting:
            # TODO change ``incremental`` to training_data when refactoring
            return self._sliding_split_at(index)

        if index != 0:
            raise IndexError("Non-sliding setting has only one split at index 0")
        if (
            isinstance(self._unlabeled_data, list)
            or isinstance(self._ground_truth_data, list)
            or isinstance(self._t_window, list)
        ):
            raise ValueError("Expected single data for non-sliding setting.")
        return SplitResult(
            unlabeled=self._unlabeled_data,
            ground_truth=self._ground_truth_data,
            incremental=None,
            t_window=self._t_window,
        )
@@ -0,0 +1,8 @@
1
+ class EOWSettingError(Exception):
2
+ """End of Window Setting Exception."""
3
+
4
+ def __init__(self, message: str | None = None) -> None:
5
+ if not message:
6
+ message = "End of Window reached for the setting."
7
+ self.message = message
8
+ super().__init__(self.message)
@@ -0,0 +1,48 @@
1
+ import logging
2
+
3
+ from recnexteval.matrix import InteractionMatrix
4
+ from .base import Setting
5
+ from .splitters import (
6
+ NLastInteractionSplitter,
7
+ )
8
+
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class LeaveNOutSetting(Setting):
    """Setting that delegates the split to an ``NLastInteractionSplitter``.

    The setting carries no timestamp window: ``_t_window`` is set to None
    after splitting.

    Args:
        n_seq_data: Stored as ``n_seq_data`` and forwarded to the splitter.
        N: Stored as ``top_K`` (the number of items to predict) and forwarded
            to the splitter.
        seed: Random seed handed to the base ``Setting``.
    """

    IS_BASE: bool = False

    def __init__(
        self,
        n_seq_data: int = 1,
        N: int = 1,
        seed: int = 42,
    ) -> None:
        super().__init__(seed=seed)
        # Assignment order is kept: public attributes are reported in the
        # order in which they were created on the instance.
        self.n_seq_data = n_seq_data
        # ``top_K`` is the name used for the number of items to predict.
        self.top_K = N

        # NOTE(review): nothing is split at construction time; the message is
        # kept as-is because it is runtime output.
        logger.info("Splitting data")

        self._splitter = NLastInteractionSplitter(N, n_seq_data)

    def _split(self, data: InteractionMatrix) -> None:
        """Split ``data`` into background, unlabeled and ground truth data.

        The splitter returns the background interactions and the held-out
        interactions. A copy of the background data is passed to the
        prediction data processor together with the held-out interactions to
        obtain the unlabeled and ground truth data.

        Args:
            data: Interaction matrix to be split.
        """
        background, held_out = self._splitter.split(data)
        self._background_data = background

        # Work on a copy so that processing cannot modify the background data.
        history = background.copy()
        unlabeled, ground_truth = self.prediction_data_processor.process(
            past_interaction=history,
            future_interaction=held_out,
            top_K=self.top_K,
        )
        self._unlabeled_data = unlabeled
        self._ground_truth_data = ground_truth
        self._t_window = None
@@ -0,0 +1,115 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ from recnexteval.matrix import InteractionMatrix
4
+
5
+
6
class Processor(ABC):
    """Base class for processing data.

    Abstract class for processing data. The `process` method should be
    implemented by the subclass to process the data. The programmer can
    create a subclass of preferred processing of data split.
    """

    @abstractmethod
    def process(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Turn a past/future pair into unlabeled data and its ground truth.

        IDs to be predicted by the model are indicated with a masked label
        ("-1"). The matrix with past interactions will contain the IDs to be
        predicted, which are derived from the future interaction matrix.
        Timestamps of the masked interactions are preserved, as only the ID
        is masked.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Number of future interactions to select for prediction.
                Declared on the base class because implementations and their
                callers pass it by keyword; defaults to 1.

        Returns:
            Tuple of the past interactions with the injected IDs to predict
            and the ground truth future interactions.
        """
41
+
42
+
43
class PredictionDataProcessor(Processor):
    """Builds the unlabeled data and its ground truth from a past/future pair.

    Rows selected from the future interactions are copied, one ID column of
    the copy is overwritten with ``InteractionMatrix.MASKED_LABEL``, and the
    copy is concatenated to the past interactions. All other columns of the
    copied rows (e.g. timestamps) are left untouched.

    The selected, unmasked future interactions are returned alongside as the
    ground truth when `process` is called.
    """

    def _inject_user_id(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Append user rows to predict, with their item ID masked.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Passed to ``get_users_n_first_interaction`` to select the
                future interactions (presumably the first ``top_K`` per user;
                verify against ``InteractionMatrix``).

        Returns:
            Tuple of the past interactions extended by the masked rows and
            the selected future interactions as ground truth.
        """
        ground_truth = future_interaction.get_users_n_first_interaction(top_K)
        placeholder_rows = ground_truth.copy_df()
        # Hide the item so that only the user (and the remaining columns) is visible.
        placeholder_rows[InteractionMatrix.ITEM_IX] = InteractionMatrix.MASKED_LABEL
        unlabeled = past_interaction.concat(placeholder_rows)
        return unlabeled, ground_truth

    def _inject_item_id(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Append item rows to predict, with their user ID masked.

        Mirror image of ``_inject_user_id``: here the user column is the one
        overwritten with the masked label.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Passed to ``get_items_n_first_interaction`` to select the
                future interactions (presumably the first ``top_K`` per item;
                verify against ``InteractionMatrix``).

        Returns:
            Tuple of the past interactions extended by the masked rows and
            the selected future interactions as ground truth.
        """
        ground_truth = future_interaction.get_items_n_first_interaction(top_K)
        placeholder_rows = ground_truth.copy_df()
        # Hide the user so that only the item (and the remaining columns) is visible.
        placeholder_rows[InteractionMatrix.USER_IX] = InteractionMatrix.MASKED_LABEL
        unlabeled = past_interaction.concat(placeholder_rows)
        return unlabeled, ground_truth

    def process(
        self,
        past_interaction: InteractionMatrix,
        future_interaction: InteractionMatrix,
        top_K: int = 1,
    ) -> tuple[InteractionMatrix, InteractionMatrix]:
        """Create unlabeled data and ground truth using user-ID injection.

        Args:
            past_interaction: Matrix of past interactions.
            future_interaction: Matrix of future interactions.
            top_K: Forwarded unchanged to ``_inject_user_id``.

        Returns:
            The tuple produced by ``_inject_user_id``.
        """
        return self._inject_user_id(past_interaction, future_interaction, top_K)
@@ -0,0 +1,11 @@
1
+ from typing import NamedTuple
2
+
3
+ from recnexteval.matrix import InteractionMatrix
4
+
5
+
6
class SplitResult(NamedTuple):
    """Data belonging to one split of a setting.

    Attributes:
        unlabeled: Interactions handed to the model for making predictions.
        ground_truth: Interactions the predictions are evaluated against.
        incremental: Interactions for incrementally updating the model, or
            None when the split has none.
        t_window: Upper timestamp of the split window, or None when the
            setting does not define one.
    """

    # Field order is part of the tuple's interface; do not reorder.
    unlabeled: InteractionMatrix
    ground_truth: InteractionMatrix
    incremental: InteractionMatrix | None
    t_window: int | None