kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +41 -80
  10. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +147 -0
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
  16. kumoai/experimental/rfm/base/__init__.py +26 -2
  17. kumoai/experimental/rfm/base/column.py +6 -12
  18. kumoai/experimental/rfm/base/column_expression.py +16 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +19 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +113 -0
  23. kumoai/experimental/rfm/base/table.py +174 -76
  24. kumoai/experimental/rfm/graph.py +444 -84
  25. kumoai/experimental/rfm/infer/__init__.py +6 -0
  26. kumoai/experimental/rfm/infer/dtype.py +77 -0
  27. kumoai/experimental/rfm/infer/pkey.py +128 -0
  28. kumoai/experimental/rfm/infer/time_col.py +61 -0
  29. kumoai/experimental/rfm/pquery/executor.py +27 -27
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  31. kumoai/experimental/rfm/rfm.py +299 -240
  32. kumoai/experimental/rfm/sagemaker.py +4 -4
  33. kumoai/pquery/predictive_query.py +10 -6
  34. kumoai/testing/snow.py +50 -0
  35. kumoai/utils/__init__.py +3 -2
  36. kumoai/utils/progress_logger.py +178 -12
  37. kumoai/utils/sql.py +3 -0
  38. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
  39. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
  40. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  41. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  42. kumoai/experimental/rfm/utils.py +0 -344
  43. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
  44. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
  45. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
+from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass
 
-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")
 
-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))
 
     def __repr__(self) -> str:
         return str((self.prediction, self.summary))
 
-    def _ipython_display_(self) -> None:
-        from IPython.display import Markdown, display
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas>=1.4
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.4
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
 
-        display(self.prediction)
-        display(Markdown(self.summary))
+    def _ipython_display_(self) -> None:
+        self.print()
 
 
 class KumoRFM:
@@ -150,31 +161,36 @@ class KumoRFM:
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be beneficial to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
         graph: Graph,
-        preprocess: bool = False,
-        verbose: Union[bool, ProgressLogger] = True,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
 
-        self._client: Optional[RFMAPI] = None
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
 
-        self._batch_size: Optional[int | Literal['max']] = None
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0
 
     @property
@@ -192,7 +208,7 @@ class KumoRFM:
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -226,17 +242,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -245,17 +261,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -263,19 +279,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.
 
         Args:
@@ -317,7 +333,7 @@ class KumoRFM:
         If ``explain`` is provided, returns an :class:`Explanation` object
         containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
+        explain_config: ExplainConfig | None = None
         if explain is True:
             explain_config = ExplainConfig()
         elif explain is not False:
@@ -361,15 +377,15 @@ class KumoRFM:
         msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
 
-            batch_size: Optional[int] = None
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type = LocalPQueryDriver.get_task_type(
-                    query_def,
-                    edge_types=self._graph_store.edge_types,
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -385,9 +401,9 @@ class KumoRFM:
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")
 
-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -421,8 +437,7 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
 
-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -443,10 +458,10 @@ class KumoRFM:
 
                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-                    entity = query_def.entity_table
-                    pkey_map = self._graph_store.pkey_map_dict[entity]
-                    df['ENTITY'] = df['ENTITY'].astype(
-                        type(pkey_map.index[0]))
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
 
                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -461,8 +476,7 @@ class KumoRFM:
 
                 predictions.append(df)
 
-                if (isinstance(verbose, InteractiveProgressLogger)
-                        and len(batches) > 1):
+                if len(batches) > 1:
                     verbose.step()
 
                 break
@@ -500,9 +514,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -529,37 +543,32 @@ class KumoRFM:
             raise ValueError("At least one entity is required")
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
 
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")
 
-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError
 
     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -607,7 +616,7 @@ class KumoRFM:
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
@@ -666,9 +675,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -688,40 +697,37 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
@@ -744,8 +750,6 @@ class KumoRFM:
 
         resp = self._api_client.parse_query(request)
 
-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -763,36 +767,92 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
            return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
 
         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -802,7 +862,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -810,40 +870,37 @@ class KumoRFM:
                           f"intended.")
 
         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:
 
         if num_neighbors is not None:
@@ -860,10 +917,9 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -895,14 +951,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
         if query.target_ast.date_offset_range is None:
-            end_offset = pd.DateOffset(0)
+            step_offset = pd.DateOffset(0)
         else:
-            end_offset = query.target_ast.date_offset_range.end_date_offset
-        forecast_end_offset = end_offset * query.num_forecasts
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
         if evaluate:
-            anchor_time = anchor_time - forecast_end_offset
+            anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -916,57 +975,71 @@ class KumoRFM:
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
 
-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -974,26 +1047,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")
 
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1016,12 +1075,12 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1033,20 +1092,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1058,23 +1121,19 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
        )
 
     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
 
@@ -1131,7 +1190,7 @@ class KumoRFM:
             f"'https://github.com/kumo-ai/kumo-rfm'.")
 
 
-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000:
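
For orientation, the diff above can be exercised roughly as sketched below. This is a minimal usage sketch, not code from the package: the `KumoRFM` re-export from `kumoai.experimental.rfm` and the example query string are assumptions; only `Graph`, the constructor's new `optimize` flag, `predict` with `explain` and `batch_mode`, and `Explanation.print()` are visible in the diff itself.

# Usage sketch; assumptions are flagged in the comments.
from kumoai.experimental.rfm import Graph, KumoRFM  # KumoRFM re-export assumed

graph: Graph = ...  # graph construction is outside the scope of this diff

# New in this version: `optimize=True` lets transactional backends (e.g.,
# SQLite) create any missing indices up front; requires write access.
model = KumoRFM(graph, optimize=True)

# Illustrative predictive query string; adapt to your own schema:
query = "PREDICT COUNT(orders.*, 0, 30, days) > 0 FOR EACH users.user_id"

df = model.predict(query, indices=[1, 2, 3])  # returns a pd.DataFrame

# With `explain=True`, `predict` returns an `Explanation`, which unpacks into
# `(prediction, summary)` and now renders via `Explanation.print()` in plain
# terminals, IPython notebooks, and Snowflake notebooks alike:
prediction, summary = model.predict(query, indices=[1], explain=True)

# Entity lists beyond the per-call limit are split up inside `batch_mode`:
with model.batch_mode(batch_size='max', num_retries=1):
    df = model.predict(query, indices=list(range(10_000)))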