kumoai 2.12.0.dev202511061731__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +41 -35
- kumoai/_version.py +1 -1
- kumoai/client/client.py +15 -13
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +15 -7
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +191 -48
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +735 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +346 -248
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/kumolib.cp311-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +188 -16
- kumoai/utils/sql.py +3 -0
- {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
- {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +54 -31
- kumoai/experimental/rfm/local_graph.py +0 -810
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
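The listing shows the former local_* modules (local_graph.py, local_table.py, local_graph_sampler.py, local_pquery_driver.py) consolidated into per-backend packages (backend/local, backend/sqlite, backend/snow) behind a shared base layer. The rfm.py diff below reflects this on the public surface; what follows is a minimal sketch of the new entry point assembled only from that diff (the DataFrame contents are placeholders, and Graph.from_data is used exactly as in the updated docstring):

    import pandas as pd

    from kumoai.experimental.rfm import Graph, KumoRFM

    # Placeholder data; in practice these come from your application:
    df_users = pd.DataFrame({'user_id': [1, 2]})
    df_orders = pd.DataFrame({'order_id': [10, 11], 'user_id': [1, 2]})

    graph = Graph.from_data({'users': df_users, 'orders': df_orders})

    # optimize=True asks database-backed graphs (e.g., SQLite) to create any
    # missing indices; it requires write access to the data backend:
    rfm = KumoRFM(graph, optimize=False)

    result = rfm.predict(
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1")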
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -2,15 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -19,17 +26,14 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.
-from kumoai.
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger, display
 
 _RANDOM_SEED = 42
 
@@ -59,6 +63,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
 @dataclass(repr=False)
 class Explanation:
     prediction: pd.DataFrame
@@ -73,19 +88,27 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        display.dataframe(self.prediction)
+        display.message(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
 
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
@@ -95,17 +118,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph =
+        graph = Graph.from_data({
            'users': df_users,
            'items': df_items,
            'orders': df_orders,
@@ -113,47 +136,63 @@ class KumoRFM:
 
        rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(
-                 "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
        print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                       # 1        0.85
 
    Args:
        graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be beneficial to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
        verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
    """
    def __init__(
        self,
-        graph:
-
-
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
        self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
 
    @contextmanager
    def batch_mode(
        self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
        num_retries: int = 1,
    ) -> Generator[None, None, None]:
        """Context manager to predict in batches.
@@ -187,17 +226,17 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
        explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        pass
@@ -206,17 +245,17 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        explain: Literal[True],
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> Explanation:
        pass
@@ -224,19 +263,19 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        explain: bool = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
        """Returns predictions for a predictive query.
 
        Args:
@@ -246,9 +285,12 @@ class KumoRFM:
                be generated for all indices, independent of whether they
                fulfill entity filter constraints. To pre-filter entities, use
                :meth:`~KumoRFM.is_valid_entity`.
-            explain:
-
-
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
            anchor_time: The anchor timestamp for the prediction. If set to
                ``None``, will use the maximum timestamp in the data.
                If set to ``"entity"``, will use the timestamp of the entity.
@@ -272,16 +314,25 @@ class KumoRFM:
 
        Returns:
            The predictions as a :class:`pandas.DataFrame`.
-            If ``explain
-
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
        """
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
        query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
        if num_hops != 2 and num_neighbors is not None:
            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                          f"custom 'num_hops={num_hops}' option")
 
-        if
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
            warnings.warn(f"Explainability is currently only supported for "
                          f"run mode 'FAST' (got '{run_mode}'). Provided run "
                          f"mode has been reset. Please lower the run mode to "
@@ -298,27 +349,27 @@ class KumoRFM:
        if len(indices) == 0:
            raise ValueError("At least one entity is required")
 
-        if
+        if explain_config is not None and len(indices) > 1:
            raise ValueError(
                f"Cannot explain predictions for more than a single entity "
                f"(got {len(indices)})")
 
        query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if
+        if explain_config is not None:
            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
        else:
            msg = f'[bold]PREDICT[/bold] {query_repr}'
 
        if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
        with verbose as logger:
 
-            batch_size:
+            batch_size: int | None = None
            if self._batch_size == 'max':
-                task_type =
-                    query_def,
-                    edge_types=self.
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                )
                batch_size = _MAX_PRED_SIZE[task_type]
            else:
@@ -334,9 +385,9 @@ class KumoRFM:
            logger.log(f"Splitting {len(indices):,} entities into "
                       f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions:
-            summary:
-            details:
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
            for i, batch in enumerate(batches):
                # TODO Re-use the context for subsequent predictions.
                context = self._get_context(
@@ -355,6 +406,7 @@ class KumoRFM:
                request = RFMPredictRequest(
                    context=context,
                    run_mode=RunMode(run_mode),
+                    query=query_str,
                    use_prediction_time=use_prediction_time,
                )
                with warnings.catch_warnings():
@@ -369,8 +421,7 @@ class KumoRFM:
                    stats = Context.get_memory_stats(request_msg.context)
                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                    verbose.init_progress(
                        total=len(batches),
                        description='Predicting',
@@ -378,20 +429,23 @@ class KumoRFM:
 
                for attempt in range(self.num_retries + 1):
                    try:
-                        if
-                            resp =
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
                            summary = resp.summary
                            details = resp.details
                        else:
-                            resp =
+                            resp = self._api_client.predict(_bytes)
                        df = pd.DataFrame(**resp.prediction)
 
                        # Cast 'ENTITY' to correct data type:
                        if 'ENTITY' in df:
-
-
-
-
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                        if 'ANCHOR_TIMESTAMP' in df:
@@ -406,8 +460,7 @@ class KumoRFM:
 
                        predictions.append(df)
 
-                        if (
-                                and len(batches) > 1):
+                        if len(batches) > 1:
                            verbose.step()
 
                        break
@@ -430,7 +483,7 @@ class KumoRFM:
            else:
                prediction = pd.concat(predictions, ignore_index=True)
 
-            if
+            if explain_config is not None:
                assert len(predictions) == 1
                assert summary is not None
                assert details is not None
@@ -445,9 +498,9 @@ class KumoRFM:
    def is_valid_entity(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        anchor_time:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    ) -> np.ndarray:
        r"""Returns a mask that denotes which entities are valid for the
        given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -474,37 +527,32 @@ class KumoRFM:
            raise ValueError("At least one entity is required")
 
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, False)
        else:
            assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
    def evaluate(
        self,
        query: str,
        *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        """Evaluates a predictive query.
@@ -552,7 +600,7 @@ class KumoRFM:
        msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
        if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
        with verbose as logger:
            context = self._get_context(
@@ -586,10 +634,10 @@ class KumoRFM:
 
            if len(request_bytes) > _MAX_SIZE:
                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
            try:
-                resp =
+                resp = self._api_client.evaluate(request_bytes)
            except HTTPException as e:
                try:
                    msg = json.loads(e.detail)['detail']
@@ -611,9 +659,9 @@ class KumoRFM:
        query: str,
        size: int,
        *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
    ) -> pd.DataFrame:
        """Returns the labels of a predictive query for a specified anchor
        time.
@@ -633,40 +681,37 @@ class KumoRFM:
        query_def = self._parse_query(query)
 
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
            if query_def.target_ast.date_offset_range is not None:
-
-
-
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, evaluate=True)
        else:
            assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
        )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
        return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
        })
 
    # Helpers #################################################################
@@ -687,8 +732,7 @@ class KumoRFM:
            graph_definition=self._graph_def,
        )
 
-        resp =
-        # TODO Expose validation warnings.
+        resp = self._api_client.parse_query(request)
 
        if len(resp.validation_response.warnings) > 0:
            msg = '\n'.join([
@@ -707,36 +751,92 @@ class KumoRFM:
            raise ValueError(f"Failed to parse query '{query}'. "
                             f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
    def _validate_time(
        self,
        query: ValidatedPredictiveQuery,
        anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
        evaluate: bool,
    ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
            return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
            raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
            raise ValueError(f"Context anchor timestamp is too early or "
                             f"aggregation time range is too large. To make "
                             f"this prediction, we would need data back to "
                             f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
        if query.target_ast.date_offset_range is not None:
            end_offset = query.target_ast.date_offset_range.end_date_offset
        else:
            end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
        if (context_anchor_time is not None
                and context_anchor_time > anchor_time):
            warnings.warn(f"Context anchor timestamp "
@@ -746,7 +846,7 @@ class KumoRFM:
                          f"intended.")
        elif (query.query_type == QueryType.TEMPORAL
              and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
            warnings.warn(f"Aggregation for context examples at timestamp "
                          f"'{context_anchor_time}' will leak information "
                          f"from the prediction anchor timestamp "
@@ -754,40 +854,37 @@ class KumoRFM:
                          f"intended.")
 
        elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
            warnings.warn(f"Context anchor timestamp is too early or "
                          f"aggregation time range is too large. To form "
                          f"proper input data, we would need data back to "
                          f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
            warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
            raise ValueError(
                f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
    def _get_context(
        self,
        query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
        run_mode: RunMode,
-        num_neighbors:
+        num_neighbors: list[int] | None,
        num_hops: int,
        max_pq_iterations: int,
        evaluate: bool,
-        random_seed:
-        logger:
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
    ) -> Context:
 
        if num_neighbors is not None:
@@ -804,10 +901,9 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
        )
 
        if logger is not None:
@@ -839,14 +935,17 @@ class KumoRFM:
            num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
        if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
        else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
        if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
        if logger is not None:
            assert isinstance(anchor_time, pd.Timestamp)
            if anchor_time == pd.Timestamp.min:
@@ -860,57 +959,71 @@ class KumoRFM:
 
        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
            if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
            self._validate_time(query, anchor_time, context_anchor_time,
                                evaluate)
        else:
            assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query.entity_table}' to "
                                 f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
        if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
            if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                num_rhs = y_test.explode().nunique()
-                msg = (f"Collected {len(y_test):,} test examples with "
-                       f"{num_rhs:,} unique items")
-            else:
-                raise NotImplementedError
-            logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
            assert indices is not None
-
            if len(indices) > _MAX_PRED_SIZE[task_type]:
                raise ValueError(f"Cannot predict for more than "
                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -918,26 +1031,12 @@ class KumoRFM:
                                 f"`KumoRFM.batch_mode` to process entities "
                                 f"in batches")
 
-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
            if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
            else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
            if logger is not None:
                if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -960,12 +1059,12 @@ class KumoRFM:
                    raise NotImplementedError
                logger.log(msg)
 
-        entity_table_names:
+        entity_table_names: tuple[str, ...]
        if task_type.is_link_pred:
            final_aggr = query.get_final_target_aggregation()
            assert final_aggr is not None
            edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                    entity_table_names = (
                        query.entity_table,
@@ -977,21 +1076,24 @@ class KumoRFM:
        # Exclude the entity anchor time from the feature set to prevent
        # running out-of-distribution between in-context and test examples:
        exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
            if entity_table_names[0] not in exclude_cols_dict:
                exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
            entity_table_names=entity_table_names,
-
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
            num_neighbors=num_neighbors,
            exclude_cols_dict=exclude_cols_dict,
        )
@@ -1003,23 +1105,19 @@ class KumoRFM:
                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
        return Context(
            task_type=task_type,
            entity_table_names=entity_table_names,
            subgraph=subgraph,
            y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
            top_k=query.top_k,
-            step_size=
+            step_size=None,
        )
 
    @staticmethod
    def _validate_metrics(
-        metrics:
+        metrics: list[str],
        task_type: TaskType,
    ) -> None:
 
@@ -1076,7 +1174,7 @@ class KumoRFM:
          f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value:
+def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000: