kumoai 2.9.0.dev202509061830__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/rfm.py +3 -2
- kumoai/connector/file_upload_connector.py +71 -102
- kumoai/connector/utils.py +1367 -236
- kumoai/experimental/rfm/__init__.py +2 -2
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -80
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +22 -6
- kumoai/experimental/rfm/local_pquery_driver.py +129 -28
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -0
- kumoai/experimental/rfm/pquery/backend.py +4 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_backend.py +71 -30
- kumoai/experimental/rfm/pquery/pandas_executor.py +506 -0
- kumoai/experimental/rfm/rfm.py +442 -94
- kumoai/jobs.py +1 -0
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/progress_logger.py +62 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/METADATA +4 -5
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/RECORD +28 -26
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/WHEEL +0 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
```diff
@@ -1,13 +1,19 @@
 import json
+import time
 import warnings
-from
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import Iterator, List, Literal, Optional, Union, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
     PQueryDefinition,
     RFMEvaluateRequest,
     RFMPredictRequest,
```
```diff
@@ -20,11 +26,17 @@ from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import
+from kumoai.experimental.rfm.local_pquery_driver import (
+    LocalPQueryDriver,
+    date_offset_to_seconds,
+)
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
```
```diff
@@ -39,7 +51,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
```
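The `{stats_msg}` placeholder was renamed to `{stats}` to match the updated call site (`_SIZE_LIMIT_MSG.format(stats=stats)` in `predict`). `str.format` raises a `KeyError` when the keyword and the placeholder disagree, so the template and its call sites have to move together; a minimal sketch with a made-up stats string:

```python
template = "Context size exceeds the 30MB limit. {stats}"
print(template.format(stats="tables: 18.2MB, edges: 9.4MB"))  # hypothetical stats
# template.format(stats_msg=...) would raise KeyError: 'stats'
```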
```diff
@@ -48,6 +60,34 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
```
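The new `Explanation` container is deliberately tuple-like: indexing and iteration expose only the `(prediction, summary)` pair, so callers that unpack two values keep working, while the structured `details` stay available as an attribute. A minimal sketch (all values are stand-ins, not real API output):

```python
import pandas as pd

explanation = Explanation(
    prediction=pd.DataFrame({'ENTITY': [1], 'TARGET_PRED': [0.87]}),  # made up
    summary='Driven by recent purchase activity.',                    # made up
    details=None,  # an `ExplanationConfig` in real usage
)

df, summary = explanation          # unpacking goes through __iter__
assert df is explanation[0]        # Literal[0] overload -> pd.DataFrame
assert summary == explanation[1]   # Literal[1] overload -> str
```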
```diff
@@ -105,28 +145,117 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Literal[True],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: bool = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
-
-
-
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: If set to ``True``, will additionally explain the
+                prediction. Explainability is currently only supported for
+                single entity predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
```
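The three `predict` signatures are two `@overload` stubs plus the implementation: the static return type narrows to `Explanation` only for a literal `explain=True`. Together with `batch_mode`, an oversized `indices` list is split into chunks capped by `_MAX_PRED_SIZE` (1,000 entities per request by default, 200 for temporal link prediction). A usage sketch, in which the graph, query string, and primary keys are assumptions:

```python
model = KumoRFM(graph)  # assumes an already-built LocalGraph

query = 'PREDICT COUNT(orders.*, 0, 30, days) > 0 FOR EACH users.user_id'
user_ids = [1, 2, 3, 42, 1337]  # hypothetical primary keys

# Split `indices` into max-sized batches; retry each failed batch once:
with model.batch_mode(batch_size='max', num_retries=1):
    df = model.predict(query, indices=user_ids)
```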
```diff
@@ -138,11 +267,15 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain=True``, additionally returns a textual summary that
+            explains the prediction.
         """
-        explain = False
         query_def = self._parse_query(query)
 
         if num_hops != 2 and num_neighbors is not None:
```
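Explainability reuses the same entry point but is restricted to a single entity under the default `run_mode="FAST"`, so a typical call looks like this sketch (continuing the assumed names above):

```python
result = model.predict(query, indices=[42], explain=True)

print(result.summary)   # textual explanation of the single prediction
result.details          # structured ExplanationConfig
df, summary = result    # tuple-style unpacking still works
```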
```diff
@@ -155,12 +288,24 @@ class KumoRFM:
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if
-
-
-
-
-
+        if indices is None:
+            if query_def.entity.ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.entity.ids.value
+        else:
+            query_def = replace(
+                query_def,
+                entity=replace(query_def.entity, ids=None),
+            )
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         if explain:
```
```diff
@@ -172,48 +317,185 @@ class KumoRFM:
             verbose = InteractiveProgressLogger(msg, verbose=verbose)
 
         with verbose as logger:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = query_def.get_task_type(
+                    stypes=self._graph_store.stype_dict,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain:
+                            resp = global_state.client.rfm_api.explain(_bytes)
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = global_state.client.rfm_api.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity.pkey.table_name
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+            if len(predictions) == 1:
+                prediction = predictions[0]
+            else:
+                prediction = pd.concat(predictions, ignore_index=True)
+
+            if explain:
+                assert len(predictions) == 1
+                assert summary is not None
+                assert details is not None
+                return Explanation(
+                    prediction=prediction,
+                    summary=summary,
+                    details=details,
                 )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-                request_bytes = request_msg.SerializeToString()
-                logger.log(f"Generated context of size "
-                           f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+            return prediction
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.entity.ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.entity.ids.value
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity.pkey.table_name
+                    not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"to have a time column")
 
-
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity.pkey.table_name,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)
 
     def evaluate(
         self,
```
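Since `predict(indices=...)` deliberately skips entity-filter constraints, `is_valid_entity` is the companion pre-filter: it returns a boolean mask aligned with `indices`. A sketch, again with assumed names:

```python
mask = model.is_valid_entity(query, indices=user_ids)  # boolean np.ndarray

valid_ids = [pkey for pkey, ok in zip(user_ids, mask) if ok]
df = model.predict(query, indices=valid_ids)
```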
```diff
@@ -221,21 +503,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the
-
-            If set to
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
```
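`context_anchor_time` decouples the in-context examples from the evaluation anchor, e.g. to mimic a deployment in which labeled context always lags predictions by one week. A sketch with illustrative timestamps:

```python
import pandas as pd

metrics_df = model.evaluate(
    query,
    anchor_time=pd.Timestamp('2025-06-01'),
    context_anchor_time=pd.Timestamp('2025-05-25'),  # context ends a week earlier
)
```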
```diff
@@ -247,6 +534,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
```
```diff
@@ -257,6 +547,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.entity.ids is not None:
+            query_def = replace(
+                query_def,
+                entity=replace(query_def.entity, ids=None),
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
```
```diff
@@ -265,8 +561,10 @@ class KumoRFM:
 
         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
```
```diff
@@ -282,6 +580,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
```
```diff
@@ -340,11 +639,12 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - query_def.target.end_offset
+            anchor_time = anchor_time - (query_def.target.end_offset *
+                                         query_def.num_forecasts)
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
             if (query_def.entity.pkey.table_name
```
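For forecasting queries the default evaluation anchor now steps back by the full horizon `end_offset * num_forecasts` instead of a single step, so that every forecast step still has ground-truth labels. A worked example with illustrative values:

```python
import pandas as pd

max_time = pd.Timestamp('2025-06-30')  # latest timestamp in the data
end_offset = pd.DateOffset(days=7)     # one forecast step
num_forecasts = 4                      # steps per forecast

anchor_time = max_time - end_offset * num_forecasts
print(anchor_time)  # 2025-06-02 00:00:00, leaving 4 x 7 days of labels
```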
```diff
@@ -361,10 +661,14 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )
 
+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity.pkey.table_name].index[node]
+
         return pd.DataFrame({
-            'ENTITY':
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })
```
```diff
@@ -411,6 +715,7 @@ class KumoRFM:
         self,
         query: PQueryDefinition,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
```
```diff
@@ -422,22 +727,41 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")
 
-        if
-
-
-            f"
-            f"
-            f"however, your data
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")
 
-        if (
-
-            warnings.warn(f"
-                          f"
-                          f"
-                          f"'{anchor_time
-                          f"
-
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None and context_anchor_time +
+              query.target.end_offset * query.num_forecasts > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None and context_anchor_time -
+              query.target.end_offset * query.num_forecasts
+              < self._graph_store.min_time):
+            _time = context_anchor_time - (query.target.end_offset *
+                                           query.num_forecasts)
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{self._graph_store.min_time}'.")
 
         if (not evaluate and anchor_time
                 > self._graph_store.max_time + pd.DateOffset(days=1)):
```
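The new `context_anchor_time` checks reject a context anchor before the earliest data, then warn when the context anchor lies after the prediction anchor, when context aggregation would leak information past the prediction anchor, and when the context window would extend before the earliest data. The leakage condition in isolation, with illustrative values:

```python
import pandas as pd

anchor_time = pd.Timestamp('2025-06-30')
context_anchor_time = pd.Timestamp('2025-06-27')
horizon = pd.DateOffset(days=7) * 1  # end_offset * num_forecasts

# Mirrors the temporal leakage check in _validate_time:
print(context_anchor_time + horizon > anchor_time)  # True -> a warning fires
```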
```diff
@@ -445,17 +769,19 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")
 
-
-
+        max_eval_time = (self._graph_store.max_time -
+                         query.target.end_offset * query.num_forecasts)
+        if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp "
-                f"'{self._graph_store.max_time - query.target.end_offset}'.")
+                f"supported timestamp '{max_eval_time}'.")
 
     def _get_context(
         self,
         query: PQueryDefinition,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
```
```diff
@@ -516,25 +842,36 @@ class KumoRFM:
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - query.target.end_offset
+                anchor_time = anchor_time - (query.target.end_offset *
+                                             query.num_forecasts)
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if
-
-
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - (query.target.end_offset *
+                                                     query.num_forecasts)
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
             if query.entity.pkey.table_name not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity.pkey.table_name}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None
 
         y_test: Optional[pd.Series] = None
         if evaluate:
```
```diff
@@ -546,6 +883,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
```
```diff
@@ -569,20 +907,18 @@ class KumoRFM:
                 logger.log(msg)
 
         else:
-            assert
+            assert indices is not None
 
-
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{
-                                 f"(got {len(
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
             test_node = self._graph_store.get_node_id(
                 table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                pkey=pd.Series(indices),
             )
 
         if isinstance(anchor_time, pd.Timestamp):
```
```diff
@@ -596,7 +932,7 @@ class KumoRFM:
 
         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
```
```diff
@@ -648,6 +984,17 @@ class KumoRFM:
             exclude_cols_dict=exclude_cols_dict,
         )
 
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        step_size: Optional[int] = None
+        if query.query_type == QueryType.TEMPORAL:
+            step_size = date_offset_to_seconds(query.target.end_offset)
+
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
```
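`step_size` serializes the forecast step into the `Context` as a number of seconds through the newly exported `date_offset_to_seconds`. The helper itself lives in `local_pquery_driver` and may differ; a hypothetical stand-in for fixed-size offsets illustrates the conversion:

```python
import pandas as pd

def date_offset_to_seconds(offset: pd.DateOffset) -> int:
    # Hypothetical stand-in: measure the offset against a fixed reference.
    ref = pd.Timestamp('1970-01-01')
    return int(((ref + offset) - ref).total_seconds())

print(date_offset_to_seconds(pd.DateOffset(days=7)))  # 604800
```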
```diff
@@ -655,6 +1002,7 @@ class KumoRFM:
             y_train=y_train,
             y_test=y_test,
             top_k=query.top_k,
+            step_size=step_size,
         )
 
     @staticmethod
```
```diff
@@ -670,7 +1018,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
            supported_metrics = [
                'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
```