kumoai 2.9.0.dev202509081831__cp312-cp312-win_amd64.whl → 2.12.0.dev202511111731__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +71 -102
- kumoai/connector/utils.py +1367 -236
- kumoai/experimental/rfm/__init__.py +5 -3
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -80
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +22 -6
- kumoai/experimental/rfm/local_pquery_driver.py +336 -42
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -222
- kumoai/experimental/rfm/rfm.py +514 -117
- kumoai/jobs.py +1 -0
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/progress_logger.py +68 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA +4 -5
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/RECORD +28 -28
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/WHEEL +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -1,17 +1,32 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import time
|
|
2
3
|
import warnings
|
|
3
|
-
from
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Generator
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from dataclasses import dataclass, replace
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
4
19
|
|
|
5
20
|
import numpy as np
|
|
6
21
|
import pandas as pd
|
|
7
22
|
from kumoapi.model_plan import RunMode
|
|
8
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.rfm import Context
|
|
25
|
+
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
9
26
|
from kumoapi.rfm import (
|
|
10
|
-
Context,
|
|
11
|
-
PQueryDefinition,
|
|
12
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
13
29
|
RFMPredictRequest,
|
|
14
|
-
RFMValidateQueryRequest,
|
|
15
30
|
)
|
|
16
31
|
from kumoapi.task import TaskType
|
|
17
32
|
|
|
@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
|
|
|
24
39
|
LocalPQueryDriver,
|
|
25
40
|
date_offset_to_seconds,
|
|
26
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
27
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
28
44
|
|
|
29
45
|
_RANDOM_SEED = 42
|
|
30
46
|
|
|
47
|
+
_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
|
|
48
|
+
_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
|
|
49
|
+
|
|
31
50
|
_MAX_CONTEXT_SIZE = {
|
|
32
51
|
RunMode.DEBUG: 100,
|
|
33
52
|
RunMode.FAST: 1_000,
|
|
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
|
|
|
42
61
|
}
|
|
43
62
|
|
|
44
63
|
_MAX_SIZE = 30 * 1024 * 1024
|
|
45
|
-
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {
|
|
64
|
+
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
46
65
|
"reduce either the number of tables in the graph, their "
|
|
47
66
|
"number of columns (e.g., large text columns), "
|
|
48
67
|
"neighborhood configuration, or the run mode. If none of "
|
|
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
|
|
|
51
70
|
"beyond this for your use-case.")
|
|
52
71
|
|
|
53
72
|
|
|
73
|
+
@dataclass(repr=False)
class ExplainConfig(CastMixin):
    """Configuration for explainability.

    Instances can be passed as the ``explain`` argument of
    :meth:`KumoRFM.predict` to control which parts of an explanation are
    generated.

    Args:
        skip_summary: Whether to skip generating a human-readable summary of
            the explanation.
    """
    # Defaults to ``False``, i.e. the summary is requested unless the caller
    # explicitly opts out.
    skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(repr=False)
class Explanation:
    """Container bundling a prediction with its explanation.

    The object can be used like a two-element ``(prediction, summary)``
    pair: it supports indexing with ``0`` and ``1`` as well as iteration
    (and therefore unpacking). The ``details`` field is only reachable
    via attribute access.

    Args:
        prediction: The prediction as a :class:`pandas.DataFrame`.
        summary: The textual summary of the explanation.
        details: The explanation details object.
    """
    prediction: pd.DataFrame
    summary: str
    details: ExplanationConfig

    @overload
    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
        ...

    @overload
    def __getitem__(self, index: Literal[1]) -> str:
        ...

    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
        """Return the prediction for ``0`` and the summary for ``1``.

        Raises:
            IndexError: If ``index`` is neither ``0`` nor ``1``.
        """
        if index == 0:
            item: Union[pd.DataFrame, str] = self.prediction
        elif index == 1:
            item = self.summary
        else:
            raise IndexError("Index out of range")
        return item

    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
        """Iterate over the prediction, then the summary."""
        pair = (self.prediction, self.summary)
        return iter(pair)

    def __repr__(self) -> str:
        """Render the ``(prediction, summary)`` pair as a string."""
        pair = (self.prediction, self.summary)
        return repr(pair)

    def _ipython_display_(self) -> None:
        """Display the prediction, then the summary rendered as Markdown."""
        # Imported lazily so that IPython is only needed when displaying.
        from IPython.display import Markdown, display

        display(self.prediction)
        rendered_summary = Markdown(self.summary)
        display(rendered_summary)
