recnexteval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- recnexteval/__init__.py +20 -0
- recnexteval/algorithms/__init__.py +99 -0
- recnexteval/algorithms/base.py +377 -0
- recnexteval/algorithms/baseline/__init__.py +10 -0
- recnexteval/algorithms/baseline/decay_popularity.py +110 -0
- recnexteval/algorithms/baseline/most_popular.py +72 -0
- recnexteval/algorithms/baseline/random.py +39 -0
- recnexteval/algorithms/baseline/recent_popularity.py +34 -0
- recnexteval/algorithms/itemknn/__init__.py +14 -0
- recnexteval/algorithms/itemknn/itemknn.py +119 -0
- recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
- recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
- recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
- recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
- recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
- recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
- recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
- recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
- recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
- recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
- recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
- recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
- recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
- recnexteval/algorithms/utils.py +51 -0
- recnexteval/datasets/__init__.py +109 -0
- recnexteval/datasets/base.py +316 -0
- recnexteval/datasets/config/__init__.py +113 -0
- recnexteval/datasets/config/amazon.py +188 -0
- recnexteval/datasets/config/base.py +72 -0
- recnexteval/datasets/config/lastfm.py +105 -0
- recnexteval/datasets/config/movielens.py +169 -0
- recnexteval/datasets/config/yelp.py +25 -0
- recnexteval/datasets/datasets/__init__.py +24 -0
- recnexteval/datasets/datasets/amazon.py +151 -0
- recnexteval/datasets/datasets/base.py +250 -0
- recnexteval/datasets/datasets/lastfm.py +121 -0
- recnexteval/datasets/datasets/movielens.py +93 -0
- recnexteval/datasets/datasets/test.py +46 -0
- recnexteval/datasets/datasets/yelp.py +103 -0
- recnexteval/datasets/metadata/__init__.py +58 -0
- recnexteval/datasets/metadata/amazon.py +68 -0
- recnexteval/datasets/metadata/base.py +38 -0
- recnexteval/datasets/metadata/lastfm.py +110 -0
- recnexteval/datasets/metadata/movielens.py +87 -0
- recnexteval/evaluators/__init__.py +189 -0
- recnexteval/evaluators/accumulator.py +167 -0
- recnexteval/evaluators/base.py +216 -0
- recnexteval/evaluators/builder/__init__.py +125 -0
- recnexteval/evaluators/builder/base.py +166 -0
- recnexteval/evaluators/builder/pipeline.py +111 -0
- recnexteval/evaluators/builder/stream.py +54 -0
- recnexteval/evaluators/evaluator_pipeline.py +287 -0
- recnexteval/evaluators/evaluator_stream.py +374 -0
- recnexteval/evaluators/state_management.py +310 -0
- recnexteval/evaluators/strategy.py +32 -0
- recnexteval/evaluators/util.py +124 -0
- recnexteval/matrix/__init__.py +48 -0
- recnexteval/matrix/exception.py +5 -0
- recnexteval/matrix/interaction_matrix.py +784 -0
- recnexteval/matrix/prediction_matrix.py +153 -0
- recnexteval/matrix/util.py +24 -0
- recnexteval/metrics/__init__.py +57 -0
- recnexteval/metrics/binary/__init__.py +4 -0
- recnexteval/metrics/binary/hit.py +49 -0
- recnexteval/metrics/core/__init__.py +10 -0
- recnexteval/metrics/core/base.py +126 -0
- recnexteval/metrics/core/elementwise_top_k.py +75 -0
- recnexteval/metrics/core/listwise_top_k.py +72 -0
- recnexteval/metrics/core/top_k.py +60 -0
- recnexteval/metrics/core/util.py +29 -0
- recnexteval/metrics/ranking/__init__.py +6 -0
- recnexteval/metrics/ranking/dcg.py +55 -0
- recnexteval/metrics/ranking/ndcg.py +78 -0
- recnexteval/metrics/ranking/precision.py +51 -0
- recnexteval/metrics/ranking/recall.py +42 -0
- recnexteval/models/__init__.py +4 -0
- recnexteval/models/base.py +69 -0
- recnexteval/preprocessing/__init__.py +37 -0
- recnexteval/preprocessing/filter.py +181 -0
- recnexteval/preprocessing/preprocessor.py +137 -0
- recnexteval/registries/__init__.py +67 -0
- recnexteval/registries/algorithm.py +68 -0
- recnexteval/registries/base.py +131 -0
- recnexteval/registries/dataset.py +37 -0
- recnexteval/registries/metric.py +57 -0
- recnexteval/settings/__init__.py +127 -0
- recnexteval/settings/base.py +414 -0
- recnexteval/settings/exception.py +8 -0
- recnexteval/settings/leave_n_out_setting.py +48 -0
- recnexteval/settings/processor.py +115 -0
- recnexteval/settings/schema.py +11 -0
- recnexteval/settings/single_time_point_setting.py +111 -0
- recnexteval/settings/sliding_window_setting.py +153 -0
- recnexteval/settings/splitters/__init__.py +14 -0
- recnexteval/settings/splitters/base.py +57 -0
- recnexteval/settings/splitters/n_last.py +39 -0
- recnexteval/settings/splitters/n_last_timestamp.py +76 -0
- recnexteval/settings/splitters/timestamp.py +82 -0
- recnexteval/settings/util.py +0 -0
- recnexteval/utils/__init__.py +115 -0
- recnexteval/utils/json_to_csv_converter.py +128 -0
- recnexteval/utils/logging_tools.py +159 -0
- recnexteval/utils/path.py +155 -0
- recnexteval/utils/url_certificate_installer.py +54 -0
- recnexteval/utils/util.py +166 -0
- recnexteval/utils/uuid_util.py +7 -0
- recnexteval/utils/yaml_tool.py +65 -0
- recnexteval-0.1.0.dist-info/METADATA +85 -0
- recnexteval-0.1.0.dist-info/RECORD +110 -0
- recnexteval-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,374 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
from scipy.sparse import csr_matrix
|
|
6
|
+
|
|
7
|
+
from recnexteval.algorithms import Algorithm
|
|
8
|
+
from recnexteval.matrix import InteractionMatrix, PredictionMatrix
|
|
9
|
+
from recnexteval.registries import (
|
|
10
|
+
METRIC_REGISTRY,
|
|
11
|
+
MetricEntry,
|
|
12
|
+
)
|
|
13
|
+
from recnexteval.settings import EOWSettingError, Setting
|
|
14
|
+
from .accumulator import MetricAccumulator
|
|
15
|
+
from .base import EvaluatorBase
|
|
16
|
+
from .state_management import AlgorithmStateEnum, AlgorithmStateManager
|
|
17
|
+
from .strategy import EvaluationStrategy, SlidingWindowStrategy
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class EvaluatorState(Enum):
    """Lifecycle states of the streaming evaluator.

    Each member's value is its lower-case name.
    """

    # Constructed; algorithms may still be registered.
    INITIALIZED = "initialized"
    # start_stream() has been called.
    STARTED = "started"
    # Streaming underway.
    IN_PROGRESS = "in_progress"
    # The setting has no more windows to stream.
    COMPLETED = "completed"
    # Evaluation failed.
    FAILED = "failed"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Module-scoped logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class EvaluatorStreamer(EvaluatorBase):
    """Evaluation via streaming through API.

    Instead of running a fixed pipeline, the streamer evaluator lets the user
    stream the data release to the algorithm for the sliding window setting.
    The evaluator mediates all data communication with the algorithm: it
    releases only the data an algorithm is entitled to for the current window
    and evaluates the predictions it receives. Streaming continues window
    after window until the setting has no more splits.

    The core API exposed to drive the evaluation is:

    1. :meth:`register_algorithm`
    2. :meth:`start_stream`
    3. :meth:`get_training_data`
    4. :meth:`get_unlabeled_data`
    5. :meth:`submit_prediction`

    See each method for details. The methods are designed so that the
    algorithm is decoupled from the evaluating platform: the evaluator only
    provides the necessary data to the algorithm and evaluates the prediction.

    Args:
        metric_entries: List of metric entries.
        setting: Setting object.
        metric_k: Number of top interactions to consider.
        ignore_unknown_user: To ignore unknown users.
        ignore_unknown_item: To ignore unknown items.
        seed: Random seed for the evaluator.
        strategy: Strategy deciding when the evaluator advances to the next
            window. Defaults to :class:`SlidingWindowStrategy`.
    """

    def __init__(
        self,
        metric_entries: list[MetricEntry],
        setting: Setting,
        metric_k: int,
        ignore_unknown_user: bool = False,
        ignore_unknown_item: bool = False,
        seed: int = 42,
        strategy: None | EvaluationStrategy = None,
    ) -> None:
        super().__init__(
            metric_entries,
            setting,
            metric_k,
            ignore_unknown_user,
            ignore_unknown_item,
            seed,
        )
        self._algo_state_mgr = AlgorithmStateManager()
        # Populated by start_stream() / load_next_window(); declared for type checkers.
        self._unlabeled_data_cache: PredictionMatrix
        self._ground_truth_data_cache: PredictionMatrix
        self._training_data_cache: PredictionMatrix

        # Evaluator lifecycle state
        self._state = EvaluatorState.INITIALIZED

        # Evaluation strategy (decides when to advance to the next window)
        self._strategy = strategy if strategy is not None else SlidingWindowStrategy()

    @property
    def state(self) -> EvaluatorState:
        """Current lifecycle state of the evaluator."""
        return self._state

    def _assert_state(
        self,
        expected: EvaluatorState | tuple[EvaluatorState, ...],
        error_msg: str,
    ) -> None:
        """Raise unless the evaluator is in (one of) the expected state(s).

        Args:
            expected: A single state, or a tuple of acceptable states.
            error_msg: Message prefix of the raised error.

        Raises:
            RuntimeError: If the current state is not among ``expected``.
        """
        allowed = expected if isinstance(expected, tuple) else (expected,)
        if self._state not in allowed:
            raise RuntimeError(f"{error_msg} (Current state: {self._state.value})")

    def _check_transition(self, new_state: EvaluatorState, allow_from: list[EvaluatorState]) -> None:
        """Validate a state transition without performing it.

        Raises:
            ValueError: If the current state is not in ``allow_from``.
        """
        if self._state not in allow_from:
            raise ValueError(f"Cannot transition from {self._state} to {new_state}. Allowed from: {allow_from}")

    def _transition_state(self, new_state: EvaluatorState, allow_from: list[EvaluatorState]) -> None:
        """Move the evaluator to ``new_state`` if the current state allows it.

        Args:
            new_state: Target state.
            allow_from: States from which the transition is permitted.

        Raises:
            ValueError: If the current state is not in ``allow_from``.
        """
        self._check_transition(new_state, allow_from)
        self._state = new_state
        logger.info(f"Evaluator transitioned to {new_state}")

    def _cache_evaluation_data(self) -> None:
        """Cache the unlabeled and ground truth data for the current step.

        Calls :meth:`_get_evaluation_data` (provided by the base class) and
        stores its first two results in :attr:`_unlabeled_data_cache` and
        :attr:`_ground_truth_data_cache`; the third result is discarded.

        NOTE(review): per the original design, ``_get_evaluation_data`` also
        updates the unknown user/item base, masks the data with the known
        user/item base, and advances :attr:`_current_timestamp` and
        :attr:`_run_step`. That logic lives outside this class; confirm it in
        ``EvaluatorBase``.

        Callers are expected to have already checked that data remains.

        Raises:
            EOWSettingError: Propagated from ``_get_evaluation_data``.
        """
        logger.debug(f"Caching evaluation data for step {self._run_step}")
        self._unlabeled_data_cache, self._ground_truth_data_cache, _ = self._get_evaluation_data()
        logger.debug(f"Data cached for step {self._run_step} complete")

    def start_stream(self) -> None:
        """Start the streaming process.

        Prepares the evaluator for streaming:

        - restores the setting;
        - creates a fresh :class:`MetricAccumulator`;
        - takes the background data as the first training data, after
          updating and masking it with the known user/item base;
        - caches the first evaluation data;
        - calls ``set_all_ready`` on the algorithm state manager with the
          current timestamp.

        Finally, the evaluator moves to the STARTED state. The method can be
        called any time after instantiation, but only once.

        Warning:
            Once `start_stream` is called, the evaluator cannot register any new algorithms.

        Raises:
            ValueError: If the stream has already started.
        """
        # Validate up front: a repeated call must fail before it restores the
        # setting or resets the accumulator and caches of a running stream.
        self._check_transition(EvaluatorState.STARTED, allow_from=[EvaluatorState.INITIALIZED])
        self.setting.restore()

        logger.debug("Preparing evaluator for streaming")
        self._acc = MetricAccumulator()
        # Convert to PredictionMatrix since it's a subclass of InteractionMatrix
        training_data = PredictionMatrix.from_interaction_matrix(self.setting.background_data)

        self.user_item_base.update_known_user_item_base(training_data)
        training_data.mask_user_item_shape(self.user_item_base.known_shape)
        self._training_data_cache = training_data
        self._cache_evaluation_data()
        self._algo_state_mgr.set_all_ready(data_segment=self._current_timestamp)
        logger.debug("Evaluator is ready for streaming")
        # TODO: allow programmer to register anytime
        self._transition_state(EvaluatorState.STARTED, allow_from=[EvaluatorState.INITIALIZED])

    def register_algorithm(
        self,
        algorithm: None | Algorithm = None,
        algorithm_name: None | str = None,
    ) -> UUID:
        """Register the algorithm with the evaluator.

        The algorithm state manager assigns the algorithm a unique identifier
        and stores it.

        Args:
            algorithm: Algorithm to register. Optional if ``algorithm_name`` is given.
            algorithm_name: Name of the algorithm. Optional if it can be
                inferred from ``algorithm``.

        Returns:
            The unique identifier assigned to the algorithm.

        Raises:
            RuntimeError: If the stream has already started.
            ValueError: If neither argument is given (raised by the state manager).
        """
        self._assert_state(EvaluatorState.INITIALIZED, "Cannot register algorithms after stream started")
        algo_id = self._algo_state_mgr.register(name=algorithm_name, algo_ptr=algorithm)
        logger.debug(f"Algorithm {algo_id} registered")
        return algo_id

    def get_algorithm_state(self, algo_id: UUID) -> AlgorithmStateEnum:
        """Get the state of the algorithm.

        Args:
            algo_id: Unique identifier of the algorithm.

        Returns:
            The state of the algorithm.

        Raises:
            ValueError: If ``algo_id`` is not registered.
        """
        return self._algo_state_mgr[algo_id].state

    def get_all_algorithm_status(self) -> dict[str, AlgorithmStateEnum]:
        """Get the status of all algorithms.

        Returns:
            A dictionary mapping each registered algorithm's name to its state.
        """
        # get_all_states() is the manager's name -> state mapping method.
        return self._algo_state_mgr.get_all_states()

    def load_next_window(self) -> None:
        """Advance the stream to the next window.

        - Resets the unknown user/item base.
        - Takes the incremental data of the split at :attr:`_run_step` as the
          new training data, updating the known user/item base with it and
          masking it to the known shape.
        - Caches the next evaluation data.
        - Calls ``set_all_ready`` on the algorithm state manager with the
          current timestamp.

        Raises:
            EOWSettingError: If the split has no incremental data, or when
                propagated from :meth:`_cache_evaluation_data`.
        """
        self.user_item_base.reset_unknown_user_item_base()
        incremental_data = self.setting.get_split_at(self._run_step).incremental
        if incremental_data is None:
            raise EOWSettingError("No more data to stream")
        # Convert to PredictionMatrix since it's a subclass of InteractionMatrix
        incremental_data = PredictionMatrix.from_interaction_matrix(incremental_data)

        self.user_item_base.update_known_user_item_base(incremental_data)
        incremental_data.mask_user_item_shape(self.user_item_base.known_shape)
        self._training_data_cache = incremental_data
        self._cache_evaluation_data()
        self._algo_state_mgr.set_all_ready(data_segment=self._current_timestamp)

    def get_training_data(self, algo_id: UUID) -> InteractionMatrix:
        """Get training data for the algorithm.

        The training data is the background data for the first window and
        the incremental data for later windows.

        Before serving, the configured strategy's ``should_advance_window`` is
        consulted. If it says so, the next window is loaded. If no window is
        left, the evaluator becomes COMPLETED and a RuntimeError is raised.

        The algorithm may only receive training data while it is READY.
        Otherwise, a PermissionError explains why:

        - NEW: the algorithm has not been made READY yet;
        - RUNNING / PREDICTED: it already received data for this window and
          must wait until the window advances;
        - COMPLETED: it has finished evaluation.

        On success, the algorithm moves to RUNNING with the current timestamp
        as its data segment, and the evaluator moves to IN_PROGRESS.

        Args:
            algo_id: Unique identifier of the algorithm.

        Raises:
            RuntimeError: If the stream has not started (or has completed), or
                when the end of the evaluation window is reached.
            PermissionError: If the algorithm may not request data now.

        Returns:
            The training data for the algorithm.
        """
        self._assert_state(
            (EvaluatorState.STARTED, EvaluatorState.IN_PROGRESS),
            "Call start_stream() first",
        )

        logger.debug(f"Getting data for algorithm {algo_id}")

        if self._strategy.should_advance_window(
            algo_state_mgr=self._algo_state_mgr,
            current_step=self._run_step,
            total_steps=self.setting.num_split,
        ):
            try:
                self.load_next_window()
            except EOWSettingError as err:
                self._transition_state(
                    EvaluatorState.COMPLETED, allow_from=[EvaluatorState.STARTED, EvaluatorState.IN_PROGRESS]
                )
                raise RuntimeError("End of evaluation window reached") from err

        can_request, reason = self._algo_state_mgr.can_request_training_data(algo_id)
        if not can_request:
            raise PermissionError(f"Cannot request data: {reason}")
        # TODO handle case when algo is ready after submitting prediction, but current timestamp has not changed, meaning algo is requesting same data again
        self._algo_state_mgr.transition(
            algo_id,
            AlgorithmStateEnum.RUNNING,
            data_segment=self._current_timestamp,
        )

        # The first successful data release moves the evaluator to IN_PROGRESS.
        # (Previously this assigned an unused `_evaluator_state` attribute.)
        if self._state == EvaluatorState.STARTED:
            self._transition_state(EvaluatorState.IN_PROGRESS, allow_from=[EvaluatorState.STARTED])
        # release data to the algorithm
        return self._training_data_cache

    def get_unlabeled_data(self, algo_id: UUID) -> PredictionMatrix:
        """Get unlabeled data for the algorithm.

        The unlabeled data is the data the algorithm must predict. It is
        expected to contain `(user_id, -1)` pairs, where the value -1 marks
        the item to be predicted. It is only released while the algorithm is
        RUNNING, i.e. after it has requested this window's training data.

        Args:
            algo_id: Unique identifier of the algorithm.

        Raises:
            PermissionError: If the algorithm may not request unlabeled data now.

        Returns:
            The cached unlabeled data of the current window.
        """
        logger.debug(f"Getting unlabeled data for algorithm {algo_id}")
        can_submit, reason = self._algo_state_mgr.can_request_unlabeled_data(algo_id)
        if not can_submit:
            raise PermissionError(f"Cannot get unlabeled data: {reason}")
        return self._unlabeled_data_cache

    def submit_prediction(self, algo_id: UUID, X_pred: csr_matrix) -> None:
        """Submit the prediction of the algorithm.

        The submission is first checked against the algorithm's state and
        then evaluated by :meth:`_evaluate_algo_pred`. Afterwards, the
        algorithm moves to PREDICTED.

        Args:
            algo_id: Unique identifier of the algorithm.
            X_pred: The algorithm's prediction.

        Raises:
            PermissionError: If the algorithm may not submit a prediction now.
        """
        logger.debug(f"Submitting prediction for algorithm {algo_id}")
        can_submit, reason = self._algo_state_mgr.can_submit_prediction(algo_id)
        if not can_submit:
            raise PermissionError(f"Cannot submit prediction: {reason}")

        self._evaluate_algo_pred(algo_id=algo_id, y_pred=X_pred)
        self._algo_state_mgr.transition(
            algo_id,
            AlgorithmStateEnum.PREDICTED,
        )

    def _evaluate_algo_pred(self, algo_id: UUID, y_pred: csr_matrix) -> None:
        """Evaluate the prediction for algorithm.

        The prediction is compared with the currently cached ground truth
        data, using every metric in :attr:`metric_entries`. The results are
        added to the metric accumulator, which is later used to compute the
        final evaluation results.

        Args:
            algo_id: The unique identifier of the algorithm.
            y_pred: The prediction of the algorithm.
        """
        y_true = self._ground_truth_data_cache.item_interaction_sequence_matrix

        y_pred = self._prediction_shape_handler(y_true, y_pred)
        algorithm_name = self._algo_state_mgr.get_algorithm_identifier(algo_id)

        # evaluate the prediction with every configured metric
        for metric_entry in self.metric_entries:
            metric_cls = METRIC_REGISTRY.get(metric_entry.name)
            params = {
                'timestamp_limit': self._current_timestamp,
                'user_id_sequence_array': self._ground_truth_data_cache.user_id_sequence_array,
                'user_item_shape': self._ground_truth_data_cache.user_item_shape,
            }
            # K is optional per metric entry; only pass it when configured.
            if metric_entry.K is not None:
                params['K'] = metric_entry.K

            metric = metric_cls(**params)
            metric.calculate(y_true, y_pred)
            self._acc.add(metric=metric, algorithm_name=algorithm_name)

        logger.debug(f"Prediction evaluated for algorithm {algo_id} complete")
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
from typing import Any
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from recnexteval.algorithms import Algorithm
|
|
9
|
+
from ..utils.uuid_util import generate_algorithm_uuid
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-scoped logger named after this module.
logger = logging.getLogger(name=__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AlgorithmStateEnum(StrEnum):
|
|
16
|
+
"""Enum for the state of the algorithm.
|
|
17
|
+
|
|
18
|
+
Used to keep track of the state of the algorithm during the streaming
|
|
19
|
+
process in the `EvaluatorStreamer`.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
NEW = "NEW"
|
|
23
|
+
READY = "READY"
|
|
24
|
+
RUNNING = "RUNNING"
|
|
25
|
+
PREDICTED = "PREDICTED"
|
|
26
|
+
COMPLETED = "COMPLETED"
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class AlgorithmStateEntry:
    """Entry for the algorithm status registry.

    Holds the bookkeeping for one algorithm managed by
    `AlgorithmStateManager`: its name, its unique identifier, its current
    lifecycle state, the data segment it is associated with, its
    parameters, and an optional pointer to the algorithm object.

    Attributes:
        name: Name of the algorithm.
        algorithm_uuid: Unique identifier for the algorithm.
        state: State of the algorithm; starts as ``AlgorithmStateEnum.NEW``.
        data_segment: Data segment the algorithm is associated with
            (defaults to 0; e.g. ``EvaluatorStreamer`` passes its current
            timestamp when transitioning algorithms).
        params: Parameters for the algorithm; defaults to a new empty dict
            per entry.
        algo_ptr: Pointer to the algorithm object (instance or class), if any.
    """

    name: str
    algorithm_uuid: UUID
    state: AlgorithmStateEnum = AlgorithmStateEnum.NEW
    data_segment: int = 0
    params: dict[str, Any] = field(default_factory=dict)
    algo_ptr: None | type[Algorithm] | Algorithm = None
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class AlgorithmStateManager:
|
|
56
|
+
def __init__(self) -> None:
    """Create a manager with an empty algorithm registry."""
    # Registry of algorithm entries keyed by their UUID.
    self._algorithms: dict[UUID, AlgorithmStateEntry] = dict()
|
|
58
|
+
|
|
59
|
+
def __iter__(self) -> Iterator[UUID]:
    """Iterate over the UUIDs of all registered algorithms.

    Returns:
        An iterator over the registered UUIDs, in the insertion order of the
        underlying dict.
    """
    registered_ids = self._algorithms.keys()
    return iter(registered_ids)
|
|
68
|
+
|
|
69
|
+
def __len__(self) -> int:
    """Count the registered algorithms.

    Returns:
        How many entries the registry currently holds.
    """
    registry = self._algorithms
    return len(registry)
|
|
76
|
+
|
|
77
|
+
def values(self) -> Iterator[AlgorithmStateEntry]:
    """Iterate over all registered :class:`AlgorithmStateEntry` objects.

    Returns:
        An iterator over the entries, in the insertion order of the
        underlying dict.
    """
    entries = self._algorithms.values()
    return iter(entries)
|
|
86
|
+
|
|
87
|
+
def __getitem__(self, key: UUID) -> AlgorithmStateEntry:
    """Return the entry registered under `key`.

    Args:
        key: The algorithm UUID to look up.

    Raises:
        ValueError: If `key` is not registered (note: not KeyError).
    """
    registry = self._algorithms
    if key in registry:
        return registry[key]
    raise ValueError(f"Algorithm with ID:{key} not registered")
|
|
91
|
+
|
|
92
|
+
def __setitem__(self, key: UUID, entry: AlgorithmStateEntry) -> None:
    """Register a new algorithm status entry under `key`.

    Allows the use of square bracket notation to register new entries.

    Args:
        key: The UUID to register the entry under.
        entry: The status entry to register.

    Raises:
        KeyError: If `key` is already registered.
    """
    # Check the backing dict directly rather than `key in self`, so this does
    # not depend on __contains__/__getitem__ (the latter raises ValueError,
    # not KeyError, for a missing key, which broke registration of new keys).
    if key in self._algorithms:
        raise KeyError(f"Algorithm with ID:{key} already registered")
    self._algorithms[key] = entry
|
|
107
|
+
|
|
108
|
+
def __contains__(self, key: UUID) -> bool:
    """Return whether the given key is known to the registry.

    Args:
        key: The key to check.

    Returns:
        True if the key is registered, False otherwise.
    """
    # Query the backing dict directly. Delegating to self[key] and catching
    # AttributeError never worked: __getitem__ signals a missing key with
    # ValueError, so membership tests for unknown keys raised instead of
    # returning False.
    return key in self._algorithms
|
|
122
|
+
|
|
123
|
+
def get(self, algo_id: UUID) -> AlgorithmStateEntry:
    """Return the :class:`AlgorithmStateEntry` registered under `algo_id`.

    Raises:
        ValueError: If `algo_id` is not registered (via ``__getitem__``).
    """
    return self.__getitem__(algo_id)
|
|
126
|
+
|
|
127
|
+
def get_state(self, algo_id: UUID) -> AlgorithmStateEnum:
    """Return the current lifecycle state of the algorithm `algo_id`.

    Raises:
        ValueError: If `algo_id` is not registered (via ``__getitem__``).
    """
    entry = self[algo_id]
    return entry.state
|
|
130
|
+
|
|
131
|
+
def register(
    self,
    name: None | str = None,
    algo_ptr: None | type[Algorithm] | Algorithm = None,
    params: None | dict[str, Any] = None,
    algo_uuid: None | UUID = None,
) -> UUID:
    """Register a new algorithm and return its UUID.

    Args:
        name: Name of the algorithm. If omitted, it is taken from
            ``algo_ptr.identifier``.
        algo_ptr: Algorithm instance, or an algorithm class, which is then
            instantiated with ``**params``.
        params: Keyword arguments used to instantiate ``algo_ptr`` when it is
            a class. They are also stored on the entry. Defaults to a new
            empty dict for every call.
        algo_uuid: UUID to register under. Generated from ``name`` with
            ``generate_algorithm_uuid`` when omitted.

    Returns:
        The UUID under which the algorithm was registered.

    Raises:
        ValueError: If neither ``name`` nor ``algo_ptr`` is given, or if no
            name is given and it cannot be inferred from ``algo_ptr``.
    """
    # A new dict per call: the old `params={}` default was one shared object
    # stored on every entry registered without explicit params.
    if params is None:
        params = {}

    if not name and algo_ptr is None:
        raise ValueError("Either name or algo_ptr must be provided for registration")
    if isinstance(algo_ptr, type):
        # A class was given: instantiate it with the supplied parameters.
        algo_ptr = algo_ptr(**params)
        name = name or algo_ptr.identifier
    elif algo_ptr is not None and hasattr(algo_ptr, "identifier") and not name:
        name = algo_ptr.identifier  # type: ignore[attr-defined]
    elif not name:
        # This should not happen if name was provided or algo_ptr has identifier
        raise ValueError("Algorithm name was not provided and could not be inferred from Algorithm pointer")

    if algo_uuid is None:
        algo_uuid = generate_algorithm_uuid(name)

    if algo_uuid in self._algorithms:
        # Existing behaviour (overwrite) is kept, but made visible.
        logger.warning(f"Algorithm with ID {algo_uuid} already registered; overwriting entry")

    entry = AlgorithmStateEntry(algorithm_uuid=algo_uuid, name=name, algo_ptr=algo_ptr, params=params)
    self._algorithms[algo_uuid] = entry
    logger.info(f"Registered algorithm '{name}' with ID {algo_uuid}")
    return algo_uuid
|
|
157
|
+
|
|
158
|
+
def can_request_training_data(self, algo_id: UUID) -> tuple[bool, str]:
    """Check whether the algorithm may request training data now.

    Only an algorithm in the READY state may request training data.

    Args:
        algo_id: UUID of the algorithm.

    Returns:
        A ``(allowed, reason)`` tuple; ``reason`` is empty when allowed.
    """
    if algo_id not in self._algorithms:
        return False, f"Algorithm {algo_id} not registered"

    state = self._algorithms[algo_id].state

    if state == AlgorithmStateEnum.COMPLETED:
        return False, "Algorithm has completed evaluation"
    if state == AlgorithmStateEnum.NEW:
        return False, "The algorithm must be set to READY state first"
    if state == AlgorithmStateEnum.RUNNING:
        # RUNNING is a known state: this window's training data was already
        # released. It used to fall through to a misleading "Unknown state".
        return False, "Algorithm has already requested training data for this window"
    if state == AlgorithmStateEnum.PREDICTED:
        return False, "Algorithm has already requested data for this window"
    if state == AlgorithmStateEnum.READY:
        return True, ""

    return False, f"Unknown state {state}"
|
|
175
|
+
|
|
176
|
+
def can_request_unlabeled_data(self, algo_id: UUID) -> tuple[bool, str]:
|
|
177
|
+
"""Check if algorithm can request unlabeled data"""
|
|
178
|
+
if algo_id not in self._algorithms:
|
|
179
|
+
return False, f"Algorithm {algo_id} not registered"
|
|
180
|
+
|
|
181
|
+
state = self._algorithms[algo_id].state
|
|
182
|
+
|
|
183
|
+
if state == AlgorithmStateEnum.RUNNING:
|
|
184
|
+
return True, ""
|
|
185
|
+
if state == AlgorithmStateEnum.COMPLETED:
|
|
186
|
+
return False, "Algorithm has completed evaluation"
|
|
187
|
+
if state == AlgorithmStateEnum.NEW:
|
|
188
|
+
return False, "The algorithm must be set to RUNNING state to request unlabeled data"
|
|
189
|
+
if state == AlgorithmStateEnum.PREDICTED:
|
|
190
|
+
return False, "Algorithm has already requested data for this window"
|
|
191
|
+
if state == AlgorithmStateEnum.READY:
|
|
192
|
+
return (
|
|
193
|
+
False,
|
|
194
|
+
"The algorithm must be set to RUNNING state to request unlabeled data. Request training data first",
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
return False, f"Unknown state {state}"
|
|
198
|
+
|
|
199
|
+
def can_submit_prediction(self, algo_id: UUID) -> tuple[bool, str]:
|
|
200
|
+
"""Check if algorithm can submit prediction"""
|
|
201
|
+
if algo_id not in self._algorithms:
|
|
202
|
+
return False, f"Algorithm {algo_id} not registered"
|
|
203
|
+
|
|
204
|
+
state = self._algorithms[algo_id].state
|
|
205
|
+
|
|
206
|
+
if state == AlgorithmStateEnum.RUNNING:
|
|
207
|
+
return True, ""
|
|
208
|
+
if state == AlgorithmStateEnum.READY:
|
|
209
|
+
return False, "There is new data to be requested"
|
|
210
|
+
if state == AlgorithmStateEnum.NEW:
|
|
211
|
+
return False, "Algorithm must request data first"
|
|
212
|
+
if state == AlgorithmStateEnum.PREDICTED:
|
|
213
|
+
return False, "Algorithm already submitted prediction for this window"
|
|
214
|
+
if state == AlgorithmStateEnum.COMPLETED:
|
|
215
|
+
return False, "Algorithm has completed evaluation"
|
|
216
|
+
|
|
217
|
+
return False, f"Unknown state {state}"
|
|
218
|
+
|
|
219
|
+
def transition(self, algo_id: UUID, new_state: AlgorithmStateEnum, data_segment: None | int = None) -> None:
    """Move a registered algorithm to ``new_state`` after validating the move.

    Permitted moves (anything else is rejected):

    * NEW -> READY or COMPLETED
    * READY -> RUNNING
    * RUNNING -> PREDICTED
    * PREDICTED -> READY or COMPLETED
    * COMPLETED -> nothing (terminal state)

    Args:
        algo_id: UUID of the registered algorithm.
        new_state: State to move the algorithm into.
        data_segment: When not None, stored on the entry together with the
            new state. None leaves the entry's current data segment as is.

    Raises:
        ValueError: If ``algo_id`` is not registered, or if moving from the
            entry's current state to ``new_state`` is not permitted.
    """
    if algo_id not in self._algorithms:
        raise ValueError(f"Algorithm {algo_id} not registered")

    record = self._algorithms[algo_id]
    previous = record.state

    # Current state -> states it may move to. A state missing from the
    # table allows no moves at all.
    allowed_targets = {
        AlgorithmStateEnum.NEW: (AlgorithmStateEnum.READY, AlgorithmStateEnum.COMPLETED),
        AlgorithmStateEnum.READY: (AlgorithmStateEnum.RUNNING,),
        AlgorithmStateEnum.RUNNING: (AlgorithmStateEnum.PREDICTED,),
        AlgorithmStateEnum.PREDICTED: (AlgorithmStateEnum.READY, AlgorithmStateEnum.COMPLETED),
        AlgorithmStateEnum.COMPLETED: (),
    }.get(previous, ())

    if new_state not in allowed_targets:
        raise ValueError(f"Invalid transition: {previous} -> {new_state}")

    record.state = new_state
    if data_segment is not None:
        record.data_segment = data_segment

    logger.debug(f"Algorithm '{record.name}' transitioned {previous.value} -> {new_state.value}")
|
|
245
|
+
|
|
246
|
+
def is_all_predicted(self) -> bool:
    """Check whether every registered algorithm has reached PREDICTED.

    Returns:
        True when at least one algorithm is registered and every registered
        entry's state is ``AlgorithmStateEnum.PREDICTED``. An empty registry
        returns False rather than the vacuous True.
    """
    registered = self._algorithms
    # Nothing registered means nothing has predicted; report False.
    if not registered:
        return False
    for record in registered.values():
        if record.state != AlgorithmStateEnum.PREDICTED:
            return False
    return True
|
|
256
|
+
|
|
257
|
+
def get_all_states(self) -> dict[str, AlgorithmStateEnum]:
    """Map each registered algorithm's name to its current state.

    Returns:
        A dictionary keyed by each entry's ``name``. Entries that share a
        name collapse onto one key, and the entry iterated last wins. Use
        ``all_algo_states`` when names may repeat.
    """
    states = {}
    for record in self._algorithms.values():
        states[record.name] = record.state
    return states
|
|
260
|
+
|
|
261
|
+
def is_all_same_data_segment(self) -> bool:
    """Check whether every registered entry reports the same data segment.

    ``None`` counts as a distinct segment value like any other.

    Returns:
        True only if exactly one distinct ``data_segment`` value occurs
        across all registered entries. An empty registry yields False, and
        so does any mix of values.
    """
    distinct_segments = {self[key].data_segment for key in self}
    return len(distinct_segments) == 1
|
|
272
|
+
|
|
273
|
+
def all_algo_states(self) -> dict[str, AlgorithmStateEnum]:
    """Map a unique ``"{name}_{key}"`` identifier to each algorithm's state.

    The registry key (the algorithm's UUID) is appended to the name, so
    algorithms that share a name still get separate entries.

    Returns:
        Dictionary from ``"{name}_{key}"`` to the entry's
        :class:`AlgorithmStateEnum`.
    """
    return {f"{self[key].name}_{key}": self[key].state for key in self}
|
|
286
|
+
|
|
287
|
+
def set_all_ready(self, data_segment: int) -> None:
    """Move every registered algorithm to READY for ``data_segment``.

    Each algorithm goes through ``self.transition``, so the normal transition
    rules apply. Under the table in ``transition``, only algorithms currently
    in NEW or PREDICTED may move to READY.

    Args:
        data_segment: Data segment recorded on every algorithm.

    Raises:
        ValueError: Propagated from ``transition`` when an algorithm cannot
            move to READY. Algorithms processed before the failing one have
            already been transitioned; no rollback is performed.
    """
    for algo_id in self:
        self.transition(algo_id, AlgorithmStateEnum.READY, data_segment)
|
|
295
|
+
|
|
296
|
+
def get_algorithm_identifier(self, algo_id: UUID) -> str:
    """Build the ``"{name}_{uuid}"`` identifier for a registered algorithm.

    Args:
        algo_id: UUID of the algorithm.

    Returns:
        The algorithm's name and its UUID joined with an underscore.

    Raises:
        AttributeError: If ``algo_id`` is not registered.
    """
    if algo_id in self._algorithms:
        return f"{self[algo_id].name}_{algo_id}"
    raise AttributeError(f"Algorithm with ID:{algo_id} not registered")
|