kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +41 -80
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +147 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
- kumoai/experimental/rfm/base/__init__.py +26 -2
- kumoai/experimental/rfm/base/column.py +6 -12
- kumoai/experimental/rfm/base/column_expression.py +16 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +113 -0
- kumoai/experimental/rfm/base/table.py +174 -76
- kumoai/experimental/rfm/graph.py +444 -84
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +77 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -240
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py CHANGED

```diff
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype

+from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger

 _RANDOM_SEED = 42

```
```diff
@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass

-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")

-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))

     def __repr__(self) -> str:
         return str((self.prediction, self.summary))

-    def
-
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas=2
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.3
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)

-
-
+    def _ipython_display_(self) -> None:
+        self.print()


 class KumoRFM:
```
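The new `print()` method picks a renderer per environment — Streamlit inside Snowflake notebooks, IPython rich display in Jupyter, plain stdout otherwise — and `_ipython_display_` makes the object auto-render in notebook cells. A minimal sketch of the tuple-like `Explanation` API above (`rfm` and `query` are placeholders, not part of this diff):

```python
# Sketch only: `rfm` is an existing KumoRFM instance and `query` a PQL
# string; both are assumed here for illustration.
explanation = rfm.predict(query, indices=[42], explain=True)

prediction, summary = explanation     # __iter__ -> (prediction, summary)
assert explanation[0] is prediction   # __getitem__(0) -> pd.DataFrame
assert explanation[1] is summary      # __getitem__(1) -> str
explanation.print()                   # Streamlit, IPython, or plain stdout
```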
```diff
@@ -150,31 +161,36 @@ class KumoRFM:

     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
-
-
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)

-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None

-        self._batch_size:
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0

     @property
```
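The constructor now dispatches on `graph.backend` to a backend-specific `Sampler` instead of always building a `LocalGraphStore`; `optimize` is only forwarded to the SQLite sampler in this version. A construction sketch, assuming `KumoRFM` is importable from the same package as `Graph` and that `graph` was built elsewhere against a SQLite backend:

```python
from kumoai.experimental.rfm import KumoRFM  # import path assumed

# `graph` is a validated kumoai Graph built elsewhere. With a SQLite
# backend, optimize=True lets SQLiteSampler create missing indices,
# which requires write access to the database.
rfm = KumoRFM(graph, verbose=True, optimize=True)
```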
```diff
@@ -192,7 +208,7 @@ class KumoRFM:
     @contextmanager
     def batch_mode(
         self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
```
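At predict time, `batch_size='max'` resolves to the per-task-type cap `_MAX_PRED_SIZE[task_type]` (see the predict body below). A usage sketch with placeholder query and ids:

```python
# Sketch: split a large prediction into capped batches, retrying each
# failed batch once. `query` and `entity_ids` are placeholders.
with rfm.batch_mode(batch_size='max', num_retries=1):
    df = rfm.predict(query, indices=entity_ids)
```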
```diff
@@ -226,17 +242,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -245,17 +261,17 @@
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -263,19 +279,19 @@
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.

         Args:
```
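The three overloads encode the return-type contract: `explain=False` yields a plain `pd.DataFrame`, while `explain=True`, an `ExplainConfig`, or an equivalent config dict yields an `Explanation`. Roughly:

```python
# Type-checker view of the overloads above (placeholders throughout):
df = rfm.predict(query, indices=ids)                 # -> pd.DataFrame
exp = rfm.predict(query, indices=ids, explain=True)  # -> Explanation
# Passing an ExplainConfig (or a dict with its fields, defined in
# kumoapi.rfm) selects the Explanation overload as well.
```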
```diff
@@ -317,7 +333,7 @@ class KumoRFM:
         If ``explain`` is provided, returns an :class:`Explanation` object
         containing the prediction, summary, and details.
         """
-        explain_config:
+        explain_config: ExplainConfig | None = None
         if explain is True:
             explain_config = ExplainConfig()
         elif explain is not False:
@@ -361,15 +377,15 @@
         msg = f'[bold]PREDICT[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:

-            batch_size:
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type =
-                    query_def,
-                    edge_types=self.
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -385,9 +401,9 @@
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")

-            predictions:
-            summary:
-            details:
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -421,8 +437,7 @@
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -443,10 +458,10 @@

                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-
-
-
-
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -461,8 +476,7 @@

                 predictions.append(df)

-                if (
-                        and len(batches) > 1):
+                if len(batches) > 1:
                     verbose.step()

                 break
```
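The new `ENTITY` cast pulls the primary-key dtype from the sampled subgraph's entity table rather than from the old graph store. The same idea in isolation, with toy data:

```python
import pandas as pd

# Toy version of the cast above: predictions arrive stringly-typed and
# are coerced to the primary-key dtype of the entity table.
pkey = pd.Series([10, 11, 12], dtype='int64', name='user_id')
df = pd.DataFrame({'ENTITY': ['10', '11']})
df['ENTITY'] = df['ENTITY'].astype(pkey.dtype)
assert df['ENTITY'].dtype == 'int64'
```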
```diff
@@ -500,9 +514,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -529,37 +543,32 @@
             raise ValueError("At least one entity is required")

         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)

         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")

-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError

     def evaluate(
         self,
         query: str,
         *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
```
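`evaluate` mirrors the `predict` signature and returns a metrics frame; with `metrics=None` it presumably falls back to task-appropriate defaults via `_validate_metrics`. A hedged sketch, where the metric name is an assumption for illustration:

```python
from kumoapi.model_plan import RunMode

# Sketch: evaluate a query against held-back history. `rfm` and `query`
# are placeholders; 'auroc' is an assumed metric name.
metrics_df = rfm.evaluate(query, metrics=['auroc'], run_mode=RunMode.FAST)
```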
```diff
@@ -607,7 +616,7 @@ class KumoRFM:
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
             context = self._get_context(
@@ -666,9 +675,9 @@
         query: str,
         size: int,
         *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -688,40 +697,37 @@
         query_def = self._parse_query(query)

         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
         if query_def.target_ast.date_offset_range is not None:
-
-
-
+            offset = query_def.target_ast.date_offset_range.end_date_offset
+            offset *= query_def.num_forecasts
+            anchor_time -= offset

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")

-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )

-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })

     # Helpers #################################################################
```
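Before sampling labels, the anchor time is shifted back by the full forecast horizon (`end_date_offset * num_forecasts`) so every label window fits inside the available history. The shift in isolation:

```python
import pandas as pd

# Toy version of the anchor-time shift above: a 7-day target window
# forecast 3 times pushes the anchor back 21 days.
anchor_time = pd.Timestamp('2025-03-31')
offset = pd.DateOffset(days=7)
offset *= 3                      # query.num_forecasts = 3
anchor_time -= offset
assert anchor_time == pd.Timestamp('2025-03-10')
```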
```diff
@@ -744,8 +750,6 @@

         resp = self._api_client.parse_query(request)

-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -763,36 +767,92 @@
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None

+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:

-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps

-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")

-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")

         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
```
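`_get_task_type` derives the task purely from the query AST: conditions and logical operations mean binary classification, a `LIST_DISTINCT` aggregation over a registered foreign key means temporal link prediction, any other aggregation means regression, and bare target columns map by semantic type. The PQL snippets below are invented to illustrate the branches:

```python
# Hypothetical PQL -> TaskType mapping, following the branches above:
expected = {
    "PREDICT COUNT(orders.*, 0, 30) > 0": "BINARY_CLASSIFICATION",
    "PREDICT SUM(orders.price, 0, 30)": "REGRESSION",
    "PREDICT LIST_DISTINCT(orders.item_id, 0, 7)": "TEMPORAL_LINK_PREDICTION",
    "PREDICT users.segment": "MULTICLASS_CLASSIFICATION",  # categorical stype
    "PREDICT users.age": "REGRESSION",                     # numerical stype
}
```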
```diff
@@ -802,7 +862,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -810,40 +870,37 @@
                           f"intended.")

         elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")

-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")

-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")

     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors:
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed:
-        logger:
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:

         if num_neighbors is not None:
```
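All bounds now run off a single `end_offset = end_date_offset * num_forecasts` and the sampler-provided `min_time`/`max_time`. A toy check of the leakage warning above:

```python
import pandas as pd

# With a 30-day target window (num_forecasts = 1), a context anchor
# later than anchor_time - 30 days overlaps the prediction window:
anchor_time = pd.Timestamp('2025-06-30')
end_offset = pd.DateOffset(days=30)
context_anchor_time = pd.Timestamp('2025-06-15')
assert context_anchor_time + end_offset > anchor_time  # would warn
```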
```diff
@@ -860,10 +917,9 @@
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")

-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )

         if logger is not None:
@@ -895,14 +951,17 @@
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -916,57 +975,71 @@

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'

-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
```
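Context assembly budgets up to `_MAX_CONTEXT_SIZE[run_mode]` in-context examples plus, when evaluating, `_MAX_TEST_SIZE[run_mode]` test examples (a fifth of that for link prediction), and gives the sampler `max_pq_iterations` trials per requested example. The arithmetic with invented cap values (the real constants live elsewhere in rfm.py):

```python
# Invented caps, for illustration only:
max_context_size, max_test_size, max_pq_iterations = 1_000, 500, 10

num_train_examples = max_context_size
num_test_examples = max_test_size // 5                      # link prediction
num_train_trials = max_pq_iterations * num_train_examples   # 10_000
num_test_trials = max_pq_iterations * num_test_examples     # 1_000
```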
```diff
@@ -974,26 +1047,12 @@
                              f"`KumoRFM.batch_mode` to process entities "
                              f"in batches")

-
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-                train_node, train_time, y_train = query_driver.collect_train(
-                    size=_MAX_CONTEXT_SIZE[run_mode],
-                    anchor_time=context_anchor_time or 'entity',
-                    exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                               or anchor_time == 'entity') else None,
-                    max_iterations=max_pq_iterations,
-                )
+                train_time = test_time = 'entity'

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1016,12 +1075,12 @@
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names:
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1033,20 +1092,24 @@
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)

-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1058,23 +1121,19 @@ class KumoRFM:
                          f"'https://github.com/kumo-ai/kumo-rfm' if you "
                          f"must go beyond this for your use-case.")

-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=
+            step_size=None,
         )

     @staticmethod
     def _validate_metrics(
-        metrics:
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
@@ -1131,7 +1190,7 @@ class KumoRFM:
             f"'https://github.com/kumo-ai/kumo-rfm'.")


-def format_value(value:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
```
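`format_value` renders integral values with thousands separators; its magnitude branches are cut off at the end of this view. A check of the visible branch only:

```python
# Reproduces just the branch shown above; the >= 1000 handling is
# truncated in this diff view, so it is elided here.
def format_value(value: int | float) -> str:
    if value == int(value):
        return f'{int(value):,}'
    ...

assert format_value(1234.0) == '1,234'
```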