|
|
116
|
+
|
|
117
|
+
|
|
54
118
|
class KumoRFM:
|
|
55
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
56
120
|
Foundation Model for In-Context Learning on Relational Data
|
|
@@ -108,28 +172,120 @@ class KumoRFM:
|
|
|
108
172
|
self._graph_store = LocalGraphStore(graph, preprocess, verbose)
|
|
109
173
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
110
174
|
|
|
175
|
+
self._batch_size: Optional[int | Literal['max']] = None
|
|
176
|
+
self.num_retries: int = 0
|
|
177
|
+
|
|
111
178
|
def __repr__(self) -> str:
|
|
112
179
|
return f'{self.__class__.__name__}()'
|
|
113
180
|
|
|
181
|
+
@contextmanager
def batch_mode(
    self,
    batch_size: Union[int, Literal['max']] = 'max',
    num_retries: int = 1,
) -> Generator[None, None, None]:
    """Context manager to predict in batches.

    .. code-block:: python

        with model.batch_mode(batch_size='max', num_retries=1):
            df = model.predict(query, indices=...)

    While the context is active, ``self._batch_size`` and
    ``self.num_retries`` hold the given values. On exit they are reset to
    ``None`` and ``0``, respectively, regardless of whether the body of
    the ``with`` statement completed normally or raised.

    Args:
        batch_size: The batch size. If set to ``"max"``, will use the
            maximum applicable batch size for the given task.
        num_retries: The maximum number of retries for failed queries due
            to unexpected server issues.

    Raises:
        ValueError: If ``batch_size`` is not ``"max"`` and not greater
            than zero, or if ``num_retries`` is negative.
    """
    if batch_size != 'max' and batch_size <= 0:
        raise ValueError(f"'batch_size' must be greater than zero "
                         f"(got {batch_size})")

    if num_retries < 0:
        raise ValueError(f"'num_retries' must be greater than or equal to "
                         f"zero (got {num_retries})")

    self._batch_size = batch_size
    self.num_retries = num_retries
    try:
        yield
    finally:
        # An exception raised inside the `with` body is re-raised at the
        # `yield`; without `finally` the reset would be skipped and the
        # model would silently stay in batch mode for later calls.
        self._batch_size = None
        self.num_retries = 0
