kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/experimental/rfm/__init__.py +49 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  11. kumoai/experimental/rfm/backend/local/table.py +32 -14
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +186 -39
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
  18. kumoai/experimental/rfm/base/__init__.py +23 -3
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +761 -0
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  24. kumoai/experimental/rfm/base/table.py +380 -185
  25. kumoai/experimental/rfm/graph.py +404 -144
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +52 -60
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +283 -230
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/pquery/predictive_query.py +10 -6
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/trainer/distilled_trainer.py +175 -0
  40. kumoai/utils/__init__.py +3 -2
  41. kumoai/utils/display.py +51 -0
  42. kumoai/utils/progress_logger.py +178 -12
  43. kumoai/utils/sql.py +3 -0
  44. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
  45. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
  46. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  47. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  48. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  49. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  50. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
@@ -2,25 +2,22 @@ import json
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

  import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,18 +26,14 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype

  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
  from kumoai.experimental.rfm import Graph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger, display

  _RANDOM_SEED = 42

@@ -95,24 +88,26 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         display.dataframe(self.prediction)
+         display.message(self.summary)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
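The `Explanation` changes above swap the direct `IPython.display` import for the new `kumoai.utils.display` helpers while keeping the tuple-like interface intact. A minimal usage sketch — `rfm`, `query`, and the entity index are placeholders, assuming the result came from `KumoRFM.predict(..., explain=True)`:

    result = rfm.predict(query, indices=[42], explain=True)  # -> Explanation
    prediction, summary = result        # __iter__ yields (prediction, summary)
    assert result[0] is prediction      # index 0: pd.DataFrame
    assert result[1] is summary         # index 1: str
    result.print()                      # display.dataframe() + display.message()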
@@ -151,20 +146,35 @@ class KumoRFM:
      Args:
          graph: The graph.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data
+             backend for optimal querying. For example, for transactional
+             database backends, will create any missing indices. Requires
+             write-access to the data backend.
      """
      def __init__(
          self,
          graph: Graph,
-         verbose: Union[bool, ProgressLogger] = True,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._client: Optional[RFMAPI] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError

-         self._batch_size: Optional[int | Literal['max']] = None
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
          self.num_retries: int = 0

      @property
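With the constructor above, backend selection is driven entirely by `graph.backend`. A hedged construction sketch — the graph setup itself is elided because it depends on the chosen backend, and `optimize=True` only has an effect on transactional backends such as SQLite:

    from kumoai.experimental.rfm import Graph, KumoRFM

    graph: Graph = ...  # a Graph whose tables live in, e.g., a SQLite database

    # Dispatches to SQLiteSampler(graph, verbose, optimize) internally;
    # optimize=True may create missing indices and needs write access.
    rfm = KumoRFM(graph, optimize=True)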
@@ -182,7 +192,7 @@ class KumoRFM:
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -216,17 +226,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
@@ -235,17 +245,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass
@@ -253,19 +263,19 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
@@ -307,7 +317,7 @@ class KumoRFM:
          If ``explain`` is provided, returns an :class:`Explanation` object
          containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
+         explain_config: ExplainConfig | None = None
          if explain is True:
              explain_config = ExplainConfig()
          elif explain is not False:
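Taken together, the overloads above mean the return type is driven by `explain`. A usage sketch with placeholder query and entity IDs (not from this diff):

    df = rfm.predict(query, indices=[1, 2, 3])              # -> pd.DataFrame
    result = rfm.predict(query, indices=[1], explain=True)  # -> Explanation
    prediction, summary = result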
@@ -351,15 +361,15 @@ class KumoRFM:
          msg = f'[bold]PREDICT[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:

-             batch_size: Optional[int] = None
+             batch_size: int | None = None
              if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
+                 task_type = self._get_task_type(
+                     query=query_def,
+                     edge_types=self._sampler.edge_types,
                  )
                  batch_size = _MAX_PRED_SIZE[task_type]
              else:
@@ -375,9 +385,9 @@ class KumoRFM:
                  logger.log(f"Splitting {len(indices):,} entities into "
                             f"{len(batches):,} batches of size {batch_size:,}")

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
              for i, batch in enumerate(batches):
                  # TODO Re-use the context for subsequent predictions.
                  context = self._get_context(
@@ -411,8 +421,7 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
+                 if i == 0 and len(batches) > 1:
                      verbose.init_progress(
                          total=len(batches),
                          description='Predicting',
@@ -433,10 +442,10 @@ class KumoRFM:

                  # Cast 'ENTITY' to correct data type:
                  if 'ENTITY' in df:
-                     entity = query_def.entity_table
-                     pkey_map = self._graph_store.pkey_map_dict[entity]
-                     df['ENTITY'] = df['ENTITY'].astype(
-                         type(pkey_map.index[0]))
+                     table_dict = context.subgraph.table_dict
+                     table = table_dict[query_def.entity_table]
+                     ser = table.df[table.primary_key]
+                     df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                  # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                  if 'ANCHOR_TIMESTAMP' in df:
@@ -451,8 +460,7 @@ class KumoRFM:

                  predictions.append(df)

-                 if (isinstance(verbose, InteractiveProgressLogger)
-                         and len(batches) > 1):
+                 if len(batches) > 1:
                      verbose.step()

                  break
@@ -490,9 +498,9 @@ class KumoRFM:
      def is_valid_entity(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
      ) -> np.ndarray:
          r"""Returns a mask that denotes which entities are valid for the
          given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -519,37 +527,32 @@ class KumoRFM:
              raise ValueError("At least one entity is required")

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)

          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, False)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column.")

-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
+         raise NotImplementedError

      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -597,7 +600,7 @@ class KumoRFM:
          msg = f'[bold]EVALUATE[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
              context = self._get_context(
@@ -656,9 +659,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -678,40 +681,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
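After this rewrite, `get_labels` sources its examples from `Sampler.sample_target` (with zero train examples) but keeps its output contract. A small sketch with a placeholder query:

    labels = rfm.get_labels(query, size=1_000)
    # -> pd.DataFrame with columns 'ENTITY', 'ANCHOR_TIMESTAMP' and 'TARGET'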
@@ -734,8 +734,6 @@ class KumoRFM:

          resp = self._api_client.parse_query(request)

-         # TODO Expose validation warnings.
-
          if len(resp.validation_response.warnings) > 0:
              msg = '\n'.join([
                  f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,36 +751,92 @@ class KumoRFM:
              raise ValueError(f"Failed to parse query '{query}'. "
                               f"{msg}") from None

+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> pd.Timestamp:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         assert query.query_type == QueryType.STATIC
+         return self._sampler.get_max_time()
+
      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -792,7 +846,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -800,40 +854,37 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

      def _get_context(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None,
+         context_anchor_time: pd.Timestamp | None,
          run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
+         num_neighbors: list[int] | None,
          num_hops: int,
          max_pq_iterations: int,
          evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
      ) -> Context:

          if num_neighbors is not None:
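The renamed `end_offset` (previously `forecast_end_offset`) makes the leakage condition in `_validate_time` easier to see with concrete numbers; a self-contained sketch with invented dates:

    import pandas as pd

    anchor_time = pd.Timestamp('2024-06-30')
    end_offset = pd.DateOffset(days=7)  # aggregation window * num_forecasts
    context_anchor_time = pd.Timestamp('2024-06-27')

    # 2024-06-27 + 7 days = 2024-07-04, which is after the anchor time, so
    # context aggregation would leak into the prediction window -> warning.
    assert context_anchor_time + end_offset > anchor_time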
@@ -850,10 +901,9 @@ class KumoRFM:
                  f"'https://github.com/kumo-ai/kumo-rfm' if you "
                  f"must go beyond this for your use-case.")

-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

          if logger is not None:
@@ -885,14 +935,17 @@ class KumoRFM:
              num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query)
+
          if evaluate:
-             anchor_time = anchor_time - forecast_end_offset
+             anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -906,57 +959,71 @@ class KumoRFM:

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
+             if context_anchor_time == 'entity':
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - forecast_end_offset
+                 context_anchor_time = anchor_time - end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
                                  evaluate)
          else:
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-                 context_anchor_time = None
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'

-         y_test: Optional[pd.Series] = None
+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
          if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
+             num_test_examples = _MAX_TEST_SIZE[run_mode]
              if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
+                 num_test_examples = num_test_examples // 5
+         else:
+             num_test_examples = 0
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, y_train = train
+         test_pkey, test_time, y_test = test

-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
+         if evaluate and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(y_test):,} test examples holding "
+                        f"{y_test.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(y_test.min()), float(y_test.max())
+                 msg = (f"Collected {len(y_test):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = y_test.explode().nunique()
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

-         else:
+         if not evaluate:
              assert indices is not None
-
              if len(indices) > _MAX_PRED_SIZE[task_type]:
                  raise ValueError(f"Cannot predict for more than "
                                   f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1031,12 @@ class KumoRFM:
                                   f"`KumoRFM.batch_mode` to process entities "
                                   f"in batches")

-             test_node = self._graph_store.get_node_id(
-                 table_name=query.entity_table,
-                 pkey=pd.Series(indices),
-             )
-
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-             train_node, train_time, y_train = query_driver.collect_train(
-                 size=_MAX_CONTEXT_SIZE[run_mode],
-                 anchor_time=context_anchor_time or 'entity',
-                 exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                            or anchor_time == 'entity') else None,
-                 max_iterations=max_pq_iterations,
-             )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1006,12 +1059,12 @@ class KumoRFM:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names: Tuple[str, ...]
+         entity_table_names: tuple[str, ...]
          if task_type.is_link_pred:
              final_aggr = query.get_final_target_aggregation()
              assert final_aggr is not None
              edge_fkey = final_aggr._get_target_column_name()
-             for edge_type in self._graph_store.edge_types:
+             for edge_type in self._sampler.edge_types:
                  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                      entity_table_names = (
                          query.entity_table,
@@ -1023,20 +1076,24 @@ class KumoRFM:
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
          exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
+         if entity_table_names[0] in self._sampler.time_column_dict:
              if entity_table_names[0] not in exclude_cols_dict:
                  exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
+             time_column = self._sampler.time_column_dict[entity_table_names[0]]
              exclude_cols_dict[entity_table_names[0]].append(time_column)

-         subgraph = self._graph_sampler(
+         subgraph = self._sampler.sample_subgraph(
              entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
+             entity_pkey=pd.concat(
+                 [train_pkey, test_pkey],
+                 axis=0,
+                 ignore_index=True,
+             ),
+             anchor_time=pd.concat(
+                 [train_time, test_time],
+                 axis=0,
+                 ignore_index=True,
+             ) if isinstance(train_time, pd.Series) else 'entity',
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
@@ -1048,23 +1105,19 @@ class KumoRFM:
                  f"'https://github.com/kumo-ai/kumo-rfm' if you "
                  f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
-
          return Context(
              task_type=task_type,
              entity_table_names=entity_table_names,
              subgraph=subgraph,
              y_train=y_train,
-             y_test=y_test,
+             y_test=y_test if evaluate else None,
              top_k=query.top_k,
-             step_size=step_size,
+             step_size=None,
          )

      @staticmethod
      def _validate_metrics(
-         metrics: List[str],
+         metrics: list[str],
          task_type: TaskType,
      ) -> None:

@@ -1121,7 +1174,7 @@
          f"'https://github.com/kumo-ai/kumo-rfm'.")


- def format_value(value: Union[int, float]) -> str:
+ def format_value(value: int | float) -> str:
      if value == int(value):
          return f'{int(value):,}'
      if abs(value) >= 1000: