kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +52 -52
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/sampler.py +761 -0
  24. kumoai/experimental/rfm/base/source.py +19 -0
  25. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  26. kumoai/experimental/rfm/base/table.py +753 -0
  27. kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
  28. kumoai/experimental/rfm/infer/__init__.py +8 -0
  29. kumoai/experimental/rfm/infer/dtype.py +81 -0
  30. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  31. kumoai/experimental/rfm/infer/pkey.py +128 -0
  32. kumoai/experimental/rfm/infer/stype.py +35 -0
  33. kumoai/experimental/rfm/infer/time_col.py +61 -0
  34. kumoai/experimental/rfm/pquery/executor.py +27 -27
  35. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  36. kumoai/experimental/rfm/rfm.py +313 -245
  37. kumoai/experimental/rfm/sagemaker.py +15 -7
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/testing/decorators.py +1 -1
  40. kumoai/testing/snow.py +50 -0
  41. kumoai/trainer/distilled_trainer.py +175 -0
  42. kumoai/utils/__init__.py +3 -2
  43. kumoai/utils/progress_logger.py +178 -12
  44. kumoai/utils/sql.py +3 -0
  45. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +10 -8
  46. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +49 -29
  47. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  48. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  49. kumoai/experimental/rfm/local_table.py +0 -545
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
  52. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
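The headline change in this diff is the rename of `LocalGraph` to `Graph` (file 27) together with new pluggable sampler backends under `kumoai/experimental/rfm/backend/` (files 10-19). A minimal sketch of the updated entry point, adapted from the `KumoRFM` class docstring in `kumoai/experimental/rfm/rfm.py` below; the data frames are illustrative placeholders:

import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

df_users = pd.DataFrame(...)   # placeholder user table
df_items = pd.DataFrame(...)   # placeholder item table
df_orders = pd.DataFrame(...)  # placeholder fact table

# `Graph.from_data` replaces the former `LocalGraph.from_data`:
graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

# `optimize=True` lets transactional backends (e.g., SQLite) create any
# missing indices; it requires write access to the data backend:
rfm = KumoRFM(graph, optimize=False)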
@@ -2,25 +2,22 @@ import json
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

  import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype

+ from kumoai import in_notebook, in_snowflake_notebook
  from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import LocalGraph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger

  _RANDOM_SEED = 42

@@ -95,24 +89,41 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         if in_snowflake_notebook():
+             import streamlit as st
+             st.dataframe(self.prediction, hide_index=True)
+             st.markdown(self.summary)
+         elif in_notebook():
+             from IPython.display import Markdown, display
+             try:
+                 if hasattr(self.prediction.style, 'hide'):
+                     display(self.prediction.style.hide(axis='index'))  # pandas>=2
+                 else:
+                     display(self.prediction.style.hide_index())  # pandas<1.3
+             except ImportError:
+                 print(self.prediction.to_string(index=False))  # missing jinja2
+             display(Markdown(self.summary))
+         else:
+             print(self.prediction.to_string(index=False))
+             print(self.summary)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
@@ -123,17 +134,17 @@ class KumoRFM:
      :class:`KumoRFM` is a foundation model to generate predictions for any
      relational dataset without training.
      The model is pre-trained and the class provides an interface to query the
-     model from a :class:`LocalGraph` object.
+     model from a :class:`Graph` object.

      .. code-block:: python

-         from kumoai.experimental.rfm import LocalGraph, KumoRFM
+         from kumoai.experimental.rfm import Graph, KumoRFM

          df_users = pd.DataFrame(...)
          df_items = pd.DataFrame(...)
          df_orders = pd.DataFrame(...)

-         graph = LocalGraph.from_data({
+         graph = Graph.from_data({
              'users': df_users,
              'items': df_items,
              'orders': df_orders,
@@ -150,32 +161,46 @@ class KumoRFM:

      Args:
          graph: The graph.
-         preprocess: Whether to pre-process the data in advance during graph
-             materialization.
-             This is a runtime trade-off between graph materialization and model
-             processing speed.
-             It can be benefical to preprocess your data once and then run many
-             queries on top to achieve maximum model speed.
-             However, if activiated, graph materialization can take potentially
-             much longer, especially on graphs with many large text columns.
-             Best to tune this option manually.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data backend
+             for optimal querying. For example, for transactional database
+             backends, will create any missing indices. Requires write-access to
+             the data backend.
      """
      def __init__(
          self,
-         graph: LocalGraph,
-         preprocess: bool = False,
-         verbose: Union[bool, ProgressLogger] = True,
+         graph: Graph,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._batch_size: Optional[int | Literal['max']] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError
+
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
          self.num_retries: int = 0
+
+     @property
+     def _api_client(self) -> RFMAPI:
+         if self._client is not None:
+             return self._client
+
          from kumoai.experimental.rfm import global_state
-         self._api_client = RFMAPI(global_state.client)
+         self._client = RFMAPI(global_state.client)
+         return self._client

      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'
@@ -183,7 +208,7 @@ class KumoRFM:
      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -217,17 +242,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
@@ -236,17 +261,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass
@@ -254,19 +279,19 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
@@ -308,7 +333,7 @@ class KumoRFM:
          If ``explain`` is provided, returns an :class:`Explanation` object
          containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
+         explain_config: ExplainConfig | None = None
          if explain is True:
              explain_config = ExplainConfig()
          elif explain is not False:
@@ -352,15 +377,15 @@ class KumoRFM:
          msg = f'[bold]PREDICT[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:

-             batch_size: Optional[int] = None
+             batch_size: int | None = None
              if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
+                 task_type = self._get_task_type(
+                     query=query_def,
+                     edge_types=self._sampler.edge_types,
                  )
                  batch_size = _MAX_PRED_SIZE[task_type]
              else:
@@ -376,9 +401,9 @@ class KumoRFM:
                  logger.log(f"Splitting {len(indices):,} entities into "
                             f"{len(batches):,} batches of size {batch_size:,}")

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
              for i, batch in enumerate(batches):
                  # TODO Re-use the context for subsequent predictions.
                  context = self._get_context(
@@ -412,8 +437,7 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
+                 if i == 0 and len(batches) > 1:
                      verbose.init_progress(
                          total=len(batches),
                          description='Predicting',
@@ -434,10 +458,10 @@ class KumoRFM:

                  # Cast 'ENTITY' to correct data type:
                  if 'ENTITY' in df:
-                     entity = query_def.entity_table
-                     pkey_map = self._graph_store.pkey_map_dict[entity]
-                     df['ENTITY'] = df['ENTITY'].astype(
-                         type(pkey_map.index[0]))
+                     table_dict = context.subgraph.table_dict
+                     table = table_dict[query_def.entity_table]
+                     ser = table.df[table.primary_key]
+                     df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                  # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                  if 'ANCHOR_TIMESTAMP' in df:
@@ -452,8 +476,7 @@ class KumoRFM:

                  predictions.append(df)

-                 if (isinstance(verbose, InteractiveProgressLogger)
-                         and len(batches) > 1):
+                 if len(batches) > 1:
                      verbose.step()

                  break
@@ -491,9 +514,9 @@ class KumoRFM:
      def is_valid_entity(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
      ) -> np.ndarray:
          r"""Returns a mask that denotes which entities are valid for the
          given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -520,37 +543,32 @@ class KumoRFM:
              raise ValueError("At least one entity is required")

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)

          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, False)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column.")

-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
+         raise NotImplementedError

      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -598,7 +616,7 @@ class KumoRFM:
          msg = f'[bold]EVALUATE[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
              context = self._get_context(
@@ -657,9 +675,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -679,40 +697,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
@@ -735,8 +750,6 @@ class KumoRFM:

          resp = self._api_client.parse_query(request)

-         # TODO Expose validation warnings.
-
          if len(resp.validation_response.warnings) > 0:
              msg = '\n'.join([
                  f'{i+1}. {warning.title}: {warning.message}' for i, warning
@@ -754,36 +767,92 @@ class KumoRFM:
              raise ValueError(f"Failed to parse query '{query}'. "
                               f"{msg}") from None

+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> pd.Timestamp:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         assert query.query_type == QueryType.STATIC
+         return self._sampler.get_max_time()
+
      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -793,7 +862,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -801,40 +870,37 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

      def _get_context(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None,
+         context_anchor_time: pd.Timestamp | None,
          run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
+         num_neighbors: list[int] | None,
          num_hops: int,
          max_pq_iterations: int,
          evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
      ) -> Context:

          if num_neighbors is not None:
@@ -851,10 +917,9 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

          if logger is not None:
@@ -886,14 +951,17 @@ class KumoRFM:
              num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query)
+
          if evaluate:
-             anchor_time = anchor_time - forecast_end_offset
+             anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -907,57 +975,71 @@ class KumoRFM:

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
+             if context_anchor_time == 'entity':
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - forecast_end_offset
+                 context_anchor_time = anchor_time - end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
                                  evaluate)
          else:
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-             context_anchor_time = None
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'

-         y_test: Optional[pd.Series] = None
+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
          if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
+             num_test_examples = _MAX_TEST_SIZE[run_mode]
              if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
+                 num_test_examples = num_test_examples // 5
+         else:
+             num_test_examples = 0
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, y_train = train
+         test_pkey, test_time, y_test = test

-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
+         if evaluate and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(y_test):,} test examples holding "
+                        f"{y_test.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(y_test.min()), float(y_test.max())
+                 msg = (f"Collected {len(y_test):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = y_test.explode().nunique()
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

-         else:
+         if not evaluate:
              assert indices is not None
-
              if len(indices) > _MAX_PRED_SIZE[task_type]:
                  raise ValueError(f"Cannot predict for more than "
                                   f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -965,26 +1047,12 @@ class KumoRFM:
                                   f"`KumoRFM.batch_mode` to process entities "
                                   f"in batches")

-             test_node = self._graph_store.get_node_id(
-                 table_name=query.entity_table,
-                 pkey=pd.Series(indices),
-             )
-
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-             train_node, train_time, y_train = query_driver.collect_train(
-                 size=_MAX_CONTEXT_SIZE[run_mode],
-                 anchor_time=context_anchor_time or 'entity',
-                 exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                            or anchor_time == 'entity') else None,
-                 max_iterations=max_pq_iterations,
-             )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1007,12 +1075,12 @@ class KumoRFM:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names: Tuple[str, ...]
+         entity_table_names: tuple[str, ...]
          if task_type.is_link_pred:
              final_aggr = query.get_final_target_aggregation()
              assert final_aggr is not None
              edge_fkey = final_aggr._get_target_column_name()
-             for edge_type in self._graph_store.edge_types:
+             for edge_type in self._sampler.edge_types:
                  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                      entity_table_names = (
                          query.entity_table,
@@ -1024,20 +1092,24 @@ class KumoRFM:
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
          exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
+         if entity_table_names[0] in self._sampler.time_column_dict:
              if entity_table_names[0] not in exclude_cols_dict:
                  exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
+             time_column = self._sampler.time_column_dict[entity_table_names[0]]
              exclude_cols_dict[entity_table_names[0]].append(time_column)

-         subgraph = self._graph_sampler(
+         subgraph = self._sampler.sample_subgraph(
              entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
+             entity_pkey=pd.concat(
+                 [train_pkey, test_pkey],
+                 axis=0,
+                 ignore_index=True,
+             ),
+             anchor_time=pd.concat(
+                 [train_time, test_time],
+                 axis=0,
+                 ignore_index=True,
+             ) if isinstance(train_time, pd.Series) else 'entity',
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
@@ -1049,23 +1121,19 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
-
          return Context(
              task_type=task_type,
              entity_table_names=entity_table_names,
              subgraph=subgraph,
              y_train=y_train,
-             y_test=y_test,
+             y_test=y_test if evaluate else None,
              top_k=query.top_k,
-             step_size=step_size,
+             step_size=None,
          )

      @staticmethod
      def _validate_metrics(
-         metrics: List[str],
+         metrics: list[str],
          task_type: TaskType,
      ) -> None:

@@ -1122,7 +1190,7 @@ class KumoRFM:
                  f"'https://github.com/kumo-ai/kumo-rfm'.")


- def format_value(value: Union[int, float]) -> str:
+ def format_value(value: int | float) -> str:
      if value == int(value):
          return f'{int(value):,}'
      if abs(value) >= 1000:
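For bulk scoring, batching is now driven unconditionally by the `batch_mode` context manager shown above rather than being gated on `InteractiveProgressLogger`. A hedged usage sketch built only from the signatures in this diff; `rfm`, `query`, and `user_ids` are assumed to already exist:

# batch_size='max' picks the per-task maximum (_MAX_PRED_SIZE) and
# num_retries re-issues failed batches, per the `batch_mode` signature:
with rfm.batch_mode(batch_size='max', num_retries=1):
    df = rfm.predict(query, indices=user_ids)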