kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -1,30 +1,52 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import time
|
|
2
3
|
import warnings
|
|
3
|
-
from
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Generator
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from dataclasses import dataclass, replace
|
|
8
|
+
from typing import (
|
|
9
|
+
Any,
|
|
10
|
+
Dict,
|
|
11
|
+
Iterator,
|
|
12
|
+
List,
|
|
13
|
+
Literal,
|
|
14
|
+
Optional,
|
|
15
|
+
Tuple,
|
|
16
|
+
Union,
|
|
17
|
+
overload,
|
|
18
|
+
)
|
|
4
19
|
|
|
5
20
|
import numpy as np
|
|
6
21
|
import pandas as pd
|
|
7
22
|
from kumoapi.model_plan import RunMode
|
|
8
|
-
from kumoapi.pquery import QueryType
|
|
23
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
24
|
+
from kumoapi.rfm import Context
|
|
25
|
+
from kumoapi.rfm import Explanation as ExplanationConfig
|
|
9
26
|
from kumoapi.rfm import (
|
|
10
|
-
Context,
|
|
11
|
-
PQueryDefinition,
|
|
12
27
|
RFMEvaluateRequest,
|
|
28
|
+
RFMParseQueryRequest,
|
|
13
29
|
RFMPredictRequest,
|
|
14
|
-
RFMValidateQueryRequest,
|
|
15
30
|
)
|
|
16
31
|
from kumoapi.task import TaskType
|
|
17
32
|
|
|
18
|
-
from kumoai import
|
|
33
|
+
from kumoai.client.rfm import RFMAPI
|
|
19
34
|
from kumoai.exceptions import HTTPException
|
|
20
|
-
from kumoai.experimental.rfm import
|
|
35
|
+
from kumoai.experimental.rfm import Graph
|
|
21
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
22
37
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
23
|
-
from kumoai.experimental.rfm.local_pquery_driver import
|
|
24
|
-
|
|
38
|
+
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
39
|
+
LocalPQueryDriver,
|
|
40
|
+
date_offset_to_seconds,
|
|
41
|
+
)
|
|
42
|
+
from kumoai.mixin import CastMixin
|
|
43
|
+
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
25
44
|
|
|
26
45
|
_RANDOM_SEED = 42
|
|
27
46
|
|
|
47
|
+
_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
|
|
48
|
+
_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
|
|
49
|
+
|
|
28
50
|
_MAX_CONTEXT_SIZE = {
|
|
29
51
|
RunMode.DEBUG: 100,
|
|
30
52
|
RunMode.FAST: 1_000,
|
|
@@ -39,7 +61,7 @@ _MAX_TEST_SIZE = { # Share test set size across run modes for fair comparison:
|
|
|
39
61
|
}
|
|
40
62
|
|
|
41
63
|
_MAX_SIZE = 30 * 1024 * 1024
|
|
42
|
-
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {
|
|
64
|
+
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
|
|
43
65
|
"reduce either the number of tables in the graph, their "
|
|
44
66
|
"number of columns (e.g., large text columns), "
|
|
45
67
|
"neighborhood configuration, or the run mode. If none of "
|
|
@@ -48,6 +70,51 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
|
|
|
48
70
|
"beyond this for your use-case.")
|
|
49
71
|
|
|
50
72
|
|
|
73
|
+
@dataclass(repr=False)
class ExplainConfig(CastMixin):
    """Options that control how :meth:`KumoRFM.predict` explains a prediction.

    Args:
        skip_summary: If ``True``, the human-readable summary of the
            explanation is not generated.
    """
    skip_summary: bool = False
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass(repr=False)
class Explanation:
    """The outcome of an explained prediction.

    For convenience, an instance can be indexed (positions ``0`` and ``1``)
    or unpacked like the two-tuple ``(prediction, summary)``. The
    ``details`` field is only reachable as an attribute.

    Attributes:
        prediction: The prediction as a :class:`pandas.DataFrame`.
        summary: The textual summary of the explanation; rendered as
            Markdown when displayed in IPython.
        details: The detailed explanation object
            (a ``kumoapi.rfm.Explanation``).
    """
    prediction: pd.DataFrame
    summary: str
    details: ExplanationConfig

    @overload
    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
        pass

    @overload
    def __getitem__(self, index: Literal[1]) -> str:
        pass

    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
        # Only the two non-negative positions are supported.
        if index == 0:
            return self.prediction
        elif index == 1:
            return self.summary
        else:
            raise IndexError("Index out of range")

    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
        pair = (self.prediction, self.summary)
        return iter(pair)

    def __repr__(self) -> str:
        # For tuples, `repr(...)` and `str(...)` produce the same text.
        return repr((self.prediction, self.summary))

    def _ipython_display_(self) -> None:
        # Imported lazily so IPython is only required when rendering in a
        # notebook.
        from IPython.display import Markdown, display

        display(self.prediction)
        display(Markdown(self.summary))
|
|
116
|
+
|
|
117
|
+
|
|
51
118
|
class KumoRFM:
|
|
52
119
|
r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
|
|
53
120
|
Foundation Model for In-Context Learning on Relational Data
|
|
@@ -56,17 +123,17 @@ class KumoRFM:
|
|
|
56
123
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
57
124
|
relational dataset without training.
|
|
58
125
|
The model is pre-trained and the class provides an interface to query the
|
|
59
|
-
model from a :class:`
|
|
126
|
+
model from a :class:`Graph` object.
|
|
60
127
|
|
|
61
128
|
.. code-block:: python
|
|
62
129
|
|
|
63
|
-
from kumoai.experimental.rfm import
|
|
130
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
64
131
|
|
|
65
132
|
df_users = pd.DataFrame(...)
|
|
66
133
|
df_items = pd.DataFrame(...)
|
|
67
134
|
df_orders = pd.DataFrame(...)
|
|
68
135
|
|
|
69
|
-
graph =
|
|
136
|
+
graph = Graph.from_data({
|
|
70
137
|
'users': df_users,
|
|
71
138
|
'items': df_items,
|
|
72
139
|
'orders': df_orders,
|
|
@@ -74,59 +141,152 @@ class KumoRFM:
|
|
|
74
141
|
|
|
75
142
|
rfm = KumoRFM(graph)
|
|
76
143
|
|
|
77
|
-
query = ("PREDICT COUNT(
|
|
78
|
-
"FOR users.user_id=
|
|
79
|
-
result = rfm.
|
|
144
|
+
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
|
|
145
|
+
"FOR users.user_id=1")
|
|
146
|
+
result = rfm.predict(query)
|
|
80
147
|
|
|
81
148
|
print(result) # user_id COUNT(transactions.*, 0, 30, days) > 0
|
|
82
149
|
# 1 0.85
|
|
83
150
|
|
|
84
151
|
Args:
|
|
85
152
|
graph: The graph.
|
|
86
|
-
preprocess: Whether to pre-process the data in advance during graph
|
|
87
|
-
materialization.
|
|
88
|
-
This is a runtime trade-off between graph materialization and model
|
|
89
|
-
processing speed.
|
|
90
|
-
It can be benefical to preprocess your data once and then run many
|
|
91
|
-
queries on top to achieve maximum model speed.
|
|
92
|
-
However, if activiated, graph materialization can take potentially
|
|
93
|
-
much longer, especially on graphs with many large text columns.
|
|
94
|
-
Best to tune this option manually.
|
|
95
153
|
verbose: Whether to print verbose output.
|
|
96
154
|
"""
|
|
97
155
|
def __init__(
|
|
98
156
|
self,
|
|
99
|
-
graph:
|
|
100
|
-
|
|
101
|
-
verbose: bool = True,
|
|
157
|
+
graph: Graph,
|
|
158
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
102
159
|
) -> None:
|
|
103
160
|
graph = graph.validate()
|
|
104
161
|
self._graph_def = graph._to_api_graph_definition()
|
|
105
|
-
self._graph_store = LocalGraphStore(graph,
|
|
162
|
+
self._graph_store = LocalGraphStore(graph, verbose)
|
|
106
163
|
self._graph_sampler = LocalGraphSampler(self._graph_store)
|
|
107
164
|
|
|
165
|
+
self._client: Optional[RFMAPI] = None
|
|
166
|
+
|
|
167
|
+
self._batch_size: Optional[int | Literal['max']] = None
|
|
168
|
+
self.num_retries: int = 0
|
|
169
|
+
|
|
170
|
+
@property
def _api_client(self) -> RFMAPI:
    """The :class:`RFMAPI` client, created on first access.

    The client wraps ``global_state.client`` and is cached in
    ``self._client``, so later accesses reuse the same instance.
    """
    if self._client is None:
        # Deferred import; presumably avoids a circular import with the
        # package ``__init__`` -- TODO confirm.
        from kumoai.experimental.rfm import global_state
        self._client = RFMAPI(global_state.client)
    return self._client
|
|
178
|
+
|
|
108
179
|
def __repr__(self) -> str:
    # Only the class name is shown, e.g. ``KumoRFM()``.
    name = self.__class__.__name__
    return '%s()' % name
|
|
110
181
|
|
|
182
|
+
@contextmanager
def batch_mode(
    self,
    batch_size: Union[int, Literal['max']] = 'max',
    num_retries: int = 1,
) -> Generator[None, None, None]:
    """Context manager to predict in batches.

    .. code-block:: python

        with model.batch_mode(batch_size='max', num_retries=1):
            df = model.predict(query, indices=...)

    On exit, including exit via an exception, the batch size and retry
    settings that were active before entering are restored. This makes
    nested ``batch_mode`` blocks behave as expected.

    Args:
        batch_size: The batch size. If set to ``"max"``, will use the
            maximum applicable batch size for the given task.
        num_retries: The maximum number of retries for failed queries due
            to unexpected server issues.

    Raises:
        ValueError: If ``batch_size`` is not ``"max"`` and not greater than
            zero, or if ``num_retries`` is negative.
    """
    if batch_size != 'max' and batch_size <= 0:
        raise ValueError(f"'batch_size' must be greater than zero "
                         f"(got {batch_size})")

    if num_retries < 0:
        raise ValueError(f"'num_retries' must be greater than or equal to "
                         f"zero (got {num_retries})")

    # Remember the surrounding settings so they can be restored on exit.
    prev_batch_size = self._batch_size
    prev_num_retries = self.num_retries

    self._batch_size = batch_size
    self.num_retries = num_retries
    try:
        yield
    finally:
        # Restore even if the `with` body raised; without `finally` a
        # failed prediction would leave the model stuck in batch mode.
        self._batch_size = prev_batch_size
        self.num_retries = prev_num_retries
|
|
214
|
+
|
|
215
|
+
@overload
|
|
111
216
|
def predict(
|
|
112
217
|
self,
|
|
113
218
|
query: str,
|
|
219
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
114
220
|
*,
|
|
221
|
+
explain: Literal[False] = False,
|
|
115
222
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
223
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
116
224
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
117
225
|
num_neighbors: Optional[List[int]] = None,
|
|
118
226
|
num_hops: int = 2,
|
|
119
227
|
max_pq_iterations: int = 20,
|
|
120
228
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
121
|
-
verbose: bool = True,
|
|
229
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
230
|
+
use_prediction_time: bool = False,
|
|
122
231
|
) -> pd.DataFrame:
|
|
232
|
+
pass
|
|
233
|
+
|
|
234
|
+
@overload
|
|
235
|
+
def predict(
|
|
236
|
+
self,
|
|
237
|
+
query: str,
|
|
238
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
239
|
+
*,
|
|
240
|
+
explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
|
|
241
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
242
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
243
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
244
|
+
num_neighbors: Optional[List[int]] = None,
|
|
245
|
+
num_hops: int = 2,
|
|
246
|
+
max_pq_iterations: int = 20,
|
|
247
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
248
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
249
|
+
use_prediction_time: bool = False,
|
|
250
|
+
) -> Explanation:
|
|
251
|
+
pass
|
|
252
|
+
|
|
253
|
+
def predict(
|
|
254
|
+
self,
|
|
255
|
+
query: str,
|
|
256
|
+
indices: Union[List[str], List[float], List[int], None] = None,
|
|
257
|
+
*,
|
|
258
|
+
explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
|
|
259
|
+
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
260
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
261
|
+
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
262
|
+
num_neighbors: Optional[List[int]] = None,
|
|
263
|
+
num_hops: int = 2,
|
|
264
|
+
max_pq_iterations: int = 20,
|
|
265
|
+
random_seed: Optional[int] = _RANDOM_SEED,
|
|
266
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
267
|
+
use_prediction_time: bool = False,
|
|
268
|
+
) -> Union[pd.DataFrame, Explanation]:
|
|
123
269
|
"""Returns predictions for a predictive query.
|
|
124
270
|
|
|
125
271
|
Args:
|
|
126
272
|
query: The predictive query.
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
273
|
+
indices: The entity primary keys to predict for. Will override the
|
|
274
|
+
indices given as part of the predictive query. Predictions will
|
|
275
|
+
be generated for all indices, independent of whether they
|
|
276
|
+
fulfill entity filter constraints. To pre-filter entities, use
|
|
277
|
+
:meth:`~KumoRFM.is_valid_entity`.
|
|
278
|
+
explain: Configuration for explainability.
|
|
279
|
+
If set to ``True``, will additionally explain the prediction.
|
|
280
|
+
Passing in an :class:`ExplainConfig` instance provides control
|
|
281
|
+
over which parts of explanation are generated.
|
|
282
|
+
Explainability is currently only supported for single entity
|
|
283
|
+
predictions with ``run_mode="FAST"``.
|
|
284
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
285
|
+
``None``, will use the maximum timestamp in the data.
|
|
286
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
287
|
+
context_anchor_time: The maximum anchor timestamp for context
|
|
288
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
289
|
+
determine the anchor time for context examples.
|
|
130
290
|
run_mode: The :class:`RunMode` for the query.
|
|
131
291
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
132
292
|
If specified, the ``num_hops`` option will be ignored.
|
|
@@ -138,79 +298,244 @@ class KumoRFM:
|
|
|
138
298
|
entities to find valid labels.
|
|
139
299
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
140
300
|
verbose: Whether to print verbose output.
|
|
301
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
302
|
+
additional feature during prediction. This is typically
|
|
303
|
+
beneficial for time series forecasting tasks.
|
|
141
304
|
|
|
142
305
|
Returns:
|
|
143
|
-
The predictions as a :class:`pandas.DataFrame
|
|
306
|
+
The predictions as a :class:`pandas.DataFrame`.
|
|
307
|
+
If ``explain`` is provided, returns an :class:`Explanation` object
|
|
308
|
+
containing the prediction, summary, and details.
|
|
144
309
|
"""
|
|
145
|
-
|
|
310
|
+
explain_config: Optional[ExplainConfig] = None
|
|
311
|
+
if explain is True:
|
|
312
|
+
explain_config = ExplainConfig()
|
|
313
|
+
elif explain is not False:
|
|
314
|
+
explain_config = ExplainConfig._cast(explain)
|
|
315
|
+
|
|
146
316
|
query_def = self._parse_query(query)
|
|
317
|
+
query_str = query_def.to_string()
|
|
147
318
|
|
|
148
319
|
if num_hops != 2 and num_neighbors is not None:
|
|
149
320
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
150
321
|
f"custom 'num_hops={num_hops}' option")
|
|
151
322
|
|
|
152
|
-
if
|
|
323
|
+
if explain_config is not None and run_mode in {
|
|
324
|
+
RunMode.NORMAL, RunMode.BEST
|
|
325
|
+
}:
|
|
153
326
|
warnings.warn(f"Explainability is currently only supported for "
|
|
154
327
|
f"run mode 'FAST' (got '{run_mode}'). Provided run "
|
|
155
328
|
f"mode has been reset. Please lower the run mode to "
|
|
156
329
|
f"suppress this warning.")
|
|
157
330
|
|
|
158
|
-
if
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
331
|
+
if indices is None:
|
|
332
|
+
if query_def.rfm_entity_ids is None:
|
|
333
|
+
raise ValueError("Cannot find entities to predict for. Please "
|
|
334
|
+
"pass them via `predict(query, indices=...)`")
|
|
335
|
+
indices = query_def.get_rfm_entity_id_list()
|
|
336
|
+
else:
|
|
337
|
+
query_def = replace(query_def, rfm_entity_ids=None)
|
|
338
|
+
|
|
339
|
+
if len(indices) == 0:
|
|
340
|
+
raise ValueError("At least one entity is required")
|
|
341
|
+
|
|
342
|
+
if explain_config is not None and len(indices) > 1:
|
|
343
|
+
raise ValueError(
|
|
344
|
+
f"Cannot explain predictions for more than a single entity "
|
|
345
|
+
f"(got {len(indices)})")
|
|
164
346
|
|
|
165
347
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
166
|
-
if
|
|
348
|
+
if explain_config is not None:
|
|
167
349
|
msg = f'[bold]EXPLAIN[/bold] {query_repr}'
|
|
168
350
|
else:
|
|
169
351
|
msg = f'[bold]PREDICT[/bold] {query_repr}'
|
|
170
352
|
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
353
|
+
if not isinstance(verbose, ProgressLogger):
|
|
354
|
+
verbose = InteractiveProgressLogger(msg, verbose=verbose)
|
|
355
|
+
|
|
356
|
+
with verbose as logger:
|
|
357
|
+
|
|
358
|
+
batch_size: Optional[int] = None
|
|
359
|
+
if self._batch_size == 'max':
|
|
360
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
361
|
+
query_def,
|
|
362
|
+
edge_types=self._graph_store.edge_types,
|
|
363
|
+
)
|
|
364
|
+
batch_size = _MAX_PRED_SIZE[task_type]
|
|
365
|
+
else:
|
|
366
|
+
batch_size = self._batch_size
|
|
367
|
+
|
|
368
|
+
if batch_size is not None:
|
|
369
|
+
offsets = range(0, len(indices), batch_size)
|
|
370
|
+
batches = [indices[step:step + batch_size] for step in offsets]
|
|
371
|
+
else:
|
|
372
|
+
batches = [indices]
|
|
373
|
+
|
|
374
|
+
if len(batches) > 1:
|
|
375
|
+
logger.log(f"Splitting {len(indices):,} entities into "
|
|
376
|
+
f"{len(batches):,} batches of size {batch_size:,}")
|
|
377
|
+
|
|
378
|
+
predictions: List[pd.DataFrame] = []
|
|
379
|
+
summary: Optional[str] = None
|
|
380
|
+
details: Optional[Explanation] = None
|
|
381
|
+
for i, batch in enumerate(batches):
|
|
382
|
+
# TODO Re-use the context for subsequent predictions.
|
|
383
|
+
context = self._get_context(
|
|
384
|
+
query=query_def,
|
|
385
|
+
indices=batch,
|
|
386
|
+
anchor_time=anchor_time,
|
|
387
|
+
context_anchor_time=context_anchor_time,
|
|
388
|
+
run_mode=RunMode(run_mode),
|
|
389
|
+
num_neighbors=num_neighbors,
|
|
390
|
+
num_hops=num_hops,
|
|
391
|
+
max_pq_iterations=max_pq_iterations,
|
|
392
|
+
evaluate=False,
|
|
393
|
+
random_seed=random_seed,
|
|
394
|
+
logger=logger if i == 0 else None,
|
|
395
|
+
)
|
|
396
|
+
request = RFMPredictRequest(
|
|
397
|
+
context=context,
|
|
398
|
+
run_mode=RunMode(run_mode),
|
|
399
|
+
query=query_str,
|
|
400
|
+
use_prediction_time=use_prediction_time,
|
|
401
|
+
)
|
|
402
|
+
with warnings.catch_warnings():
|
|
403
|
+
warnings.filterwarnings('ignore', message='gencode')
|
|
404
|
+
request_msg = request.to_protobuf()
|
|
405
|
+
_bytes = request_msg.SerializeToString()
|
|
406
|
+
if i == 0:
|
|
407
|
+
logger.log(f"Generated context of size "
|
|
408
|
+
f"{len(_bytes) / (1024*1024):.2f}MB")
|
|
409
|
+
|
|
410
|
+
if len(_bytes) > _MAX_SIZE:
|
|
411
|
+
stats = Context.get_memory_stats(request_msg.context)
|
|
412
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
|
|
413
|
+
|
|
414
|
+
if (isinstance(verbose, InteractiveProgressLogger) and i == 0
|
|
415
|
+
and len(batches) > 1):
|
|
416
|
+
verbose.init_progress(
|
|
417
|
+
total=len(batches),
|
|
418
|
+
description='Predicting',
|
|
419
|
+
)
|
|
420
|
+
|
|
421
|
+
for attempt in range(self.num_retries + 1):
|
|
422
|
+
try:
|
|
423
|
+
if explain_config is not None:
|
|
424
|
+
resp = self._api_client.explain(
|
|
425
|
+
request=_bytes,
|
|
426
|
+
skip_summary=explain_config.skip_summary,
|
|
427
|
+
)
|
|
428
|
+
summary = resp.summary
|
|
429
|
+
details = resp.details
|
|
430
|
+
else:
|
|
431
|
+
resp = self._api_client.predict(_bytes)
|
|
432
|
+
df = pd.DataFrame(**resp.prediction)
|
|
433
|
+
|
|
434
|
+
# Cast 'ENTITY' to correct data type:
|
|
435
|
+
if 'ENTITY' in df:
|
|
436
|
+
entity = query_def.entity_table
|
|
437
|
+
pkey_map = self._graph_store.pkey_map_dict[entity]
|
|
438
|
+
df['ENTITY'] = df['ENTITY'].astype(
|
|
439
|
+
type(pkey_map.index[0]))
|
|
440
|
+
|
|
441
|
+
# Cast 'ANCHOR_TIMESTAMP' to correct data type:
|
|
442
|
+
if 'ANCHOR_TIMESTAMP' in df:
|
|
443
|
+
ser = df['ANCHOR_TIMESTAMP']
|
|
444
|
+
if not pd.api.types.is_datetime64_any_dtype(ser):
|
|
445
|
+
if isinstance(ser.iloc[0], str):
|
|
446
|
+
unit = None
|
|
447
|
+
else:
|
|
448
|
+
unit = 'ms'
|
|
449
|
+
df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
|
|
450
|
+
ser, errors='coerce', unit=unit)
|
|
451
|
+
|
|
452
|
+
predictions.append(df)
|
|
453
|
+
|
|
454
|
+
if (isinstance(verbose, InteractiveProgressLogger)
|
|
455
|
+
and len(batches) > 1):
|
|
456
|
+
verbose.step()
|
|
457
|
+
|
|
458
|
+
break
|
|
459
|
+
except HTTPException as e:
|
|
460
|
+
if attempt == self.num_retries:
|
|
461
|
+
try:
|
|
462
|
+
msg = json.loads(e.detail)['detail']
|
|
463
|
+
except Exception:
|
|
464
|
+
msg = e.detail
|
|
465
|
+
raise RuntimeError(
|
|
466
|
+
f"An unexpected exception occurred. Please "
|
|
467
|
+
f"create an issue at "
|
|
468
|
+
f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
|
|
469
|
+
) from None
|
|
470
|
+
|
|
471
|
+
time.sleep(2**attempt) # 1s, 2s, 4s, 8s, ...
|
|
472
|
+
|
|
473
|
+
if len(predictions) == 1:
|
|
474
|
+
prediction = predictions[0]
|
|
475
|
+
else:
|
|
476
|
+
prediction = pd.concat(predictions, ignore_index=True)
|
|
477
|
+
|
|
478
|
+
if explain_config is not None:
|
|
479
|
+
assert len(predictions) == 1
|
|
480
|
+
assert summary is not None
|
|
481
|
+
assert details is not None
|
|
482
|
+
return Explanation(
|
|
483
|
+
prediction=prediction,
|
|
484
|
+
summary=summary,
|
|
485
|
+
details=details,
|
|
186
486
|
)
|
|
187
|
-
with warnings.catch_warnings():
|
|
188
|
-
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
189
|
-
request_msg = request.to_protobuf()
|
|
190
|
-
request_bytes = request_msg.SerializeToString()
|
|
191
|
-
logger.log(f"Generated context of size "
|
|
192
|
-
f"{len(request_bytes) / (1024*1024):.2f}MB")
|
|
193
487
|
|
|
194
|
-
|
|
195
|
-
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
196
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
|
|
488
|
+
return prediction
|
|
197
489
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
raise RuntimeError(f"An unexpected exception occurred. "
|
|
209
|
-
f"Please create an issue at "
|
|
210
|
-
f"'https://github.com/kumo-ai/kumo-rfm'. "
|
|
211
|
-
f"{msg}") from None
|
|
490
|
+
def is_valid_entity(
    self,
    query: str,
    indices: Union[List[str], List[float], List[int], None] = None,
    *,
    anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
) -> np.ndarray:
    r"""Returns a mask that denotes which entities are valid for the
    given predictive query, *i.e.*, which entities fulfill (temporal)
    entity filter constraints.

    Args:
        query: The predictive query.
        indices: The entity primary keys to check. Will override the
            indices given as part of the predictive query.
        anchor_time: The anchor timestamp for the prediction. If set to
            ``None``, will use the maximum timestamp in the data.
            If set to ``"entity"``, will use the timestamp of the entity.

    Returns:
        The validity mask produced by the query driver for the looked-up
        entity nodes.

    Raises:
        ValueError: If no entities are given (neither via ``indices`` nor
            via the query), if ``anchor_time`` does not resolve to a
            :class:`pandas.Timestamp` or ``"entity"``, or if ``"entity"``
            is requested for an entity table without a time column.
    """
    query_def = self._parse_query(query)

    if indices is None:
        if query_def.rfm_entity_ids is None:
            raise ValueError("Cannot find entities to predict for. Please "
                             "pass them via "
                             "`is_valid_entity(query, indices=...)`")
        indices = query_def.get_rfm_entity_id_list()

    if len(indices) == 0:
        raise ValueError("At least one entity is required")

    if anchor_time is None:
        anchor_time = self._graph_store.max_time

    if isinstance(anchor_time, pd.Timestamp):
        self._validate_time(query_def, anchor_time, None, False)
    elif anchor_time == 'entity':
        if query_def.entity_table not in self._graph_store.time_dict:
            raise ValueError(f"Anchor time 'entity' requires the entity "
                             f"table '{query_def.entity_table}' "
                             f"to have a time column.")
    else:
        # Explicit check instead of `assert`, which is stripped when
        # Python runs with `-O` and would otherwise let bad input through.
        raise ValueError(f"Invalid anchor time {anchor_time!r}; expected "
                         f"a 'pd.Timestamp' or 'entity'")

    node = self._graph_store.get_node_id(
        table_name=query_def.entity_table,
        pkey=pd.Series(indices),
    )
    query_driver = LocalPQueryDriver(self._graph_store, query_def)
    return query_driver.is_valid(node, anchor_time)
|
|
214
539
|
|
|
215
540
|
def evaluate(
|
|
216
541
|
self,
|
|
@@ -218,21 +543,26 @@ class KumoRFM:
|
|
|
218
543
|
*,
|
|
219
544
|
metrics: Optional[List[str]] = None,
|
|
220
545
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
|
|
546
|
+
context_anchor_time: Union[pd.Timestamp, None] = None,
|
|
221
547
|
run_mode: Union[RunMode, str] = RunMode.FAST,
|
|
222
548
|
num_neighbors: Optional[List[int]] = None,
|
|
223
549
|
num_hops: int = 2,
|
|
224
550
|
max_pq_iterations: int = 20,
|
|
225
551
|
random_seed: Optional[int] = _RANDOM_SEED,
|
|
226
|
-
verbose: bool = True,
|
|
552
|
+
verbose: Union[bool, ProgressLogger] = True,
|
|
553
|
+
use_prediction_time: bool = False,
|
|
227
554
|
) -> pd.DataFrame:
|
|
228
555
|
"""Evaluates a predictive query.
|
|
229
556
|
|
|
230
557
|
Args:
|
|
231
558
|
query: The predictive query.
|
|
232
559
|
metrics: The metrics to use.
|
|
233
|
-
anchor_time: The anchor timestamp for the
|
|
234
|
-
|
|
235
|
-
If set to
|
|
560
|
+
anchor_time: The anchor timestamp for the prediction. If set to
|
|
561
|
+
``None``, will use the maximum timestamp in the data.
|
|
562
|
+
If set to ``"entity"``, will use the timestamp of the entity.
|
|
563
|
+
context_anchor_time: The maximum anchor timestamp for context
|
|
564
|
+
examples. If set to ``None``, ``anchor_time`` will
|
|
565
|
+
determine the anchor time for context examples.
|
|
236
566
|
run_mode: The :class:`RunMode` for the query.
|
|
237
567
|
num_neighbors: The number of neighbors to sample for each hop.
|
|
238
568
|
If specified, the ``num_hops`` option will be ignored.
|
|
@@ -244,6 +574,9 @@ class KumoRFM:
|
|
|
244
574
|
entities to find valid labels.
|
|
245
575
|
random_seed: A manual seed for generating pseudo-random numbers.
|
|
246
576
|
verbose: Whether to print verbose output.
|
|
577
|
+
use_prediction_time: Whether to use the anchor timestamp as an
|
|
578
|
+
additional feature during prediction. This is typically
|
|
579
|
+
beneficial for time series forecasting tasks.
|
|
247
580
|
|
|
248
581
|
Returns:
|
|
249
582
|
The metrics as a :class:`pandas.DataFrame`
|
|
@@ -254,13 +587,24 @@ class KumoRFM:
|
|
|
254
587
|
warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
|
|
255
588
|
f"custom 'num_hops={num_hops}' option")
|
|
256
589
|
|
|
590
|
+
if query_def.rfm_entity_ids is not None:
|
|
591
|
+
query_def = replace(
|
|
592
|
+
query_def,
|
|
593
|
+
rfm_entity_ids=None,
|
|
594
|
+
)
|
|
595
|
+
|
|
257
596
|
query_repr = query_def.to_string(rich=True, exclude_predict=True)
|
|
258
597
|
msg = f'[bold]EVALUATE[/bold] {query_repr}'
|
|
259
598
|
|
|
260
|
-
|
|
599
|
+
if not isinstance(verbose, ProgressLogger):
|
|
600
|
+
verbose = InteractiveProgressLogger(msg, verbose=verbose)
|
|
601
|
+
|
|
602
|
+
with verbose as logger:
|
|
261
603
|
context = self._get_context(
|
|
262
|
-
query_def,
|
|
604
|
+
query=query_def,
|
|
605
|
+
indices=None,
|
|
263
606
|
anchor_time=anchor_time,
|
|
607
|
+
context_anchor_time=context_anchor_time,
|
|
264
608
|
run_mode=RunMode(run_mode),
|
|
265
609
|
num_neighbors=num_neighbors,
|
|
266
610
|
num_hops=num_hops,
|
|
@@ -276,6 +620,7 @@ class KumoRFM:
|
|
|
276
620
|
context=context,
|
|
277
621
|
run_mode=RunMode(run_mode),
|
|
278
622
|
metrics=metrics,
|
|
623
|
+
use_prediction_time=use_prediction_time,
|
|
279
624
|
)
|
|
280
625
|
with warnings.catch_warnings():
|
|
281
626
|
warnings.filterwarnings('ignore', message='Protobuf gencode')
|
|
@@ -286,10 +631,10 @@ class KumoRFM:
|
|
|
286
631
|
|
|
287
632
|
if len(request_bytes) > _MAX_SIZE:
|
|
288
633
|
stats_msg = Context.get_memory_stats(request_msg.context)
|
|
289
|
-
raise ValueError(_SIZE_LIMIT_MSG.format(
|
|
634
|
+
raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
|
|
290
635
|
|
|
291
636
|
try:
|
|
292
|
-
resp =
|
|
637
|
+
resp = self._api_client.evaluate(request_bytes)
|
|
293
638
|
except HTTPException as e:
|
|
294
639
|
try:
|
|
295
640
|
msg = json.loads(e.detail)['detail']
|
|
@@ -334,17 +679,19 @@ class KumoRFM:
|
|
|
334
679
|
|
|
335
680
|
if anchor_time is None:
|
|
336
681
|
anchor_time = self._graph_store.max_time
|
|
337
|
-
|
|
682
|
+
if query_def.target_ast.date_offset_range is not None:
|
|
683
|
+
anchor_time = anchor_time - (
|
|
684
|
+
query_def.target_ast.date_offset_range.end_date_offset *
|
|
685
|
+
query_def.num_forecasts)
|
|
338
686
|
|
|
339
687
|
assert anchor_time is not None
|
|
340
688
|
if isinstance(anchor_time, pd.Timestamp):
|
|
341
|
-
self._validate_time(query_def, anchor_time, evaluate=True)
|
|
689
|
+
self._validate_time(query_def, anchor_time, None, evaluate=True)
|
|
342
690
|
else:
|
|
343
691
|
assert anchor_time == 'entity'
|
|
344
|
-
if (query_def.
|
|
345
|
-
not in self._graph_store.time_dict):
|
|
692
|
+
if (query_def.entity_table not in self._graph_store.time_dict):
|
|
346
693
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
347
|
-
f"table '{query_def.
|
|
694
|
+
f"table '{query_def.entity_table}' "
|
|
348
695
|
f"to have a time column")
|
|
349
696
|
|
|
350
697
|
query_driver = LocalPQueryDriver(self._graph_store, query_def,
|
|
@@ -355,18 +702,22 @@ class KumoRFM:
|
|
|
355
702
|
anchor_time=anchor_time,
|
|
356
703
|
batch_size=min(10_000, size),
|
|
357
704
|
max_iterations=max_iterations,
|
|
705
|
+
guarantee_train_examples=False,
|
|
358
706
|
)
|
|
359
707
|
|
|
708
|
+
entity = self._graph_store.pkey_map_dict[
|
|
709
|
+
query_def.entity_table].index[node]
|
|
710
|
+
|
|
360
711
|
return pd.DataFrame({
|
|
361
|
-
'ENTITY':
|
|
712
|
+
'ENTITY': entity,
|
|
362
713
|
'ANCHOR_TIMESTAMP': time,
|
|
363
714
|
'TARGET': y,
|
|
364
715
|
})
|
|
365
716
|
|
|
366
717
|
# Helpers #################################################################
|
|
367
718
|
|
|
368
|
-
def _parse_query(self, query: str) ->
|
|
369
|
-
if isinstance(query,
|
|
719
|
+
def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
|
|
720
|
+
if isinstance(query, ValidatedPredictiveQuery):
|
|
370
721
|
return query
|
|
371
722
|
|
|
372
723
|
if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
|
|
@@ -376,12 +727,13 @@ class KumoRFM:
|
|
|
376
727
|
"predictions or evaluations.")
|
|
377
728
|
|
|
378
729
|
try:
|
|
379
|
-
request =
|
|
730
|
+
request = RFMParseQueryRequest(
|
|
380
731
|
query=query,
|
|
381
732
|
graph_definition=self._graph_def,
|
|
382
733
|
)
|
|
383
734
|
|
|
384
|
-
resp =
|
|
735
|
+
resp = self._api_client.parse_query(request)
|
|
736
|
+
|
|
385
737
|
# TODO Expose validation warnings.
|
|
386
738
|
|
|
387
739
|
if len(resp.validation_response.warnings) > 0:
|
|
@@ -392,7 +744,7 @@ class KumoRFM:
|
|
|
392
744
|
warnings.warn(f"Encountered the following warnings during "
|
|
393
745
|
f"parsing:\n{msg}")
|
|
394
746
|
|
|
395
|
-
return resp.
|
|
747
|
+
return resp.query
|
|
396
748
|
except HTTPException as e:
|
|
397
749
|
try:
|
|
398
750
|
msg = json.loads(e.detail)['detail']
|
|
@@ -403,8 +755,9 @@ class KumoRFM:
|
|
|
403
755
|
|
|
404
756
|
def _validate_time(
|
|
405
757
|
self,
|
|
406
|
-
query:
|
|
758
|
+
query: ValidatedPredictiveQuery,
|
|
407
759
|
anchor_time: pd.Timestamp,
|
|
760
|
+
context_anchor_time: Union[pd.Timestamp, None],
|
|
408
761
|
evaluate: bool,
|
|
409
762
|
) -> None:
|
|
410
763
|
|
|
@@ -416,22 +769,45 @@ class KumoRFM:
|
|
|
416
769
|
f"the earliest timestamp "
|
|
417
770
|
f"'{self._graph_store.min_time}' in the data.")
|
|
418
771
|
|
|
419
|
-
if
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
f"
|
|
423
|
-
f"
|
|
424
|
-
f"however, your data
|
|
772
|
+
if (context_anchor_time is not None
|
|
773
|
+
and context_anchor_time < self._graph_store.min_time):
|
|
774
|
+
raise ValueError(f"Context anchor timestamp is too early or "
|
|
775
|
+
f"aggregation time range is too large. To make "
|
|
776
|
+
f"this prediction, we would need data back to "
|
|
777
|
+
f"'{context_anchor_time}', however, your data "
|
|
778
|
+
f"only contains data back to "
|
|
425
779
|
f"'{self._graph_store.min_time}'.")
|
|
426
780
|
|
|
427
|
-
if
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
781
|
+
if query.target_ast.date_offset_range is not None:
|
|
782
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
783
|
+
else:
|
|
784
|
+
end_offset = pd.DateOffset(0)
|
|
785
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
786
|
+
if (context_anchor_time is not None
|
|
787
|
+
and context_anchor_time > anchor_time):
|
|
788
|
+
warnings.warn(f"Context anchor timestamp "
|
|
789
|
+
f"(got '{context_anchor_time}') is set to a later "
|
|
790
|
+
f"date than the prediction anchor timestamp "
|
|
791
|
+
f"(got '{anchor_time}'). Please make sure this is "
|
|
792
|
+
f"intended.")
|
|
793
|
+
elif (query.query_type == QueryType.TEMPORAL
|
|
794
|
+
and context_anchor_time is not None
|
|
795
|
+
and context_anchor_time + forecast_end_offset > anchor_time):
|
|
796
|
+
warnings.warn(f"Aggregation for context examples at timestamp "
|
|
797
|
+
f"'{context_anchor_time}' will leak information "
|
|
798
|
+
f"from the prediction anchor timestamp "
|
|
799
|
+
f"'{anchor_time}'. Please make sure this is "
|
|
800
|
+
f"intended.")
|
|
801
|
+
|
|
802
|
+
elif (context_anchor_time is not None
|
|
803
|
+
and context_anchor_time - forecast_end_offset
|
|
804
|
+
< self._graph_store.min_time):
|
|
805
|
+
_time = context_anchor_time - forecast_end_offset
|
|
806
|
+
warnings.warn(f"Context anchor timestamp is too early or "
|
|
807
|
+
f"aggregation time range is too large. To form "
|
|
808
|
+
f"proper input data, we would need data back to "
|
|
809
|
+
f"'{_time}', however, your data only contains "
|
|
810
|
+
f"data back to '{self._graph_store.min_time}'.")
|
|
435
811
|
|
|
436
812
|
if (not evaluate and anchor_time
|
|
437
813
|
> self._graph_store.max_time + pd.DateOffset(days=1)):
|
|
@@ -439,17 +815,18 @@ class KumoRFM:
|
|
|
439
815
|
f"latest timestamp '{self._graph_store.max_time}' "
|
|
440
816
|
f"in the data. Please make sure this is intended.")
|
|
441
817
|
|
|
442
|
-
|
|
443
|
-
|
|
818
|
+
max_eval_time = self._graph_store.max_time - forecast_end_offset
|
|
819
|
+
if evaluate and anchor_time > max_eval_time:
|
|
444
820
|
raise ValueError(
|
|
445
821
|
f"Anchor timestamp for evaluation is after the latest "
|
|
446
|
-
f"supported timestamp "
|
|
447
|
-
f"'{self._graph_store.max_time - query.target.end_offset}'.")
|
|
822
|
+
f"supported timestamp '{max_eval_time}'.")
|
|
448
823
|
|
|
449
824
|
def _get_context(
|
|
450
825
|
self,
|
|
451
|
-
query:
|
|
826
|
+
query: ValidatedPredictiveQuery,
|
|
827
|
+
indices: Union[List[str], List[float], List[int], None],
|
|
452
828
|
anchor_time: Union[pd.Timestamp, Literal['entity'], None],
|
|
829
|
+
context_anchor_time: Union[pd.Timestamp, None],
|
|
453
830
|
run_mode: RunMode,
|
|
454
831
|
num_neighbors: Optional[List[int]],
|
|
455
832
|
num_hops: int,
|
|
@@ -474,8 +851,8 @@ class KumoRFM:
|
|
|
474
851
|
f"must go beyond this for your use-case.")
|
|
475
852
|
|
|
476
853
|
query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
|
|
477
|
-
task_type =
|
|
478
|
-
|
|
854
|
+
task_type = LocalPQueryDriver.get_task_type(
|
|
855
|
+
query,
|
|
479
856
|
edge_types=self._graph_store.edge_types,
|
|
480
857
|
)
|
|
481
858
|
|
|
@@ -507,28 +884,42 @@ class KumoRFM:
|
|
|
507
884
|
else:
|
|
508
885
|
num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
|
|
509
886
|
|
|
887
|
+
if query.target_ast.date_offset_range is None:
|
|
888
|
+
end_offset = pd.DateOffset(0)
|
|
889
|
+
else:
|
|
890
|
+
end_offset = query.target_ast.date_offset_range.end_date_offset
|
|
891
|
+
forecast_end_offset = end_offset * query.num_forecasts
|
|
510
892
|
if anchor_time is None:
|
|
511
893
|
anchor_time = self._graph_store.max_time
|
|
512
894
|
if evaluate:
|
|
513
|
-
anchor_time = anchor_time -
|
|
895
|
+
anchor_time = anchor_time - forecast_end_offset
|
|
514
896
|
if logger is not None:
|
|
515
897
|
assert isinstance(anchor_time, pd.Timestamp)
|
|
516
|
-
if
|
|
517
|
-
|
|
518
|
-
|
|
898
|
+
if anchor_time == pd.Timestamp.min:
|
|
899
|
+
pass # Static graph
|
|
900
|
+
elif (anchor_time.hour == 0 and anchor_time.minute == 0
|
|
901
|
+
and anchor_time.second == 0
|
|
902
|
+
and anchor_time.microsecond == 0):
|
|
519
903
|
logger.log(f"Derived anchor time {anchor_time.date()}")
|
|
520
904
|
else:
|
|
521
905
|
logger.log(f"Derived anchor time {anchor_time}")
|
|
522
906
|
|
|
523
907
|
assert anchor_time is not None
|
|
524
908
|
if isinstance(anchor_time, pd.Timestamp):
|
|
525
|
-
|
|
909
|
+
if context_anchor_time is None:
|
|
910
|
+
context_anchor_time = anchor_time - forecast_end_offset
|
|
911
|
+
self._validate_time(query, anchor_time, context_anchor_time,
|
|
912
|
+
evaluate)
|
|
526
913
|
else:
|
|
527
914
|
assert anchor_time == 'entity'
|
|
528
|
-
if query.
|
|
915
|
+
if query.entity_table not in self._graph_store.time_dict:
|
|
529
916
|
raise ValueError(f"Anchor time 'entity' requires the entity "
|
|
530
|
-
f"table '{query.
|
|
917
|
+
f"table '{query.entity_table}' to "
|
|
531
918
|
f"have a time column")
|
|
919
|
+
if context_anchor_time is not None:
|
|
920
|
+
warnings.warn("Ignoring option 'context_anchor_time' for "
|
|
921
|
+
"`anchor_time='entity'`")
|
|
922
|
+
context_anchor_time = None
|
|
532
923
|
|
|
533
924
|
y_test: Optional[pd.Series] = None
|
|
534
925
|
if evaluate:
|
|
@@ -540,6 +931,7 @@ class KumoRFM:
|
|
|
540
931
|
size=max_test_size,
|
|
541
932
|
anchor_time=anchor_time,
|
|
542
933
|
max_iterations=max_pq_iterations,
|
|
934
|
+
guarantee_train_examples=True,
|
|
543
935
|
)
|
|
544
936
|
if logger is not None:
|
|
545
937
|
if task_type == TaskType.BINARY_CLASSIFICATION:
|
|
@@ -563,34 +955,31 @@ class KumoRFM:
|
|
|
563
955
|
logger.log(msg)
|
|
564
956
|
|
|
565
957
|
else:
|
|
566
|
-
assert
|
|
958
|
+
assert indices is not None
|
|
567
959
|
|
|
568
|
-
|
|
569
|
-
if len(query.entity.ids.value) > max_num_test:
|
|
960
|
+
if len(indices) > _MAX_PRED_SIZE[task_type]:
|
|
570
961
|
raise ValueError(f"Cannot predict for more than "
|
|
571
|
-
f"{
|
|
572
|
-
f"(got {len(
|
|
962
|
+
f"{_MAX_PRED_SIZE[task_type]:,} entities at "
|
|
963
|
+
f"once (got {len(indices):,}). Use "
|
|
964
|
+
f"`KumoRFM.batch_mode` to process entities "
|
|
965
|
+
f"in batches")
|
|
573
966
|
|
|
574
967
|
test_node = self._graph_store.get_node_id(
|
|
575
|
-
table_name=query.
|
|
576
|
-
pkey=pd.Series(
|
|
577
|
-
query.entity.ids.value,
|
|
578
|
-
dtype=query.entity.ids.dtype,
|
|
579
|
-
),
|
|
968
|
+
table_name=query.entity_table,
|
|
969
|
+
pkey=pd.Series(indices),
|
|
580
970
|
)
|
|
581
971
|
|
|
582
972
|
if isinstance(anchor_time, pd.Timestamp):
|
|
583
973
|
test_time = pd.Series(anchor_time).repeat(
|
|
584
974
|
len(test_node)).reset_index(drop=True)
|
|
585
975
|
else:
|
|
586
|
-
time = self._graph_store.time_dict[
|
|
587
|
-
query.entity.pkey.table_name]
|
|
976
|
+
time = self._graph_store.time_dict[query.entity_table]
|
|
588
977
|
time = time[test_node] * 1000**3
|
|
589
978
|
test_time = pd.Series(time, dtype='datetime64[ns]')
|
|
590
979
|
|
|
591
980
|
train_node, train_time, y_train = query_driver.collect_train(
|
|
592
981
|
size=_MAX_CONTEXT_SIZE[run_mode],
|
|
593
|
-
anchor_time=
|
|
982
|
+
anchor_time=context_anchor_time or 'entity',
|
|
594
983
|
exclude_node=test_node if (query.query_type == QueryType.STATIC
|
|
595
984
|
or anchor_time == 'entity') else None,
|
|
596
985
|
max_iterations=max_pq_iterations,
|
|
@@ -617,12 +1006,23 @@ class KumoRFM:
|
|
|
617
1006
|
raise NotImplementedError
|
|
618
1007
|
logger.log(msg)
|
|
619
1008
|
|
|
620
|
-
entity_table_names
|
|
621
|
-
|
|
1009
|
+
entity_table_names: Tuple[str, ...]
|
|
1010
|
+
if task_type.is_link_pred:
|
|
1011
|
+
final_aggr = query.get_final_target_aggregation()
|
|
1012
|
+
assert final_aggr is not None
|
|
1013
|
+
edge_fkey = final_aggr._get_target_column_name()
|
|
1014
|
+
for edge_type in self._graph_store.edge_types:
|
|
1015
|
+
if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
|
|
1016
|
+
entity_table_names = (
|
|
1017
|
+
query.entity_table,
|
|
1018
|
+
edge_type[2],
|
|
1019
|
+
)
|
|
1020
|
+
else:
|
|
1021
|
+
entity_table_names = (query.entity_table, )
|
|
622
1022
|
|
|
623
1023
|
# Exclude the entity anchor time from the feature set to prevent
|
|
624
1024
|
# running out-of-distribution between in-context and test examples:
|
|
625
|
-
exclude_cols_dict = query.
|
|
1025
|
+
exclude_cols_dict = query.get_exclude_cols_dict()
|
|
626
1026
|
if anchor_time == 'entity':
|
|
627
1027
|
if entity_table_names[0] not in exclude_cols_dict:
|
|
628
1028
|
exclude_cols_dict[entity_table_names[0]] = []
|
|
@@ -637,11 +1037,21 @@ class KumoRFM:
|
|
|
637
1037
|
train_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
638
1038
|
test_time.astype('datetime64[ns]').astype(int).to_numpy(),
|
|
639
1039
|
]),
|
|
640
|
-
run_mode=run_mode,
|
|
641
1040
|
num_neighbors=num_neighbors,
|
|
642
1041
|
exclude_cols_dict=exclude_cols_dict,
|
|
643
1042
|
)
|
|
644
1043
|
|
|
1044
|
+
if len(subgraph.table_dict) >= 15:
|
|
1045
|
+
raise ValueError(f"Cannot query from a graph with more than 15 "
|
|
1046
|
+
f"tables (got {len(subgraph.table_dict)}). "
|
|
1047
|
+
f"Please create a feature request at "
|
|
1048
|
+
f"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
1049
|
+
f"must go beyond this for your use-case.")
|
|
1050
|
+
|
|
1051
|
+
step_size: Optional[int] = None
|
|
1052
|
+
if query.query_type == QueryType.TEMPORAL:
|
|
1053
|
+
step_size = date_offset_to_seconds(end_offset)
|
|
1054
|
+
|
|
645
1055
|
return Context(
|
|
646
1056
|
task_type=task_type,
|
|
647
1057
|
entity_table_names=entity_table_names,
|
|
@@ -649,6 +1059,7 @@ class KumoRFM:
|
|
|
649
1059
|
y_train=y_train,
|
|
650
1060
|
y_test=y_test,
|
|
651
1061
|
top_k=query.top_k,
|
|
1062
|
+
step_size=step_size,
|
|
652
1063
|
)
|
|
653
1064
|
|
|
654
1065
|
@staticmethod
|
|
@@ -664,7 +1075,7 @@ class KumoRFM:
|
|
|
664
1075
|
elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
|
|
665
1076
|
supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
|
|
666
1077
|
elif task_type == TaskType.REGRESSION:
|
|
667
|
-
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
|
|
1078
|
+
supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
|
|
668
1079
|
elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
|
|
669
1080
|
supported_metrics = [
|
|
670
1081
|
'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
|