|
|
213
|
+
|
|
214
|
+
@overload
|
|
114
215
|
def predict(
|
|
115
216
|
self,
|
|
116
217
|
query: str,
|
|
218
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
117
219
|
*,
|
|
220
|
+
explain: Literal[False] = False,
|
|
118
221
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
222
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
119
223
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
120
224
|
num_neighbors: Optional[List[int]] = None,
|
|
121
225
|
num_hops: int = 2,
|
|
122
226
|
max_pq_iterations: int = 20,
|
|
123
227
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
124
228
|
verbose: Union[bool, ProgressLogger] = True,
|
|
229
|
+
use_prediction_time: bool = False,
|
|
125
230
|
) -> pd.DataFrame:
|
|
231
|
+
pass
|
|
232
|
+
|
|
233
|
+
@overload
|
|
234
|
+
def predict(
|
|
235
|
+
self,
|
|
236
|
+
query: str,
|
|
237
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
238
|
+
*,
|
|
239
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
240
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
241
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
242
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
243
|
+
num_neighbors: Optional[List[int]] = None,
|
|
244
|
+
num_hops: int = 2,
|
|
245
|
+
max_pq_iterations: int = 20,
|
|
246
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
247
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
248
|
+
use_prediction_time: bool = False,
|
|
249
|
+
) -> Explanation:
|
|
250
|
+
pass
|
|
251
|
+
|
|
252
|
+
def predict(
|
|
253
|
+
self,
|
|
254
|
+
query: str,
|
|
255
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
256
|
+
*,
|
|
257
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
258
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
259
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
260
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
261
|
+
num_neighbors: Optional[List[int]] = None,
|
|
262
|
+
num_hops: int = 2,
|
|
263
|
+
max_pq_iterations: int = 20,
|
|
264
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
265
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
266
|
+
use_prediction_time: bool = False,
|
|
267
|
+
) -> Union[pd.DataFrame, Explanation]:
|
|
126
268
|
"""Returns predictions for a predictive query.
|
|
127
269
|
|
|
128
270
|
Args:
|
|
129
271
|
query: The predictive query.
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
272
|
+
indices: The entity primary keys to predict for. Will override the
|
|
273
|
+
indices given as part of the predictive query. Predictions will
|
|
274
|
+
be generated for all indices, independent of whether they
|
|
275
|
+
fulfill entity filter constraints. To pre-filter entities, use
|
|
276
|
+
:meth:`~KumoRFM.is_valid_entity`.
|
|
277
|
+
explain: Configuration for explainability.
|
|
278
|
+
If set to ``True``, will additionally explain the prediction.
|
|
279
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
280
|
+
over which parts of explanation are generated.
|
|
281
|
+
Explainability is currently only supported for single entity
|
|
282
|
+
predictions with ``run_mode="FAST"``.
|
|
283
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
284
|
+
``None``, will use the maximum timestamp in the data.
|
|
285
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
286
|
+
context_anchor_time: The maximum anchor timestamp for context
|
|
287
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
288
|
+
determine the anchor time for context examples.
|
|
133
289
|
run_mode: The :class:`RunMode` for the query.
|
|
134
290
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
135
291
|
If specified, the ``num_hops`` option will be ignored.
|
|
@@ -141,32 +297,54 @@ class KumoRFM:
|
|
|
141
297
|
entities to find valid labels.
|
|
142
298
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
143
299
|
verbose: Whether to print verbose output.
|
|
300
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
301
|
+
additional feature during prediction. This is typically
|
|
302
|
+
beneficial for time series forecasting tasks.
|
|
144
303
|
|
|
145
304
|
Returns:
|
|
146
|
-
The predictions as a :class:`pandas.DataFrame
|
|
305
|
+
The predictions as a :class:`pandas.DataFrame`.
|
|
306
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
307
|
+
containing the prediction, summary, and details.
|
|
147
308
|
"""
|
|
148
|
-
|
|
309
|
+
explain_config: Optional[ExplainConfig] = None
|
|
310
|
+
if explain is True:
|
|
311
|
+
explain_config = ExplainConfig()
|
|
312
|
+
elif explain is not False:
|
|
313
|
+
explain_config = ExplainConfig._cast(explain)
|
|
314
|
+
|
|
149
315
|
query_def = self._parse_query(query)
|
|
316
|
+
query_str = query_def.to_string()
|
|
150
317
|
|
|
151
318
|
if num_hops != 2 and num_neighbors is not None:
|
|
152
319
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
153
320
|
f"custom 'num_hops={num_hops}' option")
|
|
154
321
|
|
|
155
|
-
if
|
|
322
|
+
if explain_config is not None and run_mode in {
|
|
323
|
+
RunMode.NORMAL, RunMode.BEST
|
|
324
|
+
}:
|
|
156
325
|
warnings.warn(f"Explainability is currently only supported for "
|
|
157
326
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
158
327
|
f"mode has been reset. Please lower the run mode to "
|
|
159
328
|
f"suppress this warning.")
|
|
160
329
|
|
|
161
|
-
if
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
330
|
+
if indices is None:
|
|
331
|
+
if query_def.rfm_entity_ids is None:
|
|
332
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
333
|
+
"pass them via `predict(query, indices=...)`")
|
|
334
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
335
|
+
else:
|
|
336
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
337
|
+
|
|
338
|
+
if len(indices) == 0:
|
|
339
|
+
raise ValueError("At least one entity is required")
|
|
340
|
+
|
|
341
|
+
if explain_config is not None and len(indices) > 1:
|
|
342
|
+
raise ValueError(
|
|
343
|
+
f"Cannot explain predictions for more than a single entity "
|
|
344
|
+
f"(got {len(indices)})")
|
|
167
345
|
|
|
168
346
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
169
|
-
if
|
|
347
|
+
if explain_config is not None:
|
|
170
348
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
171
349
|
else:
|
|
172
350
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -175,48 +353,188 @@ class KumoRFM:
|
|
|
175
353
|
verbose = InteractiveProgressLogger(msg, verbose=verbose)
|
|
176
354
|
|
|
177
355
|
with verbose as logger:
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
356
|
+
|
|
357
|
+
batch_size: Optional[int] = None
|
|
358
|
+
if self._batch_size == 'max':
|
|
359
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
360
|
+
query_def,
|
|
361
|
+
edge_types=self._graph_store.edge_types,
|
|
362
|
+
)
|
|
363
|
+
batch_size = _MAX_PRED_SIZE[task_type]
|
|
364
|
+
else:
|
|
365
|
+
batch_size = self._batch_size
|
|
366
|
+
|
|
367
|
+
if batch_size is not None:
|
|
368
|
+
offsets = range(0, len(indices), batch_size)
|
|
369
|
+
batches = [indices[step:step + batch_size] for step in offsets]
|
|
370
|
+
else:
|
|
371
|
+
batches = [indices]
|
|
372
|
+
|
|
373
|
+
if len(batches) > 1:
|
|
374
|
+
logger.log(f"Splitting {len(indices):,} entities into "
|
|
375
|
+
f"{len(batches):,} batches of size {batch_size:,}")
|
|
376
|
+
|
|
377
|
+
predictions: List[pd.DataFrame] = []
|
|
378
|
+
summary: Optional[str] = None
|
|
379
|
+
details: Optional[Explanation] = None
|
|
380
|
+
for i, batch in enumerate(batches):
|
|
381
|
+
# TODO Re-use the context for subsequent predictions.
|
|
382
|
+
context = self._get_context(
|
|
383
|
+
query=query_def,
|
|
384
|
+
indices=batch,
|
|
385
|
+
anchor_time=anchor_time,
|
|
386
|
+
context_anchor_time=context_anchor_time,
|
|
387
|
+
run_mode=RunMode(run_mode),
|
|
388
|
+
num_neighbors=num_neighbors,
|
|
389
|
+
num_hops=num_hops,
|
|
390
|
+
max_pq_iterations=max_pq_iterations,
|
|
391
|
+
evaluate=False,
|
|
392
|
+
random_seed=random_seed,
|
|
393
|
+
logger=logger if i == 0 else None,
|
|
394
|
+
)
|
|
395
|
+
request = RFMPredictRequest(
|
|
396
|
+
context=context,
|
|
397
|
+
run_mode=RunMode(run_mode),
|
|
398
|
+
query=query_str,
|
|
399
|
+
use_prediction_time=use_prediction_time,
|
|
400
|
+
)
|
|
401
|
+
with warnings.catch_warnings():
|
|
402
|
+
warnings.filterwarnings('ignore', message='gencode')
|
|
403
|
+
request_msg = request.to_protobuf()
|
|
404
|
+
_bytes = request_msg.SerializeToString()
|
|
405
|
+
if i == 0:
|
|
406
|
+
logger.log(f"Generated context of size "
|
|
407
|
+
f"{len(_bytes) / (1024*1024):.2f}MB")
|
|
408
|
+
|
|
409
|
+
if len(_bytes) > _MAX_SIZE:
|
|
410
|
+
stats = Context.get_memory_stats(request_msg.context)
|
|
411
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
412
|
+
|
|
413
|
+
if (isinstance(verbose, InteractiveProgressLogger) and i == 0
|
|
414
|
+
and len(batches) > 1):
|
|
415
|
+
verbose.init_progress(
|
|
416
|
+
total=len(batches),
|
|
417
|
+
description='Predicting',
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
for attempt in range(self.num_retries + 1):
|
|
421
|
+
try:
|
|
422
|
+
if explain_config is not None:
|
|
423
|
+
resp = global_state.client.rfm_api.explain(
|
|
424
|
+
request=_bytes,
|
|
425
|
+
skip_summary=explain_config.skip_summary,
|
|
426
|
+
)
|
|
427
|
+
summary = resp.summary
|
|
428
|
+
details = resp.details
|
|
429
|
+
else:
|
|
430
|
+
resp = global_state.client.rfm_api.predict(_bytes)
|
|
431
|
+
df = pd.DataFrame(**resp.prediction)
|
|
432
|
+
|
|
433
|
+
# Cast 'ENTITY' to correct data type:
|
|
434
|
+
if 'ENTITY' in df:
|
|
435
|
+
entity = query_def.entity_table
|
|
436
|
+
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
437
|
+
df['ENTITY'] = df['ENTITY'].astype(
|
|
438
|
+
type(pkey_map.index[0]))
|
|
439
|
+
|
|
440
|
+
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
441
|
+
if 'ANCHOR_TIMESTAMP' in df:
|
|
442
|
+
ser = df['ANCHOR_TIMESTAMP']
|
|
443
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
444
|
+
if isinstance(ser.iloc[0], str):
|
|
445
|
+
unit = None
|
|
446
|
+
else:
|
|
447
|
+
unit = 'ms'
|
|
448
|
+
df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
|
|
449
|
+
ser, errors='coerce', unit=unit)
|
|
450
|
+
|
|
451
|
+
predictions.append(df)
|
|
452
|
+
|
|
453
|
+
if (isinstance(verbose, InteractiveProgressLogger)
|
|
454
|
+
and len(batches) > 1):
|
|
455
|
+
verbose.step()
|
|
456
|
+
|
|
457
|
+
break
|
|
458
|
+
except HTTPException as e:
|
|
459
|
+
if attempt == self.num_retries:
|
|
460
|
+
try:
|
|
461
|
+
msg = json.loads(e.detail)['detail']
|
|
462
|
+
except Exception:
|
|
463
|
+
msg = e.detail
|
|
464
|
+
raise RuntimeError(
|
|
465
|
+
f"An unexpected exception occurred. Please "
|
|
466
|
+
f"create an issue at "
|
|
467
|
+
f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
|
|
468
|
+
) from None
|
|
469
|
+
|
|
470
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
471
|
+
|
|
472
|
+
if len(predictions) == 1:
|
|
473
|
+
prediction = predictions[0]
|
|
474
|
+
else:
|
|
475
|
+
prediction = pd.concat(predictions, ignore_index=True)
|
|
476
|
+
|
|
477
|
+
if explain_config is not None:
|
|
478
|
+
assert len(predictions) == 1
|
|
479
|
+
assert summary is not None
|
|
480
|
+
assert details is not None
|
|
481
|
+
return Explanation(
|
|
482
|
+
prediction=prediction,
|
|
483
|
+
summary=summary,
|
|
484
|
+
details=details,
|
|
192
485
|
)
|
|
193
|
-
with warnings.catch_warnings():
|
|
194
|
-
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
195
|
-
request_msg = request.to_protobuf()
|
|
196
|
-
request_bytes = request_msg.SerializeToString()
|
|
197
|
-
logger.log(f"Generated context of size "
|
|
198
|
-
f"{len(request_bytes) / (1024*1024):.2f}MB")
|
|
199
486
|
|
|
200
|
-
|
|
201
|
-
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
202
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
|
|
487
|
+
return prediction
|
|
203
488
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
489
|
+
def is_valid_entity(
    self,
    query: str,
    indices: Union[List[str], List[float], List[int], None] = None,
    *,
    anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
) -> np.ndarray:
    r"""Computes a validity mask over entities for a predictive query,
    *i.e.*, marks which entities fulfill the (temporal) entity filter
    constraints of the query.

    Args:
        query: The predictive query.
        indices: The entity primary keys to check. Takes precedence over
            the indices given as part of the predictive query.
        anchor_time: The anchor timestamp for the prediction. If set to
            ``None``, will use the maximum timestamp in the data.
            If set to ``"entity"``, will use the timestamp of the entity.

    Returns:
        The result of the query driver's ``is_valid`` check for the
        requested entities.

    Raises:
        ValueError: If no entities are given (neither via ``indices`` nor
            via the query), if the list of entities is empty, or if
            ``anchor_time="entity"`` is requested while the entity table
            has no time column.
    """
    parsed = self._parse_query(query)
    store = self._graph_store

    # Fall back to the entities listed in the query itself:
    if indices is None:
        if parsed.rfm_entity_ids is None:
            raise ValueError("Cannot find entities to predict for. Please "
                             "pass them via "
                             "`is_valid_entity(query, indices=...)`")
        indices = parsed.get_rfm_entity_id_list()

    if len(indices) == 0:
        raise ValueError("At least one entity is required")

    # Default to the latest timestamp in the data:
    if anchor_time is None:
        anchor_time = store.max_time

    if not isinstance(anchor_time, pd.Timestamp):
        assert anchor_time == 'entity'
        entity_table = parsed.entity_table
        if entity_table not in store.time_dict:
            raise ValueError(f"Anchor time 'entity' requires the entity "
                             f"table '{parsed.entity_table}' "
                             f"to have a time column.")
    else:
        self._validate_time(parsed, anchor_time, None, False)

    pkey = pd.Series(indices)
    node = store.get_node_id(table_name=parsed.entity_table, pkey=pkey)
    driver = LocalPQueryDriver(self._graph_store, parsed)
    return driver.is_valid(node, anchor_time)
