kumoai 2.13.0.dev202511131731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +15 -13
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +191 -50
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +753 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +81 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +322 -252
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +13 -2
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +50 -29
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
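
The headline change in this release is the restructuring of the experimental RFM package: `local_graph.py` becomes `graph.py`, the `local_*` store, sampler, and query-driver modules are removed in favour of pluggable backends under `backend/` (local, sqlite, snow), and shared abstractions move into `base/`. A minimal sketch of the new import surface, pieced together from the import lines and docstring in the `rfm.py` diff below (whether `Graph.from_data` infers keys and links automatically is not shown in this diff and is assumed):

```python
import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# Two toy tables; column names are illustrative only.
df_users = pd.DataFrame({'user_id': [1, 2], 'age': [31, 27]})
df_orders = pd.DataFrame({'order_id': [10, 11], 'user_id': [1, 2]})

# Graph.from_data() replaces the old LocalGraph/LocalTable construction path.
graph = Graph.from_data({'users': df_users, 'orders': df_orders})
rfm = KumoRFM(graph)
```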
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas>=2
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.3
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
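
With `print()` and `_ipython_display_()`, an `Explanation` now renders itself as a Streamlit widget, as rich IPython output, or as plain text, depending on the environment. A short usage sketch, assuming `rfm` and `query` are set up as in the `KumoRFM` docstring further down (the `explain=True` call is the path that returns an `Explanation`):

```python
# Returns an Explanation rather than a plain DataFrame.
result = rfm.predict(query, explain=True)

# Tuple-style access mirrors __getitem__/__iter__ above:
prediction, summary = result

# Dispatches to Streamlit, IPython display, or plain print().
result.print()
```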
@@ -123,17 +134,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph =
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -141,47 +152,63 @@ class KumoRFM:
 
         rfm = KumoRFM(graph)
 
-        query = ("PREDICT COUNT(
-                 "FOR users.user_id=
-        result = rfm.
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
        self,
-        graph:
-
-
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
        self.num_retries: int = 0
 
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
 
    @contextmanager
    def batch_mode(
        self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
        num_retries: int = 1,
    ) -> Generator[None, None, None]:
        """Context manager to predict in batches.
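
The constructor now dispatches on `graph.backend`, so the same `KumoRFM` front end can drive an in-memory, SQLite, or Snowflake sampler. A sketch of how the new `optimize` flag is intended to be used; how a SQLite-backed `Graph` is constructed is not part of this diff and is only assumed here:

```python
from kumoai.experimental.rfm import KumoRFM

# Hypothetical: a Graph whose tables live in a SQLite database,
# i.e. graph.backend == DataBackend.SQLITE.
graph = ...

# optimize=True lets the SQLite sampler create any missing indices;
# this needs write access to the underlying database.
rfm = KumoRFM(graph, verbose=True, optimize=True)
```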
@@ -215,17 +242,17 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
        explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        pass
@@ -234,17 +261,17 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> Explanation:
        pass
@@ -252,19 +279,19 @@ class KumoRFM:
    def predict(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
        """Returns predictions for a predictive query.
 
        Args:
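
Read together, the three overloads say: with the default `explain=False` the call returns a `pd.DataFrame`, and any truthy `explain` value returns an `Explanation`. A sketch using the query from the class docstring; the explicit keyword arguments simply restate the new defaults:

```python
from kumoapi.model_plan import RunMode

query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=1")

df = rfm.predict(query)              # -> pd.DataFrame

explanation = rfm.predict(           # -> Explanation
    query,
    explain=True,
    run_mode=RunMode.FAST,
    num_hops=2,
    max_pq_iterations=10,
)
```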
@@ -306,7 +333,7 @@ class KumoRFM:
        If ``explain`` is provided, returns an :class:`Explanation` object
        containing the prediction, summary, and details.
        """
-        explain_config:
+        explain_config: ExplainConfig | None = None
        if explain is True:
            explain_config = ExplainConfig()
        elif explain is not False:
@@ -350,15 +377,15 @@ class KumoRFM:
        msg = f'[bold]PREDICT[/bold] {query_repr}'
 
        if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
        with verbose as logger:
 
-            batch_size:
+            batch_size: int | None = None
            if self._batch_size == 'max':
-                task_type =
-                    query_def,
-                    edge_types=self.
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                )
                batch_size = _MAX_PRED_SIZE[task_type]
            else:
@@ -374,9 +401,9 @@ class KumoRFM:
                logger.log(f"Splitting {len(indices):,} entities into "
                           f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions:
-            summary:
-            details:
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
            for i, batch in enumerate(batches):
                # TODO Re-use the context for subsequent predictions.
                context = self._get_context(
@@ -410,8 +437,7 @@ class KumoRFM:
                    stats = Context.get_memory_stats(request_msg.context)
                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                    verbose.init_progress(
                        total=len(batches),
                        description='Predicting',
@@ -420,22 +446,22 @@ class KumoRFM:
                for attempt in range(self.num_retries + 1):
                    try:
                        if explain_config is not None:
-                            resp =
+                            resp = self._api_client.explain(
                                request=_bytes,
                                skip_summary=explain_config.skip_summary,
                            )
                            summary = resp.summary
                            details = resp.details
                        else:
-                            resp =
+                            resp = self._api_client.predict(_bytes)
                        df = pd.DataFrame(**resp.prediction)
 
                        # Cast 'ENTITY' to correct data type:
                        if 'ENTITY' in df:
-
-
-
-
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                        if 'ANCHOR_TIMESTAMP' in df:
@@ -450,8 +476,7 @@ class KumoRFM:
 
                        predictions.append(df)
 
-                        if (
-                                and len(batches) > 1):
+                        if len(batches) > 1:
                            verbose.step()
 
                        break
@@ -489,9 +514,9 @@ class KumoRFM:
    def is_valid_entity(
        self,
        query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
        *,
-        anchor_time:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
    ) -> np.ndarray:
        r"""Returns a mask that denotes which entities are valid for the
        given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -518,37 +543,32 @@ class KumoRFM:
            raise ValueError("At least one entity is required")
 
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, False)
        else:
            assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
    def evaluate(
        self,
        query: str,
        *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
        num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        """Evaluates a predictive query.
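
The evaluation entry point keeps the same shape as `predict` but always returns a `pd.DataFrame` of metric values. A sketch with the refreshed defaults; `metrics` is left at `None` because the supported metric names are validated elsewhere (`_validate_metrics`) and are not listed in this hunk:

```python
from kumoapi.model_plan import RunMode

metrics_df = rfm.evaluate(
    query,
    metrics=None,               # fall back to the task type's default metrics
    run_mode=RunMode.FAST,
    num_hops=2,
    max_pq_iterations=10,
)
print(metrics_df)
```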
@@ -596,7 +616,7 @@ class KumoRFM:
        msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
        if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
        with verbose as logger:
            context = self._get_context(
@@ -633,7 +653,7 @@ class KumoRFM:
                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
            try:
-                resp =
+                resp = self._api_client.evaluate(request_bytes)
            except HTTPException as e:
                try:
                    msg = json.loads(e.detail)['detail']
@@ -655,9 +675,9 @@ class KumoRFM:
        query: str,
        size: int,
        *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
    ) -> pd.DataFrame:
        """Returns the labels of a predictive query for a specified anchor
        time.
@@ -677,40 +697,37 @@ class KumoRFM:
        query_def = self._parse_query(query)
 
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
            if query_def.target_ast.date_offset_range is not None:
-
-
-
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, evaluate=True)
        else:
            assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
        )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
        return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
        })
 
    # Helpers #################################################################
@@ -731,8 +748,7 @@ class KumoRFM:
            graph_definition=self._graph_def,
        )
 
-        resp =
-        # TODO Expose validation warnings.
+        resp = self._api_client.parse_query(request)
 
        if len(resp.validation_response.warnings) > 0:
            msg = '\n'.join([
@@ -751,36 +767,92 @@ class KumoRFM:
            raise ValueError(f"Failed to parse query '{query}'. "
                             f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
    def _validate_time(
        self,
        query: ValidatedPredictiveQuery,
        anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
        evaluate: bool,
    ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
            return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
            raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
            raise ValueError(f"Context anchor timestamp is too early or "
                             f"aggregation time range is too large. To make "
                             f"this prediction, we would need data back to "
                             f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
        if query.target_ast.date_offset_range is not None:
            end_offset = query.target_ast.date_offset_range.end_date_offset
        else:
            end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
        if (context_anchor_time is not None
                and context_anchor_time > anchor_time):
            warnings.warn(f"Context anchor timestamp "
@@ -790,7 +862,7 @@ class KumoRFM:
                          f"intended.")
        elif (query.query_type == QueryType.TEMPORAL
              and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
            warnings.warn(f"Aggregation for context examples at timestamp "
                          f"'{context_anchor_time}' will leak information "
                          f"from the prediction anchor timestamp "
@@ -798,40 +870,37 @@ class KumoRFM:
                          f"intended.")
 
        elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
            warnings.warn(f"Context anchor timestamp is too early or "
                          f"aggregation time range is too large. To form "
                          f"proper input data, we would need data back to "
                          f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
            warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
            raise ValueError(
                f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
    def _get_context(
        self,
        query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
        run_mode: RunMode,
-        num_neighbors:
+        num_neighbors: list[int] | None,
        num_hops: int,
        max_pq_iterations: int,
        evaluate: bool,
-        random_seed:
-        logger:
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
    ) -> Context:
 
        if num_neighbors is not None:
@@ -848,10 +917,9 @@ class KumoRFM:
                f"'https://github.com/kumo-ai/kumo-rfm' if you "
                f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
        )
 
        if logger is not None:
@@ -883,14 +951,17 @@ class KumoRFM:
            num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
        if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
        else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
        if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
            if evaluate:
-                anchor_time = anchor_time -
+                anchor_time = anchor_time - end_offset
+
        if logger is not None:
            assert isinstance(anchor_time, pd.Timestamp)
            if anchor_time == pd.Timestamp.min:
@@ -904,57 +975,71 @@ class KumoRFM:
 
        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
            if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
            self._validate_time(query, anchor_time, context_anchor_time,
                                evaluate)
        else:
            assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query.entity_table}' to "
                                 f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
        if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
            if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                num_rhs = y_test.explode().nunique()
-                msg = (f"Collected {len(y_test):,} test examples with "
-                       f"{num_rhs:,} unique items")
-            else:
-                raise NotImplementedError
-            logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
            assert indices is not None
-
            if len(indices) > _MAX_PRED_SIZE[task_type]:
                raise ValueError(f"Cannot predict for more than "
                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -962,26 +1047,12 @@ class KumoRFM:
                                 f"`KumoRFM.batch_mode` to process entities "
                                 f"in batches")
 
-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
            if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
            else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-                train_node, train_time, y_train = query_driver.collect_train(
-                    size=_MAX_CONTEXT_SIZE[run_mode],
-                    anchor_time=context_anchor_time or 'entity',
-                    exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                               or anchor_time == 'entity') else None,
-                    max_iterations=max_pq_iterations,
-                )
+                train_time = test_time = 'entity'
 
        if logger is not None:
            if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1004,12 +1075,12 @@ class KumoRFM:
                raise NotImplementedError
            logger.log(msg)
 
-        entity_table_names:
+        entity_table_names: tuple[str, ...]
        if task_type.is_link_pred:
            final_aggr = query.get_final_target_aggregation()
            assert final_aggr is not None
            edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                    entity_table_names = (
                        query.entity_table,
@@ -1021,21 +1092,24 @@ class KumoRFM:
        # Exclude the entity anchor time from the feature set to prevent
        # running out-of-distribution between in-context and test examples:
        exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
            if entity_table_names[0] not in exclude_cols_dict:
                exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
            entity_table_names=entity_table_names,
-
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
            num_neighbors=num_neighbors,
            exclude_cols_dict=exclude_cols_dict,
        )
@@ -1047,23 +1121,19 @@ class KumoRFM:
                f"'https://github.com/kumo-ai/kumo-rfm' if you "
                f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
        return Context(
            task_type=task_type,
            entity_table_names=entity_table_names,
            subgraph=subgraph,
            y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
            top_k=query.top_k,
-            step_size=
+            step_size=None,
        )
 
    @staticmethod
    def _validate_metrics(
-        metrics:
+        metrics: list[str],
        task_type: TaskType,
    ) -> None:
@@ -1120,7 +1190,7 @@ class KumoRFM:
            f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value:
+def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000: