kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
- kumoai/experimental/rfm/backend/local/sampler.py +213 -14
- kumoai/experimental/rfm/backend/local/table.py +21 -16
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +101 -49
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
- kumoai/experimental/rfm/base/__init__.py +25 -6
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sampler.py +438 -38
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +165 -135
- kumoai/experimental/rfm/graph.py +266 -102
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +41 -35
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/rfm.py
CHANGED
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
+from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) ->
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def
-
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.hide(axis='index'))  # pandas=2
+                else:
+                    display(self.prediction.hide_index())  # pandas <1.3
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
 
-
-
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
@@ -151,20 +162,35 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
-        verbose:
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
 
-        self._batch_size:
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0
 
     @property
@@ -182,7 +208,7 @@ class KumoRFM:
     @contextmanager
     def batch_mode(
         self,
-        batch_size:
+        batch_size: int | Literal['max'] = 'max',
        num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -216,17 +242,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -235,17 +261,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -253,19 +279,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) ->
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -307,7 +333,7 @@ class KumoRFM:
             If ``explain`` is provided, returns an :class:`Explanation` object
             containing the prediction, summary, and details.
         """
-        explain_config:
+        explain_config: ExplainConfig | None = None
         if explain is True:
             explain_config = ExplainConfig()
         elif explain is not False:
@@ -351,15 +377,15 @@ class KumoRFM:
         msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
 
-            batch_size:
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type =
-                    query_def,
-                    edge_types=self.
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -375,9 +401,9 @@ class KumoRFM:
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions:
-            summary:
-            details:
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -411,8 +437,7 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -433,10 +458,10 @@ class KumoRFM:
 
                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-
-
-
-
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -451,8 +476,7 @@ class KumoRFM:
 
                 predictions.append(df)
 
-                if (
-                        and len(batches) > 1):
+                if len(batches) > 1:
                     verbose.step()
 
                 break
@@ -490,9 +514,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices:
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time:
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -519,37 +543,32 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
         query: str,
         *,
-        metrics:
-        anchor_time:
-        context_anchor_time:
-        run_mode:
-        num_neighbors:
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int =
-        random_seed:
-        verbose:
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -597,7 +616,7 @@ class KumoRFM:
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
@@ -656,9 +675,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time:
-        random_seed:
-        max_iterations: int =
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -678,40 +697,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-
-
-
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-
-
-
-
-
-
-
-
-
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
        )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY':
-            'ANCHOR_TIMESTAMP':
-            'TARGET':
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -734,8 +750,6 @@ class KumoRFM:
 
         resp = self._api_client.parse_query(request)
 
-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,36 +767,92 @@ class KumoRFM:
         raise ValueError(f"Failed to parse query '{query}'. "
                          f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time:
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
-        if self.
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -792,7 +862,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time +
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -800,40 +870,37 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time -
-
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{
+                          f"data back to '{min_time}'.")
 
-        if
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{
-                          f"
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices:
-        anchor_time:
-        context_anchor_time:
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors:
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed:
-        logger:
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:
 
         if num_neighbors is not None:
@@ -850,10 +917,9 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-
-
-
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -885,14 +951,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-
+            step_offset = pd.DateOffset(0)
         else:
-
-
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self.
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time -
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -906,57 +975,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time -
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time
-
-
-                context_anchor_time =
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            num_rhs = y_test.explode().nunique()
-            msg = (f"Collected {len(y_test):,} test examples with "
-                   f"{num_rhs:,} unique items")
-        else:
-            raise NotImplementedError
-        logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1047,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-
-            table_name=query.entity_table,
-            pkey=pd.Series(indices),
-        )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-        train_node, train_time, y_train = query_driver.collect_train(
-            size=_MAX_CONTEXT_SIZE[run_mode],
-            anchor_time=context_anchor_time or 'entity',
-            exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                       or anchor_time == 'entity') else None,
-            max_iterations=max_pq_iterations,
-        )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1006,12 +1075,12 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names:
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self.
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1023,20 +1092,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
        # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self.
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-
-
-
-
-
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1048,23 +1121,19 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=
+            step_size=None,
         )
 
     @staticmethod
     def _validate_metrics(
-        metrics:
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1121,7 +1190,7 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: