kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +128 -55
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +24 -5
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +429 -30
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/top_level.txt +0 -0
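
kumoai/experimental/rfm/rfm.py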
@@ -2,25 +2,22 @@ import json
 import time
 import warnings
 from collections import defaultdict
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
-from typing import (
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Union,
-    overload,
-)
+from typing import Any, Literal, overload

 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
 from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
 from kumoapi.rfm import Context
 from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
     RFMPredictRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype

+from kumoai import in_notebook, in_snowflake_notebook
 from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
 from kumoai.experimental.rfm import Graph
-from kumoai.experimental.rfm.backend.local import LocalGraphStore
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
+from kumoai.experimental.rfm.base import DataBackend, Sampler
 from kumoai.mixin import CastMixin
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.utils import ProgressLogger

 _RANDOM_SEED = 42

@@ -95,24 +89,41 @@ class Explanation:
     def __getitem__(self, index: Literal[1]) -> str:
         pass

-    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
         if index == 0:
             return self.prediction
         if index == 1:
             return self.summary
         raise IndexError("Index out of range")

-    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
         return iter((self.prediction, self.summary))

     def __repr__(self) -> str:
         return str((self.prediction, self.summary))

-    def _ipython_display_(self) -> None:
-        from IPython.display import Markdown, display
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas>=1.4
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.4
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)

-        display(self.prediction)
-        display(Markdown(self.summary))
+    def _ipython_display_(self) -> None:
+        self.print()


 class KumoRFM:
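
The tuple-like interface above can be exercised as follows (a hedged sketch; `rfm` and `query` are placeholders, not part of this diff):

    explanation = rfm.predict(query, indices=[42], explain=True)
    prediction, summary = explanation  # unpacking goes through __iter__
    assert explanation[0] is explanation.prediction
    explanation.print()  # Streamlit, IPython, or plain-text rendering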
@@ -151,20 +162,35 @@ class KumoRFM:
     Args:
         graph: The graph.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data
+            backend for faster querying. For example, for transactional
+            database backends, will create any missing indices. Requires
+            write access to the data backend.
     """
     def __init__(
         self,
         graph: Graph,
-        verbose: Union[bool, ProgressLogger] = True,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
     ) -> None:
         graph = graph.validate()
         self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)

-        self._client: Optional[RFMAPI] = None
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None

-        self._batch_size: Optional[int | Literal['max']] = None
+        self._batch_size: int | Literal['max'] | None = None
         self.num_retries: int = 0

     @property
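
A minimal usage sketch for the new backend dispatch (assumes `graph` is a validated Graph backed by SQLite; the import path is an assumption, only the constructor arguments are taken from this diff):

    from kumoai.experimental.rfm import KumoRFM  # assumed export

    # optimize=True lets the SQLite sampler create any missing indices,
    # which requires write access to the underlying database.
    rfm = KumoRFM(graph, optimize=True)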
@@ -182,7 +208,7 @@ class KumoRFM:
     @contextmanager
     def batch_mode(
         self,
-        batch_size: Union[int, Literal['max']] = 'max',
+        batch_size: int | Literal['max'] = 'max',
         num_retries: int = 1,
     ) -> Generator[None, None, None]:
         """Context manager to predict in batches.
@@ -216,17 +242,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
         explain: Literal[False] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         pass
@@ -235,17 +261,17 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> Explanation:
         pass
@@ -253,19 +279,19 @@ class KumoRFM:
     def predict(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
-    ) -> Union[pd.DataFrame, Explanation]:
+    ) -> pd.DataFrame | Explanation:
         """Returns predictions for a predictive query.

         Args:
@@ -307,7 +333,7 @@ class KumoRFM:
         If ``explain`` is provided, returns an :class:`Explanation` object
         containing the prediction, summary, and details.
         """
-        explain_config: Optional[ExplainConfig] = None
+        explain_config: ExplainConfig | None = None
         if explain is True:
             explain_config = ExplainConfig()
         elif explain is not False:
@@ -351,15 +377,15 @@ class KumoRFM:
         msg = f'[bold]PREDICT[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:

-            batch_size: Optional[int] = None
+            batch_size: int | None = None
             if self._batch_size == 'max':
-                task_type = LocalPQueryDriver.get_task_type(
-                    query_def,
-                    edge_types=self._graph_store.edge_types,
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
                 )
                 batch_size = _MAX_PRED_SIZE[task_type]
             else:
@@ -375,9 +401,9 @@ class KumoRFM:
                 logger.log(f"Splitting {len(indices):,} entities into "
                            f"{len(batches):,} batches of size {batch_size:,}")

-            predictions: List[pd.DataFrame] = []
-            summary: Optional[str] = None
-            details: Optional[Explanation] = None
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: Explanation | None = None
             for i, batch in enumerate(batches):
                 # TODO Re-use the context for subsequent predictions.
                 context = self._get_context(
@@ -411,8 +437,7 @@ class KumoRFM:
                     stats = Context.get_memory_stats(request_msg.context)
                     raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                        and len(batches) > 1):
+                if i == 0 and len(batches) > 1:
                     verbose.init_progress(
                         total=len(batches),
                         description='Predicting',
@@ -433,10 +458,10 @@ class KumoRFM:

                 # Cast 'ENTITY' to correct data type:
                 if 'ENTITY' in df:
-                    entity = query_def.entity_table
-                    pkey_map = self._graph_store.pkey_map_dict[entity]
-                    df['ENTITY'] = df['ENTITY'].astype(
-                        type(pkey_map.index[0]))
+                    table_dict = context.subgraph.table_dict
+                    table = table_dict[query_def.entity_table]
+                    ser = table.df[table.primary_key]
+                    df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                 # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                 if 'ANCHOR_TIMESTAMP' in df:
@@ -451,8 +476,7 @@ class KumoRFM:

                 predictions.append(df)

-                if (isinstance(verbose, InteractiveProgressLogger)
-                        and len(batches) > 1):
+                if len(batches) > 1:
                     verbose.step()

                 break
@@ -490,9 +514,9 @@ class KumoRFM:
     def is_valid_entity(
         self,
         query: str,
-        indices: Union[List[str], List[float], List[int], None] = None,
+        indices: list[str] | list[float] | list[int] | None = None,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
     ) -> np.ndarray:
         r"""Returns a mask that denotes which entities are valid for the
         given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -519,37 +543,32 @@ class KumoRFM:
             raise ValueError("At least one entity is required")

         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)

         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, False)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column.")

-        node = self._graph_store.get_node_id(
-            table_name=query_def.entity_table,
-            pkey=pd.Series(indices),
-        )
-        query_driver = LocalPQueryDriver(self._graph_store, query_def)
-        return query_driver.is_valid(node, anchor_time)
+        raise NotImplementedError

     def evaluate(
         self,
         query: str,
         *,
-        metrics: Optional[List[str]] = None,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        context_anchor_time: Union[pd.Timestamp, None] = None,
-        run_mode: Union[RunMode, str] = RunMode.FAST,
-        num_neighbors: Optional[List[int]] = None,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        verbose: Union[bool, ProgressLogger] = True,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
         use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
@@ -597,7 +616,7 @@ class KumoRFM:
         msg = f'[bold]EVALUATE[/bold] {query_repr}'

         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)

         with verbose as logger:
             context = self._get_context(
@@ -656,9 +675,9 @@ class KumoRFM:
         query: str,
         size: int,
         *,
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -678,40 +697,37 @@ class KumoRFM:
         query_def = self._parse_query(query)

         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query_def)
             if query_def.target_ast.date_offset_range is not None:
-                anchor_time = anchor_time - (
-                    query_def.target_ast.date_offset_range.end_date_offset *
-                    query_def.num_forecasts)
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity_table not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query_def.entity_table}' "
                                  f"to have a time column")

-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )

-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity_table].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })

     # Helpers #################################################################
@@ -734,8 +750,6 @@ class KumoRFM:

         resp = self._api_client.parse_query(request)

-        # TODO Expose validation warnings.
-
         if len(resp.validation_response.warnings) > 0:
             msg = '\n'.join([
                 f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -753,36 +767,92 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None

+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
         query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
-        context_anchor_time: Union[pd.Timestamp, None],
+        context_anchor_time: pd.Timestamp | None,
         evaluate: bool,
     ) -> None:

-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps

-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")

-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")

         if query.target_ast.date_offset_range is not None:
             end_offset = query.target_ast.date_offset_range.end_date_offset
         else:
             end_offset = pd.DateOffset(0)
-        forecast_end_offset = end_offset * query.num_forecasts
+        end_offset = end_offset * query.num_forecasts
+
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
             warnings.warn(f"Context anchor timestamp "
@@ -792,7 +862,7 @@ class KumoRFM:
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
               and context_anchor_time is not None
-              and context_anchor_time + forecast_end_offset > anchor_time):
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
@@ -800,40 +870,37 @@ class KumoRFM:
                           f"intended.")

         elif (context_anchor_time is not None
-              and context_anchor_time - forecast_end_offset
-              < self._graph_store.min_time):
-            _time = context_anchor_time - forecast_end_offset
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")

-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
            warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")

-        max_eval_time = self._graph_store.max_time - forecast_end_offset
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")

     def _get_context(
         self,
         query: ValidatedPredictiveQuery,
-        indices: Union[List[str], List[float], List[int], None],
-        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-        context_anchor_time: Union[pd.Timestamp, None],
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
         run_mode: RunMode,
-        num_neighbors: Optional[List[int]],
+        num_neighbors: list[int] | None,
         num_hops: int,
         max_pq_iterations: int,
         evaluate: bool,
-        random_seed: Optional[int] = _RANDOM_SEED,
-        logger: Optional[ProgressLogger] = None,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
     ) -> Context:

         if num_neighbors is not None:
@@ -850,10 +917,9 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")

-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = LocalPQueryDriver.get_task_type(
-            query,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )

         if logger is not None:
@@ -885,14 +951,17 @@ class KumoRFM:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

         if query.target_ast.date_offset_range is None:
-            end_offset = pd.DateOffset(0)
+            step_offset = pd.DateOffset(0)
         else:
-            end_offset = query.target_ast.date_offset_range.end_date_offset
-        forecast_end_offset = end_offset * query.num_forecasts
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
             if evaluate:
-                anchor_time = anchor_time - forecast_end_offset
+                anchor_time = anchor_time - end_offset
+
         if logger is not None:
             assert isinstance(anchor_time, pd.Timestamp)
             if anchor_time == pd.Timestamp.min:
@@ -906,57 +975,71 @@ class KumoRFM:

         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - forecast_end_offset
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity_table not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
                                  f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-            context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'

-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test

-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)

-        else:
+        if not evaluate:
             assert indices is not None
-
             if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
                                  f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -964,26 +1047,12 @@ class KumoRFM:
                                  f"`KumoRFM.batch_mode` to process entities "
                                  f"in batches")

-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity_table,
-                pkey=pd.Series(indices),
-            )
-
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[query.entity_table]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'

         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1006,12 +1075,12 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)

-        entity_table_names: Tuple[str, ...]
+        entity_table_names: tuple[str, ...]
         if task_type.is_link_pred:
             final_aggr = query.get_final_target_aggregation()
             assert final_aggr is not None
             edge_fkey = final_aggr._get_target_column_name()
-            for edge_type in self._graph_store.edge_types:
+            for edge_type in self._sampler.edge_types:
                 if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                     entity_table_names = (
                         query.entity_table,
@@ -1023,20 +1092,24 @@ class KumoRFM:
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
         exclude_cols_dict = query.get_exclude_cols_dict()
-        if anchor_time == 'entity':
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)

-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -1048,23 +1121,19 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm' if you "
                 f"must go beyond this for your use-case.")

-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )

     @staticmethod
     def _validate_metrics(
-        metrics: List[str],
+        metrics: list[str],
         task_type: TaskType,
     ) -> None:
@@ -1121,7 +1190,7 @@ class KumoRFM:
                 f"'https://github.com/kumo-ai/kumo-rfm'.")


-def format_value(value: Union[int, float]) -> str:
+def format_value(value: int | float) -> str:
     if value == int(value):
         return f'{int(value):,}'
     if abs(value) >= 1000: