kumoai 2.13.0.dev202511131731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +191 -50
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/sampler.py +761 -0
  24. kumoai/experimental/rfm/base/source.py +19 -0
  25. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  26. kumoai/experimental/rfm/base/table.py +753 -0
  27. kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
  28. kumoai/experimental/rfm/infer/__init__.py +8 -0
  29. kumoai/experimental/rfm/infer/dtype.py +81 -0
  30. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  31. kumoai/experimental/rfm/infer/pkey.py +128 -0
  32. kumoai/experimental/rfm/infer/stype.py +35 -0
  33. kumoai/experimental/rfm/infer/time_col.py +61 -0
  34. kumoai/experimental/rfm/pquery/executor.py +27 -27
  35. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  36. kumoai/experimental/rfm/rfm.py +322 -252
  37. kumoai/experimental/rfm/sagemaker.py +138 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/distilled_trainer.py +175 -0
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +178 -12
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +13 -2
  47. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +50 -29
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  50. kumoai/experimental/rfm/local_table.py +0 -545
  51. kumoai/experimental/rfm/utils.py +0 -344
  52. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
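Most of the churn below is in kumoai/experimental/rfm/rfm.py (file 36 above): KumoRFM now dispatches to pluggable sampler backends (in-memory, SQLite, Snowflake) instead of the removed LocalGraphStore/LocalGraphSampler/LocalPQueryDriver stack, and LocalGraph is renamed to Graph. A minimal sketch of the updated entry point, mirroring the revised class docstring in the diff below; the toy DataFrames and their columns are illustrative assumptions, not shipped fixtures:

import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# Toy tables with an assumed schema; real usage loads genuine data.
df_users = pd.DataFrame({'user_id': [0, 1, 2]})
df_items = pd.DataFrame({'item_id': [0, 1]})
df_orders = pd.DataFrame({
    'user_id': [0, 1, 1],
    'item_id': [0, 0, 1],
    'time': pd.to_datetime(['2024-01-01', '2024-02-01', '2024-03-01']),
})

# Graph.from_data replaces LocalGraph.from_data (see file 27 above).
graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

# The 'preprocess' flag is gone; 'optimize' is new and only matters for
# transactional database backends, where it may create missing indices.
rfm = KumoRFM(graph, optimize=False)

query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=1")
result = rfm.predict(query)
print(result)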
@@ -2,25 +2,22 @@ import json
  import time
  import warnings
  from collections import defaultdict
- from collections.abc import Generator
+ from collections.abc import Generator, Iterator
  from contextlib import contextmanager
  from dataclasses import dataclass, replace
- from typing import (
-     Any,
-     Dict,
-     Iterator,
-     List,
-     Literal,
-     Optional,
-     Tuple,
-     Union,
-     overload,
- )
+ from typing import Any, Literal, overload

  import numpy as np
  import pandas as pd
  from kumoapi.model_plan import RunMode
  from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import (
+     Aggregation,
+     Column,
+     Condition,
+     Join,
+     LogicalOperation,
+ )
  from kumoapi.rfm import Context
  from kumoapi.rfm import Explanation as ExplanationConfig
  from kumoapi.rfm import (
@@ -29,18 +26,15 @@ from kumoapi.rfm import (
      RFMPredictRequest,
  )
  from kumoapi.task import TaskType
+ from kumoapi.typing import AggregationType, Stype

- from kumoai import global_state
+ from kumoai import in_notebook, in_snowflake_notebook
+ from kumoai.client.rfm import RFMAPI
  from kumoai.exceptions import HTTPException
- from kumoai.experimental.rfm import LocalGraph
- from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
- from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
- from kumoai.experimental.rfm.local_pquery_driver import (
-     LocalPQueryDriver,
-     date_offset_to_seconds,
- )
+ from kumoai.experimental.rfm import Graph
+ from kumoai.experimental.rfm.base import DataBackend, Sampler
  from kumoai.mixin import CastMixin
- from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+ from kumoai.utils import ProgressLogger

  _RANDOM_SEED = 42

@@ -95,24 +89,41 @@ class Explanation:
      def __getitem__(self, index: Literal[1]) -> str:
          pass

-     def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+     def __getitem__(self, index: int) -> pd.DataFrame | str:
          if index == 0:
              return self.prediction
          if index == 1:
              return self.summary
          raise IndexError("Index out of range")

-     def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+     def __iter__(self) -> Iterator[pd.DataFrame | str]:
          return iter((self.prediction, self.summary))

      def __repr__(self) -> str:
          return str((self.prediction, self.summary))

-     def _ipython_display_(self) -> None:
-         from IPython.display import Markdown, display
+     def print(self) -> None:
+         r"""Prints the explanation."""
+         if in_snowflake_notebook():
+             import streamlit as st
+             st.dataframe(self.prediction, hide_index=True)
+             st.markdown(self.summary)
+         elif in_notebook():
+             from IPython.display import Markdown, display
+             try:
+                 if hasattr(self.prediction.style, 'hide'):
+                     display(self.prediction.style.hide(axis='index'))  # pandas>=1.4
+                 else:
+                     display(self.prediction.style.hide_index())  # pandas<1.4
+             except ImportError:
+                 print(self.prediction.to_string(index=False))  # missing jinja2
+             display(Markdown(self.summary))
+         else:
+             print(self.prediction.to_string(index=False))
+             print(self.summary)

-         display(self.prediction)
-         display(Markdown(self.summary))
+     def _ipython_display_(self) -> None:
+         self.print()


  class KumoRFM:
@@ -123,17 +134,17 @@ class KumoRFM:
      :class:`KumoRFM` is a foundation model to generate predictions for any
      relational dataset without training.
      The model is pre-trained and the class provides an interface to query the
-     model from a :class:`LocalGraph` object.
+     model from a :class:`Graph` object.

      .. code-block:: python

-         from kumoai.experimental.rfm import LocalGraph, KumoRFM
+         from kumoai.experimental.rfm import Graph, KumoRFM

          df_users = pd.DataFrame(...)
          df_items = pd.DataFrame(...)
          df_orders = pd.DataFrame(...)

-         graph = LocalGraph.from_data({
+         graph = Graph.from_data({
              'users': df_users,
              'items': df_items,
              'orders': df_orders,
@@ -141,47 +152,63 @@ class KumoRFM:

          rfm = KumoRFM(graph)

-         query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                  "FOR users.user_id=0")
-         result = rfm.query(query)
+         query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                  "FOR users.user_id=1")
+         result = rfm.predict(query)

          print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                         # 1        0.85

      Args:
          graph: The graph.
-         preprocess: Whether to pre-process the data in advance during graph
-             materialization.
-             This is a runtime trade-off between graph materialization and model
-             processing speed.
-             It can be beneficial to preprocess your data once and then run many
-             queries on top to achieve maximum model speed.
-             However, if activated, graph materialization can take potentially
-             much longer, especially on graphs with many large text columns.
-             Best to tune this option manually.
          verbose: Whether to print verbose output.
+         optimize: If set to ``True``, will optimize the underlying data backend
+             for optimal querying. For example, for transactional database
+             backends, will create any missing indices. Requires write-access to
+             the data backend.
      """
      def __init__(
          self,
-         graph: LocalGraph,
-         preprocess: bool = False,
-         verbose: Union[bool, ProgressLogger] = True,
+         graph: Graph,
+         verbose: bool | ProgressLogger = True,
+         optimize: bool = False,
      ) -> None:
          graph = graph.validate()
          self._graph_def = graph._to_api_graph_definition()
-         self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-         self._graph_sampler = LocalGraphSampler(self._graph_store)

-         self._batch_size: Optional[int | Literal['max']] = None
+         if graph.backend == DataBackend.LOCAL:
+             from kumoai.experimental.rfm.backend.local import LocalSampler
+             self._sampler: Sampler = LocalSampler(graph, verbose)
+         elif graph.backend == DataBackend.SQLITE:
+             from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+             self._sampler = SQLiteSampler(graph, verbose, optimize)
+         elif graph.backend == DataBackend.SNOWFLAKE:
+             from kumoai.experimental.rfm.backend.snow import SnowSampler
+             self._sampler = SnowSampler(graph, verbose)
+         else:
+             raise NotImplementedError
+
+         self._client: RFMAPI | None = None
+
+         self._batch_size: int | Literal['max'] | None = None
          self.num_retries: int = 0

+     @property
+     def _api_client(self) -> RFMAPI:
+         if self._client is not None:
+             return self._client
+
+         from kumoai.experimental.rfm import global_state
+         self._client = RFMAPI(global_state.client)
+         return self._client
+
      def __repr__(self) -> str:
          return f'{self.__class__.__name__}()'

      @contextmanager
      def batch_mode(
          self,
-         batch_size: Union[int, Literal['max']] = 'max',
+         batch_size: int | Literal['max'] = 'max',
          num_retries: int = 1,
      ) -> Generator[None, None, None]:
          """Context manager to predict in batches.
@@ -215,17 +242,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
          explain: Literal[False] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          pass
@@ -234,17 +261,17 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: Literal[True] | ExplainConfig | dict[str, Any],
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> Explanation:
          pass
@@ -252,19 +279,19 @@ class KumoRFM:
      def predict(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         explain: bool | ExplainConfig | dict[str, Any] = False,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
-     ) -> Union[pd.DataFrame, Explanation]:
+     ) -> pd.DataFrame | Explanation:
          """Returns predictions for a predictive query.

          Args:
@@ -306,7 +333,7 @@ class KumoRFM:
              If ``explain`` is provided, returns an :class:`Explanation` object
              containing the prediction, summary, and details.
          """
-         explain_config: Optional[ExplainConfig] = None
+         explain_config: ExplainConfig | None = None
          if explain is True:
              explain_config = ExplainConfig()
          elif explain is not False:
@@ -350,15 +377,15 @@ class KumoRFM:
          msg = f'[bold]PREDICT[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:

-             batch_size: Optional[int] = None
+             batch_size: int | None = None
              if self._batch_size == 'max':
-                 task_type = LocalPQueryDriver.get_task_type(
-                     query_def,
-                     edge_types=self._graph_store.edge_types,
+                 task_type = self._get_task_type(
+                     query=query_def,
+                     edge_types=self._sampler.edge_types,
                  )
                  batch_size = _MAX_PRED_SIZE[task_type]
              else:
@@ -374,9 +401,9 @@ class KumoRFM:
                  logger.log(f"Splitting {len(indices):,} entities into "
                             f"{len(batches):,} batches of size {batch_size:,}")

-             predictions: List[pd.DataFrame] = []
-             summary: Optional[str] = None
-             details: Optional[Explanation] = None
+             predictions: list[pd.DataFrame] = []
+             summary: str | None = None
+             details: Explanation | None = None
              for i, batch in enumerate(batches):
                  # TODO Re-use the context for subsequent predictions.
                  context = self._get_context(
@@ -410,8 +437,7 @@ class KumoRFM:
                      stats = Context.get_memory_stats(request_msg.context)
                      raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

-                 if (isinstance(verbose, InteractiveProgressLogger) and i == 0
-                         and len(batches) > 1):
+                 if i == 0 and len(batches) > 1:
                      verbose.init_progress(
                          total=len(batches),
                          description='Predicting',
@@ -420,22 +446,22 @@ class KumoRFM:
                  for attempt in range(self.num_retries + 1):
                      try:
                          if explain_config is not None:
-                             resp = global_state.client.rfm_api.explain(
+                             resp = self._api_client.explain(
                                  request=_bytes,
                                  skip_summary=explain_config.skip_summary,
                              )
                              summary = resp.summary
                              details = resp.details
                          else:
-                             resp = global_state.client.rfm_api.predict(_bytes)
+                             resp = self._api_client.predict(_bytes)
                          df = pd.DataFrame(**resp.prediction)

                          # Cast 'ENTITY' to correct data type:
                          if 'ENTITY' in df:
-                             entity = query_def.entity_table
-                             pkey_map = self._graph_store.pkey_map_dict[entity]
-                             df['ENTITY'] = df['ENTITY'].astype(
-                                 type(pkey_map.index[0]))
+                             table_dict = context.subgraph.table_dict
+                             table = table_dict[query_def.entity_table]
+                             ser = table.df[table.primary_key]
+                             df['ENTITY'] = df['ENTITY'].astype(ser.dtype)

                          # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                          if 'ANCHOR_TIMESTAMP' in df:
@@ -450,8 +476,7 @@ class KumoRFM:

                          predictions.append(df)

-                         if (isinstance(verbose, InteractiveProgressLogger)
-                                 and len(batches) > 1):
+                         if len(batches) > 1:
                              verbose.step()

                          break
@@ -489,9 +514,9 @@ class KumoRFM:
      def is_valid_entity(
          self,
          query: str,
-         indices: Union[List[str], List[float], List[int], None] = None,
+         indices: list[str] | list[float] | list[int] | None = None,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
      ) -> np.ndarray:
          r"""Returns a mask that denotes which entities are valid for the
          given predictive query, *i.e.*, which entities fulfill (temporal)
@@ -518,37 +543,32 @@ class KumoRFM:
              raise ValueError("At least one entity is required")

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)

          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, False)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column.")

-         node = self._graph_store.get_node_id(
-             table_name=query_def.entity_table,
-             pkey=pd.Series(indices),
-         )
-         query_driver = LocalPQueryDriver(self._graph_store, query_def)
-         return query_driver.is_valid(node, anchor_time)
+         raise NotImplementedError

      def evaluate(
          self,
          query: str,
          *,
-         metrics: Optional[List[str]] = None,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         context_anchor_time: Union[pd.Timestamp, None] = None,
-         run_mode: Union[RunMode, str] = RunMode.FAST,
-         num_neighbors: Optional[List[int]] = None,
+         metrics: list[str] | None = None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         context_anchor_time: pd.Timestamp | None = None,
+         run_mode: RunMode | str = RunMode.FAST,
+         num_neighbors: list[int] | None = None,
          num_hops: int = 2,
-         max_pq_iterations: int = 20,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         verbose: Union[bool, ProgressLogger] = True,
+         max_pq_iterations: int = 10,
+         random_seed: int | None = _RANDOM_SEED,
+         verbose: bool | ProgressLogger = True,
          use_prediction_time: bool = False,
      ) -> pd.DataFrame:
          """Evaluates a predictive query.
@@ -596,7 +616,7 @@ class KumoRFM:
          msg = f'[bold]EVALUATE[/bold] {query_repr}'

          if not isinstance(verbose, ProgressLogger):
-             verbose = InteractiveProgressLogger(msg, verbose=verbose)
+             verbose = ProgressLogger.default(msg=msg, verbose=verbose)

          with verbose as logger:
              context = self._get_context(
@@ -633,7 +653,7 @@ class KumoRFM:
                  raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

              try:
-                 resp = global_state.client.rfm_api.evaluate(request_bytes)
+                 resp = self._api_client.evaluate(request_bytes)
              except HTTPException as e:
                  try:
                      msg = json.loads(e.detail)['detail']
@@ -655,9 +675,9 @@ class KumoRFM:
          query: str,
          size: int,
          *,
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         max_iterations: int = 20,
+         anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+         random_seed: int | None = _RANDOM_SEED,
+         max_iterations: int = 10,
      ) -> pd.DataFrame:
          """Returns the labels of a predictive query for a specified anchor
          time.
@@ -677,40 +697,37 @@ class KumoRFM:
          query_def = self._parse_query(query)

          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query_def)
              if query_def.target_ast.date_offset_range is not None:
-                 anchor_time = anchor_time - (
-                     query_def.target_ast.date_offset_range.end_date_offset *
-                     query_def.num_forecasts)
+                 offset = query_def.target_ast.date_offset_range.end_date_offset
+                 offset *= query_def.num_forecasts
+                 anchor_time -= offset

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
              self._validate_time(query_def, anchor_time, None, evaluate=True)
          else:
              assert anchor_time == 'entity'
-             if (query_def.entity_table not in self._graph_store.time_dict):
+             if query_def.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query_def.entity_table}' "
                                   f"to have a time column")

-         query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                          random_seed)
-
-         node, time, y = query_driver.collect_test(
-             size=size,
-             anchor_time=anchor_time,
-             batch_size=min(10_000, size),
-             max_iterations=max_iterations,
-             guarantee_train_examples=False,
+         train, test = self._sampler.sample_target(
+             query=query_def,
+             num_train_examples=0,
+             train_anchor_time=anchor_time,
+             num_train_trials=0,
+             num_test_examples=size,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_iterations * size,
+             random_seed=random_seed,
          )

-         entity = self._graph_store.pkey_map_dict[
-             query_def.entity_table].index[node]
-
          return pd.DataFrame({
-             'ENTITY': entity,
-             'ANCHOR_TIMESTAMP': time,
-             'TARGET': y,
+             'ENTITY': test.entity_pkey,
+             'ANCHOR_TIMESTAMP': test.anchor_time,
+             'TARGET': test.target,
          })

      # Helpers #################################################################
@@ -731,8 +748,7 @@ class KumoRFM:
              graph_definition=self._graph_def,
          )

-         resp = global_state.client.rfm_api.parse_query(request)
-         # TODO Expose validation warnings.
+         resp = self._api_client.parse_query(request)

          if len(resp.validation_response.warnings) > 0:
              msg = '\n'.join([
@@ -751,36 +767,92 @@ class KumoRFM:
              raise ValueError(f"Failed to parse query '{query}'. "
                               f"{msg}") from None

+     @staticmethod
+     def _get_task_type(
+         query: ValidatedPredictiveQuery,
+         edge_types: list[tuple[str, str, str]],
+     ) -> TaskType:
+         if isinstance(query.target_ast, (Condition, LogicalOperation)):
+             return TaskType.BINARY_CLASSIFICATION
+
+         target = query.target_ast
+         if isinstance(target, Join):
+             target = target.rhs_target
+         if isinstance(target, Aggregation):
+             if target.aggr == AggregationType.LIST_DISTINCT:
+                 table_name, col_name = target._get_target_column_name().split(
+                     '.')
+                 target_edge_types = [
+                     edge_type for edge_type in edge_types
+                     if edge_type[0] == table_name and edge_type[1] == col_name
+                 ]
+                 if len(target_edge_types) != 1:
+                     raise NotImplementedError(
+                         f"Multilabel-classification queries based on "
+                         f"'LIST_DISTINCT' are not supported yet. If you "
+                         f"planned to write a link prediction query instead, "
+                         f"make sure to register '{col_name}' as a "
+                         f"foreign key.")
+                 return TaskType.TEMPORAL_LINK_PREDICTION
+
+             return TaskType.REGRESSION
+
+         assert isinstance(target, Column)
+
+         if target.stype in {Stype.ID, Stype.categorical}:
+             return TaskType.MULTICLASS_CLASSIFICATION
+
+         if target.stype in {Stype.numerical}:
+             return TaskType.REGRESSION
+
+         raise NotImplementedError("Task type not yet supported")
+
+     def _get_default_anchor_time(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> pd.Timestamp:
+         if query.query_type == QueryType.TEMPORAL:
+             aggr_table_names = [
+                 aggr._get_target_column_name().split('.')[0]
+                 for aggr in query.get_all_target_aggregations()
+             ]
+             return self._sampler.get_max_time(aggr_table_names)
+
+         assert query.query_type == QueryType.STATIC
+         return self._sampler.get_max_time()
+
      def _validate_time(
          self,
          query: ValidatedPredictiveQuery,
          anchor_time: pd.Timestamp,
-         context_anchor_time: Union[pd.Timestamp, None],
+         context_anchor_time: pd.Timestamp | None,
          evaluate: bool,
      ) -> None:

-         if self._graph_store.min_time == pd.Timestamp.max:
+         if len(self._sampler.time_column_dict) == 0:
              return  # Graph without timestamps

-         if anchor_time < self._graph_store.min_time:
+         min_time = self._sampler.get_min_time()
+         max_time = self._sampler.get_max_time()
+
+         if anchor_time < min_time:
              raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                              f"the earliest timestamp "
-                              f"'{self._graph_store.min_time}' in the data.")
+                              f"the earliest timestamp '{min_time}' in the "
+                              f"data.")

-         if (context_anchor_time is not None
-                 and context_anchor_time < self._graph_store.min_time):
+         if context_anchor_time is not None and context_anchor_time < min_time:
              raise ValueError(f"Context anchor timestamp is too early or "
                               f"aggregation time range is too large. To make "
                               f"this prediction, we would need data back to "
                               f"'{context_anchor_time}', however, your data "
-                              f"only contains data back to "
-                              f"'{self._graph_store.min_time}'.")
+                              f"only contains data back to '{min_time}'.")

          if query.target_ast.date_offset_range is not None:
              end_offset = query.target_ast.date_offset_range.end_date_offset
          else:
              end_offset = pd.DateOffset(0)
-         forecast_end_offset = end_offset * query.num_forecasts
+         end_offset = end_offset * query.num_forecasts
+
          if (context_anchor_time is not None
                  and context_anchor_time > anchor_time):
              warnings.warn(f"Context anchor timestamp "
@@ -790,7 +862,7 @@ class KumoRFM:
                            f"intended.")
          elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
-               and context_anchor_time + forecast_end_offset > anchor_time):
+               and context_anchor_time + end_offset > anchor_time):
              warnings.warn(f"Aggregation for context examples at timestamp "
                            f"'{context_anchor_time}' will leak information "
                            f"from the prediction anchor timestamp "
@@ -798,40 +870,37 @@ class KumoRFM:
                            f"intended.")

          elif (context_anchor_time is not None
-               and context_anchor_time - forecast_end_offset
-               < self._graph_store.min_time):
-             _time = context_anchor_time - forecast_end_offset
+               and context_anchor_time - end_offset < min_time):
+             _time = context_anchor_time - end_offset
              warnings.warn(f"Context anchor timestamp is too early or "
                            f"aggregation time range is too large. To form "
                            f"proper input data, we would need data back to "
                            f"'{_time}', however, your data only contains "
-                           f"data back to '{self._graph_store.min_time}'.")
+                           f"data back to '{min_time}'.")

-         if (not evaluate and anchor_time
-                 > self._graph_store.max_time + pd.DateOffset(days=1)):
+         if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
              warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                           f"latest timestamp '{self._graph_store.max_time}' "
-                           f"in the data. Please make sure this is intended.")
+                           f"latest timestamp '{max_time}' in the data. Please "
+                           f"make sure this is intended.")

-         max_eval_time = self._graph_store.max_time - forecast_end_offset
-         if evaluate and anchor_time > max_eval_time:
+         if evaluate and anchor_time > max_time - end_offset:
              raise ValueError(
                  f"Anchor timestamp for evaluation is after the latest "
-                 f"supported timestamp '{max_eval_time}'.")
+                 f"supported timestamp '{max_time - end_offset}'.")

      def _get_context(
          self,
          query: ValidatedPredictiveQuery,
-         indices: Union[List[str], List[float], List[int], None],
-         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
-         context_anchor_time: Union[pd.Timestamp, None],
+         indices: list[str] | list[float] | list[int] | None,
+         anchor_time: pd.Timestamp | Literal['entity'] | None,
+         context_anchor_time: pd.Timestamp | None,
          run_mode: RunMode,
-         num_neighbors: Optional[List[int]],
+         num_neighbors: list[int] | None,
          num_hops: int,
          max_pq_iterations: int,
          evaluate: bool,
-         random_seed: Optional[int] = _RANDOM_SEED,
-         logger: Optional[ProgressLogger] = None,
+         random_seed: int | None = _RANDOM_SEED,
+         logger: ProgressLogger | None = None,
      ) -> Context:

          if num_neighbors is not None:
@@ -848,10 +917,9 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-         task_type = LocalPQueryDriver.get_task_type(
-             query,
-             edge_types=self._graph_store.edge_types,
+         task_type = self._get_task_type(
+             query=query,
+             edge_types=self._sampler.edge_types,
          )

          if logger is not None:
@@ -883,14 +951,17 @@ class KumoRFM:
              num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

          if query.target_ast.date_offset_range is None:
-             end_offset = pd.DateOffset(0)
+             step_offset = pd.DateOffset(0)
          else:
-             end_offset = query.target_ast.date_offset_range.end_date_offset
-         forecast_end_offset = end_offset * query.num_forecasts
+             step_offset = query.target_ast.date_offset_range.end_date_offset
+         end_offset = step_offset * query.num_forecasts
+
          if anchor_time is None:
-             anchor_time = self._graph_store.max_time
+             anchor_time = self._get_default_anchor_time(query)
+
              if evaluate:
-                 anchor_time = anchor_time - forecast_end_offset
+                 anchor_time = anchor_time - end_offset
+
          if logger is not None:
              assert isinstance(anchor_time, pd.Timestamp)
              if anchor_time == pd.Timestamp.min:
@@ -904,57 +975,71 @@ class KumoRFM:

          assert anchor_time is not None
          if isinstance(anchor_time, pd.Timestamp):
+             if context_anchor_time == 'entity':
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
              if context_anchor_time is None:
-                 context_anchor_time = anchor_time - forecast_end_offset
+                 context_anchor_time = anchor_time - end_offset
              self._validate_time(query, anchor_time, context_anchor_time,
                                  evaluate)
          else:
              assert anchor_time == 'entity'
-             if query.entity_table not in self._graph_store.time_dict:
+             if query.query_type != QueryType.STATIC:
+                 raise ValueError("Anchor time 'entity' is only valid for "
+                                  "static predictive queries")
+             if query.entity_table not in self._sampler.time_column_dict:
                  raise ValueError(f"Anchor time 'entity' requires the entity "
                                   f"table '{query.entity_table}' to "
                                   f"have a time column")
-             if context_anchor_time is not None:
-                 warnings.warn("Ignoring option 'context_anchor_time' for "
-                               "`anchor_time='entity'`")
-             context_anchor_time = None
+             if isinstance(context_anchor_time, pd.Timestamp):
+                 raise ValueError("Anchor time 'entity' needs to be shared "
+                                  "for context and prediction examples")
+             context_anchor_time = 'entity'

-         y_test: Optional[pd.Series] = None
+         num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
          if evaluate:
-             max_test_size = _MAX_TEST_SIZE[run_mode]
+             num_test_examples = _MAX_TEST_SIZE[run_mode]
              if task_type.is_link_pred:
-                 max_test_size = max_test_size // 5
+                 num_test_examples = num_test_examples // 5
+         else:
+             num_test_examples = 0
+
+         train, test = self._sampler.sample_target(
+             query=query,
+             num_train_examples=num_train_examples,
+             train_anchor_time=context_anchor_time,
+             num_train_trials=max_pq_iterations * num_train_examples,
+             num_test_examples=num_test_examples,
+             test_anchor_time=anchor_time,
+             num_test_trials=max_pq_iterations * num_test_examples,
+             random_seed=random_seed,
+         )
+         train_pkey, train_time, y_train = train
+         test_pkey, test_time, y_test = test

-             test_node, test_time, y_test = query_driver.collect_test(
-                 size=max_test_size,
-                 anchor_time=anchor_time,
-                 max_iterations=max_pq_iterations,
-                 guarantee_train_examples=True,
-             )
-             if logger is not None:
-                 if task_type == TaskType.BINARY_CLASSIFICATION:
-                     pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{pos:.2f}% positive cases")
-                 elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                     msg = (f"Collected {len(y_test):,} test examples "
-                            f"holding {y_test.nunique()} classes")
-                 elif task_type == TaskType.REGRESSION:
-                     _min, _max = float(y_test.min()), float(y_test.max())
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"targets between {format_value(_min)} and "
-                            f"{format_value(_max)}")
-                 elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                     num_rhs = y_test.explode().nunique()
-                     msg = (f"Collected {len(y_test):,} test examples with "
-                            f"{num_rhs:,} unique items")
-                 else:
-                     raise NotImplementedError
-                 logger.log(msg)
+         if evaluate and logger is not None:
+             if task_type == TaskType.BINARY_CLASSIFICATION:
+                 pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{pos:.2f}% positive cases")
+             elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                 msg = (f"Collected {len(y_test):,} test examples holding "
+                        f"{y_test.nunique()} classes")
+             elif task_type == TaskType.REGRESSION:
+                 _min, _max = float(y_test.min()), float(y_test.max())
+                 msg = (f"Collected {len(y_test):,} test examples with targets "
+                        f"between {format_value(_min)} and "
+                        f"{format_value(_max)}")
+             elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                 num_rhs = y_test.explode().nunique()
+                 msg = (f"Collected {len(y_test):,} test examples with "
+                        f"{num_rhs:,} unique items")
+             else:
+                 raise NotImplementedError
+             logger.log(msg)

-         else:
+         if not evaluate:
              assert indices is not None
-
              if len(indices) > _MAX_PRED_SIZE[task_type]:
                  raise ValueError(f"Cannot predict for more than "
                                   f"{_MAX_PRED_SIZE[task_type]:,} entities at "
@@ -962,26 +1047,12 @@ class KumoRFM:
                                   f"`KumoRFM.batch_mode` to process entities "
                                   f"in batches")

-             test_node = self._graph_store.get_node_id(
-                 table_name=query.entity_table,
-                 pkey=pd.Series(indices),
-             )
-
+             test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
              if isinstance(anchor_time, pd.Timestamp):
-                 test_time = pd.Series(anchor_time).repeat(
-                     len(test_node)).reset_index(drop=True)
+                 test_time = pd.Series([anchor_time]).repeat(
+                     len(indices)).reset_index(drop=True)
              else:
-                 time = self._graph_store.time_dict[query.entity_table]
-                 time = time[test_node] * 1000**3
-                 test_time = pd.Series(time, dtype='datetime64[ns]')
-
-             train_node, train_time, y_train = query_driver.collect_train(
-                 size=_MAX_CONTEXT_SIZE[run_mode],
-                 anchor_time=context_anchor_time or 'entity',
-                 exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                            or anchor_time == 'entity') else None,
-                 max_iterations=max_pq_iterations,
-             )
+                 train_time = test_time = 'entity'

          if logger is not None:
              if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -1004,12 +1075,12 @@ class KumoRFM:
                  raise NotImplementedError
              logger.log(msg)

-         entity_table_names: Tuple[str, ...]
+         entity_table_names: tuple[str, ...]
          if task_type.is_link_pred:
              final_aggr = query.get_final_target_aggregation()
              assert final_aggr is not None
              edge_fkey = final_aggr._get_target_column_name()
-             for edge_type in self._graph_store.edge_types:
+             for edge_type in self._sampler.edge_types:
                  if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                      entity_table_names = (
                          query.entity_table,
@@ -1021,21 +1092,24 @@ class KumoRFM:
          # Exclude the entity anchor time from the feature set to prevent
          # running out-of-distribution between in-context and test examples:
          exclude_cols_dict = query.get_exclude_cols_dict()
-         if anchor_time == 'entity':
+         if entity_table_names[0] in self._sampler.time_column_dict:
              if entity_table_names[0] not in exclude_cols_dict:
                  exclude_cols_dict[entity_table_names[0]] = []
-             time_column_dict = self._graph_store.time_column_dict
-             time_column = time_column_dict[entity_table_names[0]]
+             time_column = self._sampler.time_column_dict[entity_table_names[0]]
              exclude_cols_dict[entity_table_names[0]].append(time_column)

-         subgraph = self._graph_sampler(
+         subgraph = self._sampler.sample_subgraph(
              entity_table_names=entity_table_names,
-             node=np.concatenate([train_node, test_node]),
-             time=np.concatenate([
-                 train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                 test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-             ]),
-             run_mode=run_mode,
+             entity_pkey=pd.concat(
+                 [train_pkey, test_pkey],
+                 axis=0,
+                 ignore_index=True,
+             ),
+             anchor_time=pd.concat(
+                 [train_time, test_time],
+                 axis=0,
+                 ignore_index=True,
+             ) if isinstance(train_time, pd.Series) else 'entity',
              num_neighbors=num_neighbors,
              exclude_cols_dict=exclude_cols_dict,
          )
@@ -1047,23 +1121,19 @@ class KumoRFM:
                               f"'https://github.com/kumo-ai/kumo-rfm' if you "
                               f"must go beyond this for your use-case.")

-         step_size: Optional[int] = None
-         if query.query_type == QueryType.TEMPORAL:
-             step_size = date_offset_to_seconds(end_offset)
-
          return Context(
              task_type=task_type,
              entity_table_names=entity_table_names,
              subgraph=subgraph,
              y_train=y_train,
-             y_test=y_test,
+             y_test=y_test if evaluate else None,
              top_k=query.top_k,
-             step_size=step_size,
+             step_size=None,
          )

      @staticmethod
      def _validate_metrics(
-         metrics: List[str],
+         metrics: list[str],
          task_type: TaskType,
      ) -> None:

@@ -1120,7 +1190,7 @@ class KumoRFM:
                  f"'https://github.com/kumo-ai/kumo-rfm'.")


- def format_value(value: Union[int, float]) -> str:
+ def format_value(value: int | float) -> str:
      if value == int(value):
          return f'{int(value):,}'
      if abs(value) >= 1000:
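For orientation: predict() caps the number of entities per call at _MAX_PRED_SIZE[task_type], and its error message points at KumoRFM.batch_mode, whose signature above is now batch_size: int | Literal['max'] = 'max' with num_retries: int = 1. A hedged sketch of the batch path; the entity ids and the exact multi-entity PQL form are assumptions for illustration:

# 'max' picks the per-task-type maximum batch size; num_retries re-issues a
# failed batch once before raising.
with rfm.batch_mode(batch_size='max', num_retries=1):
    result = rfm.predict(
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id",
        indices=[0, 1, 2],  # placeholder primary keys of the entity table
    )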