kumoai 2.9.0.dev202509081831__cp312-cp312-win_amd64.whl → 2.13.0.dev202511201731__cp312-cp312-win_amd64.whl
This diff compares the contents of two package versions as publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in those public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +71 -102
- kumoai/connector/utils.py +1367 -236
- kumoai/experimental/rfm/__init__.py +153 -10
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -80
- kumoai/experimental/rfm/local_graph_sampler.py +16 -10
- kumoai/experimental/rfm/local_graph_store.py +22 -6
- kumoai/experimental/rfm/local_pquery_driver.py +336 -42
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -222
- kumoai/experimental/rfm/rfm.py +523 -124
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/jobs.py +1 -0
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/progress_logger.py +68 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/METADATA +13 -5
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/RECORD +30 -29
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/WHEEL +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,21 +1,36 @@
 import json
+import time
 import warnings
-from …
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType

-from kumoai import …
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
     LocalPQueryDriver,
     date_offset_to_seconds,
 )
+from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger

 _RANDOM_SEED = 42

+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }

 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")


+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
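The new `Explanation` container behaves like a `(prediction, summary)` tuple while keeping the structured `details` payload. A minimal usage sketch based only on the accessors above; the values are illustrative, and `None` stands in for the kumoapi details payload:

```python
import pandas as pd

result = Explanation(
    prediction=pd.DataFrame({'TARGET_PRED': [0.85]}),
    summary='The user was highly active in the previous 30 days.',
    details=None,  # illustrative stand-in for the kumoapi `Explanation`
)

prediction, summary = result    # __iter__ yields exactly these two fields
assert result[0] is prediction  # index 0 -> pd.DataFrame
assert result[1] == summary     # index 1 -> str
```

In notebooks, `_ipython_display_` renders the prediction frame and the Markdown summary automatically.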
@@ -77,9 +141,9 @@ class KumoRFM:

         rfm = KumoRFM(graph)

-        query = ("PREDICT COUNT(…
-                 "FOR users.user_id=…
-        result = rfm.…
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)

         print(result)  # user_id COUNT(transactions.*, 0, 30, days) > 0
                        # 1       0.85
@@ -108,28 +172,122 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)

+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+        from kumoai.experimental.rfm import global_state
+        self._api_client = RFMAPI(global_state.client)
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.

         Args:
             query: The predictive query.
- …
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
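Taken together, `batch_mode` and the new `indices` argument let callers go past the per-request entity cap (`_MAX_PRED_SIZE`: 1,000 by default, 200 for temporal link prediction). A hedged sketch; the graph, query, and entity IDs are invented:

```python
model = KumoRFM(graph)
query = "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id"

# Outside batch_mode, passing more than _MAX_PRED_SIZE entities raises
# a ValueError; inside, predictions are split into per-batch requests.
with model.batch_mode(batch_size='max', num_retries=2):
    df = model.predict(query, indices=list(range(5_000)))

# Explanations are limited to a single entity and run_mode='FAST':
explanation = model.predict(query, indices=[42], explain=True)
```

One detail worth noting: the `yield` in `batch_mode` is not wrapped in `try`/`finally`, so an exception raised inside the block leaves `_batch_size` and `num_retries` set.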
@@ -141,32 +299,54 @@ class KumoRFM:
             entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
-            The predictions as a :class:`pandas.DataFrame…
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
- …
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()

         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if …
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")

-        if …
- …
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if …
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -175,48 +355,188 @@ class KumoRFM:
         verbose = InteractiveProgressLogger(msg, verbose=verbose)

         with verbose as logger:
- …
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+            if len(predictions) == 1:
+                prediction = predictions[0]
+            else:
+                prediction = pd.concat(predictions, ignore_index=True)
+
+            if explain_config is not None:
+                assert len(predictions) == 1
+                assert summary is not None
+                assert details is not None
+                return Explanation(
+                    prediction=prediction,
+                    summary=summary,
+                    details=details,
                 )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")

- …
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+            return prediction

- …
-            raise RuntimeError(f"An unexpected exception occurred. "
-                               f"Please create an issue at "
-                               f"'https://github.com/kumo-ai/kumo-rfm'. "
-                               f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.

- …
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)

     def evaluate(
         self,
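Because `predict` now generates outputs for all explicit `indices` regardless of entity filters, `is_valid_entity` is the intended pre-filter. A sketch under that assumption; the entity IDs are illustrative:

```python
candidates = [1, 2, 3, 4]
# Boolean np.ndarray, aligned with `candidates`:
mask = model.is_valid_entity(query, indices=candidates)
valid = [idx for idx, ok in zip(candidates, mask) if ok]
df = model.predict(query, indices=valid)
```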
@@ -224,21 +544,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.

         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the …
- …
-            If set to …
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -250,6 +575,9 @@ class KumoRFM:
             entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -260,6 +588,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
@@ -268,8 +602,10 @@ class KumoRFM:

         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -285,6 +621,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -295,10 +632,10 @@ class KumoRFM:

             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(…
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

             try:
-                resp = …
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
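A sketch of the extended `evaluate` call with the new `context_anchor_time` and `use_prediction_time` options; the query and timestamps are illustrative, and `'r2'` is newly accepted for regression tasks per the last hunk below:

```python
import pandas as pd

metrics_df = model.evaluate(
    "PREDICT SUM(orders.price, 0, 7, days) FOR users.user_id",
    metrics=['mae', 'rmse', 'r2'],
    anchor_time=pd.Timestamp('2024-06-01'),
    context_anchor_time=pd.Timestamp('2024-05-01'),
    use_prediction_time=True,  # typically helps forecasting tasks
)
```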
@@ -343,17 +680,19 @@ class KumoRFM:

         if anchor_time is None:
             anchor_time = self._graph_store.max_time
- …
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.…
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.…
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")

         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -364,18 +703,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )

+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY': …
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })

     # Helpers #################################################################

-    def _parse_query(self, query: str) -> …
-        if isinstance(query, …
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query

         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -385,12 +728,13 @@ class KumoRFM:
                 "predictions or evaluations.")

         try:
-            request = …
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )

-            resp = …
+            resp = self._api_client.parse_query(request)
+
             # TODO Expose validation warnings.

             if len(resp.validation_response.warnings) > 0:
@@ -401,7 +745,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")

-            return resp.…
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -412,8 +756,9 @@ class KumoRFM:

     def _validate_time(
         self,
-        query: …
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
@@ -425,20 +770,44 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")

- …
-            raise ValueError(f"…
-                             f"time range is too large. To make …
-                             f"prediction, we would need data back to "
-                             f"'{…
-                             f"data back to …
- …
-        if …
- …
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
+                             f"'{self._graph_store.min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
+              < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
                           f"data back to '{self._graph_store.min_time}'.")

         if (not evaluate and anchor_time
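The checks above reduce to date-offset arithmetic: context labels aggregated at `context_anchor_time` span up to `context_anchor_time + forecast_end_offset`, and if that crosses `anchor_time`, the in-context examples overlap the prediction window. A minimal numeric sketch of the condition, with illustrative dates:

```python
import pandas as pd

anchor_time = pd.Timestamp('2024-06-01')
end_offset = pd.DateOffset(days=30)   # target aggregation window
forecast_end_offset = end_offset * 1  # num_forecasts = 1

context_anchor_time = pd.Timestamp('2024-05-15')
# 2024-05-15 + 30 days = 2024-06-14 > 2024-06-01 -> leakage warning fires
assert context_anchor_time + forecast_end_offset > anchor_time

# The default chosen by _get_context avoids this:
safe = anchor_time - forecast_end_offset  # 2024-05-02
assert not (safe + forecast_end_offset > anchor_time)
```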
@@ -447,8 +816,7 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")

-        max_eval_time = …
-            query.target.end_offset * query.num_forecasts)
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
         if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
@@ -456,8 +824,10 @@ class KumoRFM:

     def _get_context(
         self,
-        query: …
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -482,8 +852,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")

         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = …
- …
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )
@@ -515,28 +885,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time - …
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if …
- …
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
- …
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.…
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.…
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None

         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -548,6 +932,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -571,34 +956,31 @@ class KumoRFM:
                 logger.log(msg)

         else:
-            assert …
+            assert indices is not None

- …
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{…
-                                 f"(got {len(…
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")

             test_node = self._graph_store.get_node_id(
-                table_name=query.…
-                pkey=pd.Series(…
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )

             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[…
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')

         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=…
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -625,12 +1007,23 @@ class KumoRFM:
             raise NotImplementedError
         logger.log(msg)

-        entity_table_names …
- …
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )

         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.…
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -645,14 +1038,20 @@ class KumoRFM:
                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
             ]),
-            run_mode=run_mode,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )

+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
         step_size: Optional[int] = None
         if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(…
+            step_size = date_offset_to_seconds(end_offset)

         return Context(
             task_type=task_type,
@@ -677,7 +1076,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',