kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
- kumoai/experimental/rfm/backend/local/sampler.py +131 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +14 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +287 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +43 -4
- kumoai/experimental/rfm/local_pquery_driver.py +222 -27
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +153 -96
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/METADATA +12 -2
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/RECORD +40 -27
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511031731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -5,31 +5,41 @@ from collections import defaultdict
|
|
|
5
5
|
from collections.abc import Generator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
7
|
from dataclasses import dataclass, replace
|
|
8
|
-
from typing import
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
9
19
|
|
|
10
20
|
import numpy as np
|
|
11
21
|
import pandas as pd
|
|
12
22
|
from kumoapi.model_plan import RunMode
|
|
13
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
14
24
|
from kumoapi.rfm import Context
|
|
15
25
|
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
16
26
|
from kumoapi.rfm import (
|
|
17
|
-
PQueryDefinition,
|
|
18
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
19
29
|
RFMPredictRequest,
|
|
20
|
-
RFMValidateQueryRequest,
|
|
21
30
|
)
|
|
22
31
|
from kumoapi.task import TaskType
|
|
23
32
|
|
|
24
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
25
34
|
from kumoai.exceptions import HTTPException
|
|
26
|
-
from kumoai.experimental.rfm import
|
|
35
|
+
from kumoai.experimental.rfm import Graph
|
|
36
|
+
from kumoai.experimental.rfm.backend.local import LocalGraphStore
|
|
27
37
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
28
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
29
38
|
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
30
39
|
LocalPQueryDriver,
|
|
31
40
|
date_offset_to_seconds,
|
|
32
41
|
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
33
43
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
34
44
|
|
|
35
45
|
_RANDOM_SEED = 42
|
|
@@ -60,6 +70,17 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
|
60
70
|
"beyond this for your use-case.")
|
|
61
71
|
|
|
62
72
|
|
|
73
|
+
@dataclass(repr=False)
|
|
74
|
+
class ExplainConfig(CastMixin):
|
|
75
|
+
"""Configuration for explainability.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
skip_summary: Whether to skip generating a human-readable summary of
|
|
79
|
+
the explanation.
|
|
80
|
+
"""
|
|
81
|
+
skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
63
84
|
@dataclass(repr=False)
|
|
64
85
|
class Explanation:
|
|
65
86
|
prediction: pd.DataFrame
|
|
@@ -87,6 +108,12 @@ class Explanation:
|
|
|
87
108
|
def __repr__(self) -> str:
|
|
88
109
|
return str((self.prediction, self.summary))
|
|
89
110
|
|
|
111
|
+
def _ipython_display_(self) -> None:
|
|
112
|
+
from IPython.display import Markdown, display
|
|
113
|
+
|
|
114
|
+
display(self.prediction)
|
|
115
|
+
display(Markdown(self.summary))
|
|
116
|
+
|
|
90
117
|
|
|
91
118
|
class KumoRFM:
|
|
92
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
@@ -96,17 +123,17 @@ class KumoRFM:
|
|
|
96
123
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
97
124
|
relational dataset without training.
|
|
98
125
|
The model is pre-trained and the class provides an interface to query the
|
|
99
|
-
model from a :class:`
|
|
126
|
+
model from a :class:`Graph` object.
|
|
100
127
|
|
|
101
128
|
.. code-block:: python
|
|
102
129
|
|
|
103
|
-
from kumoai.experimental.rfm import
|
|
130
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
104
131
|
|
|
105
132
|
df_users = pd.DataFrame(...)
|
|
106
133
|
df_items = pd.DataFrame(...)
|
|
107
134
|
df_orders = pd.DataFrame(...)
|
|
108
135
|
|
|
109
|
-
graph =
|
|
136
|
+
graph = Graph.from_data({
|
|
110
137
|
'users': df_users,
|
|
111
138
|
'items': df_items,
|
|
112
139
|
'orders': df_orders,
|
|
@@ -114,40 +141,41 @@ class KumoRFM:
|
|
|
114
141
|
|
|
115
142
|
rfm = KumoRFM(graph)
|
|
116
143
|
|
|
117
|
-
query = ("PREDICT COUNT(
|
|
118
|
-
"FOR users.user_id=
|
|
119
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
120
147
|
|
|
121
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
122
149
|
# 1 0.85
|
|
123
150
|
|
|
124
151
|
Args:
|
|
125
152
|
graph: The graph.
|
|
126
|
-
preprocess: Whether to pre-process the data in advance during graph
|
|
127
|
-
materialization.
|
|
128
|
-
This is a runtime trade-off between graph materialization and model
|
|
129
|
-
processing speed.
|
|
130
|
-
It can be benefical to preprocess your data once and then run many
|
|
131
|
-
queries on top to achieve maximum model speed.
|
|
132
|
-
However, if activiated, graph materialization can take potentially
|
|
133
|
-
much longer, especially on graphs with many large text columns.
|
|
134
|
-
Best to tune this option manually.
|
|
135
153
|
verbose: Whether to print verbose output.
|
|
136
154
|
"""
|
|
137
155
|
def __init__(
|
|
138
156
|
self,
|
|
139
|
-
graph:
|
|
140
|
-
preprocess: bool = False,
|
|
157
|
+
graph: Graph,
|
|
141
158
|
verbose: Union[bool, ProgressLogger] = True,
|
|
142
159
|
) -> None:
|
|
143
160
|
graph = graph.validate()
|
|
144
161
|
self._graph_def = graph._to_api_graph_definition()
|
|
145
|
-
self._graph_store = LocalGraphStore(graph,
|
|
162
|
+
self._graph_store = LocalGraphStore(graph, verbose)
|
|
146
163
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
147
164
|
|
|
165
|
+
self._client: Optional[RFMAPI] = None
|
|
166
|
+
|
|
148
167
|
self._batch_size: Optional[int | Literal['max']] = None
|
|
149
168
|
self.num_retries: int = 0
|
|
150
169
|
|
|
170
|
+
@property
|
|
171
|
+
def _api_client(self) -> RFMAPI:
|
|
172
|
+
if self._client is not None:
|
|
173
|
+
return self._client
|
|
174
|
+
|
|
175
|
+
from kumoai.experimental.rfm import global_state
|
|
176
|
+
self._client = RFMAPI(global_state.client)
|
|
177
|
+
return self._client
|
|
178
|
+
|
|
151
179
|
def __repr__(self) -> str:
|
|
152
180
|
return f'{self.__class__.__name__}()'
|
|
153
181
|
|
|
@@ -209,7 +237,7 @@ class KumoRFM:
|
|
|
209
237
|
query: str,
|
|
210
238
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
211
239
|
*,
|
|
212
|
-
explain: Literal[True],
|
|
240
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
213
241
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
214
242
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
215
243
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -227,7 +255,7 @@ class KumoRFM:
|
|
|
227
255
|
query: str,
|
|
228
256
|
indices: Union[List[str], List[float], List[int], None] = None,
|
|
229
257
|
*,
|
|
230
|
-
explain: bool = False,
|
|
258
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
231
259
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
232
260
|
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
233
261
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
@@ -247,9 +275,12 @@ class KumoRFM:
|
|
|
247
275
|
be generated for all indices, independent of whether they
|
|
248
276
|
fulfill entity filter constraints. To pre-filter entities, use
|
|
249
277
|
:meth:`~KumoRFM.is_valid_entity`.
|
|
250
|
-
explain:
|
|
251
|
-
|
|
252
|
-
|
|
278
|
+
explain: Configuration for explainability.
|
|
279
|
+
If set to ``True``, will additionally explain the prediction.
|
|
280
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
281
|
+
over which parts of explanation are generated.
|
|
282
|
+
Explainability is currently only supported for single entity
|
|
283
|
+
predictions with ``run_mode="FAST"``.
|
|
253
284
|
anchor_time: The anchor timestamp for the prediction. If set to
|
|
254
285
|
``None``, will use the maximum timestamp in the data.
|
|
255
286
|
If set to ``"entity"``, will use the timestamp of the entity.
|
|
@@ -273,42 +304,48 @@ class KumoRFM:
|
|
|
273
304
|
|
|
274
305
|
Returns:
|
|
275
306
|
The predictions as a :class:`pandas.DataFrame`.
|
|
276
|
-
If ``explain
|
|
277
|
-
|
|
307
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
308
|
+
containing the prediction, summary, and details.
|
|
278
309
|
"""
|
|
310
|
+
explain_config: Optional[ExplainConfig] = None
|
|
311
|
+
if explain is True:
|
|
312
|
+
explain_config = ExplainConfig()
|
|
313
|
+
elif explain is not False:
|
|
314
|
+
explain_config = ExplainConfig._cast(explain)
|
|
315
|
+
|
|
279
316
|
query_def = self._parse_query(query)
|
|
317
|
+
query_str = query_def.to_string()
|
|
280
318
|
|
|
281
319
|
if num_hops != 2 and num_neighbors is not None:
|
|
282
320
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
283
321
|
f"custom 'num_hops={num_hops}' option")
|
|
284
322
|
|
|
285
|
-
if
|
|
323
|
+
if explain_config is not None and run_mode in {
|
|
324
|
+
RunMode.NORMAL, RunMode.BEST
|
|
325
|
+
}:
|
|
286
326
|
warnings.warn(f"Explainability is currently only supported for "
|
|
287
327
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
288
328
|
f"mode has been reset. Please lower the run mode to "
|
|
289
329
|
f"suppress this warning.")
|
|
290
330
|
|
|
291
331
|
if indices is None:
|
|
292
|
-
if query_def.
|
|
332
|
+
if query_def.rfm_entity_ids is None:
|
|
293
333
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
294
334
|
"pass them via `predict(query, indices=...)`")
|
|
295
|
-
indices = query_def.
|
|
335
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
296
336
|
else:
|
|
297
|
-
query_def = replace(
|
|
298
|
-
query_def,
|
|
299
|
-
entity=replace(query_def.entity, ids=None),
|
|
300
|
-
)
|
|
337
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
301
338
|
|
|
302
339
|
if len(indices) == 0:
|
|
303
340
|
raise ValueError("At least one entity is required")
|
|
304
341
|
|
|
305
|
-
if
|
|
342
|
+
if explain_config is not None and len(indices) > 1:
|
|
306
343
|
raise ValueError(
|
|
307
344
|
f"Cannot explain predictions for more than a single entity "
|
|
308
345
|
f"(got {len(indices)})")
|
|
309
346
|
|
|
310
347
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
311
|
-
if
|
|
348
|
+
if explain_config is not None:
|
|
312
349
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
313
350
|
else:
|
|
314
351
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
@@ -320,8 +357,8 @@ class KumoRFM:
|
|
|
320
357
|
|
|
321
358
|
batch_size: Optional[int] = None
|
|
322
359
|
if self._batch_size == 'max':
|
|
323
|
-
task_type =
|
|
324
|
-
|
|
360
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
361
|
+
query_def,
|
|
325
362
|
edge_types=self._graph_store.edge_types,
|
|
326
363
|
)
|
|
327
364
|
batch_size = _MAX_PRED_SIZE[task_type]
|
|
@@ -359,6 +396,7 @@ class KumoRFM:
|
|
|
359
396
|
request = RFMPredictRequest(
|
|
360
397
|
context=context,
|
|
361
398
|
run_mode=RunMode(run_mode),
|
|
399
|
+
query=query_str,
|
|
362
400
|
use_prediction_time=use_prediction_time,
|
|
363
401
|
)
|
|
364
402
|
with warnings.catch_warnings():
|
|
@@ -382,17 +420,20 @@ class KumoRFM:
|
|
|
382
420
|
|
|
383
421
|
for attempt in range(self.num_retries + 1):
|
|
384
422
|
try:
|
|
385
|
-
if
|
|
386
|
-
resp =
|
|
423
|
+
if explain_config is not None:
|
|
424
|
+
resp = self._api_client.explain(
|
|
425
|
+
request=_bytes,
|
|
426
|
+
skip_summary=explain_config.skip_summary,
|
|
427
|
+
)
|
|
387
428
|
summary = resp.summary
|
|
388
429
|
details = resp.details
|
|
389
430
|
else:
|
|
390
|
-
resp =
|
|
431
|
+
resp = self._api_client.predict(_bytes)
|
|
391
432
|
df = pd.DataFrame(**resp.prediction)
|
|
392
433
|
|
|
393
434
|
# Cast 'ENTITY' to correct data type:
|
|
394
435
|
if 'ENTITY' in df:
|
|
395
|
-
entity = query_def.
|
|
436
|
+
entity = query_def.entity_table
|
|
396
437
|
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
397
438
|
df['ENTITY'] = df['ENTITY'].astype(
|
|
398
439
|
type(pkey_map.index[0]))
|
|
@@ -434,7 +475,7 @@ class KumoRFM:
|
|
|
434
475
|
else:
|
|
435
476
|
prediction = pd.concat(predictions, ignore_index=True)
|
|
436
477
|
|
|
437
|
-
if
|
|
478
|
+
if explain_config is not None:
|
|
438
479
|
assert len(predictions) == 1
|
|
439
480
|
assert summary is not None
|
|
440
481
|
assert details is not None
|
|
@@ -468,11 +509,11 @@ class KumoRFM:
|
|
|
468
509
|
query_def = self._parse_query(query)
|
|
469
510
|
|
|
470
511
|
if indices is None:
|
|
471
|
-
if query_def.
|
|
512
|
+
if query_def.rfm_entity_ids is None:
|
|
472
513
|
raise ValueError("Cannot find entities to predict for. Please "
|
|
473
514
|
"pass them via "
|
|
474
515
|
"`is_valid_entity(query, indices=...)`")
|
|
475
|
-
indices = query_def.
|
|
516
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
476
517
|
|
|
477
518
|
if len(indices) == 0:
|
|
478
519
|
raise ValueError("At least one entity is required")
|
|
@@ -484,14 +525,13 @@ class KumoRFM:
|
|
|
484
525
|
self._validate_time(query_def, anchor_time, None, False)
|
|
485
526
|
else:
|
|
486
527
|
assert anchor_time == 'entity'
|
|
487
|
-
if (query_def.
|
|
488
|
-
not in self._graph_store.time_dict):
|
|
528
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
489
529
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
490
|
-
f"table '{query_def.
|
|
491
|
-
f"to have a time column")
|
|
530
|
+
f"table '{query_def.entity_table}' "
|
|
531
|
+
f"to have a time column.")
|
|
492
532
|
|
|
493
533
|
node = self._graph_store.get_node_id(
|
|
494
|
-
table_name=query_def.
|
|
534
|
+
table_name=query_def.entity_table,
|
|
495
535
|
pkey=pd.Series(indices),
|
|
496
536
|
)
|
|
497
537
|
query_driver = LocalPQueryDriver(self._graph_store, query_def)
|
|
@@ -547,10 +587,10 @@ class KumoRFM:
|
|
|
547
587
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
548
588
|
f"custom 'num_hops={num_hops}' option")
|
|
549
589
|
|
|
550
|
-
if query_def.
|
|
590
|
+
if query_def.rfm_entity_ids is not None:
|
|
551
591
|
query_def = replace(
|
|
552
592
|
query_def,
|
|
553
|
-
|
|
593
|
+
rfm_entity_ids=None,
|
|
554
594
|
)
|
|
555
595
|
|
|
556
596
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
@@ -591,10 +631,10 @@ class KumoRFM:
|
|
|
591
631
|
|
|
592
632
|
if len(request_bytes) > _MAX_SIZE:
|
|
593
633
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
594
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
634
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
595
635
|
|
|
596
636
|
try:
|
|
597
|
-
resp =
|
|
637
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
598
638
|
except HTTPException as e:
|
|
599
639
|
try:
|
|
600
640
|
msg = json.loads(e.detail)['detail']
|
|
@@ -639,18 +679,19 @@ class KumoRFM:
|
|
|
639
679
|
|
|
640
680
|
if anchor_time is None:
|
|
641
681
|
anchor_time = self._graph_store.max_time
|
|
642
|
-
|
|
643
|
-
|
|
682
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
+
anchor_time = anchor_time - (
|
|
684
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
685
|
+
query_def.num_forecasts)
|
|
644
686
|
|
|
645
687
|
assert anchor_time is not None
|
|
646
688
|
if isinstance(anchor_time, pd.Timestamp):
|
|
647
689
|
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
648
690
|
else:
|
|
649
691
|
assert anchor_time == 'entity'
|
|
650
|
-
if (query_def.
|
|
651
|
-
not in self._graph_store.time_dict):
|
|
692
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
652
693
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
653
|
-
f"table '{query_def.
|
|
694
|
+
f"table '{query_def.entity_table}' "
|
|
654
695
|
f"to have a time column")
|
|
655
696
|
|
|
656
697
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -665,7 +706,7 @@ class KumoRFM:
|
|
|
665
706
|
)
|
|
666
707
|
|
|
667
708
|
entity = self._graph_store.pkey_map_dict[
|
|
668
|
-
query_def.
|
|
709
|
+
query_def.entity_table].index[node]
|
|
669
710
|
|
|
670
711
|
return pd.DataFrame({
|
|
671
712
|
'ENTITY': entity,
|
|
@@ -675,8 +716,8 @@ class KumoRFM:
|
|
|
675
716
|
|
|
676
717
|
# Helpers #################################################################
|
|
677
718
|
|
|
678
|
-
def _parse_query(self, query: str) ->
|
|
679
|
-
if isinstance(query,
|
|
719
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
720
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
680
721
|
return query
|
|
681
722
|
|
|
682
723
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -686,12 +727,13 @@ class KumoRFM:
|
|
|
686
727
|
"predictions or evaluations.")
|
|
687
728
|
|
|
688
729
|
try:
|
|
689
|
-
request =
|
|
730
|
+
request = RFMParseQueryRequest(
|
|
690
731
|
query=query,
|
|
691
732
|
graph_definition=self._graph_def,
|
|
692
733
|
)
|
|
693
734
|
|
|
694
|
-
resp =
|
|
735
|
+
resp = self._api_client.parse_query(request)
|
|
736
|
+
|
|
695
737
|
# TODO Expose validation warnings.
|
|
696
738
|
|
|
697
739
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -702,7 +744,7 @@ class KumoRFM:
|
|
|
702
744
|
warnings.warn(f"Encountered the following warnings during "
|
|
703
745
|
f"parsing:\n{msg}")
|
|
704
746
|
|
|
705
|
-
return resp.
|
|
747
|
+
return resp.query
|
|
706
748
|
except HTTPException as e:
|
|
707
749
|
try:
|
|
708
750
|
msg = json.loads(e.detail)['detail']
|
|
@@ -713,7 +755,7 @@ class KumoRFM:
|
|
|
713
755
|
|
|
714
756
|
def _validate_time(
|
|
715
757
|
self,
|
|
716
|
-
query:
|
|
758
|
+
query: ValidatedPredictiveQuery,
|
|
717
759
|
anchor_time: pd.Timestamp,
|
|
718
760
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
719
761
|
evaluate: bool,
|
|
@@ -736,6 +778,11 @@ class KumoRFM:
|
|
|
736
778
|
f"only contains data back to "
|
|
737
779
|
f"'{self._graph_store.min_time}'.")
|
|
738
780
|
|
|
781
|
+
if query.target_ast.date_offset_range is not None:
|
|
782
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
|
+
else:
|
|
784
|
+
end_offset = pd.DateOffset(0)
|
|
785
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
739
786
|
if (context_anchor_time is not None
|
|
740
787
|
and context_anchor_time > anchor_time):
|
|
741
788
|
warnings.warn(f"Context anchor timestamp "
|
|
@@ -744,19 +791,18 @@ class KumoRFM:
|
|
|
744
791
|
f"(got '{anchor_time}'). Please make sure this is "
|
|
745
792
|
f"intended.")
|
|
746
793
|
elif (query.query_type == QueryType.TEMPORAL
|
|
747
|
-
and context_anchor_time is not None
|
|
748
|
-
|
|
794
|
+
and context_anchor_time is not None
|
|
795
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
749
796
|
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
750
797
|
f"'{context_anchor_time}' will leak information "
|
|
751
798
|
f"from the prediction anchor timestamp "
|
|
752
799
|
f"'{anchor_time}'. Please make sure this is "
|
|
753
800
|
f"intended.")
|
|
754
801
|
|
|
755
|
-
elif (context_anchor_time is not None
|
|
756
|
-
|
|
802
|
+
elif (context_anchor_time is not None
|
|
803
|
+
and context_anchor_time - forecast_end_offset
|
|
757
804
|
< self._graph_store.min_time):
|
|
758
|
-
_time = context_anchor_time -
|
|
759
|
-
query.num_forecasts)
|
|
805
|
+
_time = context_anchor_time - forecast_end_offset
|
|
760
806
|
warnings.warn(f"Context anchor timestamp is too early or "
|
|
761
807
|
f"aggregation time range is too large. To form "
|
|
762
808
|
f"proper input data, we would need data back to "
|
|
@@ -769,8 +815,7 @@ class KumoRFM:
|
|
|
769
815
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
770
816
|
f"in the data. Please make sure this is intended.")
|
|
771
817
|
|
|
772
|
-
max_eval_time =
|
|
773
|
-
query.target.end_offset * query.num_forecasts)
|
|
818
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
774
819
|
if evaluate and anchor_time > max_eval_time:
|
|
775
820
|
raise ValueError(
|
|
776
821
|
f"Anchor timestamp for evaluation is after the latest "
|
|
@@ -778,7 +823,7 @@ class KumoRFM:
|
|
|
778
823
|
|
|
779
824
|
def _get_context(
|
|
780
825
|
self,
|
|
781
|
-
query:
|
|
826
|
+
query: ValidatedPredictiveQuery,
|
|
782
827
|
indices: Union[List[str], List[float], List[int], None],
|
|
783
828
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
784
829
|
context_anchor_time: Union[pd.Timestamp, None],
|
|
@@ -806,8 +851,8 @@ class KumoRFM:
|
|
|
806
851
|
f"must go beyond this for your use-case.")
|
|
807
852
|
|
|
808
853
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
809
|
-
task_type =
|
|
810
|
-
|
|
854
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
855
|
+
query,
|
|
811
856
|
edge_types=self._graph_store.edge_types,
|
|
812
857
|
)
|
|
813
858
|
|
|
@@ -839,11 +884,15 @@ class KumoRFM:
|
|
|
839
884
|
else:
|
|
840
885
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
841
886
|
|
|
887
|
+
if query.target_ast.date_offset_range is None:
|
|
888
|
+
end_offset = pd.DateOffset(0)
|
|
889
|
+
else:
|
|
890
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
891
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
842
892
|
if anchor_time is None:
|
|
843
893
|
anchor_time = self._graph_store.max_time
|
|
844
894
|
if evaluate:
|
|
845
|
-
anchor_time = anchor_time -
|
|
846
|
-
query.num_forecasts)
|
|
895
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
847
896
|
if logger is not None:
|
|
848
897
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
849
898
|
if anchor_time == pd.Timestamp.min:
|
|
@@ -858,15 +907,14 @@ class KumoRFM:
|
|
|
858
907
|
assert anchor_time is not None
|
|
859
908
|
if isinstance(anchor_time, pd.Timestamp):
|
|
860
909
|
if context_anchor_time is None:
|
|
861
|
-
context_anchor_time = anchor_time -
|
|
862
|
-
query.num_forecasts)
|
|
910
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
863
911
|
self._validate_time(query, anchor_time, context_anchor_time,
|
|
864
912
|
evaluate)
|
|
865
913
|
else:
|
|
866
914
|
assert anchor_time == 'entity'
|
|
867
|
-
if query.
|
|
915
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
868
916
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
869
|
-
f"table '{query.
|
|
917
|
+
f"table '{query.entity_table}' to "
|
|
870
918
|
f"have a time column")
|
|
871
919
|
if context_anchor_time is not None:
|
|
872
920
|
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
@@ -917,7 +965,7 @@ class KumoRFM:
|
|
|
917
965
|
f"in batches")
|
|
918
966
|
|
|
919
967
|
test_node = self._graph_store.get_node_id(
|
|
920
|
-
table_name=query.
|
|
968
|
+
table_name=query.entity_table,
|
|
921
969
|
pkey=pd.Series(indices),
|
|
922
970
|
)
|
|
923
971
|
|
|
@@ -925,8 +973,7 @@ class KumoRFM:
|
|
|
925
973
|
test_time = pd.Series(anchor_time).repeat(
|
|
926
974
|
len(test_node)).reset_index(drop=True)
|
|
927
975
|
else:
|
|
928
|
-
time = self._graph_store.time_dict[
|
|
929
|
-
query.entity.pkey.table_name]
|
|
976
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
930
977
|
time = time[test_node] * 1000**3
|
|
931
978
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
932
979
|
|
|
@@ -959,12 +1006,23 @@ class KumoRFM:
|
|
|
959
1006
|
raise NotImplementedError
|
|
960
1007
|
logger.log(msg)
|
|
961
1008
|
|
|
962
|
-
entity_table_names
|
|
963
|
-
|
|
1009
|
+
entity_table_names: Tuple[str, ...]
|
|
1010
|
+
if task_type.is_link_pred:
|
|
1011
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1012
|
+
assert final_aggr is not None
|
|
1013
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1014
|
+
for edge_type in self._graph_store.edge_types:
|
|
1015
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1016
|
+
entity_table_names = (
|
|
1017
|
+
query.entity_table,
|
|
1018
|
+
edge_type[2],
|
|
1019
|
+
)
|
|
1020
|
+
else:
|
|
1021
|
+
entity_table_names = (query.entity_table, )
|
|
964
1022
|
|
|
965
1023
|
# Exclude the entity anchor time from the feature set to prevent
|
|
966
1024
|
# running out-of-distribution between in-context and test examples:
|
|
967
|
-
exclude_cols_dict = query.
|
|
1025
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
968
1026
|
if anchor_time == 'entity':
|
|
969
1027
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
970
1028
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -979,7 +1037,6 @@ class KumoRFM:
|
|
|
979
1037
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
980
1038
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
981
1039
|
]),
|
|
982
|
-
run_mode=run_mode,
|
|
983
1040
|
num_neighbors=num_neighbors,
|
|
984
1041
|
exclude_cols_dict=exclude_cols_dict,
|
|
985
1042
|
)
|
|
@@ -993,7 +1050,7 @@ class KumoRFM:
|
|
|
993
1050
|
|
|
994
1051
|
step_size: Optional[int] = None
|
|
995
1052
|
if query.query_type == QueryType.TEMPORAL:
|
|
996
|
-
step_size = date_offset_to_seconds(
|
|
1053
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
997
1054
|
|
|
998
1055
|
return Context(
|
|
999
1056
|
task_type=task_type,
|