kumoai 2.10.0.dev202509291830__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202511161731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/experimental/rfm/__init__.py +5 -3
- kumoai/experimental/rfm/infer/timestamp.py +5 -4
- kumoai/experimental/rfm/local_graph.py +90 -74
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +13 -1
- kumoai/experimental/rfm/local_pquery_driver.py +323 -38
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +313 -84
- kumoai/jobs.py +1 -0
- kumoai/trainer/trainer.py +12 -10
- kumoai/utils/progress_logger.py +13 -0
- {kumoai-2.10.0.dev202509291830.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/METADATA +4 -5
- {kumoai-2.10.0.dev202509291830.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/RECORD +24 -24
- {kumoai-2.10.0.dev202509291830.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202509291830.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202509291830.dist-info → kumoai-2.13.0.dev202511161731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -4,19 +4,29 @@ import warnings
|
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
|
-
from dataclasses import replace
|
|
8
|
-
from typing import
|
|
7
|
+
from dataclasses import dataclass, replace
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.rfm import Context
|
|
25
|
+
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
14
26
|
from kumoapi.rfm import (
|
|
15
|
-
Context,
|
|
16
|
-
PQueryDefinition,
|
|
17
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
18
29
|
RFMPredictRequest,
|
|
19
|
-
RFMValidateQueryRequest,
|
|
20
30
|
)
|
|
21
31
|
from kumoapi.task import TaskType
|
|
22
32
|
|
|
@@ -29,6 +39,7 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
29
39
|
LocalPQueryDriver,
|
|
30
40
|
date_offset_to_seconds,
|
|
31
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
32
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
33
44
|
|
|
34
45
|
_RANDOM_SEED = 42
|
|
@@ -59,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
59
70
|
"beyond this for your use-case.")
|
|
60
71
|
|
|
61
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(repr=False)
|
|
85
|
+
class Explanation:
|
|
86
|
+
prediction: pd.DataFrame
|
|
87
|
+
summary: str
|
|
88
|
+
details: ExplanationConfig
|
|
89
|
+
|
|
90
|
+
@overload
|
|
91
|
+
def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
|
|
92
|
+
pass
|
|
93
|
+
|
|
94
|
+
@overload
|
|
95
|
+
def __getitem__(self, index: Literal[1]) -> str:
|
|
96
|
+
pass
|
|
97
|
+
|
|
98
|
+
def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
|
|
99
|
+
if index == 0:
|
|
100
|
+
return self.prediction
|
|
101
|
+
if index == 1:
|
|
102
|
+
return self.summary
|
|
103
|
+
raise IndexError("Index out of range")
|
|
104
|
+
|
|
105
|
+
def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
|
|
106
|
+
return iter((self.prediction, self.summary))
|
|
107
|
+
|
|
108
|
+
def __repr__(self) -> str:
|
|
109
|
+
return str((self.prediction, self.summary))
|
|
110
|
+
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
117
|
+
|
|
62
118
|
class KumoRFM:
|
|
63
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
64
120
|
Foundation Model for In-Context Learning on Relational Data
|
|
@@ -116,7 +172,7 @@ class KumoRFM:
|
|
|
116
172
|
self._graph_store = LocalGraphStore(graph, preprocess, verbose)
|
|
117
173
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
118
174
|
|
|
119
|
-
self._batch_size: Optional[int | Literal['
|
|
175
|
+
self._batch_size: Optional[int | Literal['max']] = None
|
|
120
176
|
self.num_retries: int = 0
|
|
121
177
|
|
|
122
178
|
def __repr__(self) -> str:
|
|
@@ -125,23 +181,23 @@ class KumoRFM:
|
|
|
125
181
|
@contextmanager
|
|
126
182
|
def batch_mode(
|
|
127
183
|
self,
|
|
128
|
-
batch_size: Union[int, Literal['
|
|
184
|
+
batch_size: Union[int, Literal['max']] = 'max',
|
|
129
185
|
num_retries: int = 1,
|
|
130
186
|
) -> Generator[None, None, None]:
|
|
131
187
|
"""Context manager to predict in batches.
|
|
132
188
|
|
|
133
189
|
.. code-block:: python
|
|
134
190
|
|
|
135
|
-
with model.batch_mode(batch_size='
|
|
191
|
+
with model.batch_mode(batch_size='max', num_retries=1):
|
|
136
192
|
df = model.predict(query, indices=...)
|
|
137
193
|
|
|
138
194
|
Args:
|
|
139
|
-
batch_size: The batch size. If set to ``"
|
|
195
|
+
batch_size: The batch size. If set to ``"max"``, will use the
|
|
140
196
|
maximum applicable batch size for the given task.
|
|
141
197
|
num_retries: The maximum number of retries for failed queries due
|
|
142
198
|
to unexpected server issues.
|
|
143
199
|
"""
|
|
144
|
-
if batch_size != '
|
|
200
|
+
if batch_size != 'max' and batch_size <= 0:
|
|
145
201
|
raise ValueError(f"'batch_size' must be greater than zero "
|
|
146
202
|
f"(got {batch_size})")
|
|
147
203
|
|
|
@@ -155,11 +211,13 @@ class KumoRFM:
|
|
|
155
211
|
self._batch_size = None
|
|
156
212
|
self.num_retries = 0
|
|
157
213
|
|
|
214
|
+
@overload
|
|
158
215
|
def predict(
|
|
159
216
|
self,
|
|
160
217
|
query: str,
|
|
161
218
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
162
219
|
*,
|
|
220
|
+
explain: Literal[False] = False,
|
|
163
221
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
164
222
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
165
223
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -168,18 +226,65 @@ class KumoRFM:
|
|
|
168
226
|
max_pq_iterations: int = 20,
|
|
169
227
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
170
228
|
verbose: Union[bool, ProgressLogger] = True,
|
|
229
|
+
use_prediction_time: bool = False,
|
|
171
230
|
) -> pd.DataFrame:
|
|
231
|
+
pass
|
|
232
|
+
|
|
233
|
+
@overload
|
|
234
|
+
def predict(
|
|
235
|
+
self,
|
|
236
|
+
query: str,
|
|
237
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
238
|
+
*,
|
|
239
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
240
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
241
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
242
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
243
|
+
num_neighbors: Optional[List[int]] = None,
|
|
244
|
+
num_hops: int = 2,
|
|
245
|
+
max_pq_iterations: int = 20,
|
|
246
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
247
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
248
|
+
use_prediction_time: bool = False,
|
|
249
|
+
) -> Explanation:
|
|
250
|
+
pass
|
|
251
|
+
|
|
252
|
+
def predict(
|
|
253
|
+
self,
|
|
254
|
+
query: str,
|
|
255
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
256
|
+
*,
|
|
257
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
258
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
259
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
260
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
261
|
+
num_neighbors: Optional[List[int]] = None,
|
|
262
|
+
num_hops: int = 2,
|
|
263
|
+
max_pq_iterations: int = 20,
|
|
264
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
265
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
266
|
+
use_prediction_time: bool = False,
|
|
267
|
+
) -> Union[pd.DataFrame, Explanation]:
|
|
172
268
|
"""Returns predictions for a predictive query.
|
|
173
269
|
|
|
174
270
|
Args:
|
|
175
271
|
query: The predictive query.
|
|
176
|
-
indices: The entity primary keys to predict
|
|
177
|
-
indices given as part of the predictive query.
|
|
272
|
+
indices: The entity primary keys to predict for. Will override the
|
|
273
|
+
indices given as part of the predictive query. Predictions will
|
|
274
|
+
be generated for all indices, independent of whether they
|
|
275
|
+
fulfill entity filter constraints. To pre-filter entities, use
|
|
276
|
+
:meth:`~KumoRFM.is_valid_entity`.
|
|
277
|
+
explain: Configuration for explainability.
|
|
278
|
+
If set to ``True``, will additionally explain the prediction.
|
|
279
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
280
|
+
over which parts of explanation are generated.
|
|
281
|
+
Explainability is currently only supported for single entity
|
|
282
|
+
predictions with ``run_mode="FAST"``.
|
|
178
283
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
179
|
-
|
|
180
|
-
If set to
|
|
284
|
+
``None``, will use the maximum timestamp in the data.
|
|
285
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
181
286
|
context_anchor_time: The maximum anchor timestamp for context
|
|
182
|
-
examples. If set to
|
|
287
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
183
288
|
determine the anchor time for context examples.
|
|
184
289
|
run_mode: The :class:`RunMode` for the query.
|
|
185
290
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
@@ -192,46 +297,54 @@ class KumoRFM:
|
|
|
192
297
|
entities to find valid labels.
|
|
193
298
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
194
299
|
verbose: Whether to print verbose output.
|
|
300
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
301
|
+
additional feature during prediction. This is typically
|
|
302
|
+
beneficial for time series forecasting tasks.
|
|
195
303
|
|
|
196
304
|
Returns:
|
|
197
|
-
The predictions as a :class:`pandas.DataFrame
|
|
305
|
+
The predictions as a :class:`pandas.DataFrame`.
|
|
306
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
307
|
+
containing the prediction, summary, and details.
|
|
198
308
|
"""
|
|
199
|
-
|
|
309
|
+
explain_config: Optional[ExplainConfig] = None
|
|
310
|
+
if explain is True:
|
|
311
|
+
explain_config = ExplainConfig()
|
|
312
|
+
elif explain is not False:
|
|
313
|
+
explain_config = ExplainConfig._cast(explain)
|
|
314
|
+
|
|
200
315
|
query_def = self._parse_query(query)
|
|
316
|
+
query_str = query_def.to_string()
|
|
201
317
|
|
|
202
318
|
if num_hops != 2 and num_neighbors is not None:
|
|
203
319
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
204
320
|
f"custom 'num_hops={num_hops}' option")
|
|
205
321
|
|
|
206
|
-
if
|
|
322
|
+
if explain_config is not None and run_mode in {
|
|
323
|
+
RunMode.NORMAL, RunMode.BEST
|
|
324
|
+
}:
|
|
207
325
|
warnings.warn(f"Explainability is currently only supported for "
|
|
208
326
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
209
327
|
f"mode has been reset. Please lower the run mode to "
|
|
210
328
|
f"suppress this warning.")
|
|
211
329
|
|
|
212
330
|
if indices is None:
|
|
213
|
-
if query_def.
|
|
331
|
+
if query_def.rfm_entity_ids is None:
|
|
214
332
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
215
333
|
"pass them via `predict(query, indices=...)`")
|
|
216
|
-
indices = query_def.
|
|
334
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
217
335
|
else:
|
|
218
|
-
query_def = replace(
|
|
219
|
-
query_def,
|
|
220
|
-
entity=replace(query_def.entity, ids=None),
|
|
221
|
-
)
|
|
336
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
222
337
|
|
|
223
338
|
if len(indices) == 0:
|
|
224
|
-
raise ValueError("At least one entity is required
|
|
225
|
-
"prediction")
|
|
339
|
+
raise ValueError("At least one entity is required")
|
|
226
340
|
|
|
227
|
-
if
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
f"entity (got {len(indices)})")
|
|
341
|
+
if explain_config is not None and len(indices) > 1:
|
|
342
|
+
raise ValueError(
|
|
343
|
+
f"Cannot explain predictions for more than a single entity "
|
|
344
|
+
f"(got {len(indices)})")
|
|
232
345
|
|
|
233
346
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
234
|
-
if
|
|
347
|
+
if explain_config is not None:
|
|
235
348
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
236
349
|
else:
|
|
237
350
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -242,9 +355,9 @@ class KumoRFM:
|
|
|
242
355
|
with verbose as logger:
|
|
243
356
|
|
|
244
357
|
batch_size: Optional[int] = None
|
|
245
|
-
if self._batch_size == '
|
|
246
|
-
task_type =
|
|
247
|
-
|
|
358
|
+
if self._batch_size == 'max':
|
|
359
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
360
|
+
query_def,
|
|
248
361
|
edge_types=self._graph_store.edge_types,
|
|
249
362
|
)
|
|
250
363
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -261,7 +374,9 @@ class KumoRFM:
|
|
|
261
374
|
logger.log(f"Splitting {len(indices):,} entities into "
|
|
262
375
|
f"{len(batches):,} batches of size {batch_size:,}")
|
|
263
376
|
|
|
264
|
-
|
|
377
|
+
predictions: List[pd.DataFrame] = []
|
|
378
|
+
summary: Optional[str] = None
|
|
379
|
+
details: Optional[Explanation] = None
|
|
265
380
|
for i, batch in enumerate(batches):
|
|
266
381
|
# TODO Re-use the context for subsequent predictions.
|
|
267
382
|
context = self._get_context(
|
|
@@ -280,6 +395,8 @@ class KumoRFM:
|
|
|
280
395
|
request = RFMPredictRequest(
|
|
281
396
|
context=context,
|
|
282
397
|
run_mode=RunMode(run_mode),
|
|
398
|
+
query=query_str,
|
|
399
|
+
use_prediction_time=use_prediction_time,
|
|
283
400
|
)
|
|
284
401
|
with warnings.catch_warnings():
|
|
285
402
|
warnings.filterwarnings('ignore', message='gencode')
|
|
@@ -302,11 +419,36 @@ class KumoRFM:
|
|
|
302
419
|
|
|
303
420
|
for attempt in range(self.num_retries + 1):
|
|
304
421
|
try:
|
|
305
|
-
if
|
|
306
|
-
resp = global_state.client.rfm_api.explain(
|
|
422
|
+
if explain_config is not None:
|
|
423
|
+
resp = global_state.client.rfm_api.explain(
|
|
424
|
+
request=_bytes,
|
|
425
|
+
skip_summary=explain_config.skip_summary,
|
|
426
|
+
)
|
|
427
|
+
summary = resp.summary
|
|
428
|
+
details = resp.details
|
|
307
429
|
else:
|
|
308
430
|
resp = global_state.client.rfm_api.predict(_bytes)
|
|
309
|
-
|
|
431
|
+
df = pd.DataFrame(**resp.prediction)
|
|
432
|
+
|
|
433
|
+
# Cast 'ENTITY' to correct data type:
|
|
434
|
+
if 'ENTITY' in df:
|
|
435
|
+
entity = query_def.entity_table
|
|
436
|
+
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
437
|
+
df['ENTITY'] = df['ENTITY'].astype(
|
|
438
|
+
type(pkey_map.index[0]))
|
|
439
|
+
|
|
440
|
+
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
441
|
+
if 'ANCHOR_TIMESTAMP' in df:
|
|
442
|
+
ser = df['ANCHOR_TIMESTAMP']
|
|
443
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
444
|
+
if isinstance(ser.iloc[0], str):
|
|
445
|
+
unit = None
|
|
446
|
+
else:
|
|
447
|
+
unit = 'ms'
|
|
448
|
+
df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
|
|
449
|
+
ser, errors='coerce', unit=unit)
|
|
450
|
+
|
|
451
|
+
predictions.append(df)
|
|
310
452
|
|
|
311
453
|
if (isinstance(verbose, InteractiveProgressLogger)
|
|
312
454
|
and len(batches) > 1):
|
|
@@ -327,7 +469,72 @@ class KumoRFM:
|
|
|
327
469
|
|
|
328
470
|
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
329
471
|
|
|
330
|
-
|
|
472
|
+
if len(predictions) == 1:
|
|
473
|
+
prediction = predictions[0]
|
|
474
|
+
else:
|
|
475
|
+
prediction = pd.concat(predictions, ignore_index=True)
|
|
476
|
+
|
|
477
|
+
if explain_config is not None:
|
|
478
|
+
assert len(predictions) == 1
|
|
479
|
+
assert summary is not None
|
|
480
|
+
assert details is not None
|
|
481
|
+
return Explanation(
|
|
482
|
+
prediction=prediction,
|
|
483
|
+
summary=summary,
|
|
484
|
+
details=details,
|
|
485
|
+
)
|
|
486
|
+
|
|
487
|
+
return prediction
|
|
488
|
+
|
|
489
|
+
def is_valid_entity(
|
|
490
|
+
self,
|
|
491
|
+
query: str,
|
|
492
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
493
|
+
*,
|
|
494
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
495
|
+
) -> np.ndarray:
|
|
496
|
+
r"""Returns a mask that denotes which entities are valid for the
|
|
497
|
+
given predictive query, *i.e.*, which entities fulfill (temporal)
|
|
498
|
+
entity filter constraints.
|
|
499
|
+
|
|
500
|
+
Args:
|
|
501
|
+
query: The predictive query.
|
|
502
|
+
indices: The entity primary keys to predict for. Will override the
|
|
503
|
+
indices given as part of the predictive query.
|
|
504
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
505
|
+
``None``, will use the maximum timestamp in the data.
|
|
506
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
507
|
+
"""
|
|
508
|
+
query_def = self._parse_query(query)
|
|
509
|
+
|
|
510
|
+
if indices is None:
|
|
511
|
+
if query_def.rfm_entity_ids is None:
|
|
512
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
513
|
+
"pass them via "
|
|
514
|
+
"`is_valid_entity(query, indices=...)`")
|
|
515
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
516
|
+
|
|
517
|
+
if len(indices) == 0:
|
|
518
|
+
raise ValueError("At least one entity is required")
|
|
519
|
+
|
|
520
|
+
if anchor_time is None:
|
|
521
|
+
anchor_time = self._graph_store.max_time
|
|
522
|
+
|
|
523
|
+
if isinstance(anchor_time, pd.Timestamp):
|
|
524
|
+
self._validate_time(query_def, anchor_time, None, False)
|
|
525
|
+
else:
|
|
526
|
+
assert anchor_time == 'entity'
|
|
527
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
528
|
+
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
529
|
+
f"table '{query_def.entity_table}' "
|
|
530
|
+
f"to have a time column.")
|
|
531
|
+
|
|
532
|
+
node = self._graph_store.get_node_id(
|
|
533
|
+
table_name=query_def.entity_table,
|
|
534
|
+
pkey=pd.Series(indices),
|
|
535
|
+
)
|
|
536
|
+
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
537
|
+
return query_driver.is_valid(node, anchor_time)
|
|
331
538
|
|
|
332
539
|
def evaluate(
|
|
333
540
|
self,
|
|
@@ -342,6 +549,7 @@ class KumoRFM:
|
|
|
342
549
|
max_pq_iterations: int = 20,
|
|
343
550
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
344
551
|
verbose: Union[bool, ProgressLogger] = True,
|
|
552
|
+
use_prediction_time: bool = False,
|
|
345
553
|
) -> pd.DataFrame:
|
|
346
554
|
"""Evaluates a predictive query.
|
|
347
555
|
|
|
@@ -349,10 +557,10 @@ class KumoRFM:
|
|
|
349
557
|
query: The predictive query.
|
|
350
558
|
metrics: The metrics to use.
|
|
351
559
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
352
|
-
|
|
353
|
-
If set to
|
|
560
|
+
``None``, will use the maximum timestamp in the data.
|
|
561
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
354
562
|
context_anchor_time: The maximum anchor timestamp for context
|
|
355
|
-
examples. If set to
|
|
563
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
356
564
|
determine the anchor time for context examples.
|
|
357
565
|
run_mode: The :class:`RunMode` for the query.
|
|
358
566
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
@@ -365,6 +573,9 @@ class KumoRFM:
|
|
|
365
573
|
entities to find valid labels.
|
|
366
574
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
367
575
|
verbose: Whether to print verbose output.
|
|
576
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
577
|
+
additional feature during prediction. This is typically
|
|
578
|
+
beneficial for time series forecasting tasks.
|
|
368
579
|
|
|
369
580
|
Returns:
|
|
370
581
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -375,10 +586,10 @@ class KumoRFM:
|
|
|
375
586
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
376
587
|
f"custom 'num_hops={num_hops}' option")
|
|
377
588
|
|
|
378
|
-
if query_def.
|
|
589
|
+
if query_def.rfm_entity_ids is not None:
|
|
379
590
|
query_def = replace(
|
|
380
591
|
query_def,
|
|
381
|
-
|
|
592
|
+
rfm_entity_ids=None,
|
|
382
593
|
)
|
|
383
594
|
|
|
384
595
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -408,6 +619,7 @@ class KumoRFM:
|
|
|
408
619
|
context=context,
|
|
409
620
|
run_mode=RunMode(run_mode),
|
|
410
621
|
metrics=metrics,
|
|
622
|
+
use_prediction_time=use_prediction_time,
|
|
411
623
|
)
|
|
412
624
|
with warnings.catch_warnings():
|
|
413
625
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -418,7 +630,7 @@ class KumoRFM:
|
|
|
418
630
|
|
|
419
631
|
if len(request_bytes) > _MAX_SIZE:
|
|
420
632
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
421
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
633
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
422
634
|
|
|
423
635
|
try:
|
|
424
636
|
resp = global_state.client.rfm_api.evaluate(request_bytes)
|
|
@@ -466,18 +678,19 @@ class KumoRFM:
|
|
|
466
678
|
|
|
467
679
|
if anchor_time is None:
|
|
468
680
|
anchor_time = self._graph_store.max_time
|
|
469
|
-
|
|
470
|
-
|
|
681
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
682
|
+
anchor_time = anchor_time - (
|
|
683
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
684
|
+
query_def.num_forecasts)
|
|
471
685
|
|
|
472
686
|
assert anchor_time is not None
|
|
473
687
|
if isinstance(anchor_time, pd.Timestamp):
|
|
474
688
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
475
689
|
else:
|
|
476
690
|
assert anchor_time == 'entity'
|
|
477
|
-
if (query_def.
|
|
478
|
-
not in self._graph_store.time_dict):
|
|
691
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
479
692
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
480
|
-
f"table '{query_def.
|
|
693
|
+
f"table '{query_def.entity_table}' "
|
|
481
694
|
f"to have a time column")
|
|
482
695
|
|
|
483
696
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -492,7 +705,7 @@ class KumoRFM:
|
|
|
492
705
|
)
|
|
493
706
|
|
|
494
707
|
entity = self._graph_store.pkey_map_dict[
|
|
495
|
-
query_def.
|
|
708
|
+
query_def.entity_table].index[node]
|
|
496
709
|
|
|
497
710
|
return pd.DataFrame({
|
|
498
711
|
'ENTITY': entity,
|
|
@@ -502,8 +715,8 @@ class KumoRFM:
|
|
|
502
715
|
|
|
503
716
|
# Helpers #################################################################
|
|
504
717
|
|
|
505
|
-
def _parse_query(self, query: str) ->
|
|
506
|
-
if isinstance(query,
|
|
718
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
719
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
507
720
|
return query
|
|
508
721
|
|
|
509
722
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -513,12 +726,12 @@ class KumoRFM:
|
|
|
513
726
|
"predictions or evaluations.")
|
|
514
727
|
|
|
515
728
|
try:
|
|
516
|
-
request =
|
|
729
|
+
request = RFMParseQueryRequest(
|
|
517
730
|
query=query,
|
|
518
731
|
graph_definition=self._graph_def,
|
|
519
732
|
)
|
|
520
733
|
|
|
521
|
-
resp = global_state.client.rfm_api.
|
|
734
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
522
735
|
# TODO Expose validation warnings.
|
|
523
736
|
|
|
524
737
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -529,7 +742,7 @@ class KumoRFM:
|
|
|
529
742
|
warnings.warn(f"Encountered the following warnings during "
|
|
530
743
|
f"parsing:\n{msg}")
|
|
531
744
|
|
|
532
|
-
return resp.
|
|
745
|
+
return resp.query
|
|
533
746
|
except HTTPException as e:
|
|
534
747
|
try:
|
|
535
748
|
msg = json.loads(e.detail)['detail']
|
|
@@ -540,7 +753,7 @@ class KumoRFM:
|
|
|
540
753
|
|
|
541
754
|
def _validate_time(
|
|
542
755
|
self,
|
|
543
|
-
query:
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
544
757
|
anchor_time: pd.Timestamp,
|
|
545
758
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
546
759
|
evaluate: bool,
|
|
@@ -563,6 +776,11 @@ class KumoRFM:
|
|
|
563
776
|
f"only contains data back to "
|
|
564
777
|
f"'{self._graph_store.min_time}'.")
|
|
565
778
|
|
|
779
|
+
if query.target_ast.date_offset_range is not None:
|
|
780
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
781
|
+
else:
|
|
782
|
+
end_offset = pd.DateOffset(0)
|
|
783
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
566
784
|
if (context_anchor_time is not None
|
|
567
785
|
and context_anchor_time > anchor_time):
|
|
568
786
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -571,19 +789,18 @@ class KumoRFM:
|
|
|
571
789
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
572
790
|
f"intended.")
|
|
573
791
|
elif (query.query_type == QueryType.TEMPORAL
|
|
574
|
-
and context_anchor_time is not None
|
|
575
|
-
|
|
792
|
+
and context_anchor_time is not None
|
|
793
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
576
794
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
577
795
|
f"'{context_anchor_time}' will leak information "
|
|
578
796
|
f"from the prediction anchor timestamp "
|
|
579
797
|
f"'{anchor_time}'. Please make sure this is "
|
|
580
798
|
f"intended.")
|
|
581
799
|
|
|
582
|
-
elif (context_anchor_time is not None
|
|
583
|
-
|
|
800
|
+
elif (context_anchor_time is not None
|
|
801
|
+
and context_anchor_time - forecast_end_offset
|
|
584
802
|
< self._graph_store.min_time):
|
|
585
|
-
_time = context_anchor_time -
|
|
586
|
-
query.num_forecasts)
|
|
803
|
+
_time = context_anchor_time - forecast_end_offset
|
|
587
804
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
588
805
|
f"aggregation time range is too large. To form "
|
|
589
806
|
f"proper input data, we would need data back to "
|
|
@@ -596,8 +813,7 @@ class KumoRFM:
|
|
|
596
813
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
597
814
|
f"in the data. Please make sure this is intended.")
|
|
598
815
|
|
|
599
|
-
max_eval_time =
|
|
600
|
-
query.target.end_offset * query.num_forecasts)
|
|
816
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
601
817
|
if evaluate and anchor_time > max_eval_time:
|
|
602
818
|
raise ValueError(
|
|
603
819
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -605,7 +821,7 @@ class KumoRFM:
|
|
|
605
821
|
|
|
606
822
|
def _get_context(
|
|
607
823
|
self,
|
|
608
|
-
query:
|
|
824
|
+
query: ValidatedPredictiveQuery,
|
|
609
825
|
indices: Union[List[str], List[float], List[int], None],
|
|
610
826
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
611
827
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -633,8 +849,8 @@ class KumoRFM:
|
|
|
633
849
|
f"must go beyond this for your use-case.")
|
|
634
850
|
|
|
635
851
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
636
|
-
task_type =
|
|
637
|
-
|
|
852
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
853
|
+
query,
|
|
638
854
|
edge_types=self._graph_store.edge_types,
|
|
639
855
|
)
|
|
640
856
|
|
|
@@ -666,11 +882,15 @@ class KumoRFM:
|
|
|
666
882
|
else:
|
|
667
883
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
668
884
|
|
|
885
|
+
if query.target_ast.date_offset_range is None:
|
|
886
|
+
end_offset = pd.DateOffset(0)
|
|
887
|
+
else:
|
|
888
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
889
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
669
890
|
if anchor_time is None:
|
|
670
891
|
anchor_time = self._graph_store.max_time
|
|
671
892
|
if evaluate:
|
|
672
|
-
anchor_time = anchor_time -
|
|
673
|
-
query.num_forecasts)
|
|
893
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
674
894
|
if logger is not None:
|
|
675
895
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
676
896
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -685,15 +905,14 @@ class KumoRFM:
|
|
|
685
905
|
assert anchor_time is not None
|
|
686
906
|
if isinstance(anchor_time, pd.Timestamp):
|
|
687
907
|
if context_anchor_time is None:
|
|
688
|
-
context_anchor_time = anchor_time -
|
|
689
|
-
query.num_forecasts)
|
|
908
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
690
909
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
691
910
|
evaluate)
|
|
692
911
|
else:
|
|
693
912
|
assert anchor_time == 'entity'
|
|
694
|
-
if query.
|
|
913
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
695
914
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
696
|
-
f"table '{query.
|
|
915
|
+
f"table '{query.entity_table}' to "
|
|
697
916
|
f"have a time column")
|
|
698
917
|
if context_anchor_time is not None:
|
|
699
918
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -744,7 +963,7 @@ class KumoRFM:
|
|
|
744
963
|
f"in batches")
|
|
745
964
|
|
|
746
965
|
test_node = self._graph_store.get_node_id(
|
|
747
|
-
table_name=query.
|
|
966
|
+
table_name=query.entity_table,
|
|
748
967
|
pkey=pd.Series(indices),
|
|
749
968
|
)
|
|
750
969
|
|
|
@@ -752,8 +971,7 @@ class KumoRFM:
|
|
|
752
971
|
test_time = pd.Series(anchor_time).repeat(
|
|
753
972
|
len(test_node)).reset_index(drop=True)
|
|
754
973
|
else:
|
|
755
|
-
time = self._graph_store.time_dict[
|
|
756
|
-
query.entity.pkey.table_name]
|
|
974
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
757
975
|
time = time[test_node] * 1000**3
|
|
758
976
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
759
977
|
|
|
@@ -786,12 +1004,23 @@ class KumoRFM:
|
|
|
786
1004
|
raise NotImplementedError
|
|
787
1005
|
logger.log(msg)
|
|
788
1006
|
|
|
789
|
-
entity_table_names
|
|
790
|
-
|
|
1007
|
+
entity_table_names: Tuple[str, ...]
|
|
1008
|
+
if task_type.is_link_pred:
|
|
1009
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1010
|
+
assert final_aggr is not None
|
|
1011
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1012
|
+
for edge_type in self._graph_store.edge_types:
|
|
1013
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1014
|
+
entity_table_names = (
|
|
1015
|
+
query.entity_table,
|
|
1016
|
+
edge_type[2],
|
|
1017
|
+
)
|
|
1018
|
+
else:
|
|
1019
|
+
entity_table_names = (query.entity_table, )
|
|
791
1020
|
|
|
792
1021
|
# Exclude the entity anchor time from the feature set to prevent
|
|
793
1022
|
# running out-of-distribution between in-context and test examples:
|
|
794
|
-
exclude_cols_dict = query.
|
|
1023
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
795
1024
|
if anchor_time == 'entity':
|
|
796
1025
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
797
1026
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -820,7 +1049,7 @@ class KumoRFM:
|
|
|
820
1049
|
|
|
821
1050
|
step_size: Optional[int] = None
|
|
822
1051
|
if query.query_type == QueryType.TEMPORAL:
|
|
823
|
-
step_size = date_offset_to_seconds(
|
|
1052
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
824
1053
|
|
|
825
1054
|
return Context(
|
|
826
1055
|
task_type=task_type,
|
|
@@ -845,7 +1074,7 @@ class KumoRFM:
|
|
|
845
1074
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
846
1075
|
supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
|
|
847
1076
|
elif task_type == TaskType.REGRESSION:
|
|
848
|
-
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
|
|
1077
|
+
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
|
|
849
1078
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
850
1079
|
supported_metrics = [
|
|
851
1080
|
'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|