kumoai-2.10.0.dev202509231831-cp313-cp313-macosx_11_0_arm64.whl → kumoai-2.14.0.dev202512161731-cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +37 -8
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
- kumoai/experimental/rfm/base/__init__.py +25 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +60 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
- kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
- kumoai/experimental/rfm/rfm.py +669 -246
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +1 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/trainer.py +12 -10
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +239 -4
- kumoai/utils/sql.py +3 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
- kumoai/experimental/rfm/local_graph_sampler.py +0 -176
- kumoai/experimental/rfm/local_pquery_driver.py +0 -404
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -1,33 +1,56 @@
 import json
+import time
 import warnings
-from
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.
-from kumoai.
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -42,7 +65,7 @@ _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +74,68 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.hide(axis='index'))  # pandas=2
+                else:
+                    display(self.prediction.hide_index())  # pandas <1.3
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
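The new `Explanation` container above is duck-typed to a `(prediction, summary)` tuple via its `__iter__`/`__getitem__` definitions. A minimal usage sketch, assuming an already-constructed `KumoRFM` instance `rfm` (the query string and entity id are placeholders):

```python
# Hypothetical sketch based solely on the dataclass definitions above.
result = rfm.predict("PREDICT COUNT(orders.*, 0, 30, days)>0 "
                     "FOR users.user_id=1", explain=True)

# __iter__ yields (prediction, summary), so tuple unpacking keeps working:
prediction_df, summary_text = result

# __getitem__ mirrors the tuple: index 0 is the DataFrame, index 1 the str.
assert result[0] is prediction_df and result[1] == summary_text

result.print()  # Streamlit/IPython-aware rendering; plain text elsewhere.
```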
@@ -59,17 +144,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph =
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -77,62 +162,166 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(
-                 "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph:
-        preprocess: bool = False,
+        graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-
-
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: Optional[RFMAPI] = None
+
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
-
-                If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
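Taken together with the `_MAX_PRED_SIZE` limits introduced earlier, the new `batch_mode` context manager lets a single `predict` call fan out over many entities. A hedged usage sketch (entity ids and query are placeholders; `rfm` is an existing `KumoRFM` instance):

```python
# Hypothetical sketch of the new batching API. With batch_size='max',
# predict() derives the per-task batch size from _MAX_PRED_SIZE (1,000 by
# default, 200 for temporal link prediction) and retries failed batches
# with exponential backoff, up to num_retries times.
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=1")

with rfm.batch_mode(batch_size='max', num_retries=1):
    # `indices` overrides the entity id given in the query itself:
    df = rfm.predict(query, indices=[1, 2, 3, 4, 5])
```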
@@ -145,83 +334,237 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if
-
-
-
-
-
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if i == 0 and len(batches) > 1:
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if len(batches) > 1:
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
             )
-        with warnings.catch_warnings():
-            warnings.filterwarnings('ignore', message='Protobuf gencode')
-            request_msg = request.to_protobuf()
-        request_bytes = request_msg.SerializeToString()
-        logger.log(f"Generated context of size "
-                   f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-
-            stats_msg = Context.get_memory_stats(request_msg.context)
-            raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+        return prediction
 
-
-
-
-
-
-
-
-
-
-
-            raise RuntimeError(f"An unexpected exception occurred. "
-                               f"Please create an issue at "
-                               f"'https://github.com/kumo-ai/kumo-rfm'. "
-                               f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
 
-
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._get_default_anchor_time(query_def)
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if query_def.entity_table not in self._sampler.time_column_dict:
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        raise NotImplementedError
 
     def evaluate(
         self,
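The hunk above routes `explain` through `ExplainConfig._cast`, so a plain dict works as well as `True` or an `ExplainConfig` instance. A hedged sketch of the single-entity explain path (query and entity id are placeholders):

```python
# Hypothetical sketch: explain accepts True, an ExplainConfig, or a dict
# (cast via CastMixin). Only one entity may be explained per call, and
# explainability is restricted to the default 'FAST' run mode.
explanation = rfm.predict(
    "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1",
    explain={'skip_summary': True},  # same as ExplainConfig(skip_summary=True)
)
explanation.print()
```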
@@ -233,9 +576,10 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
@@ -243,10 +587,10 @@ class KumoRFM:
             query: The predictive query.
             metrics: The metrics to use.
             anchor_time: The anchor timestamp for the prediction. If set to
-
-                If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
@@ -259,6 +603,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -269,15 +616,22 @@ class KumoRFM:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
                 context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
@@ -295,6 +649,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -305,10 +660,10 @@ class KumoRFM:
 
             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
             try:
-                resp =
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -332,7 +687,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int =
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -352,45 +707,43 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
-
-
+            anchor_time = self._get_default_anchor_time(query_def)
+            if query_def.target_ast.date_offset_range is not None:
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity.pkey.table_name].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) ->
-        if isinstance(query,
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query
 
         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -400,13 +753,12 @@ class KumoRFM:
                 "predictions or evaluations.")
 
         try:
-            request =
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp =
-            # TODO Expose validation warnings.
+            resp = self._api_client.parse_query(request)
 
             if len(resp.validation_response.warnings) > 0:
                 msg = '\n'.join([
@@ -416,7 +768,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -425,30 +777,91 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
         context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-
+                             f"only contains data back to '{min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        end_offset = end_offset * query.num_forecasts
 
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
@@ -458,41 +871,37 @@ class KumoRFM:
                           f"(got '{anchor_time}'). Please make sure this is "
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
-              and context_anchor_time is not None
-
+              and context_anchor_time is not None
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
                           f"'{anchor_time}'. Please make sure this is "
                           f"intended.")
 
-        elif (context_anchor_time is not None
-
-
-            _time = context_anchor_time - (query.target.end_offset *
-                                           query.num_forecasts)
+        elif (context_anchor_time is not None
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-            query.target.end_offset * query.num_forecasts)
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
         context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
@@ -518,10 +927,9 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -552,104 +960,109 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            step_offset = pd.DateOffset(0)
+        else:
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
             if evaluate:
-                anchor_time = anchor_time -
-
+                anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
-            if
-
-
+            if anchor_time == pd.Timestamp.min:
+                pass  # Static graph
+            elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                  and anchor_time.second == 0
+                  and anchor_time.microsecond == 0):
                 logger.log(f"Derived anchor time {anchor_time.date()}")
             else:
                 logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
-                    query.num_forecasts)
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
-
-                test_node, test_time, y_test = query_driver.collect_test(
-                    size=max_test_size,
-                    anchor_time=anchor_time,
-                    max_iterations=max_pq_iterations,
-                    guarantee_train_examples=True,
-                )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
-
+                num_test_examples = num_test_examples // 5
         else:
-
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-        if
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
+
+        if not evaluate:
+            assert indices is not None
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{
-                                 f"(got {len(
-
-
-            table_name=query.entity.pkey.table_name,
-            pkey=pd.Series(
-                query.entity.ids.value,
-                dtype=query.entity.ids.dtype,
-            ),
-        )
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                    query.entity.pkey.table_name]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -672,27 +1085,41 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names
-
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._sampler.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.
-        if
+        exclude_cols_dict = query.get_exclude_cols_dict()
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -704,18 +1131,14 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
-
         return Context(
             task_type=task_type,
            entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=
+            step_size=None,
         )
 
     @staticmethod
@@ -731,7 +1154,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
         elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
             supported_metrics = [
                 'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',