recnexteval 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. recnexteval/__init__.py +20 -0
  2. recnexteval/algorithms/__init__.py +99 -0
  3. recnexteval/algorithms/base.py +377 -0
  4. recnexteval/algorithms/baseline/__init__.py +10 -0
  5. recnexteval/algorithms/baseline/decay_popularity.py +110 -0
  6. recnexteval/algorithms/baseline/most_popular.py +72 -0
  7. recnexteval/algorithms/baseline/random.py +39 -0
  8. recnexteval/algorithms/baseline/recent_popularity.py +34 -0
  9. recnexteval/algorithms/itemknn/__init__.py +14 -0
  10. recnexteval/algorithms/itemknn/itemknn.py +119 -0
  11. recnexteval/algorithms/itemknn/itemknn_incremental.py +65 -0
  12. recnexteval/algorithms/itemknn/itemknn_incremental_movielens.py +95 -0
  13. recnexteval/algorithms/itemknn/itemknn_rolling.py +17 -0
  14. recnexteval/algorithms/itemknn/itemknn_static.py +31 -0
  15. recnexteval/algorithms/time_aware_item_knn/__init__.py +11 -0
  16. recnexteval/algorithms/time_aware_item_knn/base.py +248 -0
  17. recnexteval/algorithms/time_aware_item_knn/decay_functions.py +260 -0
  18. recnexteval/algorithms/time_aware_item_knn/ding_2005.py +52 -0
  19. recnexteval/algorithms/time_aware_item_knn/liu_2010.py +65 -0
  20. recnexteval/algorithms/time_aware_item_knn/similarity_functions.py +106 -0
  21. recnexteval/algorithms/time_aware_item_knn/top_k.py +61 -0
  22. recnexteval/algorithms/time_aware_item_knn/utils.py +47 -0
  23. recnexteval/algorithms/time_aware_item_knn/vaz_2013.py +50 -0
  24. recnexteval/algorithms/utils.py +51 -0
  25. recnexteval/datasets/__init__.py +109 -0
  26. recnexteval/datasets/base.py +316 -0
  27. recnexteval/datasets/config/__init__.py +113 -0
  28. recnexteval/datasets/config/amazon.py +188 -0
  29. recnexteval/datasets/config/base.py +72 -0
  30. recnexteval/datasets/config/lastfm.py +105 -0
  31. recnexteval/datasets/config/movielens.py +169 -0
  32. recnexteval/datasets/config/yelp.py +25 -0
  33. recnexteval/datasets/datasets/__init__.py +24 -0
  34. recnexteval/datasets/datasets/amazon.py +151 -0
  35. recnexteval/datasets/datasets/base.py +250 -0
  36. recnexteval/datasets/datasets/lastfm.py +121 -0
  37. recnexteval/datasets/datasets/movielens.py +93 -0
  38. recnexteval/datasets/datasets/test.py +46 -0
  39. recnexteval/datasets/datasets/yelp.py +103 -0
  40. recnexteval/datasets/metadata/__init__.py +58 -0
  41. recnexteval/datasets/metadata/amazon.py +68 -0
  42. recnexteval/datasets/metadata/base.py +38 -0
  43. recnexteval/datasets/metadata/lastfm.py +110 -0
  44. recnexteval/datasets/metadata/movielens.py +87 -0
  45. recnexteval/evaluators/__init__.py +189 -0
  46. recnexteval/evaluators/accumulator.py +167 -0
  47. recnexteval/evaluators/base.py +216 -0
  48. recnexteval/evaluators/builder/__init__.py +125 -0
  49. recnexteval/evaluators/builder/base.py +166 -0
  50. recnexteval/evaluators/builder/pipeline.py +111 -0
  51. recnexteval/evaluators/builder/stream.py +54 -0
  52. recnexteval/evaluators/evaluator_pipeline.py +287 -0
  53. recnexteval/evaluators/evaluator_stream.py +374 -0
  54. recnexteval/evaluators/state_management.py +310 -0
  55. recnexteval/evaluators/strategy.py +32 -0
  56. recnexteval/evaluators/util.py +124 -0
  57. recnexteval/matrix/__init__.py +48 -0
  58. recnexteval/matrix/exception.py +5 -0
  59. recnexteval/matrix/interaction_matrix.py +784 -0
  60. recnexteval/matrix/prediction_matrix.py +153 -0
  61. recnexteval/matrix/util.py +24 -0
  62. recnexteval/metrics/__init__.py +57 -0
  63. recnexteval/metrics/binary/__init__.py +4 -0
  64. recnexteval/metrics/binary/hit.py +49 -0
  65. recnexteval/metrics/core/__init__.py +10 -0
  66. recnexteval/metrics/core/base.py +126 -0
  67. recnexteval/metrics/core/elementwise_top_k.py +75 -0
  68. recnexteval/metrics/core/listwise_top_k.py +72 -0
  69. recnexteval/metrics/core/top_k.py +60 -0
  70. recnexteval/metrics/core/util.py +29 -0
  71. recnexteval/metrics/ranking/__init__.py +6 -0
  72. recnexteval/metrics/ranking/dcg.py +55 -0
  73. recnexteval/metrics/ranking/ndcg.py +78 -0
  74. recnexteval/metrics/ranking/precision.py +51 -0
  75. recnexteval/metrics/ranking/recall.py +42 -0
  76. recnexteval/models/__init__.py +4 -0
  77. recnexteval/models/base.py +69 -0
  78. recnexteval/preprocessing/__init__.py +37 -0
  79. recnexteval/preprocessing/filter.py +181 -0
  80. recnexteval/preprocessing/preprocessor.py +137 -0
  81. recnexteval/registries/__init__.py +67 -0
  82. recnexteval/registries/algorithm.py +68 -0
  83. recnexteval/registries/base.py +131 -0
  84. recnexteval/registries/dataset.py +37 -0
  85. recnexteval/registries/metric.py +57 -0
  86. recnexteval/settings/__init__.py +127 -0
  87. recnexteval/settings/base.py +414 -0
  88. recnexteval/settings/exception.py +8 -0
  89. recnexteval/settings/leave_n_out_setting.py +48 -0
  90. recnexteval/settings/processor.py +115 -0
  91. recnexteval/settings/schema.py +11 -0
  92. recnexteval/settings/single_time_point_setting.py +111 -0
  93. recnexteval/settings/sliding_window_setting.py +153 -0
  94. recnexteval/settings/splitters/__init__.py +14 -0
  95. recnexteval/settings/splitters/base.py +57 -0
  96. recnexteval/settings/splitters/n_last.py +39 -0
  97. recnexteval/settings/splitters/n_last_timestamp.py +76 -0
  98. recnexteval/settings/splitters/timestamp.py +82 -0
  99. recnexteval/settings/util.py +0 -0
  100. recnexteval/utils/__init__.py +115 -0
  101. recnexteval/utils/json_to_csv_converter.py +128 -0
  102. recnexteval/utils/logging_tools.py +159 -0
  103. recnexteval/utils/path.py +155 -0
  104. recnexteval/utils/url_certificate_installer.py +54 -0
  105. recnexteval/utils/util.py +166 -0
  106. recnexteval/utils/uuid_util.py +7 -0
  107. recnexteval/utils/yaml_tool.py +65 -0
  108. recnexteval-0.1.0.dist-info/METADATA +85 -0
  109. recnexteval-0.1.0.dist-info/RECORD +110 -0
  110. recnexteval-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,374 @@
1
+ import logging
2
+ from enum import Enum
3
+ from uuid import UUID
4
+
5
+ from scipy.sparse import csr_matrix
6
+
7
+ from recnexteval.algorithms import Algorithm
8
+ from recnexteval.matrix import InteractionMatrix, PredictionMatrix
9
+ from recnexteval.registries import (
10
+ METRIC_REGISTRY,
11
+ MetricEntry,
12
+ )
13
+ from recnexteval.settings import EOWSettingError, Setting
14
+ from .accumulator import MetricAccumulator
15
+ from .base import EvaluatorBase
16
+ from .state_management import AlgorithmStateEnum, AlgorithmStateManager
17
+ from .strategy import EvaluationStrategy, SlidingWindowStrategy
18
+
19
+
20
class EvaluatorState(Enum):
    """Lifecycle phases of an :class:`EvaluatorStreamer`.

    Each member's value is its lowercase name.
    """

    INITIALIZED = "initialized"  # constructed; algorithms may still register
    STARTED = "started"  # start_stream() has run
    IN_PROGRESS = "in_progress"  # training data has been released
    COMPLETED = "completed"  # the setting ran out of windows
    FAILED = "failed"
28
+
29
+
30
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
31
+
32
+
33
class EvaluatorStreamer(EvaluatorBase):
    """Evaluation via streaming through API.

    The diagram below shows the streamer evaluator for the sliding window
    setting. Instead of the pipeline, the user streams the data release to
    the algorithm; the diagram shows the data communication between the
    evaluator and the algorithm. Although only 2 splits are shown here, the
    evaluator keeps streaming data until the setting has no more splits.

    ![stream scheme](../../../assets/_static/stream_scheme.png)

    This class exposes the core API that allows the user to stream the
    evaluation process:

    1. :meth:`register_algorithm`
    2. :meth:`start_stream`
    3. :meth:`get_training_data`
    4. :meth:`get_unlabeled_data`
    5. :meth:`submit_prediction`

    See each method for details. The API is designed so that the algorithm
    is decoupled from the evaluating platform: the evaluator only provides
    the necessary data to the algorithm and evaluates the prediction.

    Args:
        metric_entries: List of metric entries.
        setting: Setting object.
        metric_k: Number of top interactions to consider.
        ignore_unknown_user: To ignore unknown users.
        ignore_unknown_item: To ignore unknown items.
        seed: Random seed for the evaluator.
        strategy: Strategy deciding when the evaluator advances to the next
            window. Defaults to :class:`SlidingWindowStrategy`.
    """

    def __init__(
        self,
        metric_entries: list[MetricEntry],
        setting: Setting,
        metric_k: int,
        ignore_unknown_user: bool = False,
        ignore_unknown_item: bool = False,
        seed: int = 42,
        strategy: None | EvaluationStrategy = None,
    ) -> None:
        super().__init__(
            metric_entries,
            setting,
            metric_k,
            ignore_unknown_user,
            ignore_unknown_item,
            seed,
        )
        self._algo_state_mgr = AlgorithmStateManager()
        # Declared only; assigned by start_stream() / load_next_window().
        self._unlabeled_data_cache: PredictionMatrix
        self._ground_truth_data_cache: PredictionMatrix
        self._training_data_cache: PredictionMatrix

        # Evaluator state management
        self._state = EvaluatorState.INITIALIZED

        # Evaluation strategy
        self._strategy = strategy or SlidingWindowStrategy()

    @property
    def state(self) -> EvaluatorState:
        """Current lifecycle state of the evaluator."""
        return self._state

    def _assert_state(self, expected: EvaluatorState, error_msg: str) -> None:
        """Raise RuntimeError unless the evaluator is in `expected` state."""
        if self._state != expected:
            raise RuntimeError(f"{error_msg} (Current state: {self._state.value})")

    def _transition_state(self, new_state: EvaluatorState, allow_from: list[EvaluatorState]) -> None:
        """Move to `new_state`, raising ValueError if the current state is not in `allow_from`."""
        if self._state not in allow_from:
            raise ValueError(f"Cannot transition from {self._state} to {new_state}. Allowed from: {allow_from}")
        self._state = new_state
        logger.info(f"Evaluator transitioned to {new_state}")

    def _cache_evaluation_data(self) -> None:
        """Fetch and cache the unlabeled and ground-truth data for the current step.

        Delegates to :meth:`_get_evaluation_data` (inherited from
        :class:`EvaluatorBase`). It stores the first two returned values in
        :attr:`_unlabeled_data_cache` and :attr:`_ground_truth_data_cache`
        and discards the third.

        NOTE(review): the base method presumably also advances
        :attr:`_current_timestamp` and :attr:`_run_step`, because callers read
        `_current_timestamp` immediately afterwards. Confirm in ``base.py``.

        Callers are expected to have checked that data remains. Any
        :class:`EOWSettingError` raised by the base method propagates unchanged.
        """
        logger.debug(f"Caching evaluation data for step {self._run_step}")
        self._unlabeled_data_cache, self._ground_truth_data_cache, _ = self._get_evaluation_data()
        logger.debug(f"Data cached for step {self._run_step} complete")

    def start_stream(self) -> None:
        """Start the streaming process.

        Restores the setting and creates a fresh :class:`MetricAccumulator`.
        It then registers the background data in the known user/item base,
        masks it to the known shape and caches it as the first training data.
        Finally it caches the first evaluation window and marks every
        registered algorithm READY.

        Warning:
            Once `start_stream` is called, the evaluator cannot register any
            new algorithms.

        Raises:
            ValueError: If the stream has already started. The check runs
                before any state is touched, so a repeated call cannot reset
                an ongoing evaluation.
        """
        # Guard first: otherwise a repeated call would wipe the accumulator
        # and re-cache data before failing on the state transition.
        if self._state != EvaluatorState.INITIALIZED:
            raise ValueError(f"Stream has already started (Current state: {self._state.value})")

        self.setting.restore()

        logger.debug("Preparing evaluator for streaming")
        self._acc = MetricAccumulator()
        # Convert to PredictionMatrix since it's a subclass of InteractionMatrix
        training_data = PredictionMatrix.from_interaction_matrix(self.setting.background_data)

        self.user_item_base.update_known_user_item_base(training_data)
        training_data.mask_user_item_shape(self.user_item_base.known_shape)
        self._training_data_cache = training_data
        self._cache_evaluation_data()
        self._algo_state_mgr.set_all_ready(data_segment=self._current_timestamp)
        logger.debug("Evaluator is ready for streaming")
        # TODO: allow programmer to register anytime
        self._transition_state(EvaluatorState.STARTED, allow_from=[EvaluatorState.INITIALIZED])

    def register_algorithm(
        self,
        algorithm: None | Algorithm = None,
        algorithm_name: None | str = None,
    ) -> UUID:
        """Register an algorithm with the evaluator.

        Args:
            algorithm: Algorithm instance (or class) to register.
            algorithm_name: Name to register under. If omitted, the name is
                taken from the algorithm's ``identifier``.

        Returns:
            The UUID assigned to the algorithm.

        Raises:
            RuntimeError: If the stream has already started.
            ValueError: If neither a name nor an algorithm is given.
        """
        self._assert_state(EvaluatorState.INITIALIZED, "Cannot register algorithms after stream started")
        algo_id = self._algo_state_mgr.register(name=algorithm_name, algo_ptr=algorithm)
        logger.debug(f"Algorithm {algo_id} registered")
        return algo_id

    def get_algorithm_state(self, algo_id: UUID) -> AlgorithmStateEnum:
        """Return the state of the algorithm identified by `algo_id`.

        Args:
            algo_id: Unique identifier of the algorithm.

        Returns:
            The state of the algorithm.
        """
        return self._algo_state_mgr[algo_id].state

    def get_all_algorithm_status(self) -> dict[str, AlgorithmStateEnum]:
        """Return the state of every registered algorithm.

        Returns:
            Mapping from ``"{name}_{uuid}"`` to the algorithm's state.
        """
        return self._algo_state_mgr.all_algo_states()

    def load_next_window(self) -> None:
        """Advance to the next window of the setting.

        Resets the unknown user/item base and fetches the incremental data
        of split ``_run_step``. It adds that data to the known user/item base,
        masks it to the known shape and caches it as training data. It then
        caches the matching evaluation data and marks all algorithms READY.

        Raises:
            EOWSettingError: If no incremental data is left to stream.
        """
        self.user_item_base.reset_unknown_user_item_base()
        incremental_data = self.setting.get_split_at(self._run_step).incremental
        if incremental_data is None:
            raise EOWSettingError("No more data to stream")
        # Convert to PredictionMatrix since it's a subclass of InteractionMatrix
        incremental_data = PredictionMatrix.from_interaction_matrix(incremental_data)

        self.user_item_base.update_known_user_item_base(incremental_data)
        incremental_data.mask_user_item_shape(self.user_item_base.known_shape)
        self._training_data_cache = incremental_data
        self._cache_evaluation_data()
        self._algo_state_mgr.set_all_ready(data_segment=self._current_timestamp)

    def get_training_data(self, algo_id: UUID) -> InteractionMatrix:
        """Release the current window's training data to an algorithm.

        The training data is the background data for the first window and
        the incremental data for later windows.

        If the strategy decides the window should advance, the next window
        is loaded first. When the setting is exhausted, the evaluator moves
        to COMPLETED and a RuntimeError is raised. The algorithm must then be
        allowed to request data, as decided by
        :meth:`AlgorithmStateManager.can_request_training_data`, which in
        practice requires READY. On success it transitions to RUNNING, tagged
        with the current data segment.

        Args:
            algo_id: Unique identifier of the algorithm.

        Raises:
            RuntimeError: If the stream has not started, or the end of the
                evaluation has been reached (now or on an earlier call).
            PermissionError: If the algorithm may not request data now.

        Returns:
            The training data for the algorithm.
        """
        if self._state == EvaluatorState.COMPLETED:
            raise RuntimeError("End of evaluation window reached")
        if self._state not in (EvaluatorState.STARTED, EvaluatorState.IN_PROGRESS):
            raise RuntimeError(f"Call start_stream() first (Current state: {self._state.value})")

        logger.debug(f"Getting data for algorithm {algo_id}")

        if self._strategy.should_advance_window(
            algo_state_mgr=self._algo_state_mgr,
            current_step=self._run_step,
            total_steps=self.setting.num_split,
        ):
            try:
                self.load_next_window()
            except EOWSettingError as err:
                self._transition_state(
                    EvaluatorState.COMPLETED, allow_from=[EvaluatorState.STARTED, EvaluatorState.IN_PROGRESS]
                )
                raise RuntimeError("End of evaluation window reached") from err

        can_request, reason = self._algo_state_mgr.can_request_training_data(algo_id)
        if not can_request:
            raise PermissionError(f"Cannot request data: {reason}")
        # TODO: handle an algorithm that is READY again after submitting while
        # the timestamp has not changed (i.e. it re-requests the same data).
        self._algo_state_mgr.transition(
            algo_id,
            AlgorithmStateEnum.RUNNING,
            data_segment=self._current_timestamp,
        )

        # The original assigned to a stray attribute (`_evaluator_state`), so
        # IN_PROGRESS was never reached. Record the transition on `_state`.
        if self._state == EvaluatorState.STARTED:
            self._transition_state(EvaluatorState.IN_PROGRESS, allow_from=[EvaluatorState.STARTED])
        # release data to the algorithm
        return self._training_data_cache

    def get_unlabeled_data(self, algo_id: UUID) -> PredictionMatrix:
        """Return the unlabeled data the algorithm must predict on.

        The unlabeled data contains `(user_id, -1)` pairs, where the value
        -1 indicates that the item is to be predicted.

        Raises:
            PermissionError: If the algorithm is not RUNNING (it must request
                training data first).
        """
        logger.debug(f"Getting unlabeled data for algorithm {algo_id}")
        can_request, reason = self._algo_state_mgr.can_request_unlabeled_data(algo_id)
        if not can_request:
            raise PermissionError(f"Cannot get unlabeled data: {reason}")
        return self._unlabeled_data_cache

    def submit_prediction(self, algo_id: UUID, X_pred: csr_matrix) -> None:
        """Evaluate and record the algorithm's prediction for the current window.

        After the permission check, the prediction is scored by
        :meth:`_evaluate_algo_pred` and the algorithm moves to PREDICTED.

        Raises:
            PermissionError: If the algorithm may not submit now.
        """
        logger.debug(f"Submitting prediction for algorithm {algo_id}")
        can_submit, reason = self._algo_state_mgr.can_submit_prediction(algo_id)
        if not can_submit:
            raise PermissionError(f"Cannot submit prediction: {reason}")

        self._evaluate_algo_pred(algo_id=algo_id, y_pred=X_pred)
        self._algo_state_mgr.transition(
            algo_id,
            AlgorithmStateEnum.PREDICTED,
        )

    def _evaluate_algo_pred(self, algo_id: UUID, y_pred: csr_matrix) -> None:
        """Score a prediction against the cached ground truth.

        Each configured metric is instantiated from :data:`METRIC_REGISTRY`,
        computed on (ground truth, prediction) and added to the metric
        accumulator under the algorithm's ``"{name}_{uuid}"`` identifier.

        Args:
            algo_id: The unique identifier of the algorithm.
            y_pred: The prediction of the algorithm.
        """
        ground_truth = self._ground_truth_data_cache
        y_true = ground_truth.item_interaction_sequence_matrix

        # Base-class helper aligning the prediction with the ground-truth shape.
        y_pred = self._prediction_shape_handler(y_true, y_pred)
        algorithm_name = self._algo_state_mgr.get_algorithm_identifier(algo_id)

        # Constructor arguments shared by every metric in this window.
        base_params = {
            'timestamp_limit': self._current_timestamp,
            'user_id_sequence_array': ground_truth.user_id_sequence_array,
            'user_item_shape': ground_truth.user_item_shape,
        }
        for metric_entry in self.metric_entries:
            metric_cls = METRIC_REGISTRY.get(metric_entry.name)
            params = dict(base_params)
            if metric_entry.K is not None:
                params['K'] = metric_entry.K

            metric = metric_cls(**params)
            metric.calculate(y_true, y_pred)
            self._acc.add(metric=metric, algorithm_name=algorithm_name)

        logger.debug(f"Prediction evaluated for algorithm {algo_id} complete")
