kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl

Files changed (53)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +37 -8
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
  12. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  13. kumoai/experimental/rfm/backend/local/table.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
  20. kumoai/experimental/rfm/base/__init__.py +25 -0
  21. kumoai/experimental/rfm/base/column.py +66 -0
  22. kumoai/experimental/rfm/base/sampler.py +773 -0
  23. kumoai/experimental/rfm/base/source.py +19 -0
  24. kumoai/experimental/rfm/base/sql_sampler.py +60 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
  27. kumoai/experimental/rfm/infer/__init__.py +6 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +126 -0
  30. kumoai/experimental/rfm/infer/time_col.py +62 -0
  31. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  32. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  33. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  34. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
  35. kumoai/experimental/rfm/rfm.py +669 -246
  36. kumoai/experimental/rfm/sagemaker.py +138 -0
  37. kumoai/jobs.py +1 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/trainer.py +12 -10
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +239 -4
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
  47. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -176
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -404
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
  52. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
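
The bulk of this diff is a rewrite of kumoai/experimental/rfm/rfm.py, shown below: LocalGraph becomes Graph (now backed by pluggable local, SQLite, and Snowflake samplers), rfm.query(...) is replaced by rfm.predict(...), entity primary keys can be supplied through a new indices= argument, and a batch_mode() context manager, an explain= option, and a use_prediction_time flag are introduced. The sketch below is assembled from the docstrings visible in the diff; the placeholder DataFrames and the example query are illustrative assumptions, not shipped code:

    import pandas as pd

    from kumoai.experimental.rfm import Graph, KumoRFM

    # Placeholder tables (assumption): any relational data with primary and
    # foreign keys works here.
    df_users = pd.DataFrame(...)
    df_items = pd.DataFrame(...)
    df_orders = pd.DataFrame(...)

    graph = Graph.from_data({
        'users': df_users,
        'items': df_items,
        'orders': df_orders,
    })
    rfm = KumoRFM(graph)

    query = "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1"

    # New in this release: pass entity ids separately and split large
    # workloads into batches with automatic retries.
    with rfm.batch_mode(batch_size='max', num_retries=1):
        df = rfm.predict(query, indices=[1, 2, 3])

    # Explainability is limited to a single entity with run_mode='FAST':
    explanation = rfm.predict(query, indices=[1], explain=True)
    explanation.print()
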
@@ -1,33 +1,56 @@
 import json
+import time
 import warnings
-from typing import List, Literal, Optional, Union
+from collections import defaultdict
+from collections.abc import Generator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import (
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    overload,
+)
 
 import numpy as np
 import pandas as pd
 from kumoapi.model_plan import RunMode
-from kumoapi.pquery import QueryType
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
 from kumoapi.rfm import (
-    Context,
-    PQueryDefinition,
     RFMEvaluateRequest,
+    RFMParseQueryRequest,
     RFMPredictRequest,
-    RFMValidateQueryRequest,
 )
 from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
 
-from kumoai import global_state
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.client.rfm import RFMAPI
 from kumoai.exceptions import HTTPException
-from kumoai.experimental.rfm import LocalGraph
-from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
-from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.local_pquery_driver import (
-    LocalPQueryDriver,
-    date_offset_to_seconds,
-)
-from kumoai.utils import InteractiveProgressLogger, ProgressLogger
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger
 
 _RANDOM_SEED = 42
 
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
 _MAX_CONTEXT_SIZE = {
     RunMode.DEBUG: 100,
     RunMode.FAST: 1_000,
@@ -42,7 +65,7 @@ _MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
 }
 
 _MAX_SIZE = 30 * 1024 * 1024
-_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                    "reduce either the number of tables in the graph, their "
                    "number of columns (e.g., large text columns), "
                    "neighborhood configuration, or the run mode. If none of "
@@ -51,6 +74,68 @@ _SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats_msg}\nPlease "
                    "beyond this for your use-case.")
 
 
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.dataframe(self.prediction, hide_index=True)
+            st.markdown(self.summary)
+        elif in_notebook():
+            from IPython.display import Markdown, display
+            try:
+                if hasattr(self.prediction.style, 'hide'):
+                    display(self.prediction.style.hide(axis='index'))  # pandas>=1.4
+                else:
+                    display(self.prediction.style.hide_index())  # pandas<1.4
+            except ImportError:
+                print(self.prediction.to_string(index=False))  # missing jinja2
+            display(Markdown(self.summary))
+        else:
+            print(self.prediction.to_string(index=False))
+            print(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
+
 class KumoRFM:
     r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
     Foundation Model for In-Context Learning on Relational Data
@@ -59,17 +144,17 @@ class KumoRFM:
     :class:`KumoRFM` is a foundation model to generate predictions for any
     relational dataset without training.
     The model is pre-trained and the class provides an interface to query the
-    model from a :class:`LocalGraph` object.
+    model from a :class:`Graph` object.
 
     .. code-block:: python
 
-        from kumoai.experimental.rfm import LocalGraph, KumoRFM
+        from kumoai.experimental.rfm import Graph, KumoRFM
 
         df_users = pd.DataFrame(...)
         df_items = pd.DataFrame(...)
         df_orders = pd.DataFrame(...)
 
-        graph = LocalGraph.from_data({
+        graph = Graph.from_data({
             'users': df_users,
             'items': df_items,
             'orders': df_orders,
@@ -77,62 +162,166 @@ class KumoRFM:
         })
 
         rfm = KumoRFM(graph)
-        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
-                 "FOR users.user_id=0")
-        result = rfm.query(query)
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
 
         print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                        # 1        0.85
 
     Args:
         graph: The graph.
-        preprocess: Whether to pre-process the data in advance during graph
-            materialization.
-            This is a runtime trade-off between graph materialization and model
-            processing speed.
-            It can be benefical to preprocess your data once and then run many
-            queries on top to achieve maximum model speed.
-            However, if activiated, graph materialization can take potentially
-            much longer, especially on graphs with many large text columns.
-            Best to tune this option manually.
         verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data backend
+            for optimal querying. For example, for transactional database
+            backends, will create any missing indices. Requires write-access to
+            the data backend.
     """
     def __init__(
         self,
-        graph: LocalGraph,
-        preprocess: bool = False,
+        graph: Graph,
         verbose: Union[bool, ProgressLogger] = True,
+        optimize: bool = False,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
-        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
-        self._graph_sampler = LocalGraphSampler(self._graph_store)
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: Optional[RFMAPI] = None
+
+        self._batch_size: Optional[int | Literal['max']] = None
+        self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
 
     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'
 
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: Union[int, Literal['max']] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
     def predict(
         self,
         query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
         *,
+        explain: Literal[False] = False,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         context_anchor_time: Union[pd.Timestamp, None] = None,
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+        context_anchor_time: Union[pd.Timestamp, None] = None,
+        run_mode: Union[RunMode, str] = RunMode.FAST,
+        num_neighbors: Optional[List[int]] = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: Optional[int] = _RANDOM_SEED,
+        verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
+    ) -> Union[pd.DataFrame, Explanation]:
         """Returns predictions for a predictive query.
 
         Args:
             query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query. Predictions will
+                be generated for all indices, independent of whether they
+                fulfill entity filter constraints. To pre-filter entities, use
+                :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides control
+                over which parts of explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
             anchor_time: The anchor timestamp for the prediction. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to :obj:`None`, :obj:`anchor_time` will
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
@@ -145,83 +334,237 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
-            The predictions as a :class:`pandas.DataFrame`
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation` object
+            containing the prediction, summary, and details.
         """
-        explain = False
+        explain_config: Optional[ExplainConfig] = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
         query_def = self._parse_query(query)
+        query_str = query_def.to_string()
 
         if num_hops != 2 and num_neighbors is not None:
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
-        if explain and run_mode in {RunMode.NORMAL, RunMode.BEST}:
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
             warnings.warn(f"Explainability is currently only supported for "
                           f"run mode 'FAST' (got '{run_mode}'). Provided run "
                           f"mode has been reset. Please lower the run mode to "
                           f"suppress this warning.")
 
-        if explain:
-            assert query_def.entity.ids is not None
-            if len(query_def.entity.ids.value) > 1:
-                raise ValueError(
-                    f"Cannot explain predictions for more than a single "
-                    f"entity (got {len(query_def.entity.ids.value)})")
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via `predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
 
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
-        if explain:
+        if explain_config is not None:
             msg = f'[bold]EXPLAIN[/bold] {query_repr}'
         else:
             msg = f'[bold]PREDICT[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
-            context = self._get_context(
-                query_def,
-                anchor_time=anchor_time,
-                context_anchor_time=context_anchor_time,
-                run_mode=RunMode(run_mode),
-                num_neighbors=num_neighbors,
-                num_hops=num_hops,
-                max_pq_iterations=max_pq_iterations,
-                evaluate=False,
-                random_seed=random_seed,
-                logger=logger,
-            )
-            request = RFMPredictRequest(
-                context=context,
-                run_mode=RunMode(run_mode),
+
+            batch_size: Optional[int] = None
+            if self._batch_size == 'max':
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: List[pd.DataFrame] = []
+            summary: Optional[str] = None
+            details: Optional[Explanation] = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore', message='Protobuf gencode')
+                    request_msg = request.to_protobuf()
+                    _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if i == 0 and len(batches) > 1:
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if len(batches) > 1:
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. {msg}"
+                            ) from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+            if len(predictions) == 1:
+                prediction = predictions[0]
+            else:
+                prediction = pd.concat(predictions, ignore_index=True)
+
+            if explain_config is not None:
+                assert len(predictions) == 1
+                assert summary is not None
+                assert details is not None
+                return Explanation(
+                    prediction=prediction,
+                    summary=summary,
+                    details=details,
                 )
-            with warnings.catch_warnings():
-                warnings.filterwarnings('ignore', message='Protobuf gencode')
-                request_msg = request.to_protobuf()
-                request_bytes = request_msg.SerializeToString()
-            logger.log(f"Generated context of size "
-                       f"{len(request_bytes) / (1024*1024):.2f}MB")
 
-            if len(request_bytes) > _MAX_SIZE:
-                stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+            return prediction
 
-            try:
-                if explain:
-                    resp = global_state.client.rfm_api.explain(request_bytes)
-                else:
-                    resp = global_state.client.rfm_api.predict(request_bytes)
-            except HTTPException as e:
-                try:
-                    msg = json.loads(e.detail)['detail']
-                except Exception:
-                    msg = e.detail
-                raise RuntimeError(f"An unexpected exception occurred. "
-                                   f"Please create an issue at "
-                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
-                                   f"{msg}") from None
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: Union[List[str], List[float], List[int], None] = None,
+        *,
+        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
 
-        return pd.DataFrame(**resp.prediction)
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override the
+                indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. Please "
+                                 "pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._get_default_anchor_time(query_def)
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if query_def.entity_table not in self._sampler.time_column_dict:
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        raise NotImplementedError
 
     def evaluate(
         self,
@@ -233,9 +576,10 @@ class KumoRFM:
         run_mode: Union[RunMode, str] = RunMode.FAST,
         num_neighbors: Optional[List[int]] = None,
         num_hops: int = 2,
-        max_pq_iterations: int = 20,
+        max_pq_iterations: int = 10,
         random_seed: Optional[int] = _RANDOM_SEED,
         verbose: Union[bool, ProgressLogger] = True,
+        use_prediction_time: bool = False,
     ) -> pd.DataFrame:
         """Evaluates a predictive query.
 
@@ -243,10 +587,10 @@ class KumoRFM:
             query: The predictive query.
             metrics: The metrics to use.
             anchor_time: The anchor timestamp for the prediction. If set to
-                :obj:`None`, will use the maximum timestamp in the data.
-                If set to :`"entity"`, will use the timestamp of the entity.
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
             context_anchor_time: The maximum anchor timestamp for context
-                examples. If set to :obj:`None`, :obj:`anchor_time` will
+                examples. If set to ``None``, ``anchor_time`` will
                 determine the anchor time for context examples.
             run_mode: The :class:`RunMode` for the query.
             num_neighbors: The number of neighbors to sample for each hop.
@@ -259,6 +603,9 @@ class KumoRFM:
                 entities to find valid labels.
             random_seed: A manual seed for generating pseudo-random numbers.
             verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
 
         Returns:
             The metrics as a :class:`pandas.DataFrame`
@@ -269,15 +616,22 @@
             warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                           f"custom 'num_hops={num_hops}' option")
 
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
         query_repr = query_def.to_string(rich=True, exclude_predict=True)
         msg = f'[bold]EVALUATE[/bold] {query_repr}'
 
         if not isinstance(verbose, ProgressLogger):
-            verbose = InteractiveProgressLogger(msg, verbose=verbose)
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
 
         with verbose as logger:
             context = self._get_context(
-                query_def,
+                query=query_def,
+                indices=None,
                 anchor_time=anchor_time,
                 context_anchor_time=context_anchor_time,
                 run_mode=RunMode(run_mode),
@@ -295,6 +649,7 @@ class KumoRFM:
                 context=context,
                 run_mode=RunMode(run_mode),
                 metrics=metrics,
+                use_prediction_time=use_prediction_time,
             )
             with warnings.catch_warnings():
                 warnings.filterwarnings('ignore', message='Protobuf gencode')
@@ -305,10 +660,10 @@ class KumoRFM:
 
             if len(request_bytes) > _MAX_SIZE:
                 stats_msg = Context.get_memory_stats(request_msg.context)
-                raise ValueError(_SIZE_LIMIT_MSG.format(stats_msg=stats_msg))
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
 
             try:
-                resp = global_state.client.rfm_api.evaluate(request_bytes)
+                resp = self._api_client.evaluate(request_bytes)
             except HTTPException as e:
                 try:
                     msg = json.loads(e.detail)['detail']
@@ -332,7 +687,7 @@ class KumoRFM:
         *,
         anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
         random_seed: Optional[int] = _RANDOM_SEED,
-        max_iterations: int = 20,
+        max_iterations: int = 10,
     ) -> pd.DataFrame:
         """Returns the labels of a predictive query for a specified anchor
         time.
@@ -352,45 +707,43 @@ class KumoRFM:
         query_def = self._parse_query(query)
 
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
-            anchor_time = anchor_time - (query_def.target.end_offset *
-                                         query_def.num_forecasts)
+            anchor_time = self._get_default_anchor_time(query_def)
+            if query_def.target_ast.date_offset_range is not None:
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
             self._validate_time(query_def, anchor_time, None, evaluate=True)
         else:
             assert anchor_time == 'entity'
-            if (query_def.entity.pkey.table_name
-                    not in self._graph_store.time_dict):
+            if query_def.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query_def.entity.pkey.table_name}' "
+                                 f"table '{query_def.entity_table}' "
                                  f"to have a time column")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query_def,
-                                         random_seed)
-
-        node, time, y = query_driver.collect_test(
-            size=size,
-            anchor_time=anchor_time,
-            batch_size=min(10_000, size),
-            max_iterations=max_iterations,
-            guarantee_train_examples=False,
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
         )
 
-        entity = self._graph_store.pkey_map_dict[
-            query_def.entity.pkey.table_name].index[node]
-
         return pd.DataFrame({
-            'ENTITY': entity,
-            'ANCHOR_TIMESTAMP': time,
-            'TARGET': y,
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
         })
 
     # Helpers #################################################################
 
-    def _parse_query(self, query: str) -> PQueryDefinition:
-        if isinstance(query, PQueryDefinition):
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
             return query
 
         if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
@@ -400,13 +753,12 @@ class KumoRFM:
                              "predictions or evaluations.")
 
         try:
-            request = RFMValidateQueryRequest(
+            request = RFMParseQueryRequest(
                 query=query,
                 graph_definition=self._graph_def,
             )
 
-            resp = global_state.client.rfm_api.validate_query(request)
-            # TODO Expose validation warnings.
+            resp = self._api_client.parse_query(request)
 
             if len(resp.validation_response.warnings) > 0:
                 msg = '\n'.join([
@@ -416,7 +768,7 @@ class KumoRFM:
                 warnings.warn(f"Encountered the following warnings during "
                               f"parsing:\n{msg}")
 
-            return resp.query_definition
+            return resp.query
         except HTTPException as e:
             try:
                 msg = json.loads(e.detail)['detail']
@@ -425,30 +777,91 @@ class KumoRFM:
             raise ValueError(f"Failed to parse query '{query}'. "
                              f"{msg}") from None
 
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
     def _validate_time(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         anchor_time: pd.Timestamp,
         context_anchor_time: Union[pd.Timestamp, None],
         evaluate: bool,
     ) -> None:
 
-        if self._graph_store.min_time == pd.Timestamp.max:
+        if len(self._sampler.time_column_dict) == 0:
             return  # Graph without timestamps
 
-        if anchor_time < self._graph_store.min_time:
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
             raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
-                             f"the earliest timestamp "
-                             f"'{self._graph_store.min_time}' in the data.")
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
 
-        if (context_anchor_time is not None
-                and context_anchor_time < self._graph_store.min_time):
+        if context_anchor_time is not None and context_anchor_time < min_time:
             raise ValueError(f"Context anchor timestamp is too early or "
                              f"aggregation time range is too large. To make "
                              f"this prediction, we would need data back to "
                              f"'{context_anchor_time}', however, your data "
-                             f"only contains data back to "
-                             f"'{self._graph_store.min_time}'.")
+                             f"only contains data back to '{min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        end_offset = end_offset * query.num_forecasts
 
         if (context_anchor_time is not None
                 and context_anchor_time > anchor_time):
@@ -458,41 +871,37 @@ class KumoRFM:
                           f"(got '{anchor_time}'). Please make sure this is "
                           f"intended.")
         elif (query.query_type == QueryType.TEMPORAL
-              and context_anchor_time is not None and context_anchor_time +
-              query.target.end_offset * query.num_forecasts > anchor_time):
+              and context_anchor_time is not None
+              and context_anchor_time + end_offset > anchor_time):
             warnings.warn(f"Aggregation for context examples at timestamp "
                           f"'{context_anchor_time}' will leak information "
                           f"from the prediction anchor timestamp "
                           f"'{anchor_time}'. Please make sure this is "
                           f"intended.")
 
-        elif (context_anchor_time is not None and context_anchor_time -
-              query.target.end_offset * query.num_forecasts
-              < self._graph_store.min_time):
-            _time = context_anchor_time - (query.target.end_offset *
-                                           query.num_forecasts)
+        elif (context_anchor_time is not None
+              and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
             warnings.warn(f"Context anchor timestamp is too early or "
                           f"aggregation time range is too large. To form "
                           f"proper input data, we would need data back to "
                           f"'{_time}', however, your data only contains "
-                          f"data back to '{self._graph_store.min_time}'.")
+                          f"data back to '{min_time}'.")
 
-        if (not evaluate and anchor_time
-                > self._graph_store.max_time + pd.DateOffset(days=1)):
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
             warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
-                          f"latest timestamp '{self._graph_store.max_time}' "
-                          f"in the data. Please make sure this is intended.")
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
 
-        max_eval_time = (self._graph_store.max_time -
-                         query.target.end_offset * query.num_forecasts)
-        if evaluate and anchor_time > max_eval_time:
+        if evaluate and anchor_time > max_time - end_offset:
             raise ValueError(
                 f"Anchor timestamp for evaluation is after the latest "
-                f"supported timestamp '{max_eval_time}'.")
+                f"supported timestamp '{max_time - end_offset}'.")
 
     def _get_context(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
+        indices: Union[List[str], List[float], List[int], None],
         anchor_time: Union[pd.Timestamp, Literal['entity'], None],
         context_anchor_time: Union[pd.Timestamp, None],
         run_mode: RunMode,
@@ -518,10 +927,9 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
-        task_type = query.get_task_type(
-            stypes=self._graph_store.stype_dict,
-            edge_types=self._graph_store.edge_types,
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
         )
 
         if logger is not None:
@@ -552,104 +960,109 @@ class KumoRFM:
         else:
             num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
 
+        if query.target_ast.date_offset_range is None:
+            step_offset = pd.DateOffset(0)
+        else:
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
         if anchor_time is None:
-            anchor_time = self._graph_store.max_time
+            anchor_time = self._get_default_anchor_time(query)
+
             if evaluate:
-                anchor_time = anchor_time - (query.target.end_offset *
-                                             query.num_forecasts)
+                anchor_time = anchor_time - end_offset
+
             if logger is not None:
                 assert isinstance(anchor_time, pd.Timestamp)
-                if (anchor_time.hour == 0 and anchor_time.minute == 0
-                        and anchor_time.second == 0
-                        and anchor_time.microsecond == 0):
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                      and anchor_time.second == 0
+                      and anchor_time.microsecond == 0):
                     logger.log(f"Derived anchor time {anchor_time.date()}")
                 else:
                     logger.log(f"Derived anchor time {anchor_time}")
 
         assert anchor_time is not None
         if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
             if context_anchor_time is None:
-                context_anchor_time = anchor_time - (query.target.end_offset *
-                                                     query.num_forecasts)
+                context_anchor_time = anchor_time - end_offset
             self._validate_time(query, anchor_time, context_anchor_time,
                                 evaluate)
         else:
             assert anchor_time == 'entity'
-            if query.entity.pkey.table_name not in self._graph_store.time_dict:
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
                 raise ValueError(f"Anchor time 'entity' requires the entity "
-                                 f"table '{query.entity.pkey.table_name}' to "
+                                 f"table '{query.entity_table}' to "
                                  f"have a time column")
-            if context_anchor_time is not None:
-                warnings.warn("Ignoring option 'context_anchor_time' for "
-                              "`anchor_time='entity'`")
-                context_anchor_time = None
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
 
-        y_test: Optional[pd.Series] = None
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
         if evaluate:
-            max_test_size = _MAX_TEST_SIZE[run_mode]
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
             if task_type.is_link_pred:
-                max_test_size = max_test_size // 5
-
-            test_node, test_time, y_test = query_driver.collect_test(
-                size=max_test_size,
-                anchor_time=anchor_time,
-                max_iterations=max_pq_iterations,
-                guarantee_train_examples=True,
-            )
-            if logger is not None:
-                if task_type == TaskType.BINARY_CLASSIFICATION:
-                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{pos:.2f}% positive cases")
-                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
-                    msg = (f"Collected {len(y_test):,} test examples "
-                           f"holding {y_test.nunique()} classes")
-                elif task_type == TaskType.REGRESSION:
-                    _min, _max = float(y_test.min()), float(y_test.max())
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"targets between {format_value(_min)} and "
-                           f"{format_value(_max)}")
-                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
-                    num_rhs = y_test.explode().nunique()
-                    msg = (f"Collected {len(y_test):,} test examples with "
-                           f"{num_rhs:,} unique items")
-                else:
-                    raise NotImplementedError
-                logger.log(msg)
-
+                num_test_examples = num_test_examples // 5
         else:
-            assert query.entity.ids is not None
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
 
-            max_num_test = 200 if task_type.is_link_pred else 1000
-            if len(query.entity.ids.value) > max_num_test:
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
+
+        if not evaluate:
+            assert indices is not None
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
                 raise ValueError(f"Cannot predict for more than "
-                                 f"{max_num_test:,} entities at once "
-                                 f"(got {len(query.entity.ids.value):,})")
-
-            test_node = self._graph_store.get_node_id(
-                table_name=query.entity.pkey.table_name,
-                pkey=pd.Series(
-                    query.entity.ids.value,
-                    dtype=query.entity.ids.dtype,
-                ),
-            )
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
 
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
             if isinstance(anchor_time, pd.Timestamp):
-                test_time = pd.Series(anchor_time).repeat(
-                    len(test_node)).reset_index(drop=True)
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
             else:
-                time = self._graph_store.time_dict[
-                    query.entity.pkey.table_name]
-                time = time[test_node] * 1000**3
-                test_time = pd.Series(time, dtype='datetime64[ns]')
-
-            train_node, train_time, y_train = query_driver.collect_train(
-                size=_MAX_CONTEXT_SIZE[run_mode],
-                anchor_time=context_anchor_time or 'entity',
-                exclude_node=test_node if (query.query_type == QueryType.STATIC
-                                           or anchor_time == 'entity') else None,
-                max_iterations=max_pq_iterations,
-            )
+                train_time = test_time = 'entity'
 
         if logger is not None:
             if task_type == TaskType.BINARY_CLASSIFICATION:
@@ -672,27 +1085,41 @@ class KumoRFM:
                 raise NotImplementedError
             logger.log(msg)
 
-        entity_table_names = query.get_entity_table_names(
-            self._graph_store.edge_types)
+        entity_table_names: Tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._sampler.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
 
         # Exclude the entity anchor time from the feature set to prevent
         # running out-of-distribution between in-context and test examples:
-        exclude_cols_dict = query.exclude_cols_dict
-        if anchor_time == 'entity':
+        exclude_cols_dict = query.get_exclude_cols_dict()
+        if entity_table_names[0] in self._sampler.time_column_dict:
             if entity_table_names[0] not in exclude_cols_dict:
                 exclude_cols_dict[entity_table_names[0]] = []
-            time_column_dict = self._graph_store.time_column_dict
-            time_column = time_column_dict[entity_table_names[0]]
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
             exclude_cols_dict[entity_table_names[0]].append(time_column)
 
-        subgraph = self._graph_sampler(
+        subgraph = self._sampler.sample_subgraph(
             entity_table_names=entity_table_names,
-            node=np.concatenate([train_node, test_node]),
-            time=np.concatenate([
-                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
-                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
-            ]),
-            run_mode=run_mode,
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
             num_neighbors=num_neighbors,
             exclude_cols_dict=exclude_cols_dict,
         )
@@ -704,18 +1131,14 @@ class KumoRFM:
                              f"'https://github.com/kumo-ai/kumo-rfm' if you "
                              f"must go beyond this for your use-case.")
 
-        step_size: Optional[int] = None
-        if query.query_type == QueryType.TEMPORAL:
-            step_size = date_offset_to_seconds(query.target.end_offset)
-
         return Context(
             task_type=task_type,
             entity_table_names=entity_table_names,
             subgraph=subgraph,
             y_train=y_train,
-            y_test=y_test,
+            y_test=y_test if evaluate else None,
             top_k=query.top_k,
-            step_size=step_size,
+            step_size=None,
         )
 
     @staticmethod
@@ -731,7 +1154,7 @@ class KumoRFM:
         elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
             supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
         elif task_type == TaskType.REGRESSION:
-            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape']
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
        elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
            supported_metrics = [
                'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
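
Also visible in the closing hunks: regression evaluation gains an 'r2' metric, and evaluate() accepts the new use_prediction_time flag. A hypothetical call under the signatures shown above (the metric list and variable names are assumptions, continuing the earlier sketch):

    # Evaluate instead of predict; 'r2' is newly supported for regression.
    metrics_df = rfm.evaluate(
        "PREDICT COUNT(orders.*, 0, 30, days) FOR users.user_id=1",
        metrics=['mae', 'rmse', 'r2'],
        use_prediction_time=True,  # new flag, aimed at forecasting tasks
    )
    print(metrics_df)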