kumoai 2.10.0.dev202509281831__cp313-cp313-win_amd64.whl → 2.13.0.dev202511211730__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/experimental/rfm/__init__.py +153 -10
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -74
- kumoai/experimental/rfm/local_graph_sampler.py +16 -10
- kumoai/experimental/rfm/local_graph_store.py +13 -1
- kumoai/experimental/rfm/local_pquery_driver.py +323 -38
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +458 -115
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/jobs.py +1 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +12 -10
- kumoai/utils/progress_logger.py +68 -0
- {kumoai-2.10.0.dev202509281831.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/METADATA +13 -5
- {kumoai-2.10.0.dev202509281831.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/RECORD +27 -26
- {kumoai-2.10.0.dev202509281831.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202509281831.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202509281831.dist-info → kumoai-2.13.0.dev202511211730.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED

@@ -1,21 +1,36 @@
 import json
+import time
 import warnings
-from
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType

-from kumoai import
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
@@ -24,10 +39,14 @@ from kumoai.experimental.rfm.local_pquery_driver import (
     LocalPQueryDriver,
     date_offset_to_seconds,
 )
+from kumoai.mixin import CastMixin
 from kumoai.utils import InteractiveProgressLogger, ProgressLogger

 _RANDOM_SEED = 42

+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -42,7 +61,7 @@ _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
 }

 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")


+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
@@ -77,9 +141,9 @@ class KumoRFM:

         rfm = KumoRFM(graph)

-        query = ("PREDICT COUNT(
-                 "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)

         print(result)  # user_id COUNT(transactions.*, 0, 30, days) > 0
                        # 1       0.85
@@ -108,13 +172,54 @@ class KumoRFM:
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)

+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+        from kumoai.experimental.rfm import global_state
+        self._api_client = RFMAPI(global_state.client)
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
@@ -123,16 +228,65 @@ class KumoRFM:
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.

         Args:
             query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
-
-                If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
@@ -145,32 +299,54 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
-            The predictions as a :class:`pandas.DataFrame
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()

         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

-        if
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")

-        if
-
-
-
-
-
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")

         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
@@ -179,49 +355,188 @@ class KumoRFM:
             verbose = InteractiveProgressLogger(msg, verbose=verbose)

         with verbose as logger:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+            if len(predictions) == 1:
+                prediction = predictions[0]
+            else:
+                prediction = pd.concat(predictions, ignore_index=True)
+
+            if explain_config is not None:
+                assert len(predictions) == 1
+                assert summary is not None
+                assert details is not None
+                return Explanation(
+                    prediction=prediction,
+                    summary=summary,
+                    details=details,
                 )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-            request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")

-
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+            return prediction

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")

-
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)

     def evaluate(
         self,
@@ -236,6 +551,7 @@ class KumoRFM:
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.

@@ -243,10 +559,10 @@ class KumoRFM:
             query: The predictive query.
             metrics: The metrics to use.
             anchor_time: The anchor timestamp for the prediction. If set to
-
-                If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
@@ -259,6 +575,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.

         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -269,6 +588,12 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")

+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

@@ -277,7 +602,8 @@ class KumoRFM:

         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
                 context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
@@ -295,6 +621,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -305,10 +632,10 @@ class KumoRFM:

             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

             try:
-                resp =
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -353,18 +680,19 @@ class KumoRFM:

         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-
-
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")

         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -379,7 +707,7 @@ class KumoRFM:
         )

         entity = self._graph_store.pkey_map_dict[
-            query_def.
+            query_def.entity_table].index[node]

         return pd.DataFrame({
             'ENTITY': entity,
@@ -389,8 +717,8 @@ class KumoRFM:

     # Helpers #################################################################

-    def _parse_query(self, query: str) ->
-        if isinstance(query,
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query

         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -400,12 +728,13 @@ class KumoRFM:
                           "predictions or evaluations.")

         try:
-            request =
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )

-            resp =
+            resp = self._api_client.parse_query(request)
+
             # TODO Expose validation warnings.

             if len(resp.validation_response.warnings) > 0:
@@ -416,7 +745,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")

-            return resp.
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -427,7 +756,7 @@ class KumoRFM:

     def _validate_time(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
         context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
@@ -450,6 +779,11 @@ class KumoRFM:
                              f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")

+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -458,19 +792,18 @@ class KumoRFM:
                           f"(got '{anchor_time}'). Please make sure this is "
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
-              and context_anchor_time is not None
-
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
                           f"'{anchor_time}'. Please make sure this is "
                           f"intended.")

-        elif (context_anchor_time is not None
-
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
                 < self._graph_store.min_time):
-            _time = context_anchor_time -
-                query.num_forecasts)
+            _time = context_anchor_time - forecast_end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
@@ -483,8 +816,7 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")

-        max_eval_time =
-            query.target.end_offset * query.num_forecasts)
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
         if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
@@ -492,7 +824,8 @@ class KumoRFM:

     def _get_context(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
         context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
@@ -519,8 +852,8 @@ class KumoRFM:
                 f"must go beyond this for your use-case.")

         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type =
-
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )

@@ -552,11 +885,15 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time -
-                    query.num_forecasts)
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
                 if anchor_time == pd.Timestamp.min:
@@ -571,15 +908,14 @@ class KumoRFM:
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
-                    query.num_forecasts)
+                context_anchor_time = anchor_time - forecast_end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
             if context_anchor_time is not None:
                 warnings.warn("Ignoring option 'context_anchor_time' for "
@@ -620,28 +956,25 @@ class KumoRFM:
                 logger.log(msg)

         else:
-            assert
+            assert indices is not None

-
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{
-                                 f"(got {len(
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")

             test_node = self._graph_store.get_node_id(
-                table_name=query.
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )

             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')

@@ -674,12 +1007,23 @@ class KumoRFM:
             raise NotImplementedError
         logger.log(msg)

-        entity_table_names
-
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )

         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -694,7 +1038,6 @@ class KumoRFM:
                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
             ]),
-            run_mode=run_mode,
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -708,7 +1051,7 @@ class KumoRFM:

         step_size: Optional[int] = None
         if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(
+            step_size = date_offset_to_seconds(end_offset)

         return Context(
             task_type=task_type,
@@ -733,7 +1076,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|