kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +52 -52
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +753 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +81 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +313 -245
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +49 -29
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
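
The largest change in this release replaces the local-only `LocalGraph`/`LocalGraphStore`/`LocalPQueryDriver` stack with a backend-agnostic `Graph` plus `Sampler` abstraction: `base/` defines the shared interfaces, while `backend/local`, `backend/sqlite`, and `backend/snow` provide in-memory, SQLite, and Snowflake implementations. A minimal usage sketch, adapted from the updated `KumoRFM` docstring in the diff below (the table schemas and values are illustrative placeholders, not part of the package):

```python
import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# Hypothetical toy tables; column names and keys are illustrative only.
df_users = pd.DataFrame({'user_id': [0, 1, 2]})
df_items = pd.DataFrame({'item_id': [0, 1]})
df_orders = pd.DataFrame({
    'user_id': [0, 1, 1, 2],
    'item_id': [0, 0, 1, 1],
    'time': pd.to_datetime(['2025-01-01', '2025-02-01',
                            '2025-03-01', '2025-04-01']),
})

# `Graph.from_data` replaces the former `LocalGraph.from_data` entry point:
graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

# `optimize` is new in 2.14 (see the constructor in the diff below): when
# True, transactional backends such as SQLite may create missing indices,
# which requires write access to the data backend.
model = KumoRFM(graph, optimize=False)
```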
kumoai/experimental/rfm/rfm.py
CHANGED
```diff
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
+from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas>=2
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.3
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
@@ -123,17 +134,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -150,32 +161,46 @@ class KumoRFM:
 
     Args:
        graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
        verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
    """
    def __init__(
        self,
-        graph: LocalGraph,
-        preprocess: bool = False,
-        verbose: Union[bool, ProgressLogger] = True,
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-        self._batch_size: Optional[Union[int, Literal['max']]] = None
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
        self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
        from kumoai.experimental.rfm import global_state
-        self._api_client = RFMAPI(global_state.client)
+        self._client = RFMAPI(global_state.client)
+        return self._client
 
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'
@@ -183,7 +208,7 @@ class KumoRFM:
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -217,17 +242,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Optional[Union[List[str], List[float], List[int]]] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -236,17 +261,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Optional[Union[List[str], List[float], List[int]]] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -254,19 +279,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Optional[Union[List[str], List[float], List[int]]] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -308,7 +333,7 @@ class KumoRFM:
         If ``explain`` is provided, returns an :class:`Explanation` object
         containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
+        explain_config: ExplainConfig | None = None
         if explain is True:
             explain_config = ExplainConfig()
         elif explain is not False:
@@ -352,15 +377,15 @@ class KumoRFM:
         msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
 
-            batch_size: Optional[int] = None
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type =
-                    query_def,
-                    edge_types=self._graph_store.edge_types,
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -376,9 +401,9 @@ class KumoRFM:
             logger.log(f"Splitting {len(indices):,} entities into "
                        f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -412,8 +437,7 @@ class KumoRFM:
                 stats = Context.get_memory_stats(request_msg.context)
                 raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -434,10 +458,10 @@ class KumoRFM:
 
                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-
-
-
-
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -452,8 +476,7 @@ class KumoRFM:
 
                 predictions.append(df)
 
-                if (
-                        and len(batches) > 1):
+                if len(batches) > 1:
                     verbose.step()
 
             break
@@ -491,9 +514,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices: Optional[Union[List[str], List[float], List[int]]] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -520,37 +543,32 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
-        context_anchor_time: Optional[pd.Timestamp] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -598,7 +616,7 @@ class KumoRFM:
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
@@ -657,9 +675,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -679,40 +697,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
         if query_def.target_ast.date_offset_range is not None:
-
-
-
+            offset = query_def.target_ast.date_offset_range.end_date_offset
+            offset *= query_def.num_forecasts
+            anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -735,8 +750,6 @@ class KumoRFM:
 
         resp = self._api_client.parse_query(request)
 
-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -754,36 +767,92 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Optional[pd.Timestamp],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                             f"this prediction, we would need data back to "
                             f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -793,7 +862,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -801,40 +870,37 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Optional[Union[List[str], List[float], List[int]]],
-        anchor_time: Optional[Union[pd.Timestamp, Literal['entity']]],
-        context_anchor_time: Optional[pd.Timestamp],
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
    ) -> Context:
 
         if num_neighbors is not None:
@@ -851,10 +917,9 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -886,14 +951,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -907,57 +975,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -965,26 +1047,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-                train_node, train_time, y_train = query_driver.collect_train(
-                    size=_MAX_CONTEXT_SIZE[run_mode],
-                    anchor_time=context_anchor_time or 'entity',
-                    exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                               or anchor_time == 'entity') else None,
-                    max_iterations=max_pq_iterations,
-                )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1007,12 +1075,12 @@ class KumoRFM:
                raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1024,20 +1092,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1049,23 +1121,19 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
@@ -1122,7 +1190,7 @@ class KumoRFM:
             f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
```