@@ -0,0 +1,310 @@
1
+ import logging
2
+ from collections.abc import Iterator
3
+ from dataclasses import dataclass, field
4
+ from enum import StrEnum
5
+ from typing import Any
6
+ from uuid import UUID
7
+
8
+ from recnexteval.algorithms import Algorithm
9
+ from ..utils.uuid_util import generate_algorithm_uuid
10
+
11
+
12
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
13
+
14
+
15
+ class AlgorithmStateEnum(StrEnum):
16
+ """Enum for the state of the algorithm.
17
+
18
+ Used to keep track of the state of the algorithm during the streaming
19
+ process in the `EvaluatorStreamer`.
20
+ """
21
+
22
+ NEW = "NEW"
23
+ READY = "READY"
24
+ RUNNING = "RUNNING"
25
+ PREDICTED = "PREDICTED"
26
+ COMPLETED = "COMPLETED"
27
+
28
+
29
@dataclass
class AlgorithmStateEntry:
    """Entry for the algorithm status registry.

    Stores the status of one algorithm for use by `AlgorithmStateManager`.

    Attributes:
        name: Name of the algorithm.
        algorithm_uuid: Unique identifier for the algorithm.
        state: Current lifecycle state; starts as NEW.
        data_segment: Data segment the algorithm is associated with. It
            starts at 0 and is overwritten on transitions that pass a
            segment; the streaming evaluator passes its current timestamp.
        params: Keyword arguments used when instantiating the algorithm from
            a class.
        algo_ptr: Optional algorithm class or instance.
    """

    name: str
    algorithm_uuid: UUID
    state: AlgorithmStateEnum = AlgorithmStateEnum.NEW
    data_segment: int = 0
    params: dict[str, Any] = field(default_factory=dict)
    algo_ptr: None | type[Algorithm] | Algorithm = None
