kumoai-2.7.0.dev202508201830-cp312-cp312-win_amd64.whl → kumoai-2.12.0.dev202511111731-cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/client.py +10 -5
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/snowflake_connector.py +9 -0
- kumoai/connector/utils.py +1377 -209
- kumoai/experimental/rfm/__init__.py +5 -3
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +96 -82
- kumoai/experimental/rfm/local_graph_sampler.py +16 -8
- kumoai/experimental/rfm/local_graph_store.py +32 -10
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/local_table.py +142 -45
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +535 -125
- kumoai/experimental/rfm/utils.py +0 -3
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/METADATA +4 -5
- {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/RECORD +34 -34
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/WHEEL +0 -0
- {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.7.0.dev202508201830.dist-info → kumoai-2.12.0.dev202511111731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,17 +1,32 @@
 import json
+import time
 import warnings
-from
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
 
@@ -20,11 +35,18 @@ from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import LocalGraph
 from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import
-
+from kumoai.experimental.rfm.local_pquery_driver import (
+    LocalPQueryDriver,
+    date_offset_to_seconds,
+)
+from kumoai.mixin import CastMixin
+from kumoai.utils import InteractiveProgressLogger, ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
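
Note on the `_MAX_PRED_SIZE` addition above: because the table is a `defaultdict`, every task type resolves to a per-request cap of 1,000 entities unless explicitly overridden, and only temporal link prediction is lowered to 200. A minimal, standalone sketch of the lookup semantics (plain strings stand in for the real `TaskType` enum):

    from collections import defaultdict

    max_pred_size = defaultdict(lambda: 1_000)  # default cap for any task type
    max_pred_size['TEMPORAL_LINK_PREDICTION'] = 200  # explicit override

    print(max_pred_size['REGRESSION'])                # 1000 (falls back to default)
    print(max_pred_size['TEMPORAL_LINK_PREDICTION'])  # 200
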
@@ -39,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -48,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def _ipython_display_(self) -> None:
+        from IPython.display import Markdown, display
+
+        display(self.prediction)
+        display(Markdown(self.summary))
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
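
The new `Explanation` dataclass above implements both `__getitem__` and `__iter__`, so an explained prediction can be consumed by attribute, by index, or by tuple-style unpacking. A short usage sketch (the `explanation` value is assumed to come from `KumoRFM.predict(..., explain=True)`):

    prediction, summary = explanation                # unpacking via __iter__
    assert explanation[0] is explanation.prediction  # pd.DataFrame of predictions
    assert explanation[1] is explanation.summary     # human-readable summary string
    details = explanation.details                    # structured explanation details
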
@@ -98,35 +165,127 @@ class KumoRFM:
         self,
         graph: LocalGraph,
         preprocess: bool = False,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
         self._graph_sampler = LocalGraphSampler(self._graph_store)
 
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 20,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
-
-
-
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
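
Taken together, `batch_mode` and the new `explain` overloads support two calling patterns. A usage sketch with hypothetical `model`, `query`, and `ids` values (the predictive query string itself is not shown):

    # Batched prediction over many entities, retrying each failed batch once:
    with model.batch_mode(batch_size='max', num_retries=1):
        df = model.predict(query, indices=ids)  # -> pd.DataFrame

    # Explained prediction; only a single entity and run_mode='FAST' are supported:
    explanation = model.predict(query, indices=ids[:1], explain=True)
    prediction, summary = explanation  # -> Explanation, unpacked
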
@@ -138,79 +297,244 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if
-
-
-
-
-
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = LocalPQueryDriver.get_task_type(
+                    query_def,
+                    edge_types=self._graph_store.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
+                        and len(batches) > 1):
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = global_state.client.rfm_api.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = global_state.client.rfm_api.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            entity = query_def.entity_table
+                            pkey_map = self._graph_store.pkey_map_dict[entity]
+                            df['ENTITY'] = df['ENTITY'].astype(
+                                type(pkey_map.index[0]))
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if (isinstance(verbose, InteractiveProgressLogger)
+                                and len(batches) > 1):
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-        with warnings.catch_warnings():
-            warnings.filterwarnings('ignore', message='Protobuf gencode')
-            request_msg = request.to_protobuf()
-        request_bytes = request_msg.SerializeToString()
-        logger.log(f"Generated context of size "
-                   f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-
-            stats_msg = Context.get_memory_stats(request_msg.context)
-            raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction
 
-
-
-
-
-
-
-
-
-
-
-            raise RuntimeError(f"An unexpected exception occurred. "
-                               f"Please create an issue at "
-                               f"'https://github.com/kumo-ai/kumo-rfm'. "
-                               f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
 
-
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._graph_store.max_time
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if (query_def.entity_table not in self._graph_store.time_dict):
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        node = self._graph_store.get_node_id(
+            table_name=query_def.entity_table,
+            pkey=pd.Series(indices),
+        )
+        query_driver = LocalPQueryDriver(self._graph_store, query_def)
+        return query_driver.is_valid(node, anchor_time)
 
     def evaluate(
         self,
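
Because `predict` now generates outputs for every passed index regardless of entity filter constraints, the new `is_valid_entity` method is the intended pre-filter. A sketch (hypothetical `model`, `query`, and `ids`; the returned mask is a boolean `np.ndarray` aligned with `ids`):

    mask = model.is_valid_entity(query, indices=ids)
    valid_ids = [id_ for id_, ok in zip(ids, mask) if ok]
    df = model.predict(query, indices=valid_ids)
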
@@ -218,21 +542,26 @@ class KumoRFM:
         *,
         metrics: Optional[List[str]] = None,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
         max_pq_iterations: int = 20,
         random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: bool = True,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
         Args:
             query: The predictive query.
             metrics: The metrics to use.
-            anchor_time: The anchor timestamp for the
-
-                If set to
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will
+                determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
                 If specified, the ``num_hops`` option will be ignored.
@@ -244,6 +573,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
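
`evaluate` gains the same `context_anchor_time`, logger-aware `verbose`, and `use_prediction_time` options as `predict`, and the regression metric list now includes `'r2'` (see the final hunk of this diff). A usage sketch with a hypothetical `model` and regression `query`:

    metrics_df = model.evaluate(
        query,
        metrics=['mae', 'rmse', 'r2'],  # 'r2' is newly supported for regression
        run_mode='FAST',
    )
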
@@ -254,13 +586,24 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
-
+        if not isinstance(verbose, ProgressLogger):
+            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+
+        with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
                 num_neighbors=num_neighbors,
                 num_hops=num_hops,
@@ -276,6 +619,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -334,17 +678,19 @@ class KumoRFM:
 
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
-
+            if query_def.target_ast.date_offset_range is not None:
+                anchor_time = anchor_time - (
+                    query_def.target_ast.date_offset_range.end_date_offset *
+                    query_def.num_forecasts)
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-            self._validate_time(query_def, anchor_time, evaluate=True)
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.
-                    not in self._graph_store.time_dict):
+            if (query_def.entity_table not in self._graph_store.time_dict):
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
         query_driver = LocalPQueryDriver(self._graph_store, query_def,
@@ -355,18 +701,22 @@ class KumoRFM:
             anchor_time=anchor_time,
             batch_size=min(10_000, size),
             max_iterations=max_iterations,
+            guarantee_train_examples=False,
         )
 
+        entity = self._graph_store.pkey_map_dict[
+            query_def.entity_table].index[node]
+
         return pd.DataFrame({
-            'ENTITY':
+            'ENTITY': entity,
             'ANCHOR_TIMESTAMP': time,
             'TARGET': y,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) ->
-        if isinstance(query,
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query
 
         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -376,12 +726,12 @@ class KumoRFM:
                 "predictions or evaluations.")
 
         try:
-            request =
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp = global_state.client.rfm_api.
+            resp = global_state.client.rfm_api.parse_query(request)
             # TODO Expose validation warnings.
 
             if len(resp.validation_response.warnings) > 0:
@@ -392,7 +742,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -403,8 +753,9 @@ class KumoRFM:
 
     def _validate_time(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
+        context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
@@ -416,22 +767,45 @@ class KumoRFM:
                              f"the earliest timestamp "
                              f"'{self._graph_store.min_time}' in the data.")
 
-        if
-
-
-                             f"
-                             f"
-                             f"however, your data
+        if (context_anchor_time is not None
+                and context_anchor_time < self._graph_store.min_time):
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to "
                              f"'{self._graph_store.min_time}'.")
 
-        if
-
-
-
-
-
-
-
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        forecast_end_offset = end_offset * query.num_forecasts
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+              and context_anchor_time is not None
+              and context_anchor_time + forecast_end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+              and context_anchor_time - forecast_end_offset
+              < self._graph_store.min_time):
+            _time = context_anchor_time - forecast_end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{self._graph_store.min_time}'.")
 
         if (not evaluate and anchor_time
                 > self._graph_store.max_time + pd.DateOffset(days=1)):
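
The validation above reduces to plain `pd.Timestamp`/`pd.DateOffset` arithmetic: `forecast_end_offset = end_date_offset * num_forecasts`, and a temporal context anchor later than `anchor_time - forecast_end_offset` leaks label information. A self-contained sketch with made-up values:

    import pandas as pd

    anchor_time = pd.Timestamp('2025-06-01')
    end_offset = pd.DateOffset(days=7)    # aggregation window of one forecast step
    forecast_end_offset = end_offset * 4  # e.g., num_forecasts=4 -> 28 days total

    context_anchor_time = pd.Timestamp('2025-05-20')
    # Temporal leakage check, mirroring _validate_time:
    leaks = context_anchor_time + forecast_end_offset > anchor_time
    print(leaks)  # True: 2025-05-20 + 28 days = 2025-06-17 > 2025-06-01
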
@@ -439,17 +813,18 @@ class KumoRFM:
                           f"latest timestamp '{self._graph_store.max_time}' "
                           f"in the data. Please make sure this is intended.")
 
-
-
+        max_eval_time = self._graph_store.max_time - forecast_end_offset
+        if evaluate and anchor_time > max_eval_time:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp "
-                f"'{self._graph_store.max_time - query.target.end_offset}'.")
+                f"supported timestamp '{max_eval_time}'.")
 
     def _get_context(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
+        context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
         num_neighbors: Optional[List[int]],
         num_hops: int,
@@ -474,8 +849,8 @@ class KumoRFM:
                              f"must go beyond this for your use-case.")
 
         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type =
-
+        task_type = LocalPQueryDriver.get_task_type(
+            query,
             edge_types=self._graph_store.edge_types,
         )
 
@@ -507,28 +882,42 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            end_offset = pd.DateOffset(0)
+        else:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        forecast_end_offset = end_offset * query.num_forecasts
         if anchor_time is None:
             anchor_time = self._graph_store.max_time
             if evaluate:
-                anchor_time = anchor_time -
+                anchor_time = anchor_time - forecast_end_offset
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if
-
-
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
-
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - forecast_end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.entity_table not in self._graph_store.time_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
+            if context_anchor_time is not None:
+                warnings.warn("Ignoring option 'context_anchor_time' for "
+                              "`anchor_time='entity'`")
+                context_anchor_time = None
 
         y_test: Optional[pd.Series] = None
         if evaluate:
@@ -540,6 +929,7 @@ class KumoRFM:
                 size=max_test_size,
                 anchor_time=anchor_time,
                 max_iterations=max_pq_iterations,
+                guarantee_train_examples=True,
             )
             if logger is not None:
                 if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -563,34 +953,31 @@ class KumoRFM:
                 logger.log(msg)
 
         else:
-            assert
+            assert indices is not None
 
-
-            if len(query.entity.ids.value) > max_num_test:
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{
-                                 f"(got {len(
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
             test_node = self._graph_store.get_node_id(
-                table_name=query.
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
+                table_name=query.entity_table,
+                pkey=pd.Series(indices),
             )
 
             if isinstance(anchor_time, pd.Timestamp):
                 test_time = pd.Series(anchor_time).repeat(
                     len(test_node)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[query.entity_table]
                 time = time[test_node] * 1000**3
                 test_time = pd.Series(time, dtype='datetime64[ns]')
 
         train_node, train_time, y_train = query_driver.collect_train(
             size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=
+            anchor_time=context_anchor_time or 'entity',
             exclude_node=test_node if (query.query_type == QueryType.STATIC
                                        or anchor_time == 'entity') else None,
             max_iterations=max_pq_iterations,
@@ -617,12 +1004,23 @@ class KumoRFM:
             raise NotImplementedError
         logger.log(msg)
 
-        entity_table_names
-
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._graph_store.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.
+        exclude_cols_dict = query.get_exclude_cols_dict()
         if anchor_time == 'entity':
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
@@ -642,6 +1040,17 @@ class KumoRFM:
             exclude_cols_dict=exclude_cols_dict,
         )
 
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        step_size: Optional[int] = None
+        if query.query_type == QueryType.TEMPORAL:
+            step_size = date_offset_to_seconds(end_offset)
+
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
@@ -649,6 +1058,7 @@ class KumoRFM:
             y_train=y_train,
             y_test=y_test,
             top_k=query.top_k,
+            step_size=step_size,
         )
 
     @staticmethod
@@ -664,7 +1074,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
        elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|