|
|
220
538
|
|
|
221
539
|
def evaluate(
|
|
222
540
|
self,
|
|
@@ -224,21 +542,26 @@ class KumoRFM:
|
|
|
224
542
|
*,
|
|
225
543
|
metrics: Optional[List[str]] = None,
|
|
226
544
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
545
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
227
546
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
228
547
|
num_neighbors: Optional[List[int]] = None,
|
|
229
548
|
num_hops: int = 2,
|
|
230
549
|
max_pq_iterations: int = 20,
|
|
231
550
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
232
551
|
verbose: Union[bool, ProgressLogger] = True,
|
|
552
|
+
use_prediction_time: bool = False,
|
|
233
553
|
) -> pd.DataFrame:
|
|
234
554
|
"""Evaluates a predictive query.
|
|
235
555
|
|
|
236
556
|
Args:
|
|
237
557
|
query: The predictive query.
|
|
238
558
|
metrics: The metrics to use.
|
|
239
|
-
anchor_time: The anchor timestamp for the
|
|
240
|
-
|
|
241
|
-
If set to
|
|
559
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
560
|
+
``None``, will use the maximum timestamp in the data.
|
|
561
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
562
|
+
context_anchor_time: The maximum anchor timestamp for context
|
|
563
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
564
|
+
determine the anchor time for context examples.
|
|
242
565
|
run_mode: The :class:`RunMode` for the query.
|
|
243
566
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
244
567
|
If specified, the ``num_hops`` option will be ignored.
|
|
@@ -250,6 +573,9 @@ class KumoRFM:
|
|
|
250
573
|
entities to find valid labels.
|
|
251
574
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
252
575
|
verbose: Whether to print verbose output.
|
|
576
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
577
|
+
additional feature during prediction. This is typically
|
|
578
|
+
beneficial for time series forecasting tasks.
|
|
253
579
|
|
|
254
580
|
Returns:
|
|
255
581
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -260,6 +586,12 @@ class KumoRFM:
|
|
|
260
586
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
261
587
|
f"custom 'num_hops={num_hops}' option")
|
|
262
588
|
|
|
589
|
+
if query_def.rfm_entity_ids is not None:
|
|
590
|
+
query_def = replace(
|
|
591
|
+
query_def,
|
|
592
|
+
rfm_entity_ids=None,
|
|
593
|
+
)
|
|
594
|
+
|
|
263
595
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
264
596
|
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
265
597
|
|
|
@@ -268,8 +600,10 @@ class KumoRFM:
|
|
|
268
600
|
|
|
269
601
|
with verbose as logger:
|
|
270
602
|
context = self._get_context(
|
|
271
|
-
query_def,
|
|
603
|
+
query=query_def,
|
|
604
|
+
indices=None,
|
|
272
605
|
anchor_time=anchor_time,
|
|
606
|
+
context_anchor_time=context_anchor_time,
|
|
273
607
|
run_mode=RunMode(run_mode),
|
|
274
608
|
num_neighbors=num_neighbors,
|
|
275
609
|
num_hops=num_hops,
|
|
@@ -285,6 +619,7 @@ class KumoRFM:
|
|
|
285
619
|
context=context,
|
|
286
620
|
run_mode=RunMode(run_mode),
|
|
287
621
|
metrics=metrics,
|
|
622
|
+
use_prediction_time=use_prediction_time,
|
|
288
623
|
)
|
|
289
624
|
with warnings.catch_warnings():
|
|
290
625
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -343,17 +678,19 @@ class KumoRFM:
|
|
|
343
678
|
|
|
344
679
|
if anchor_time is None:
|
|
345
680
|
anchor_time = self._graph_store.max_time
|
|
346
|
-
|
|
681
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
682
|
+
anchor_time = anchor_time - (
|
|
683
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
684
|
+
query_def.num_forecasts)
|
|
347
685
|
|
|
348
686
|
assert anchor_time is not None
|
|
349
687
|
if isinstance(anchor_time, pd.Timestamp):
|
|
350
|
-
self._validate_time(query_def, anchor_time, evaluate=True)
|
|
688
|
+
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
351
689
|
else:
|
|
352
690
|
assert anchor_time == 'entity'
|
|
353
|
-
if (query_def.
|
|
354
|
-
not in self._graph_store.time_dict):
|
|
691
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
355
692
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
356
|
-
f"table '{query_def.
|
|
693
|
+
f"table '{query_def.entity_table}' "
|
|
357
694
|
f"to have a time column")
|
|
358
695
|
|
|
359
696
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -364,18 +701,22 @@ class KumoRFM:
|
|
|
364
701
|
anchor_time=anchor_time,
|
|
365
702
|
batch_size=min(10_000, size),
|
|
366
703
|
max_iterations=max_iterations,
|
|
704
|
+
guarantee_train_examples=False,
|
|
367
705
|
)
|
|
368
706
|
|
|
707
|
+
entity = self._graph_store.pkey_map_dict[
|
|
708
|
+
query_def.entity_table].index[node]
|
|
709
|
+
|
|
369
710
|
return pd.DataFrame({
|
|
370
|
-
'ENTITY':
|
|
711
|
+
'ENTITY': entity,
|
|
371
712
|
'ANCHOR_TIMESTAMP': time,
|
|
372
713
|
'TARGET': y,
|
|
373
714
|
})
|
|
374
715
|
|
|
375
716
|
# Helpers #################################################################
|
|
376
717
|
|
|
377
|
-
def _parse_query(self, query: str) ->
|
|
378
|
-
if isinstance(query,
|
|
718
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
719
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
379
720
|
return query
|
|
380
721
|
|
|
381
722
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -385,12 +726,12 @@ class KumoRFM:
|
|
|
385
726
|
"predictions or evaluations.")
|
|
386
727
|
|
|
387
728
|
try:
|
|
388
|
-
request =
|
|
729
|
+
request = RFMParseQueryRequest(
|
|
389
730
|
query=query,
|
|
390
731
|
graph_definition=self._graph_def,
|
|
391
732
|
)
|
|
392
733
|
|
|
393
|
-
resp = global_state.client.rfm_api.
|
|
734
|
+
resp = global_state.client.rfm_api.parse_query(request)
|
|
394
735
|
# TODO Expose validation warnings.
|
|
395
736
|
|
|
396
737
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -401,7 +742,7 @@ class KumoRFM:
|
|
|
401
742
|
warnings.warn(f"Encountered the following warnings during "
|
|
402
743
|
f"parsing:\n{msg}")
|
|
403
744
|
|
|
404
|
-
return resp.
|
|
745
|
+
return resp.query
|
|
405
746
|
except HTTPException as e:
|
|
406
747
|
try:
|
|
407
748
|
msg = json.loads(e.detail)['detail']
|
|
@@ -412,8 +753,9 @@ class KumoRFM:
|
|
|
412
753
|
|
|
413
754
|
def _validate_time(
|
|
414
755
|
self,
|
|
415
|
-
query:
|
|
756
|
+
query: ValidatedPredictiveQuery,
|
|
416
757
|
anchor_time: pd.Timestamp,
|
|
758
|
+
context_anchor_time: Union[pd.Timestamp, None],
|
|
417
759
|
evaluate: bool,
|
|
418
760
|
) -> None:
|
|
419
761
|
|
|
@@ -425,20 +767,44 @@ class KumoRFM:
|
|
|
425
767
|
f"the earliest timestamp "
|
|
426
768
|
f"'{self._graph_store.min_time}' in the data.")
|
|
427
769
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
raise ValueError(f"
|
|
431
|
-
f"time range is too large. To make
|
|
432
|
-
f"prediction, we would need data back to "
|
|
433
|
-
f"'{
|
|
434
|
-
f"data back to
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
if
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
770
|
+
if (context_anchor_time is not None
|
|
771
|
+
and context_anchor_time < self._graph_store.min_time):
|
|
772
|
+
raise ValueError(f"Context anchor timestamp is too early or "
|
|
773
|
+
f"aggregation time range is too large. To make "
|
|
774
|
+
f"this prediction, we would need data back to "
|
|
775
|
+
f"'{context_anchor_time}', however, your data "
|
|
776
|
+
f"only contains data back to "
|
|
777
|
+
f"'{self._graph_store.min_time}'.")
|
|
778
|
+
|
|
779
|
+
if query.target_ast.date_offset_range is not None:
|
|
780
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
781
|
+
else:
|
|
782
|
+
end_offset = pd.DateOffset(0)
|
|
783
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
784
|
+
if (context_anchor_time is not None
|
|
785
|
+
and context_anchor_time > anchor_time):
|
|
786
|
+
warnings.warn(f"Context anchor timestamp "
|
|
787
|
+
f"(got '{context_anchor_time}') is set to a later "
|
|
788
|
+
f"date than the prediction anchor timestamp "
|
|
789
|
+
f"(got '{anchor_time}'). Please make sure this is "
|
|
790
|
+
f"intended.")
|
|
791
|
+
elif (query.query_type == QueryType.TEMPORAL
|
|
792
|
+
and context_anchor_time is not None
|
|
793
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
794
|
+
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
795
|
+
f"'{context_anchor_time}' will leak information "
|
|
796
|
+
f"from the prediction anchor timestamp "
|
|
797
|
+
f"'{anchor_time}'. Please make sure this is "
|
|
798
|
+
f"intended.")
|
|
799
|
+
|
|
800
|
+
elif (context_anchor_time is not None
|
|
801
|
+
and context_anchor_time - forecast_end_offset
|
|
802
|
+
< self._graph_store.min_time):
|
|
803
|
+
_time = context_anchor_time - forecast_end_offset
|
|
804
|
+
warnings.warn(f"Context anchor timestamp is too early or "
|
|
805
|
+
f"aggregation time range is too large. To form "
|
|
806
|
+
f"proper input data, we would need data back to "
|
|
807
|
+
f"'{_time}', however, your data only contains "
|
|
442
808
|
f"data back to '{self._graph_store.min_time}'.")
|
|
443
809
|
|
|
444
810
|
if (not evaluate and anchor_time
|
|
@@ -447,8 +813,7 @@ class KumoRFM:
|
|
|
447
813
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
448
814
|
f"in the data. Please make sure this is intended.")
|
|
449
815
|
|
|
450
|
-
max_eval_time =
|
|
451
|
-
query.target.end_offset * query.num_forecasts)
|
|
816
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
452
817
|
if evaluate and anchor_time > max_eval_time:
|
|
453
818
|
raise ValueError(
|
|
454
819
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -456,8 +821,10 @@ class KumoRFM:
|
|
|
456
821
|
|
|
457
822
|
def _get_context(
|
|
458
823
|
self,
|
|
459
|
-
query:
|
|
824
|
+
query: ValidatedPredictiveQuery,
|
|
825
|
+
indices: Union[List[str], List[float], List[int], None],
|
|
460
826
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
827
|
+
context_anchor_time: Union[pd.Timestamp, None],
|
|
461
828
|
run_mode: RunMode,
|
|
462
829
|
num_neighbors: Optional[List[int]],
|
|
463
830
|
num_hops: int,
|
|
@@ -482,8 +849,8 @@ class KumoRFM:
|
|
|
482
849
|
f"must go beyond this for your use-case.")
|
|
483
850
|
|
|
484
851
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
485
|
-
task_type =
|
|
486
|
-
|
|
852
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
853
|
+
query,
|
|
487
854
|
edge_types=self._graph_store.edge_types,
|
|
488
855
|
)
|
|
489
856
|
|
|
@@ -515,28 +882,42 @@ class KumoRFM:
|
|
|
515
882
|
else:
|
|
516
883
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
517
884
|
|
|
885
|
+
if query.target_ast.date_offset_range is None:
|
|
886
|
+
end_offset = pd.DateOffset(0)
|
|
887
|
+
else:
|
|
888
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
889
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
518
890
|
if anchor_time is None:
|
|
519
891
|
anchor_time = self._graph_store.max_time
|
|
520
892
|
if evaluate:
|
|
521
|
-
anchor_time = anchor_time -
|
|
893
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
522
894
|
if logger is not None:
|
|
523
895
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
524
|
-
if
|
|
525
|
-
|
|
526
|
-
|
|
896
|
+
if anchor_time == pd.Timestamp.min:
|
|
897
|
+
pass # Static graph
|
|
898
|
+
elif (anchor_time.hour == 0 and anchor_time.minute == 0
|
|
899
|
+
and anchor_time.second == 0
|
|
900
|
+
and anchor_time.microsecond == 0):
|
|
527
901
|
logger.log(f"Derived anchor time {anchor_time.date()}")
|
|
528
902
|
else:
|
|
529
903
|
logger.log(f"Derived anchor time {anchor_time}")
|
|
530
904
|
|
|
531
905
|
assert anchor_time is not None
|
|
532
906
|
if isinstance(anchor_time, pd.Timestamp):
|
|
533
|
-
|
|
907
|
+
if context_anchor_time is None:
|
|
908
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
909
|
+
self._validate_time(query, anchor_time, context_anchor_time,
|
|
910
|
+
evaluate)
|
|
534
911
|
else:
|
|
535
912
|
assert anchor_time == 'entity'
|
|
536
|
-
if query.
|
|
913
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
537
914
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
538
|
-
f"table '{query.
|
|
915
|
+
f"table '{query.entity_table}' to "
|
|
539
916
|
f"have a time column")
|
|
917
|
+
if context_anchor_time is not None:
|
|
918
|
+
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
919
|
+
"`anchor_time='entity'`")
|
|
920
|
+
context_anchor_time = None
|
|
540
921
|
|
|
541
922
|
y_test: Optional[pd.Series] = None
|
|
542
923
|
if evaluate:
|
|
@@ -548,6 +929,7 @@ class KumoRFM:
|
|
|
548
929
|
size=max_test_size,
|
|
549
930
|
anchor_time=anchor_time,
|
|
550
931
|
max_iterations=max_pq_iterations,
|
|
932
|
+
guarantee_train_examples=True,
|
|
551
933
|
)
|
|
552
934
|
if logger is not None:
|
|
553
935
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -571,34 +953,31 @@ class KumoRFM:
|
|
|
571
953
|
logger.log(msg)
|
|
572
954
|
|
|
573
955
|
else:
|
|
574
|
-
assert
|
|
956
|
+
assert indices is not None
|
|
575
957
|
|
|
576
|
-
|
|
577
|
-
if len(query.entity.ids.value) > max_num_test:
|
|
958
|
+
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
578
959
|
raise ValueError(f"Cannot predict for more than "
|
|
579
|
-
f"{
|
|
580
|
-
f"(got {len(
|
|
960
|
+
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
961
|
+
f"once (got {len(indices):,}). Use "
|
|
962
|
+
f"`KumoRFM.batch_mode` to process entities "
|
|
963
|
+
f"in batches")
|
|
581
964
|
|
|
582
965
|
test_node = self._graph_store.get_node_id(
|
|
583
|
-
table_name=query.
|
|
584
|
-
pkey=pd.Series(
|
|
585
|
-
query.entity.ids.value,
|
|
586
|
-
dtype=query.entity.ids.dtype,
|
|
587
|
-
),
|
|
966
|
+
table_name=query.entity_table,
|
|
967
|
+
pkey=pd.Series(indices),
|
|
588
968
|
)
|
|
589
969
|
|
|
590
970
|
if isinstance(anchor_time, pd.Timestamp):
|
|
591
971
|
test_time = pd.Series(anchor_time).repeat(
|
|
592
972
|
len(test_node)).reset_index(drop=True)
|
|
593
973
|
else:
|
|
594
|
-
time = self._graph_store.time_dict[
|
|
595
|
-
query.entity.pkey.table_name]
|
|
974
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
596
975
|
time = time[test_node] * 1000**3
|
|
597
976
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
598
977
|
|
|
599
978
|
train_node, train_time, y_train = query_driver.collect_train(
|
|
600
979
|
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
601
|
-
anchor_time=
|
|
980
|
+
anchor_time=context_anchor_time or 'entity',
|
|
602
981
|
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
603
982
|
or anchor_time == 'entity') else None,
|
|
604
983
|
max_iterations=max_pq_iterations,
|
|
@@ -625,12 +1004,23 @@ class KumoRFM:
|
|
|
625
1004
|
raise NotImplementedError
|
|
626
1005
|
logger.log(msg)
|
|
627
1006
|
|
|
628
|
-
entity_table_names
|
|
629
|
-
|
|
1007
|
+
entity_table_names: Tuple[str, ...]
|
|
1008
|
+
if task_type.is_link_pred:
|
|
1009
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1010
|
+
assert final_aggr is not None
|
|
1011
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1012
|
+
for edge_type in self._graph_store.edge_types:
|
|
1013
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1014
|
+
entity_table_names = (
|
|
1015
|
+
query.entity_table,
|
|
1016
|
+
edge_type[2],
|
|
1017
|
+
)
|
|
1018
|
+
else:
|
|
1019
|
+
entity_table_names = (query.entity_table, )
|
|
630
1020
|
|
|
631
1021
|
# Exclude the entity anchor time from the feature set to prevent
|
|
632
1022
|
# running out-of-distribution between in-context and test examples:
|
|
633
|
-
exclude_cols_dict = query.
|
|
1023
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
634
1024
|
if anchor_time == 'entity':
|
|
635
1025
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
636
1026
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -650,9 +1040,16 @@ class KumoRFM:
|
|
|
650
1040
|
exclude_cols_dict=exclude_cols_dict,
|
|
651
1041
|
)
|
|
652
1042
|
|
|
1043
|
+
if len(subgraph.table_dict) >= 15:
|
|
1044
|
+
raise ValueError(f"Cannot query from a graph with more than 15 "
|
|
1045
|
+
f"tables (got {len(subgraph.table_dict)}). "
|
|
1046
|
+
f"Please create a feature request at "
|
|
1047
|
+
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1048
|
+
f"must go beyond this for your use-case.")
|
|
1049
|
+
|
|
653
1050
|
step_size: Optional[int] = None
|
|
654
1051
|
if query.query_type == QueryType.TEMPORAL:
|
|
655
|
-
step_size = date_offset_to_seconds(
|
|
1052
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
656
1053
|
|
|
657
1054
|
return Context(
|
|
658
1055
|
task_type=task_type,
|
|
@@ -677,7 +1074,7 @@ class KumoRFM:
|
|
|
677
1074
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
678
1075
|
supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
|
|
679
1076
|
elif task_type == TaskType.REGRESSION:
|
|
680
|
-
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
|
|
1077
|
+
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
|
|
681
1078
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
682
1079
|
supported_metrics = [
|
|
683
1080
|
'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|