53
+
54
+
55
class AlgorithmStateManager:
    """Registry of :class:`AlgorithmStateEntry` objects keyed by algorithm UUID.

    Besides storage, the manager validates lifecycle transitions (see
    :meth:`transition`) and answers the permission queries used by the
    streaming evaluator (``can_request_training_data`` and friends).
    """

    # Allowed lifecycle transitions: current state -> permitted next states.
    _VALID_TRANSITIONS: dict[AlgorithmStateEnum, tuple[AlgorithmStateEnum, ...]] = {
        AlgorithmStateEnum.NEW: (AlgorithmStateEnum.READY, AlgorithmStateEnum.COMPLETED),
        AlgorithmStateEnum.READY: (AlgorithmStateEnum.RUNNING,),
        AlgorithmStateEnum.RUNNING: (AlgorithmStateEnum.PREDICTED,),
        AlgorithmStateEnum.PREDICTED: (AlgorithmStateEnum.READY, AlgorithmStateEnum.COMPLETED),
        AlgorithmStateEnum.COMPLETED: (),
    }

    def __init__(self) -> None:
        self._algorithms: dict[UUID, AlgorithmStateEntry] = {}

    def __iter__(self) -> Iterator[UUID]:
        """Return an iterator over registered algorithm UUIDs."""
        return iter(self._algorithms)

    def __len__(self) -> int:
        """Return the number of registered algorithms."""
        return len(self._algorithms)

    def values(self) -> Iterator[AlgorithmStateEntry]:
        """Return an iterator over the registered :class:`AlgorithmStateEntry` objects."""
        return iter(self._algorithms.values())

    def __getitem__(self, key: UUID) -> AlgorithmStateEntry:
        """Return the entry registered under `key`.

        Raises:
            ValueError: If `key` is not registered.
        """
        if key not in self._algorithms:
            raise ValueError(f"Algorithm with ID:{key} not registered")
        return self._algorithms[key]

    def __setitem__(self, key: UUID, entry: AlgorithmStateEntry) -> None:
        """Register a new algorithm status entry under `key`.

        Args:
            key: The UUID to register the entry under.
            entry: The status entry to register.

        Raises:
            KeyError: If `key` is already registered.
        """
        if key in self:
            raise KeyError(f"Algorithm with ID:{key} already registered")
        self._algorithms[key] = entry

    def __contains__(self, key: UUID) -> bool:
        """Return whether `key` is registered.

        The original implementation went through ``self[key]`` and caught
        AttributeError. ``__getitem__`` raises ValueError, though, so
        unknown keys crashed ``in`` (and therefore ``__setitem__``) instead
        of yielding False.
        """
        return key in self._algorithms

    def get(self, algo_id: UUID) -> AlgorithmStateEntry:
        """Get the :class:`AlgorithmStateEntry` for `algo_id` (ValueError if unknown)."""
        return self[algo_id]

    def get_state(self, algo_id: UUID) -> AlgorithmStateEnum:
        """Get the current state of the algorithm with `algo_id`."""
        return self[algo_id].state

    def register(
        self,
        name: None | str = None,
        algo_ptr: None | type[Algorithm] | Algorithm = None,
        params: None | dict[str, Any] = None,
        algo_uuid: None | UUID = None,
    ) -> UUID:
        """Register a new algorithm and return its UUID.

        If `algo_ptr` is a class, it is instantiated with `params`. The
        name falls back to the algorithm's ``identifier`` when not given.

        Args:
            name: Name of the algorithm.
            algo_ptr: Algorithm class or instance.
            params: Keyword arguments for instantiating `algo_ptr` when it
                is a class. Defaults to a fresh empty dict per call.
            algo_uuid: Explicit UUID. Generated from the name if omitted.

        Raises:
            ValueError: If no name is given and none can be inferred.
        """
        # A fresh dict per call: a shared `{}` default would be stored in every
        # entry registered without params, aliasing them all.
        if params is None:
            params = {}

        if not name and not algo_ptr:
            raise ValueError("Either name or algo_ptr must be provided for registration")
        elif algo_ptr and isinstance(algo_ptr, type):
            algo_ptr = algo_ptr(**params)
            name = name or algo_ptr.identifier
        elif algo_ptr and hasattr(algo_ptr, "identifier") and not name:
            name = name or algo_ptr.identifier  # type: ignore[attr-defined]
        elif not name:
            # This should not happen if name was provided or algo_ptr has identifier
            raise ValueError("Algorithm name was not provided and could not be inferred from Algorithm pointer")

        if algo_uuid is None:
            algo_uuid = generate_algorithm_uuid(name)

        if algo_uuid in self._algorithms:
            # Kept as an overwrite for backward compatibility, but no longer silent.
            logger.warning(f"Algorithm ID {algo_uuid} already registered; overwriting previous entry")

        entry = AlgorithmStateEntry(algorithm_uuid=algo_uuid, name=name, algo_ptr=algo_ptr, params=params)
        self._algorithms[algo_uuid] = entry
        logger.info(f"Registered algorithm '{name}' with ID {algo_uuid}")
        return algo_uuid

    def can_request_training_data(self, algo_id: UUID) -> tuple[bool, str]:
        """Return ``(allowed, reason)``; only READY algorithms may request training data."""
        if algo_id not in self._algorithms:
            return False, f"Algorithm {algo_id} not registered"

        state = self._algorithms[algo_id].state

        if state == AlgorithmStateEnum.READY:
            return True, ""
        if state == AlgorithmStateEnum.COMPLETED:
            return False, "Algorithm has completed evaluation"
        if state == AlgorithmStateEnum.NEW:
            return False, "The algorithm must be set to READY state first"
        if state == AlgorithmStateEnum.PREDICTED:
            return False, "Algorithm has already requested data for this window"
        if state == AlgorithmStateEnum.RUNNING:
            # Previously fell through to "Unknown state".
            return False, "Algorithm has already requested training data for this window"

        return False, f"Unknown state {state}"

    def can_request_unlabeled_data(self, algo_id: UUID) -> tuple[bool, str]:
        """Return ``(allowed, reason)``; only RUNNING algorithms may request unlabeled data."""
        if algo_id not in self._algorithms:
            return False, f"Algorithm {algo_id} not registered"

        state = self._algorithms[algo_id].state

        if state == AlgorithmStateEnum.RUNNING:
            return True, ""
        if state == AlgorithmStateEnum.COMPLETED:
            return False, "Algorithm has completed evaluation"
        if state == AlgorithmStateEnum.NEW:
            return False, "The algorithm must be set to RUNNING state to request unlabeled data"
        if state == AlgorithmStateEnum.PREDICTED:
            return False, "Algorithm has already requested data for this window"
        if state == AlgorithmStateEnum.READY:
            return (
                False,
                "The algorithm must be set to RUNNING state to request unlabeled data. Request training data first",
            )

        return False, f"Unknown state {state}"

    def can_submit_prediction(self, algo_id: UUID) -> tuple[bool, str]:
        """Return ``(allowed, reason)``; only RUNNING algorithms may submit a prediction."""
        if algo_id not in self._algorithms:
            return False, f"Algorithm {algo_id} not registered"

        state = self._algorithms[algo_id].state

        if state == AlgorithmStateEnum.RUNNING:
            return True, ""
        if state == AlgorithmStateEnum.READY:
            return False, "There is new data to be requested"
        if state == AlgorithmStateEnum.NEW:
            return False, "Algorithm must request data first"
        if state == AlgorithmStateEnum.PREDICTED:
            return False, "Algorithm already submitted prediction for this window"
        if state == AlgorithmStateEnum.COMPLETED:
            return False, "Algorithm has completed evaluation"

        return False, f"Unknown state {state}"

    def transition(self, algo_id: UUID, new_state: AlgorithmStateEnum, data_segment: None | int = None) -> None:
        """Move an algorithm to `new_state`, validating against :attr:`_VALID_TRANSITIONS`.

        Args:
            algo_id: UUID of the algorithm.
            new_state: Target state.
            data_segment: If given, stored on the entry as its data segment.

        Raises:
            ValueError: If the algorithm is unknown or the transition is not allowed.
        """
        if algo_id not in self._algorithms:
            raise ValueError(f"Algorithm {algo_id} not registered")

        entry = self._algorithms[algo_id]
        old_state = entry.state

        if new_state not in self._VALID_TRANSITIONS.get(old_state, ()):
            raise ValueError(f"Invalid transition: {old_state} -> {new_state}")

        entry.state = new_state
        if data_segment is not None:
            entry.data_segment = data_segment

        logger.debug(f"Algorithm '{entry.name}' transitioned {old_state.value} -> {new_state.value}")

    def is_all_predicted(self) -> bool:
        """Return True if at least one algorithm is registered and all are PREDICTED."""
        if not self._algorithms:
            return False
        return all(entry.state == AlgorithmStateEnum.PREDICTED for entry in self._algorithms.values())

    def get_all_states(self) -> dict[str, AlgorithmStateEnum]:
        """Map each algorithm's name to its state (same-named entries collapse)."""
        return {entry.name: entry.state for entry in self._algorithms.values()}

    def is_all_same_data_segment(self) -> bool:
        """Return True if there is exactly one distinct data segment across all entries.

        An empty registry returns False.
        """
        return len({entry.data_segment for entry in self._algorithms.values()}) == 1

    def all_algo_states(self) -> dict[str, AlgorithmStateEnum]:
        """Map ``"{name}_{uuid}"`` identifiers to each registered algorithm's state."""
        return {f"{entry.name}_{key}": entry.state for key, entry in self._algorithms.items()}

    def set_all_ready(self, data_segment: int) -> None:
        """Transition every registered algorithm to READY with the given data segment.

        Raises:
            ValueError: If any algorithm is in a state that cannot move to
                READY (only NEW and PREDICTED can).
        """
        for key in self:
            self.transition(key, AlgorithmStateEnum.READY, data_segment)

    def get_algorithm_identifier(self, algo_id: UUID) -> str:
        """Return the identifier ``"{name}_{uuid}"`` for the algorithm.

        Raises:
            AttributeError: If `algo_id` is not registered.
        """
        if algo_id not in self._algorithms:
            raise AttributeError(f"Algorithm with ID:{algo_id} not registered")
        return f"{self._algorithms[algo_id].name}_{algo_id}"