kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +177 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +247 -0
- kumoai/experimental/rfm/base/table.py +404 -203
- kumoai/experimental/rfm/graph.py +374 -172
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +762 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -1,26 +1,23 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import math
|
|
2
3
|
import time
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
5
|
-
from collections.abc import Generator
|
|
6
|
+
from collections.abc import Generator, Iterator
|
|
6
7
|
from contextlib import contextmanager
|
|
7
8
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
9
|
-
Any,
|
|
10
|
-
Dict,
|
|
11
|
-
Iterator,
|
|
12
|
-
List,
|
|
13
|
-
Literal,
|
|
14
|
-
Optional,
|
|
15
|
-
Tuple,
|
|
16
|
-
Union,
|
|
17
|
-
overload,
|
|
18
|
-
)
|
|
9
|
+
from typing import Any, Literal, overload
|
|
19
10
|
|
|
20
|
-
import numpy as np
|
|
21
11
|
import pandas as pd
|
|
22
12
|
from kumoapi.model_plan import RunMode
|
|
23
13
|
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
|
+
from kumoapi.pquery.AST import (
|
|
15
|
+
Aggregation,
|
|
16
|
+
Column,
|
|
17
|
+
Condition,
|
|
18
|
+
Join,
|
|
19
|
+
LogicalOperation,
|
|
20
|
+
)
|
|
24
21
|
from kumoapi.rfm import Context
|
|
25
22
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
26
23
|
from kumoapi.rfm import (
|
|
@@ -29,35 +26,38 @@ from kumoapi.rfm import (
|
|
|
29
26
|
RFMPredictRequest,
|
|
30
27
|
)
|
|
31
28
|
from kumoapi.task import TaskType
|
|
29
|
+
from kumoapi.typing import AggregationType, Stype
|
|
30
|
+
from rich.console import Console
|
|
31
|
+
from rich.markdown import Markdown
|
|
32
32
|
|
|
33
|
+
from kumoai import in_notebook
|
|
33
34
|
from kumoai.client.rfm import RFMAPI
|
|
34
35
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import Graph
|
|
36
|
-
from kumoai.experimental.rfm.
|
|
37
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
|
-
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
-
LocalPQueryDriver,
|
|
40
|
-
date_offset_to_seconds,
|
|
41
|
-
)
|
|
36
|
+
from kumoai.experimental.rfm import Graph, TaskTable
|
|
37
|
+
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
42
38
|
from kumoai.mixin import CastMixin
|
|
43
|
-
from kumoai.utils import
|
|
39
|
+
from kumoai.utils import ProgressLogger, display
|
|
44
40
|
|
|
45
41
|
_RANDOM_SEED = 42
|
|
46
42
|
|
|
47
43
|
_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
|
|
48
44
|
_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
|
|
49
45
|
|
|
46
|
+
_MAX_TEST_SIZE: dict[TaskType, int] = defaultdict(lambda: 2_000)
|
|
47
|
+
_MAX_TEST_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 400
|
|
48
|
+
|
|
50
49
|
_MAX_CONTEXT_SIZE = {
|
|
51
50
|
RunMode.DEBUG: 100,
|
|
52
51
|
RunMode.FAST: 1_000,
|
|
53
52
|
RunMode.NORMAL: 5_000,
|
|
54
53
|
RunMode.BEST: 10_000,
|
|
55
54
|
}
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
RunMode.
|
|
59
|
-
RunMode.
|
|
60
|
-
RunMode.
|
|
55
|
+
|
|
56
|
+
_DEFAULT_NUM_NEIGHBORS = {
|
|
57
|
+
RunMode.DEBUG: [16, 16, 4, 4, 1, 1],
|
|
58
|
+
RunMode.FAST: [32, 32, 8, 8, 4, 4],
|
|
59
|
+
RunMode.NORMAL: [64, 64, 8, 8, 4, 4],
|
|
60
|
+
RunMode.BEST: [64, 64, 8, 8, 4, 4],
|
|
61
61
|
}
|
|
62
62
|
|
|
63
63
|
_MAX_SIZE = 30 * 1024 * 1024
|
|
@@ -95,24 +95,36 @@ class Explanation:
|
|
|
95
95
|
def __getitem__(self, index: Literal[1]) -> str:
|
|
96
96
|
pass
|
|
97
97
|
|
|
98
|
-
def __getitem__(self, index: int) ->
|
|
98
|
+
def __getitem__(self, index: int) -> pd.DataFrame | str:
|
|
99
99
|
if index == 0:
|
|
100
100
|
return self.prediction
|
|
101
101
|
if index == 1:
|
|
102
102
|
return self.summary
|
|
103
103
|
raise IndexError("Index out of range")
|
|
104
104
|
|
|
105
|
-
def __iter__(self) -> Iterator[
|
|
105
|
+
def __iter__(self) -> Iterator[pd.DataFrame | str]:
|
|
106
106
|
return iter((self.prediction, self.summary))
|
|
107
107
|
|
|
108
108
|
def __repr__(self) -> str:
|
|
109
109
|
return str((self.prediction, self.summary))
|
|
110
110
|
|
|
111
|
-
def
|
|
112
|
-
|
|
111
|
+
def __str__(self) -> str:
|
|
112
|
+
console = Console(soft_wrap=True)
|
|
113
|
+
with console.capture() as cap:
|
|
114
|
+
console.print(display.to_rich_table(self.prediction))
|
|
115
|
+
console.print(Markdown(self.summary))
|
|
116
|
+
return cap.get()[:-1]
|
|
117
|
+
|
|
118
|
+
def print(self) -> None:
|
|
119
|
+
r"""Prints the explanation."""
|
|
120
|
+
if in_notebook():
|
|
121
|
+
display.dataframe(self.prediction)
|
|
122
|
+
display.message(self.summary)
|
|
123
|
+
else:
|
|
124
|
+
print(self)
|
|
113
125
|
|
|
114
|
-
|
|
115
|
-
|
|
126
|
+
def _ipython_display_(self) -> None:
|
|
127
|
+
self.print()
|
|
116
128
|
|
|
117
129
|
|
|
118
130
|
class KumoRFM:
|
|
@@ -151,21 +163,36 @@ class KumoRFM:
|
|
|
151
163
|
Args:
|
|
152
164
|
graph: The graph.
|
|
153
165
|
verbose: Whether to print verbose output.
|
|
166
|
+
optimize: If set to ``True``, will optimize the underlying data backend
|
|
167
|
+
for optimal querying. For example, for transactional database
|
|
168
|
+
backends, will create any missing indices. Requires write-access to
|
|
169
|
+
the data backend.
|
|
154
170
|
"""
|
|
155
171
|
def __init__(
|
|
156
172
|
self,
|
|
157
173
|
graph: Graph,
|
|
158
|
-
verbose:
|
|
174
|
+
verbose: bool | ProgressLogger = True,
|
|
175
|
+
optimize: bool = False,
|
|
159
176
|
) -> None:
|
|
160
177
|
graph = graph.validate()
|
|
161
178
|
self._graph_def = graph._to_api_graph_definition()
|
|
162
|
-
self._graph_store = LocalGraphStore(graph, verbose)
|
|
163
|
-
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
164
179
|
|
|
165
|
-
|
|
180
|
+
if graph.backend == DataBackend.LOCAL:
|
|
181
|
+
from kumoai.experimental.rfm.backend.local import LocalSampler
|
|
182
|
+
self._sampler: Sampler = LocalSampler(graph, verbose)
|
|
183
|
+
elif graph.backend == DataBackend.SQLITE:
|
|
184
|
+
from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
|
|
185
|
+
self._sampler = SQLiteSampler(graph, verbose, optimize)
|
|
186
|
+
elif graph.backend == DataBackend.SNOWFLAKE:
|
|
187
|
+
from kumoai.experimental.rfm.backend.snow import SnowSampler
|
|
188
|
+
self._sampler = SnowSampler(graph, verbose)
|
|
189
|
+
else:
|
|
190
|
+
raise NotImplementedError
|
|
166
191
|
|
|
167
|
-
self.
|
|
168
|
-
|
|
192
|
+
self._client: RFMAPI | None = None
|
|
193
|
+
|
|
194
|
+
self._batch_size: int | Literal['max'] | None = None
|
|
195
|
+
self._num_retries: int = 0
|
|
169
196
|
|
|
170
197
|
@property
|
|
171
198
|
def _api_client(self) -> RFMAPI:
|
|
@@ -179,10 +206,34 @@ class KumoRFM:
|
|
|
179
206
|
def __repr__(self) -> str:
|
|
180
207
|
return f'{self.__class__.__name__}()'
|
|
181
208
|
|
|
209
|
+
@contextmanager
|
|
210
|
+
def retry(
|
|
211
|
+
self,
|
|
212
|
+
num_retries: int = 1,
|
|
213
|
+
) -> Generator[None, None, None]:
|
|
214
|
+
"""Context manager to retry failed queries due to unexpected server
|
|
215
|
+
issues.
|
|
216
|
+
|
|
217
|
+
.. code-block:: python
|
|
218
|
+
|
|
219
|
+
with model.retry(num_retries=1):
|
|
220
|
+
df = model.predict(query, indices=...)
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
num_retries: The maximum number of retries.
|
|
224
|
+
"""
|
|
225
|
+
if num_retries < 0:
|
|
226
|
+
raise ValueError(f"'num_retries' must be greater than or equal to "
|
|
227
|
+
f"zero (got {num_retries})")
|
|
228
|
+
|
|
229
|
+
self._num_retries = num_retries
|
|
230
|
+
yield
|
|
231
|
+
self._num_retries = 0
|
|
232
|
+
|
|
182
233
|
@contextmanager
|
|
183
234
|
def batch_mode(
|
|
184
235
|
self,
|
|
185
|
-
batch_size:
|
|
236
|
+
batch_size: int | Literal['max'] = 'max',
|
|
186
237
|
num_retries: int = 1,
|
|
187
238
|
) -> Generator[None, None, None]:
|
|
188
239
|
"""Context manager to predict in batches.
|
|
@@ -202,31 +253,26 @@ class KumoRFM:
|
|
|
202
253
|
raise ValueError(f"'batch_size' must be greater than zero "
|
|
203
254
|
f"(got {batch_size})")
|
|
204
255
|
|
|
205
|
-
if num_retries < 0:
|
|
206
|
-
raise ValueError(f"'num_retries' must be greater than or equal to "
|
|
207
|
-
f"zero (got {num_retries})")
|
|
208
|
-
|
|
209
256
|
self._batch_size = batch_size
|
|
210
|
-
self.
|
|
211
|
-
|
|
257
|
+
with self.retry(self._num_retries or num_retries):
|
|
258
|
+
yield
|
|
212
259
|
self._batch_size = None
|
|
213
|
-
self.num_retries = 0
|
|
214
260
|
|
|
215
261
|
@overload
|
|
216
262
|
def predict(
|
|
217
263
|
self,
|
|
218
264
|
query: str,
|
|
219
|
-
indices:
|
|
265
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
220
266
|
*,
|
|
221
267
|
explain: Literal[False] = False,
|
|
222
|
-
anchor_time:
|
|
223
|
-
context_anchor_time:
|
|
224
|
-
run_mode:
|
|
225
|
-
num_neighbors:
|
|
268
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
269
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
270
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
271
|
+
num_neighbors: list[int] | None = None,
|
|
226
272
|
num_hops: int = 2,
|
|
227
|
-
max_pq_iterations: int =
|
|
228
|
-
random_seed:
|
|
229
|
-
verbose:
|
|
273
|
+
max_pq_iterations: int = 10,
|
|
274
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
275
|
+
verbose: bool | ProgressLogger = True,
|
|
230
276
|
use_prediction_time: bool = False,
|
|
231
277
|
) -> pd.DataFrame:
|
|
232
278
|
pass
|
|
@@ -235,37 +281,56 @@ class KumoRFM:
|
|
|
235
281
|
def predict(
|
|
236
282
|
self,
|
|
237
283
|
query: str,
|
|
238
|
-
indices:
|
|
284
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
239
285
|
*,
|
|
240
|
-
explain:
|
|
241
|
-
anchor_time:
|
|
242
|
-
context_anchor_time:
|
|
243
|
-
run_mode:
|
|
244
|
-
num_neighbors:
|
|
286
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
287
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
288
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
289
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
290
|
+
num_neighbors: list[int] | None = None,
|
|
245
291
|
num_hops: int = 2,
|
|
246
|
-
max_pq_iterations: int =
|
|
247
|
-
random_seed:
|
|
248
|
-
verbose:
|
|
292
|
+
max_pq_iterations: int = 10,
|
|
293
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
294
|
+
verbose: bool | ProgressLogger = True,
|
|
249
295
|
use_prediction_time: bool = False,
|
|
250
296
|
) -> Explanation:
|
|
251
297
|
pass
|
|
252
298
|
|
|
299
|
+
@overload
|
|
253
300
|
def predict(
|
|
254
301
|
self,
|
|
255
302
|
query: str,
|
|
256
|
-
indices:
|
|
303
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
257
304
|
*,
|
|
258
|
-
explain:
|
|
259
|
-
anchor_time:
|
|
260
|
-
context_anchor_time:
|
|
261
|
-
run_mode:
|
|
262
|
-
num_neighbors:
|
|
305
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
306
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
307
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
308
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
309
|
+
num_neighbors: list[int] | None = None,
|
|
263
310
|
num_hops: int = 2,
|
|
264
|
-
max_pq_iterations: int =
|
|
265
|
-
random_seed:
|
|
266
|
-
verbose:
|
|
311
|
+
max_pq_iterations: int = 10,
|
|
312
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
313
|
+
verbose: bool | ProgressLogger = True,
|
|
267
314
|
use_prediction_time: bool = False,
|
|
268
|
-
) ->
|
|
315
|
+
) -> pd.DataFrame | Explanation:
|
|
316
|
+
pass
|
|
317
|
+
|
|
318
|
+
def predict(
|
|
319
|
+
self,
|
|
320
|
+
query: str,
|
|
321
|
+
indices: list[str] | list[float] | list[int] | None = None,
|
|
322
|
+
*,
|
|
323
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
324
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
325
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
326
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
327
|
+
num_neighbors: list[int] | None = None,
|
|
328
|
+
num_hops: int = 2,
|
|
329
|
+
max_pq_iterations: int = 10,
|
|
330
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
331
|
+
verbose: bool | ProgressLogger = True,
|
|
332
|
+
use_prediction_time: bool = False,
|
|
333
|
+
) -> pd.DataFrame | Explanation:
|
|
269
334
|
"""Returns predictions for a predictive query.
|
|
270
335
|
|
|
271
336
|
Args:
|
|
@@ -273,8 +338,7 @@ class KumoRFM:
|
|
|
273
338
|
indices: The entity primary keys to predict for. Will override the
|
|
274
339
|
indices given as part of the predictive query. Predictions will
|
|
275
340
|
be generated for all indices, independent of whether they
|
|
276
|
-
fulfill entity filter constraints.
|
|
277
|
-
:meth:`~KumoRFM.is_valid_entity`.
|
|
341
|
+
fulfill entity filter constraints.
|
|
278
342
|
explain: Configuration for explainability.
|
|
279
343
|
If set to ``True``, will additionally explain the prediction.
|
|
280
344
|
Passing in an :class:`ExplainConfig` instance provides control
|
|
@@ -307,18 +371,152 @@ class KumoRFM:
|
|
|
307
371
|
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
308
372
|
containing the prediction, summary, and details.
|
|
309
373
|
"""
|
|
310
|
-
explain_config: Optional[ExplainConfig] = None
|
|
311
|
-
if explain is True:
|
|
312
|
-
explain_config = ExplainConfig()
|
|
313
|
-
elif explain is not False:
|
|
314
|
-
explain_config = ExplainConfig._cast(explain)
|
|
315
|
-
|
|
316
374
|
query_def = self._parse_query(query)
|
|
317
|
-
query_str = query_def.to_string()
|
|
318
375
|
|
|
376
|
+
if indices is None:
|
|
377
|
+
if query_def.rfm_entity_ids is None:
|
|
378
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
379
|
+
"pass them via `predict(query, indices=...)`")
|
|
380
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
381
|
+
query_def = replace(
|
|
382
|
+
query_def,
|
|
383
|
+
for_each='FOR EACH',
|
|
384
|
+
rfm_entity_ids=None,
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
if not isinstance(verbose, ProgressLogger):
|
|
388
|
+
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
389
|
+
if explain is not False:
|
|
390
|
+
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
391
|
+
else:
|
|
392
|
+
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
393
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
394
|
+
|
|
395
|
+
with verbose as logger:
|
|
396
|
+
task_table = self._get_task_table(
|
|
397
|
+
query=query_def,
|
|
398
|
+
indices=indices,
|
|
399
|
+
anchor_time=anchor_time,
|
|
400
|
+
context_anchor_time=context_anchor_time,
|
|
401
|
+
run_mode=run_mode,
|
|
402
|
+
max_pq_iterations=max_pq_iterations,
|
|
403
|
+
random_seed=random_seed,
|
|
404
|
+
logger=logger,
|
|
405
|
+
)
|
|
406
|
+
task_table._query = query_def.to_string()
|
|
407
|
+
|
|
408
|
+
return self.predict_task(
|
|
409
|
+
task_table,
|
|
410
|
+
explain=explain,
|
|
411
|
+
run_mode=run_mode,
|
|
412
|
+
num_neighbors=num_neighbors,
|
|
413
|
+
num_hops=num_hops,
|
|
414
|
+
verbose=verbose,
|
|
415
|
+
exclude_cols_dict=query_def.get_exclude_cols_dict(),
|
|
416
|
+
use_prediction_time=use_prediction_time,
|
|
417
|
+
top_k=query_def.top_k,
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
@overload
|
|
421
|
+
def predict_task(
|
|
422
|
+
self,
|
|
423
|
+
task: TaskTable,
|
|
424
|
+
*,
|
|
425
|
+
explain: Literal[False] = False,
|
|
426
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
427
|
+
num_neighbors: list[int] | None = None,
|
|
428
|
+
num_hops: int = 2,
|
|
429
|
+
verbose: bool | ProgressLogger = True,
|
|
430
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
431
|
+
use_prediction_time: bool = False,
|
|
432
|
+
top_k: int | None = None,
|
|
433
|
+
) -> pd.DataFrame:
|
|
434
|
+
pass
|
|
435
|
+
|
|
436
|
+
@overload
|
|
437
|
+
def predict_task(
|
|
438
|
+
self,
|
|
439
|
+
task: TaskTable,
|
|
440
|
+
*,
|
|
441
|
+
explain: Literal[True] | ExplainConfig | dict[str, Any],
|
|
442
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
443
|
+
num_neighbors: list[int] | None = None,
|
|
444
|
+
num_hops: int = 2,
|
|
445
|
+
verbose: bool | ProgressLogger = True,
|
|
446
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
447
|
+
use_prediction_time: bool = False,
|
|
448
|
+
top_k: int | None = None,
|
|
449
|
+
) -> Explanation:
|
|
450
|
+
pass
|
|
451
|
+
|
|
452
|
+
@overload
|
|
453
|
+
def predict_task(
|
|
454
|
+
self,
|
|
455
|
+
task: TaskTable,
|
|
456
|
+
*,
|
|
457
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
458
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
459
|
+
num_neighbors: list[int] | None = None,
|
|
460
|
+
num_hops: int = 2,
|
|
461
|
+
verbose: bool | ProgressLogger = True,
|
|
462
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
463
|
+
use_prediction_time: bool = False,
|
|
464
|
+
top_k: int | None = None,
|
|
465
|
+
) -> pd.DataFrame | Explanation:
|
|
466
|
+
pass
|
|
467
|
+
|
|
468
|
+
def predict_task(
|
|
469
|
+
self,
|
|
470
|
+
task: TaskTable,
|
|
471
|
+
*,
|
|
472
|
+
explain: bool | ExplainConfig | dict[str, Any] = False,
|
|
473
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
474
|
+
num_neighbors: list[int] | None = None,
|
|
475
|
+
num_hops: int = 2,
|
|
476
|
+
verbose: bool | ProgressLogger = True,
|
|
477
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
478
|
+
use_prediction_time: bool = False,
|
|
479
|
+
top_k: int | None = None,
|
|
480
|
+
) -> pd.DataFrame | Explanation:
|
|
481
|
+
"""Returns predictions for a custom task specification.
|
|
482
|
+
|
|
483
|
+
Args:
|
|
484
|
+
task: The custom :class:`TaskTable`.
|
|
485
|
+
explain: Configuration for explainability.
|
|
486
|
+
If set to ``True``, will additionally explain the prediction.
|
|
487
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
488
|
+
over which parts of explanation are generated.
|
|
489
|
+
Explainability is currently only supported for single entity
|
|
490
|
+
predictions with ``run_mode="FAST"``.
|
|
491
|
+
run_mode: The :class:`RunMode` for the query.
|
|
492
|
+
num_neighbors: The number of neighbors to sample for each hop.
|
|
493
|
+
If specified, the ``num_hops`` option will be ignored.
|
|
494
|
+
num_hops: The number of hops to sample when generating the context.
|
|
495
|
+
verbose: Whether to print verbose output.
|
|
496
|
+
exclude_cols_dict: Any column in any table to exclude from the
|
|
497
|
+
model input.
|
|
498
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
499
|
+
additional feature during prediction. This is typically
|
|
500
|
+
beneficial for time series forecasting tasks.
|
|
501
|
+
top_k: The number of predictions to return per entity.
|
|
502
|
+
|
|
503
|
+
Returns:
|
|
504
|
+
The predictions as a :class:`pandas.DataFrame`.
|
|
505
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
506
|
+
containing the prediction, summary, and details.
|
|
507
|
+
"""
|
|
319
508
|
if num_hops != 2 and num_neighbors is not None:
|
|
320
509
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
321
510
|
f"custom 'num_hops={num_hops}' option")
|
|
511
|
+
if num_neighbors is None:
|
|
512
|
+
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
513
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
514
|
+
|
|
515
|
+
explain_config: ExplainConfig | None = None
|
|
516
|
+
if explain is True:
|
|
517
|
+
explain_config = ExplainConfig()
|
|
518
|
+
elif explain is not False:
|
|
519
|
+
explain_config = ExplainConfig._cast(explain)
|
|
322
520
|
|
|
323
521
|
if explain_config is not None and run_mode in {
|
|
324
522
|
RunMode.NORMAL, RunMode.BEST
|
|
@@ -327,83 +525,82 @@ class KumoRFM:
|
|
|
327
525
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
328
526
|
f"mode has been reset. Please lower the run mode to "
|
|
329
527
|
f"suppress this warning.")
|
|
528
|
+
run_mode = RunMode.FAST
|
|
330
529
|
|
|
331
|
-
if
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
indices = query_def.get_rfm_entity_id_list()
|
|
336
|
-
else:
|
|
337
|
-
query_def = replace(query_def, rfm_entity_ids=None)
|
|
338
|
-
|
|
339
|
-
if len(indices) == 0:
|
|
340
|
-
raise ValueError("At least one entity is required")
|
|
341
|
-
|
|
342
|
-
if explain_config is not None and len(indices) > 1:
|
|
343
|
-
raise ValueError(
|
|
344
|
-
f"Cannot explain predictions for more than a single entity "
|
|
345
|
-
f"(got {len(indices)})")
|
|
346
|
-
|
|
347
|
-
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
348
|
-
if explain_config is not None:
|
|
349
|
-
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
350
|
-
else:
|
|
351
|
-
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
530
|
+
if explain_config is not None and task.num_prediction_examples > 1:
|
|
531
|
+
raise ValueError(f"Cannot explain predictions for more than a "
|
|
532
|
+
f"single entity "
|
|
533
|
+
f"(got {task.num_prediction_examples:,})")
|
|
352
534
|
|
|
353
535
|
if not isinstance(verbose, ProgressLogger):
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
edge_types=self._graph_store.edge_types,
|
|
363
|
-
)
|
|
364
|
-
batch_size = _MAX_PRED_SIZE[task_type]
|
|
536
|
+
if task.task_type == TaskType.BINARY_CLASSIFICATION:
|
|
537
|
+
task_type_repr = 'binary classification'
|
|
538
|
+
elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
539
|
+
task_type_repr = 'multi-class classification'
|
|
540
|
+
elif task.task_type == TaskType.REGRESSION:
|
|
541
|
+
task_type_repr = 'regression'
|
|
542
|
+
elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
543
|
+
task_type_repr = 'link prediction'
|
|
365
544
|
else:
|
|
366
|
-
|
|
545
|
+
task_type_repr = str(task.task_type)
|
|
367
546
|
|
|
368
|
-
if
|
|
369
|
-
|
|
370
|
-
batches = [indices[step:step + batch_size] for step in offsets]
|
|
547
|
+
if explain_config is not None:
|
|
548
|
+
msg = f"Explaining {task_type_repr} task"
|
|
371
549
|
else:
|
|
372
|
-
|
|
550
|
+
msg = f"Predicting {task_type_repr} task"
|
|
551
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
373
552
|
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
553
|
+
with verbose as logger:
|
|
554
|
+
if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
|
|
555
|
+
logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
|
|
556
|
+
f"out of {task.num_context_examples:,} in-context "
|
|
557
|
+
f"examples")
|
|
558
|
+
task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
|
|
559
|
+
|
|
560
|
+
if self._batch_size is None:
|
|
561
|
+
batch_size = task.num_prediction_examples
|
|
562
|
+
elif self._batch_size == 'max':
|
|
563
|
+
batch_size = _MAX_PRED_SIZE[task.task_type]
|
|
564
|
+
else:
|
|
565
|
+
batch_size = self._batch_size
|
|
377
566
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
567
|
+
if batch_size > _MAX_PRED_SIZE[task.task_type]:
|
|
568
|
+
raise ValueError(f"Cannot predict for more than "
|
|
569
|
+
f"{_MAX_PRED_SIZE[task.task_type]:,} "
|
|
570
|
+
f"entities at once (got {batch_size:,}). Use "
|
|
571
|
+
f"`KumoRFM.batch_mode` to process entities "
|
|
572
|
+
f"in batches with a sufficient batch size.")
|
|
573
|
+
|
|
574
|
+
if task.num_prediction_examples > batch_size:
|
|
575
|
+
num = math.ceil(task.num_prediction_examples / batch_size)
|
|
576
|
+
logger.log(f"Splitting {task.num_prediction_examples:,} "
|
|
577
|
+
f"entities into {num:,} batches of size "
|
|
578
|
+
f"{batch_size:,}")
|
|
579
|
+
|
|
580
|
+
predictions: list[pd.DataFrame] = []
|
|
581
|
+
summary: str | None = None
|
|
582
|
+
details: Explanation | None = None
|
|
583
|
+
for start in range(0, task.num_prediction_examples, batch_size):
|
|
383
584
|
context = self._get_context(
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
anchor_time=anchor_time,
|
|
387
|
-
context_anchor_time=context_anchor_time,
|
|
388
|
-
run_mode=RunMode(run_mode),
|
|
585
|
+
task=task.narrow_prediction(start, length=batch_size),
|
|
586
|
+
run_mode=run_mode,
|
|
389
587
|
num_neighbors=num_neighbors,
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
evaluate=False,
|
|
393
|
-
random_seed=random_seed,
|
|
394
|
-
logger=logger if i == 0 else None,
|
|
588
|
+
exclude_cols_dict=exclude_cols_dict,
|
|
589
|
+
top_k=top_k,
|
|
395
590
|
)
|
|
591
|
+
context.y_test = None
|
|
592
|
+
|
|
396
593
|
request = RFMPredictRequest(
|
|
397
594
|
context=context,
|
|
398
595
|
run_mode=RunMode(run_mode),
|
|
399
|
-
query=
|
|
596
|
+
query=task._query,
|
|
400
597
|
use_prediction_time=use_prediction_time,
|
|
401
598
|
)
|
|
402
599
|
with warnings.catch_warnings():
|
|
403
600
|
warnings.filterwarnings('ignore', message='gencode')
|
|
404
601
|
request_msg = request.to_protobuf()
|
|
405
602
|
_bytes = request_msg.SerializeToString()
|
|
406
|
-
if
|
|
603
|
+
if start == 0:
|
|
407
604
|
logger.log(f"Generated context of size "
|
|
408
605
|
f"{len(_bytes) / (1024*1024):.2f}MB")
|
|
409
606
|
|
|
@@ -411,14 +608,11 @@ class KumoRFM:
|
|
|
411
608
|
stats = Context.get_memory_stats(request_msg.context)
|
|
412
609
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
413
610
|
|
|
414
|
-
if
|
|
415
|
-
|
|
416
|
-
verbose.init_progress(
|
|
417
|
-
total=len(batches),
|
|
418
|
-
description='Predicting',
|
|
419
|
-
)
|
|
611
|
+
if start == 0 and task.num_prediction_examples > batch_size:
|
|
612
|
+
num = math.ceil(task.num_prediction_examples / batch_size)
|
|
613
|
+
verbose.init_progress(total=num, description='Predicting')
|
|
420
614
|
|
|
421
|
-
for attempt in range(self.
|
|
615
|
+
for attempt in range(self._num_retries + 1):
|
|
422
616
|
try:
|
|
423
617
|
if explain_config is not None:
|
|
424
618
|
resp = self._api_client.explain(
|
|
@@ -433,10 +627,10 @@ class KumoRFM:
|
|
|
433
627
|
|
|
434
628
|
# Cast 'ENTITY' to correct data type:
|
|
435
629
|
if 'ENTITY' in df:
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
630
|
+
table_dict = context.subgraph.table_dict
|
|
631
|
+
table = table_dict[context.entity_table_names[0]]
|
|
632
|
+
ser = table.df[table.primary_key]
|
|
633
|
+
df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
|
|
440
634
|
|
|
441
635
|
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
442
636
|
if 'ANCHOR_TIMESTAMP' in df:
|
|
@@ -451,13 +645,12 @@ class KumoRFM:
|
|
|
451
645
|
|
|
452
646
|
predictions.append(df)
|
|
453
647
|
|
|
454
|
-
if
|
|
455
|
-
and len(batches) > 1):
|
|
648
|
+
if task.num_prediction_examples > batch_size:
|
|
456
649
|
verbose.step()
|
|
457
650
|
|
|
458
651
|
break
|
|
459
652
|
except HTTPException as e:
|
|
460
|
-
if attempt == self.
|
|
653
|
+
if attempt == self._num_retries:
|
|
461
654
|
try:
|
|
462
655
|
msg = json.loads(e.detail)['detail']
|
|
463
656
|
except Exception:
|
|
@@ -487,69 +680,19 @@ class KumoRFM:
|
|
|
487
680
|
|
|
488
681
|
return prediction
|
|
489
682
|
|
|
490
|
-
def is_valid_entity(
|
|
491
|
-
self,
|
|
492
|
-
query: str,
|
|
493
|
-
indices: Union[List[str], List[float], List[int], None] = None,
|
|
494
|
-
*,
|
|
495
|
-
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
496
|
-
) -> np.ndarray:
|
|
497
|
-
r"""Returns a mask that denotes which entities are valid for the
|
|
498
|
-
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
499
|
-
entity filter constraints.
|
|
500
|
-
|
|
501
|
-
Args:
|
|
502
|
-
query: The predictive query.
|
|
503
|
-
indices: The entity primary keys to predict for. Will override the
|
|
504
|
-
indices given as part of the predictive query.
|
|
505
|
-
anchor_time: The anchor timestamp for the prediction. If set to
|
|
506
|
-
``None``, will use the maximum timestamp in the data.
|
|
507
|
-
If set to ``"entity"``, will use the timestamp of the entity.
|
|
508
|
-
"""
|
|
509
|
-
query_def = self._parse_query(query)
|
|
510
|
-
|
|
511
|
-
if indices is None:
|
|
512
|
-
if query_def.rfm_entity_ids is None:
|
|
513
|
-
raise ValueError("Cannot find entities to predict for. Please "
|
|
514
|
-
"pass them via "
|
|
515
|
-
"`is_valid_entity(query, indices=...)`")
|
|
516
|
-
indices = query_def.get_rfm_entity_id_list()
|
|
517
|
-
|
|
518
|
-
if len(indices) == 0:
|
|
519
|
-
raise ValueError("At least one entity is required")
|
|
520
|
-
|
|
521
|
-
if anchor_time is None:
|
|
522
|
-
anchor_time = self._graph_store.max_time
|
|
523
|
-
|
|
524
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
525
|
-
self._validate_time(query_def, anchor_time, None, False)
|
|
526
|
-
else:
|
|
527
|
-
assert anchor_time == 'entity'
|
|
528
|
-
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
529
|
-
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
530
|
-
f"table '{query_def.entity_table}' "
|
|
531
|
-
f"to have a time column.")
|
|
532
|
-
|
|
533
|
-
node = self._graph_store.get_node_id(
|
|
534
|
-
table_name=query_def.entity_table,
|
|
535
|
-
pkey=pd.Series(indices),
|
|
536
|
-
)
|
|
537
|
-
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
538
|
-
return query_driver.is_valid(node, anchor_time)
|
|
539
|
-
|
|
540
683
|
def evaluate(
|
|
541
684
|
self,
|
|
542
685
|
query: str,
|
|
543
686
|
*,
|
|
544
|
-
metrics:
|
|
545
|
-
anchor_time:
|
|
546
|
-
context_anchor_time:
|
|
547
|
-
run_mode:
|
|
548
|
-
num_neighbors:
|
|
687
|
+
metrics: list[str] | None = None,
|
|
688
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
689
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
690
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
691
|
+
num_neighbors: list[int] | None = None,
|
|
549
692
|
num_hops: int = 2,
|
|
550
|
-
max_pq_iterations: int =
|
|
551
|
-
random_seed:
|
|
552
|
-
verbose:
|
|
693
|
+
max_pq_iterations: int = 10,
|
|
694
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
695
|
+
verbose: bool | ProgressLogger = True,
|
|
553
696
|
use_prediction_time: bool = False,
|
|
554
697
|
) -> pd.DataFrame:
|
|
555
698
|
"""Evaluates a predictive query.
|
|
@@ -581,41 +724,120 @@ class KumoRFM:
|
|
|
581
724
|
Returns:
|
|
582
725
|
The metrics as a :class:`pandas.DataFrame`
|
|
583
726
|
"""
|
|
584
|
-
query_def =
|
|
727
|
+
query_def = replace(
|
|
728
|
+
self._parse_query(query),
|
|
729
|
+
for_each='FOR EACH',
|
|
730
|
+
rfm_entity_ids=None,
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
if not isinstance(verbose, ProgressLogger):
|
|
734
|
+
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
735
|
+
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
736
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
737
|
+
|
|
738
|
+
with verbose as logger:
|
|
739
|
+
task_table = self._get_task_table(
|
|
740
|
+
query=query_def,
|
|
741
|
+
indices=None,
|
|
742
|
+
anchor_time=anchor_time,
|
|
743
|
+
context_anchor_time=context_anchor_time,
|
|
744
|
+
run_mode=run_mode,
|
|
745
|
+
max_pq_iterations=max_pq_iterations,
|
|
746
|
+
random_seed=random_seed,
|
|
747
|
+
logger=logger,
|
|
748
|
+
)
|
|
749
|
+
|
|
750
|
+
return self.evaluate_task(
|
|
751
|
+
task_table,
|
|
752
|
+
metrics=metrics,
|
|
753
|
+
run_mode=run_mode,
|
|
754
|
+
num_neighbors=num_neighbors,
|
|
755
|
+
num_hops=num_hops,
|
|
756
|
+
verbose=verbose,
|
|
757
|
+
exclude_cols_dict=query_def.get_exclude_cols_dict(),
|
|
758
|
+
use_prediction_time=use_prediction_time,
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
def evaluate_task(
|
|
762
|
+
self,
|
|
763
|
+
task: TaskTable,
|
|
764
|
+
*,
|
|
765
|
+
metrics: list[str] | None = None,
|
|
766
|
+
run_mode: RunMode | str = RunMode.FAST,
|
|
767
|
+
num_neighbors: list[int] | None = None,
|
|
768
|
+
num_hops: int = 2,
|
|
769
|
+
verbose: bool | ProgressLogger = True,
|
|
770
|
+
exclude_cols_dict: dict[str, list[str]] | None = None,
|
|
771
|
+
use_prediction_time: bool = False,
|
|
772
|
+
) -> pd.DataFrame:
|
|
773
|
+
"""Evaluates a custom task specification.
|
|
585
774
|
|
|
775
|
+
Args:
|
|
776
|
+
task: The custom :class:`TaskTable`.
|
|
777
|
+
metrics: The metrics to use.
|
|
778
|
+
run_mode: The :class:`RunMode` for the query.
|
|
779
|
+
num_neighbors: The number of neighbors to sample for each hop.
|
|
780
|
+
If specified, the ``num_hops`` option will be ignored.
|
|
781
|
+
num_hops: The number of hops to sample when generating the context.
|
|
782
|
+
verbose: Whether to print verbose output.
|
|
783
|
+
exclude_cols_dict: Any column in any table to exclude from the
|
|
784
|
+
model input.
|
|
785
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
786
|
+
additional feature during prediction. This is typically
|
|
787
|
+
beneficial for time series forecasting tasks.
|
|
788
|
+
|
|
789
|
+
Returns:
|
|
790
|
+
The metrics as a :class:`pandas.DataFrame`
|
|
791
|
+
"""
|
|
586
792
|
if num_hops != 2 and num_neighbors is not None:
|
|
587
793
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
588
794
|
f"custom 'num_hops={num_hops}' option")
|
|
795
|
+
if num_neighbors is None:
|
|
796
|
+
key = RunMode.FAST if task.task_type.is_link_pred else run_mode
|
|
797
|
+
num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:num_hops]
|
|
589
798
|
|
|
590
|
-
if
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
rfm_entity_ids=None,
|
|
594
|
-
)
|
|
595
|
-
|
|
596
|
-
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
597
|
-
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
799
|
+
if metrics is not None and len(metrics) > 0:
|
|
800
|
+
self._validate_metrics(metrics, task.task_type)
|
|
801
|
+
metrics = list(dict.fromkeys(metrics))
|
|
598
802
|
|
|
599
803
|
if not isinstance(verbose, ProgressLogger):
|
|
600
|
-
|
|
804
|
+
if task.task_type == TaskType.BINARY_CLASSIFICATION:
|
|
805
|
+
task_type_repr = 'binary classification'
|
|
806
|
+
elif task.task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
807
|
+
task_type_repr = 'multi-class classification'
|
|
808
|
+
elif task.task_type == TaskType.REGRESSION:
|
|
809
|
+
task_type_repr = 'regression'
|
|
810
|
+
elif task.task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
811
|
+
task_type_repr = 'link prediction'
|
|
812
|
+
else:
|
|
813
|
+
task_type_repr = str(task.task_type)
|
|
814
|
+
|
|
815
|
+
msg = f"Evaluating {task_type_repr} task"
|
|
816
|
+
verbose = ProgressLogger.default(msg=msg, verbose=verbose)
|
|
601
817
|
|
|
602
818
|
with verbose as logger:
|
|
819
|
+
if task.num_context_examples > _MAX_CONTEXT_SIZE[run_mode]:
|
|
820
|
+
logger.log(f"Sub-sampled {_MAX_CONTEXT_SIZE[run_mode]:,} "
|
|
821
|
+
f"out of {task.num_context_examples:,} in-context "
|
|
822
|
+
f"examples")
|
|
823
|
+
task = task.narrow_context(0, _MAX_CONTEXT_SIZE[run_mode])
|
|
824
|
+
|
|
825
|
+
if task.num_prediction_examples > _MAX_TEST_SIZE[task.task_type]:
|
|
826
|
+
logger.log(f"Sub-sampled {_MAX_TEST_SIZE[task.task_type]:,} "
|
|
827
|
+
f"out of {task.num_prediction_examples:,} test "
|
|
828
|
+
f"examples")
|
|
829
|
+
task = task.narrow_prediction(
|
|
830
|
+
start=0,
|
|
831
|
+
length=_MAX_TEST_SIZE[task.task_type],
|
|
832
|
+
)
|
|
833
|
+
|
|
603
834
|
context = self._get_context(
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
anchor_time=anchor_time,
|
|
607
|
-
context_anchor_time=context_anchor_time,
|
|
608
|
-
run_mode=RunMode(run_mode),
|
|
835
|
+
task=task,
|
|
836
|
+
run_mode=run_mode,
|
|
609
837
|
num_neighbors=num_neighbors,
|
|
610
|
-
|
|
611
|
-
max_pq_iterations=max_pq_iterations,
|
|
612
|
-
evaluate=True,
|
|
613
|
-
random_seed=random_seed,
|
|
614
|
-
logger=logger if verbose else None,
|
|
838
|
+
exclude_cols_dict=exclude_cols_dict,
|
|
615
839
|
)
|
|
616
|
-
|
|
617
|
-
self._validate_metrics(metrics, context.task_type)
|
|
618
|
-
metrics = list(dict.fromkeys(metrics))
|
|
840
|
+
|
|
619
841
|
request = RFMEvaluateRequest(
|
|
620
842
|
context=context,
|
|
621
843
|
run_mode=RunMode(run_mode),
|
|
@@ -633,17 +855,23 @@ class KumoRFM:
|
|
|
633
855
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
634
856
|
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
635
857
|
|
|
636
|
-
|
|
637
|
-
resp = self._api_client.evaluate(request_bytes)
|
|
638
|
-
except HTTPException as e:
|
|
858
|
+
for attempt in range(self._num_retries + 1):
|
|
639
859
|
try:
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
860
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
861
|
+
break
|
|
862
|
+
except HTTPException as e:
|
|
863
|
+
if attempt == self._num_retries:
|
|
864
|
+
try:
|
|
865
|
+
msg = json.loads(e.detail)['detail']
|
|
866
|
+
except Exception:
|
|
867
|
+
msg = e.detail
|
|
868
|
+
raise RuntimeError(
|
|
869
|
+
f"An unexpected exception occurred. Please create "
|
|
870
|
+
f"an issue at "
|
|
871
|
+
f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
|
|
872
|
+
) from None
|
|
873
|
+
|
|
874
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
647
875
|
|
|
648
876
|
return pd.DataFrame.from_dict(
|
|
649
877
|
resp.metrics,
|
|
@@ -656,9 +884,9 @@ class KumoRFM:
|
|
|
656
884
|
query: str,
|
|
657
885
|
size: int,
|
|
658
886
|
*,
|
|
659
|
-
anchor_time:
|
|
660
|
-
random_seed:
|
|
661
|
-
max_iterations: int =
|
|
887
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
888
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
889
|
+
max_iterations: int = 10,
|
|
662
890
|
) -> pd.DataFrame:
|
|
663
891
|
"""Returns the labels of a predictive query for a specified anchor
|
|
664
892
|
time.
|
|
@@ -678,40 +906,37 @@ class KumoRFM:
|
|
|
678
906
|
query_def = self._parse_query(query)
|
|
679
907
|
|
|
680
908
|
if anchor_time is None:
|
|
681
|
-
anchor_time = self.
|
|
909
|
+
anchor_time = self._get_default_anchor_time(query_def)
|
|
682
910
|
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
911
|
+
offset = query_def.target_ast.date_offset_range.end_date_offset
|
|
912
|
+
offset *= query_def.num_forecasts
|
|
913
|
+
anchor_time -= offset
|
|
686
914
|
|
|
687
915
|
assert anchor_time is not None
|
|
688
916
|
if isinstance(anchor_time, pd.Timestamp):
|
|
689
917
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
690
918
|
else:
|
|
691
919
|
assert anchor_time == 'entity'
|
|
692
|
-
if
|
|
920
|
+
if query_def.entity_table not in self._sampler.time_column_dict:
|
|
693
921
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
694
922
|
f"table '{query_def.entity_table}' "
|
|
695
923
|
f"to have a time column")
|
|
696
924
|
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
925
|
+
train, test = self._sampler.sample_target(
|
|
926
|
+
query=query_def,
|
|
927
|
+
num_train_examples=0,
|
|
928
|
+
train_anchor_time=anchor_time,
|
|
929
|
+
num_train_trials=0,
|
|
930
|
+
num_test_examples=size,
|
|
931
|
+
test_anchor_time=anchor_time,
|
|
932
|
+
num_test_trials=max_iterations * size,
|
|
933
|
+
random_seed=random_seed,
|
|
706
934
|
)
|
|
707
935
|
|
|
708
|
-
entity = self._graph_store.pkey_map_dict[
|
|
709
|
-
query_def.entity_table].index[node]
|
|
710
|
-
|
|
711
936
|
return pd.DataFrame({
|
|
712
|
-
'ENTITY':
|
|
713
|
-
'ANCHOR_TIMESTAMP':
|
|
714
|
-
'TARGET':
|
|
937
|
+
'ENTITY': test.entity_pkey,
|
|
938
|
+
'ANCHOR_TIMESTAMP': test.anchor_time,
|
|
939
|
+
'TARGET': test.target,
|
|
715
940
|
})
|
|
716
941
|
|
|
717
942
|
# Helpers #################################################################
|
|
@@ -726,63 +951,120 @@ class KumoRFM:
|
|
|
726
951
|
"`predict()` or `evaluate()` methods to perform "
|
|
727
952
|
"predictions or evaluations.")
|
|
728
953
|
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
)
|
|
954
|
+
request = RFMParseQueryRequest(
|
|
955
|
+
query=query,
|
|
956
|
+
graph_definition=self._graph_def,
|
|
957
|
+
)
|
|
734
958
|
|
|
735
|
-
|
|
959
|
+
for attempt in range(self._num_retries + 1):
|
|
960
|
+
try:
|
|
961
|
+
resp = self._api_client.parse_query(request)
|
|
962
|
+
break
|
|
963
|
+
except HTTPException as e:
|
|
964
|
+
if attempt == self._num_retries:
|
|
965
|
+
try:
|
|
966
|
+
msg = json.loads(e.detail)['detail']
|
|
967
|
+
except Exception:
|
|
968
|
+
msg = e.detail
|
|
969
|
+
raise ValueError(f"Failed to parse query '{query}'. {msg}")
|
|
736
970
|
|
|
737
|
-
|
|
971
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
738
972
|
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
973
|
+
if len(resp.validation_response.warnings) > 0:
|
|
974
|
+
msg = '\n'.join([
|
|
975
|
+
f'{i+1}. {warning.title}: {warning.message}'
|
|
976
|
+
for i, warning in enumerate(resp.validation_response.warnings)
|
|
977
|
+
])
|
|
978
|
+
warnings.warn(f"Encountered the following warnings during "
|
|
979
|
+
f"parsing:\n{msg}")
|
|
746
980
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
981
|
+
return resp.query
|
|
982
|
+
|
|
983
|
+
@staticmethod
|
|
984
|
+
def _get_task_type(
|
|
985
|
+
query: ValidatedPredictiveQuery,
|
|
986
|
+
edge_types: list[tuple[str, str, str]],
|
|
987
|
+
) -> TaskType:
|
|
988
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
989
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
990
|
+
|
|
991
|
+
target = query.target_ast
|
|
992
|
+
if isinstance(target, Join):
|
|
993
|
+
target = target.rhs_target
|
|
994
|
+
if isinstance(target, Aggregation):
|
|
995
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
996
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
997
|
+
'.')
|
|
998
|
+
target_edge_types = [
|
|
999
|
+
edge_type for edge_type in edge_types
|
|
1000
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
1001
|
+
]
|
|
1002
|
+
if len(target_edge_types) != 1:
|
|
1003
|
+
raise NotImplementedError(
|
|
1004
|
+
f"Multilabel-classification queries based on "
|
|
1005
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
1006
|
+
f"planned to write a link prediction query instead, "
|
|
1007
|
+
f"make sure to register '{col_name}' as a "
|
|
1008
|
+
f"foreign key.")
|
|
1009
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
1010
|
+
|
|
1011
|
+
return TaskType.REGRESSION
|
|
1012
|
+
|
|
1013
|
+
assert isinstance(target, Column)
|
|
1014
|
+
|
|
1015
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
1016
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
1017
|
+
|
|
1018
|
+
if target.stype in {Stype.numerical}:
|
|
1019
|
+
return TaskType.REGRESSION
|
|
1020
|
+
|
|
1021
|
+
raise NotImplementedError("Task type not yet supported")
|
|
1022
|
+
|
|
1023
|
+
def _get_default_anchor_time(
|
|
1024
|
+
self,
|
|
1025
|
+
query: ValidatedPredictiveQuery | None = None,
|
|
1026
|
+
) -> pd.Timestamp:
|
|
1027
|
+
if query is not None and query.query_type == QueryType.TEMPORAL:
|
|
1028
|
+
aggr_table_names = [
|
|
1029
|
+
aggr._get_target_column_name().split('.')[0]
|
|
1030
|
+
for aggr in query.get_all_target_aggregations()
|
|
1031
|
+
]
|
|
1032
|
+
return self._sampler.get_max_time(aggr_table_names)
|
|
1033
|
+
|
|
1034
|
+
return self._sampler.get_max_time()
|
|
755
1035
|
|
|
756
1036
|
def _validate_time(
|
|
757
1037
|
self,
|
|
758
1038
|
query: ValidatedPredictiveQuery,
|
|
759
1039
|
anchor_time: pd.Timestamp,
|
|
760
|
-
context_anchor_time:
|
|
1040
|
+
context_anchor_time: pd.Timestamp | None,
|
|
761
1041
|
evaluate: bool,
|
|
762
1042
|
) -> None:
|
|
763
1043
|
|
|
764
|
-
if self.
|
|
1044
|
+
if len(self._sampler.time_column_dict) == 0:
|
|
765
1045
|
return # Graph without timestamps
|
|
766
1046
|
|
|
767
|
-
|
|
1047
|
+
min_time = self._sampler.get_min_time()
|
|
1048
|
+
max_time = self._sampler.get_max_time()
|
|
1049
|
+
|
|
1050
|
+
if anchor_time < min_time:
|
|
768
1051
|
raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
|
|
769
|
-
f"the earliest timestamp "
|
|
770
|
-
f"
|
|
1052
|
+
f"the earliest timestamp '{min_time}' in the "
|
|
1053
|
+
f"data.")
|
|
771
1054
|
|
|
772
|
-
if
|
|
773
|
-
and context_anchor_time < self._graph_store.min_time):
|
|
1055
|
+
if context_anchor_time is not None and context_anchor_time < min_time:
|
|
774
1056
|
raise ValueError(f"Context anchor timestamp is too early or "
|
|
775
1057
|
f"aggregation time range is too large. To make "
|
|
776
1058
|
f"this prediction, we would need data back to "
|
|
777
1059
|
f"'{context_anchor_time}', however, your data "
|
|
778
|
-
f"only contains data back to "
|
|
779
|
-
f"'{self._graph_store.min_time}'.")
|
|
1060
|
+
f"only contains data back to '{min_time}'.")
|
|
780
1061
|
|
|
781
1062
|
if query.target_ast.date_offset_range is not None:
|
|
782
1063
|
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
1064
|
else:
|
|
784
1065
|
end_offset = pd.DateOffset(0)
|
|
785
|
-
|
|
1066
|
+
end_offset = end_offset * query.num_forecasts
|
|
1067
|
+
|
|
786
1068
|
if (context_anchor_time is not None
|
|
787
1069
|
and context_anchor_time > anchor_time):
|
|
788
1070
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -792,7 +1074,7 @@ class KumoRFM:
|
|
|
792
1074
|
f"intended.")
|
|
793
1075
|
elif (query.query_type == QueryType.TEMPORAL
|
|
794
1076
|
and context_anchor_time is not None
|
|
795
|
-
and context_anchor_time +
|
|
1077
|
+
and context_anchor_time + end_offset > anchor_time):
|
|
796
1078
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
797
1079
|
f"'{context_anchor_time}' will leak information "
|
|
798
1080
|
f"from the prediction anchor timestamp "
|
|
@@ -800,62 +1082,44 @@ class KumoRFM:
|
|
|
800
1082
|
f"intended.")
|
|
801
1083
|
|
|
802
1084
|
elif (context_anchor_time is not None
|
|
803
|
-
and context_anchor_time -
|
|
804
|
-
|
|
805
|
-
_time = context_anchor_time - forecast_end_offset
|
|
1085
|
+
and context_anchor_time - end_offset < min_time):
|
|
1086
|
+
_time = context_anchor_time - end_offset
|
|
806
1087
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
807
1088
|
f"aggregation time range is too large. To form "
|
|
808
1089
|
f"proper input data, we would need data back to "
|
|
809
1090
|
f"'{_time}', however, your data only contains "
|
|
810
|
-
f"data back to '{
|
|
1091
|
+
f"data back to '{min_time}'.")
|
|
811
1092
|
|
|
812
|
-
if
|
|
813
|
-
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
1093
|
+
if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
|
|
814
1094
|
warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
|
|
815
|
-
f"latest timestamp '{
|
|
816
|
-
f"
|
|
1095
|
+
f"latest timestamp '{max_time}' in the data. Please "
|
|
1096
|
+
f"make sure this is intended.")
|
|
817
1097
|
|
|
818
|
-
|
|
819
|
-
if evaluate and anchor_time > max_eval_time:
|
|
1098
|
+
if evaluate and anchor_time > max_time - end_offset:
|
|
820
1099
|
raise ValueError(
|
|
821
1100
|
f"Anchor timestamp for evaluation is after the latest "
|
|
822
|
-
f"supported timestamp '{
|
|
1101
|
+
f"supported timestamp '{max_time - end_offset}'.")
|
|
823
1102
|
|
|
824
|
-
def
|
|
1103
|
+
def _get_task_table(
|
|
825
1104
|
self,
|
|
826
1105
|
query: ValidatedPredictiveQuery,
|
|
827
|
-
indices:
|
|
828
|
-
anchor_time:
|
|
829
|
-
context_anchor_time:
|
|
830
|
-
run_mode: RunMode,
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
if num_neighbors is not None:
|
|
840
|
-
num_hops = len(num_neighbors)
|
|
841
|
-
|
|
842
|
-
if num_hops < 0:
|
|
843
|
-
raise ValueError(f"'num_hops' must be non-negative "
|
|
844
|
-
f"(got {num_hops})")
|
|
845
|
-
if num_hops > 6:
|
|
846
|
-
raise ValueError(f"Cannot predict on subgraphs with more than 6 "
|
|
847
|
-
f"hops (got {num_hops}). Please reduce the "
|
|
848
|
-
f"number of hops and try again. Please create a "
|
|
849
|
-
f"feature request at "
|
|
850
|
-
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
851
|
-
f"must go beyond this for your use-case.")
|
|
852
|
-
|
|
853
|
-
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
854
|
-
task_type = LocalPQueryDriver.get_task_type(
|
|
855
|
-
query,
|
|
856
|
-
edge_types=self._graph_store.edge_types,
|
|
1106
|
+
indices: list[str] | list[float] | list[int] | None,
|
|
1107
|
+
anchor_time: pd.Timestamp | Literal['entity'] | None = None,
|
|
1108
|
+
context_anchor_time: pd.Timestamp | None = None,
|
|
1109
|
+
run_mode: RunMode = RunMode.FAST,
|
|
1110
|
+
max_pq_iterations: int = 10,
|
|
1111
|
+
random_seed: int | None = _RANDOM_SEED,
|
|
1112
|
+
logger: ProgressLogger | None = None,
|
|
1113
|
+
) -> TaskTable:
|
|
1114
|
+
|
|
1115
|
+
task_type = self._get_task_type(
|
|
1116
|
+
query=query,
|
|
1117
|
+
edge_types=self._sampler.edge_types,
|
|
857
1118
|
)
|
|
858
1119
|
|
|
1120
|
+
num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
|
|
1121
|
+
num_test_examples = _MAX_TEST_SIZE[task_type] if indices is None else 0
|
|
1122
|
+
|
|
859
1123
|
if logger is not None:
|
|
860
1124
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
861
1125
|
task_type_repr = 'binary classification'
|
|
@@ -869,30 +1133,17 @@ class KumoRFM:
|
|
|
869
1133
|
task_type_repr = str(task_type)
|
|
870
1134
|
logger.log(f"Identified {query.query_type} {task_type_repr} task")
|
|
871
1135
|
|
|
872
|
-
if task_type.is_link_pred and num_hops < 2:
|
|
873
|
-
raise ValueError(f"Cannot perform link prediction on subgraphs "
|
|
874
|
-
f"with less than 2 hops (got {num_hops}) since "
|
|
875
|
-
f"historical target entities need to be part of "
|
|
876
|
-
f"the context. Please increase the number of "
|
|
877
|
-
f"hops and try again.")
|
|
878
|
-
|
|
879
|
-
if num_neighbors is None:
|
|
880
|
-
if run_mode == RunMode.DEBUG:
|
|
881
|
-
num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
|
|
882
|
-
elif run_mode == RunMode.FAST or task_type.is_link_pred:
|
|
883
|
-
num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
|
|
884
|
-
else:
|
|
885
|
-
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
886
|
-
|
|
887
1136
|
if query.target_ast.date_offset_range is None:
|
|
888
|
-
|
|
1137
|
+
step_offset = pd.DateOffset(0)
|
|
889
1138
|
else:
|
|
890
|
-
|
|
891
|
-
|
|
1139
|
+
step_offset = query.target_ast.date_offset_range.end_date_offset
|
|
1140
|
+
end_offset = step_offset * query.num_forecasts
|
|
1141
|
+
|
|
892
1142
|
if anchor_time is None:
|
|
893
|
-
anchor_time = self.
|
|
894
|
-
if
|
|
895
|
-
anchor_time = anchor_time -
|
|
1143
|
+
anchor_time = self._get_default_anchor_time(query)
|
|
1144
|
+
if num_test_examples > 0:
|
|
1145
|
+
anchor_time = anchor_time - end_offset
|
|
1146
|
+
|
|
896
1147
|
if logger is not None:
|
|
897
1148
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
898
1149
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -904,114 +1155,98 @@ class KumoRFM:
|
|
|
904
1155
|
else:
|
|
905
1156
|
logger.log(f"Derived anchor time {anchor_time}")
|
|
906
1157
|
|
|
907
|
-
assert anchor_time is not None
|
|
908
1158
|
if isinstance(anchor_time, pd.Timestamp):
|
|
1159
|
+
if context_anchor_time == 'entity':
|
|
1160
|
+
raise ValueError("Anchor time 'entity' needs to be shared "
|
|
1161
|
+
"for context and prediction examples")
|
|
909
1162
|
if context_anchor_time is None:
|
|
910
|
-
context_anchor_time = anchor_time -
|
|
1163
|
+
context_anchor_time = anchor_time - end_offset
|
|
911
1164
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
912
|
-
evaluate)
|
|
1165
|
+
evaluate=num_test_examples > 0)
|
|
913
1166
|
else:
|
|
914
1167
|
assert anchor_time == 'entity'
|
|
915
|
-
if query.
|
|
1168
|
+
if query.query_type != QueryType.STATIC:
|
|
1169
|
+
raise ValueError("Anchor time 'entity' is only valid for "
|
|
1170
|
+
"static predictive queries")
|
|
1171
|
+
if query.entity_table not in self._sampler.time_column_dict:
|
|
916
1172
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
917
1173
|
f"table '{query.entity_table}' to "
|
|
918
1174
|
f"have a time column")
|
|
919
|
-
if context_anchor_time
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
context_anchor_time =
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
if logger is not None:
|
|
937
|
-
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
938
|
-
pos = 100 * int((y_test > 0).sum()) / len(y_test)
|
|
939
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
940
|
-
f"{pos:.2f}% positive cases")
|
|
941
|
-
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
942
|
-
msg = (f"Collected {len(y_test):,} test examples "
|
|
943
|
-
f"holding {y_test.nunique()} classes")
|
|
944
|
-
elif task_type == TaskType.REGRESSION:
|
|
945
|
-
_min, _max = float(y_test.min()), float(y_test.max())
|
|
946
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
947
|
-
f"targets between {format_value(_min)} and "
|
|
948
|
-
f"{format_value(_max)}")
|
|
949
|
-
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
950
|
-
num_rhs = y_test.explode().nunique()
|
|
951
|
-
msg = (f"Collected {len(y_test):,} test examples with "
|
|
952
|
-
f"{num_rhs:,} unique items")
|
|
953
|
-
else:
|
|
954
|
-
raise NotImplementedError
|
|
955
|
-
logger.log(msg)
|
|
956
|
-
|
|
957
|
-
else:
|
|
958
|
-
assert indices is not None
|
|
959
|
-
|
|
960
|
-
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
961
|
-
raise ValueError(f"Cannot predict for more than "
|
|
962
|
-
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
963
|
-
f"once (got {len(indices):,}). Use "
|
|
964
|
-
f"`KumoRFM.batch_mode` to process entities "
|
|
965
|
-
f"in batches")
|
|
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, train_y = train
+        test_pkey, test_time, test_y = test
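Both splits now come back from a single `sample_target` call, each as a `(primary key, anchor time, target)` triple of row-aligned Series. A hedged sketch of the returned shape (the data below is a stand-in; only the unpacking pattern comes from this diff):

    import pandas as pd

    # Stand-in for one (pkey, time, target) triple as returned per split:
    train = (pd.Series([11, 42]),
             pd.Series(pd.to_datetime(['2024-05-01', '2024-05-01'])),
             pd.Series([0, 1]))
    train_pkey, train_time, train_y = train
    assert len(train_pkey) == len(train_time) == len(train_y)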
 
-
-
-
-
+        if num_test_examples > 0 and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((test_y > 0).sum()) / len(test_y)
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(test_y):,} test examples holding "
+                       f"{test_y.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(test_y.min()), float(test_y.max())
+                msg = (f"Collected {len(test_y):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = test_y.explode().nunique()
+                msg = (f"Collected {len(test_y):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
+        if num_test_examples == 0:
+            assert indices is not None
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
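The binary-classification log line reports the positive rate of the sampled test labels; a short worked example of the arithmetic:

    import pandas as pd

    test_y = pd.Series([0, 1, 0, 1, 1])
    pos = 100 * int((test_y > 0).sum()) / len(test_y)
    print(f"{pos:.2f}% positive cases")   # 60.00% positive cases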
         if isinstance(anchor_time, pd.Timestamp):
-            test_time = pd.Series(anchor_time).repeat(
-                len(
+            test_time = pd.Series([anchor_time]).repeat(
+                len(indices)).reset_index(drop=True)
         else:
-
-            time = time[test_node] * 1000**3
-            test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+            train_time = test_time = 'entity'
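When a single timestamp anchors every prediction row, it is broadcast into a per-row Series; note the list wrapper around the scalar in the replacement line above. A small runnable sketch of the broadcast:

    import pandas as pd

    anchor_time = pd.Timestamp('2024-06-01')
    test_time = pd.Series([anchor_time]).repeat(3).reset_index(drop=True)
    # 0   2024-06-01
    # 1   2024-06-01
    # 2   2024-06-01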
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
-                pos = 100 * int((
-                msg = (f"Collected {len(
+                pos = 100 * int((train_y > 0).sum()) / len(train_y)
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{pos:.2f}% positive cases")
             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                msg = (f"Collected {len(
-                       f"holding {
+                msg = (f"Collected {len(train_y):,} in-context examples "
+                       f"holding {train_y.nunique()} classes")
             elif task_type == TaskType.REGRESSION:
-                _min, _max = float(
-                msg = (f"Collected {len(
+                _min, _max = float(train_y.min()), float(train_y.max())
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"targets between {format_value(_min)} and "
                        f"{format_value(_max)}")
             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                num_rhs =
-                msg = (f"Collected {len(
+                num_rhs = train_y.explode().nunique()
+                msg = (f"Collected {len(train_y):,} in-context examples with "
                        f"{num_rhs:,} unique items")
             else:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names:
+        entity_table_names: tuple[str] | tuple[str, str]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1020,23 +1255,80 @@ class KumoRFM:
         else:
             entity_table_names = (query.entity_table, )
 
+        context_df = pd.DataFrame({'ENTITY': train_pkey, 'TARGET': train_y})
+        if isinstance(train_time, pd.Series):
+            context_df['ANCHOR_TIMESTAMP'] = train_time
+        pred_df = pd.DataFrame({'ENTITY': test_pkey})
+        if num_test_examples > 0:
+            pred_df['TARGET'] = test_y
+        if isinstance(test_time, pd.Series):
+            pred_df['ANCHOR_TIMESTAMP'] = test_time
+
+        return TaskTable(
+            task_type=task_type,
+            context_df=context_df,
+            pred_df=pred_df,
+            entity_table_name=entity_table_names,
+            entity_column='ENTITY',
+            target_column='TARGET',
+            time_column='ANCHOR_TIMESTAMP' if isinstance(
+                train_time, pd.Series) else TaskTable.ENTITY_TIME,
+        )
+
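The refactor packages context and prediction examples into two plain DataFrames before wrapping them in a `TaskTable`. A hedged sketch of their shape (the column names come from this diff; the values are invented):

    import pandas as pd

    context_df = pd.DataFrame({
        'ENTITY': [101, 102, 103],            # entity primary keys
        'TARGET': [1, 0, 1],                  # in-context labels
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-05-01'] * 3),
    })
    pred_df = pd.DataFrame({
        'ENTITY': [104, 105],                 # entities to predict for
        'ANCHOR_TIMESTAMP': pd.to_datetime(['2024-06-01'] * 2),
    })  # 'TARGET' is only added when test examples are sampled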
+    def _get_context(
+        self,
+        task: TaskTable,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+        top_k: int | None = None,
+    ) -> Context:
+
+        if num_neighbors is None:
+            key = RunMode.FAST if task.task_type.is_link_pred else run_mode
+            num_neighbors = _DEFAULT_NUM_NEIGHBORS[key][:2]
+
+        if len(num_neighbors) > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {len(num_neighbors)}). Reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
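`num_neighbors` is a per-hop fan-out list, so its length is the number of hops; the default is looked up per run mode (always the fast mode for link prediction) and sliced to two hops. An illustration with invented values (only the lookup-and-slice pattern comes from this diff):

    # Illustrative values; not the library's actual defaults.
    _DEFAULT_NUM_NEIGHBORS = {'fast': [16, 16, 4, 4], 'best': [32, 32, 8, 8]}

    num_neighbors = _DEFAULT_NUM_NEIGHBORS['fast'][:2]   # [16, 16] -> 2 hops
    assert len(num_neighbors) <= 6   # deeper subgraphs raise a ValueError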
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict =
-        if
-        if
-            exclude_cols_dict[
-
-
-
-
-
-
-
-
-
-
-
+        exclude_cols_dict = exclude_cols_dict or {}
+        if task.entity_table_name in self._sampler.time_column_dict:
+            if task.entity_table_name not in exclude_cols_dict:
+                exclude_cols_dict[task.entity_table_name] = []
+            time_col = self._sampler.time_column_dict[task.entity_table_name]
+            exclude_cols_dict[task.entity_table_name].append(time_col)
+
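The nested membership checks above amount to appending the entity table's time column to that table's exclusion list. An equivalent formulation with `dict.setdefault`, shown as a behavioral sketch with hypothetical names:

    exclude_cols_dict: dict[str, list[str]] = {}
    time_column_dict = {'users': 'signup_time'}   # hypothetical mapping
    table = 'users'

    if table in time_column_dict:
        exclude_cols_dict.setdefault(table, []).append(time_column_dict[table])
    # -> {'users': ['signup_time']}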
+        entity_pkey = pd.concat([
+            task._context_df[task._entity_column],
+            task._pred_df[task._entity_column],
+        ], axis=0, ignore_index=True)
+
+        if task.use_entity_time:
+            if task.entity_table_name not in self._sampler.time_column_dict:
+                raise ValueError(f"The given anchor time requires the entity "
+                                 f"table '{task.entity_table_name}' to have a "
+                                 f"time column")
+            anchor_time = 'entity'
+        elif task._time_column is not None:
+            anchor_time = pd.concat([
+                task._context_df[task._time_column],
+                task._pred_df[task._time_column],
+            ], axis=0, ignore_index=True)
+        else:
+            anchor_time = pd.Series(self._get_default_anchor_time()).repeat(
+                len(entity_pkey)).reset_index(drop=True)
+
+        subgraph = self._sampler.sample_subgraph(
+            entity_table_names=task.entity_table_names,
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1048,23 +1340,26 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-
-
-
+        if (task.task_type.is_link_pred
+                and task.entity_table_names[-1] not in subgraph.table_dict):
+            raise ValueError("Cannot perform link prediction on subgraphs "
+                             "without any historical target entities. Please "
+                             "increase the number of hops and try again.")
 
         return Context(
-            task_type=task_type,
-            entity_table_names=entity_table_names,
+            task_type=task.task_type,
+            entity_table_names=task.entity_table_names,
             subgraph=subgraph,
-            y_train=
-            y_test=
-
-
+            y_train=task._context_df[task.target_column.name],
+            y_test=task._pred_df[task.target_column.name]
+            if task.evaluate else None,
+            top_k=top_k,
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics:
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1121,7 +1416,7 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
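From the two branches visible in this hunk, `format_value` comma-groups integral values and routes large magnitudes to a separate path; the rest of the function lies outside the hunk. Expected behavior for the visible branches:

    # Behavior implied by the branches shown above:
    #   format_value(42)      -> '42'
    #   format_value(1234.0)  -> '1,234'   (integral values are comma-grouped)
    #   fractional values with abs(value) >= 1000 take the branch that
    #   continues beyond this hunk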