kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (98)
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/rfm.py

@@ -0,0 +1,1130 @@
import json
import time
import warnings
from collections import defaultdict
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, replace
from typing import (
    Any,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
    overload,
)

import numpy as np
import pandas as pd
from kumoapi.model_plan import RunMode
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
from kumoapi.rfm import Context
from kumoapi.rfm import Explanation as ExplanationConfig
from kumoapi.rfm import (
    RFMEvaluateRequest,
    RFMParseQueryRequest,
    RFMPredictRequest,
)
from kumoapi.task import TaskType

from kumoai import global_state
from kumoai.exceptions import HTTPException
from kumoai.experimental.rfm import LocalGraph
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
from kumoai.experimental.rfm.local_pquery_driver import (
    LocalPQueryDriver,
    date_offset_to_seconds,
)
from kumoai.mixin import CastMixin
from kumoai.utils import InteractiveProgressLogger, ProgressLogger

_RANDOM_SEED = 42

_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200

_MAX_CONTEXT_SIZE = {
    RunMode.DEBUG: 100,
    RunMode.FAST: 1_000,
    RunMode.NORMAL: 5_000,
    RunMode.BEST: 10_000,
}
_MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
    RunMode.DEBUG: 100,
    RunMode.FAST: 2_000,
    RunMode.NORMAL: 2_000,
    RunMode.BEST: 2_000,
}

_MAX_SIZE = 30 * 1024 * 1024
_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
                   "reduce either the number of tables in the graph, their "
                   "number of columns (e.g., large text columns), "
                   "neighborhood configuration, or the run mode. If none of "
                   "this is possible, please create a feature request at "
                   "'https://github.com/kumo-ai/kumo-rfm' if you must go "
                   "beyond this for your use-case.")


@dataclass(repr=False)
class ExplainConfig(CastMixin):
    """Configuration for explainability.

    Args:
        skip_summary: Whether to skip generating a human-readable summary of
            the explanation.
    """
    skip_summary: bool = False


@dataclass(repr=False)
class Explanation:
    prediction: pd.DataFrame
    summary: str
    details: ExplanationConfig

    @overload
    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
        pass

    @overload
    def __getitem__(self, index: Literal[1]) -> str:
        pass

    def __getitem__(self, index: int) -> Union[pd.DataFrame, str]:
        if index == 0:
            return self.prediction
        if index == 1:
            return self.summary
        raise IndexError("Index out of range")

    def __iter__(self) -> Iterator[Union[pd.DataFrame, str]]:
        return iter((self.prediction, self.summary))

    def __repr__(self) -> str:
        return str((self.prediction, self.summary))

    def _ipython_display_(self) -> None:
        from IPython.display import Markdown, display

        display(self.prediction)
        display(Markdown(self.summary))

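# -- Editor's note (not part of rfm.py): `Explanation` is tuple-like via the
# `__getitem__`/`__iter__` methods above, so the result of
# `predict(..., explain=True)` unpacks directly. A minimal sketch, where
# `model` and `query` are hypothetical placeholders:
#
#     result = model.predict(query, indices=[1], explain=True)
#     prediction, summary = result      # pd.DataFrame, str (via __iter__)
#     assert result[0] is prediction    # index access works as well
#     details = result.details          # structured ExplanationConfig payload
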

class KumoRFM:
    r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
    Foundation Model for In-Context Learning on Relational Data
    <https://kumo.ai/research/kumo_relational_foundation_model.pdf>`_ paper.

    :class:`KumoRFM` is a foundation model that generates predictions for any
    relational dataset without training.
    The model is pre-trained, and this class provides an interface to query
    the model from a :class:`LocalGraph` object.

    .. code-block:: python

        from kumoai.experimental.rfm import LocalGraph, KumoRFM

        df_users = pd.DataFrame(...)
        df_items = pd.DataFrame(...)
        df_transactions = pd.DataFrame(...)

        graph = LocalGraph.from_data({
            'users': df_users,
            'items': df_items,
            'transactions': df_transactions,
        })

        rfm = KumoRFM(graph)

        query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
                 "FOR users.user_id=0")
        result = rfm.predict(query)

        print(result)  # user_id  COUNT(transactions.*, 0, 30, days) > 0
                       # 1        0.85

    Args:
        graph: The graph.
        preprocess: Whether to pre-process the data in advance during graph
            materialization.
            This is a runtime trade-off between graph materialization and
            model processing speed.
            It can be beneficial to preprocess your data once and then run
            many queries on top to achieve maximum model speed.
            However, if activated, graph materialization can potentially take
            much longer, especially on graphs with many large text columns.
            It is best to tune this option manually.
        verbose: Whether to print verbose output.
    """
    def __init__(
        self,
        graph: LocalGraph,
        preprocess: bool = False,
        verbose: Union[bool, ProgressLogger] = True,
    ) -> None:
        graph = graph.validate()
        self._graph_def = graph._to_api_graph_definition()
        self._graph_store = LocalGraphStore(graph, preprocess, verbose)
        self._graph_sampler = LocalGraphSampler(self._graph_store)

        self._batch_size: Optional[int | Literal['max']] = None
        self.num_retries: int = 0

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

    @contextmanager
    def batch_mode(
        self,
        batch_size: Union[int, Literal['max']] = 'max',
        num_retries: int = 1,
    ) -> Generator[None, None, None]:
        """Context manager to predict in batches.

        .. code-block:: python

            with model.batch_mode(batch_size='max', num_retries=1):
                df = model.predict(query, indices=...)

        Args:
            batch_size: The batch size. If set to ``"max"``, will use the
                maximum applicable batch size for the given task.
            num_retries: The maximum number of retries for failed queries due
                to unexpected server issues.
        """
        if batch_size != 'max' and batch_size <= 0:
            raise ValueError(f"'batch_size' must be greater than zero "
                             f"(got {batch_size})")

        if num_retries < 0:
            raise ValueError(f"'num_retries' must be greater than or equal "
                             f"to zero (got {num_retries})")

        self._batch_size = batch_size
        self.num_retries = num_retries
        yield
        self._batch_size = None
        self.num_retries = 0

    @overload
    def predict(
        self,
        query: str,
        indices: Union[List[str], List[float], List[int], None] = None,
        *,
        explain: Literal[False] = False,
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
        context_anchor_time: Union[pd.Timestamp, None] = None,
        run_mode: Union[RunMode, str] = RunMode.FAST,
        num_neighbors: Optional[List[int]] = None,
        num_hops: int = 2,
        max_pq_iterations: int = 20,
        random_seed: Optional[int] = _RANDOM_SEED,
        verbose: Union[bool, ProgressLogger] = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        pass

    @overload
    def predict(
        self,
        query: str,
        indices: Union[List[str], List[float], List[int], None] = None,
        *,
        explain: Union[Literal[True], ExplainConfig, Dict[str, Any]],
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
        context_anchor_time: Union[pd.Timestamp, None] = None,
        run_mode: Union[RunMode, str] = RunMode.FAST,
        num_neighbors: Optional[List[int]] = None,
        num_hops: int = 2,
        max_pq_iterations: int = 20,
        random_seed: Optional[int] = _RANDOM_SEED,
        verbose: Union[bool, ProgressLogger] = True,
        use_prediction_time: bool = False,
    ) -> Explanation:
        pass

    def predict(
        self,
        query: str,
        indices: Union[List[str], List[float], List[int], None] = None,
        *,
        explain: Union[bool, ExplainConfig, Dict[str, Any]] = False,
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
        context_anchor_time: Union[pd.Timestamp, None] = None,
        run_mode: Union[RunMode, str] = RunMode.FAST,
        num_neighbors: Optional[List[int]] = None,
        num_hops: int = 2,
        max_pq_iterations: int = 20,
        random_seed: Optional[int] = _RANDOM_SEED,
        verbose: Union[bool, ProgressLogger] = True,
        use_prediction_time: bool = False,
    ) -> Union[pd.DataFrame, Explanation]:
        """Returns predictions for a predictive query.

        Args:
            query: The predictive query.
            indices: The entity primary keys to predict for. Will override
                the indices given as part of the predictive query.
                Predictions will be generated for all indices, independent of
                whether they fulfill entity filter constraints. To pre-filter
                entities, use :meth:`~KumoRFM.is_valid_entity`.
            explain: Configuration for explainability.
                If set to ``True``, will additionally explain the prediction.
                Passing in an :class:`ExplainConfig` instance provides
                control over which parts of the explanation are generated.
                Explainability is currently only supported for single entity
                predictions with ``run_mode="FAST"``.
            anchor_time: The anchor timestamp for the prediction. If set to
                ``None``, will use the maximum timestamp in the data.
                If set to ``"entity"``, will use the timestamp of the entity.
            context_anchor_time: The maximum anchor timestamp for context
                examples. If set to ``None``, ``anchor_time`` will determine
                the anchor time for context examples.
            run_mode: The :class:`RunMode` for the query.
            num_neighbors: The number of neighbors to sample for each hop.
                If specified, the ``num_hops`` option will be ignored.
            num_hops: The number of hops to sample when generating the
                context.
            max_pq_iterations: The maximum number of iterations to perform to
                collect valid labels. It is advised to increase the number of
                iterations in case the predictive query has strict entity
                filters, in which case :class:`KumoRFM` needs to sample more
                entities to find valid labels.
            random_seed: A manual seed for generating pseudo-random numbers.
            verbose: Whether to print verbose output.
            use_prediction_time: Whether to use the anchor timestamp as an
                additional feature during prediction. This is typically
                beneficial for time series forecasting tasks.

        Returns:
            The predictions as a :class:`pandas.DataFrame`.
            If ``explain`` is provided, returns an :class:`Explanation`
            object containing the prediction, summary, and details.
        """
        explain_config: Optional[ExplainConfig] = None
        if explain is True:
            explain_config = ExplainConfig()
        elif explain is not False:
            explain_config = ExplainConfig._cast(explain)

        query_def = self._parse_query(query)
        query_str = query_def.to_string()

        if num_hops != 2 and num_neighbors is not None:
            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                          f"custom 'num_hops={num_hops}' option")

        if explain_config is not None and run_mode in {
                RunMode.NORMAL, RunMode.BEST
        }:
            warnings.warn(f"Explainability is currently only supported for "
                          f"run mode 'FAST' (got '{run_mode}'). The provided "
                          f"run mode has been reset. Please lower the run "
                          f"mode to suppress this warning.")

        if indices is None:
            if query_def.rfm_entity_ids is None:
                raise ValueError("Cannot find entities to predict for. "
                                 "Please pass them via "
                                 "`predict(query, indices=...)`")
            indices = query_def.get_rfm_entity_id_list()
        else:
            query_def = replace(query_def, rfm_entity_ids=None)

        if len(indices) == 0:
            raise ValueError("At least one entity is required")

        if explain_config is not None and len(indices) > 1:
            raise ValueError(
                f"Cannot explain predictions for more than a single entity "
                f"(got {len(indices)})")

        query_repr = query_def.to_string(rich=True, exclude_predict=True)
        if explain_config is not None:
            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
        else:
            msg = f'[bold]PREDICT[/bold] {query_repr}'

        if not isinstance(verbose, ProgressLogger):
            verbose = InteractiveProgressLogger(msg, verbose=verbose)

        with verbose as logger:

            batch_size: Optional[int] = None
            if self._batch_size == 'max':
                task_type = LocalPQueryDriver.get_task_type(
                    query_def,
                    edge_types=self._graph_store.edge_types,
                )
                batch_size = _MAX_PRED_SIZE[task_type]
            else:
                batch_size = self._batch_size

            if batch_size is not None:
                offsets = range(0, len(indices), batch_size)
                batches = [indices[step:step + batch_size] for step in offsets]
            else:
                batches = [indices]

            if len(batches) > 1:
                logger.log(f"Splitting {len(indices):,} entities into "
                           f"{len(batches):,} batches of size {batch_size:,}")

            predictions: List[pd.DataFrame] = []
            summary: Optional[str] = None
            details: Optional[ExplanationConfig] = None
            for i, batch in enumerate(batches):
                # TODO Re-use the context for subsequent predictions.
                context = self._get_context(
                    query=query_def,
                    indices=batch,
                    anchor_time=anchor_time,
                    context_anchor_time=context_anchor_time,
                    run_mode=RunMode(run_mode),
                    num_neighbors=num_neighbors,
                    num_hops=num_hops,
                    max_pq_iterations=max_pq_iterations,
                    evaluate=False,
                    random_seed=random_seed,
                    logger=logger if i == 0 else None,
                )
                request = RFMPredictRequest(
                    context=context,
                    run_mode=RunMode(run_mode),
                    query=query_str,
                    use_prediction_time=use_prediction_time,
                )
                with warnings.catch_warnings():
                    warnings.filterwarnings('ignore',
                                            message='Protobuf gencode')
                    request_msg = request.to_protobuf()
                    _bytes = request_msg.SerializeToString()
                if i == 0:
                    logger.log(f"Generated context of size "
                               f"{len(_bytes) / (1024*1024):.2f}MB")

                if len(_bytes) > _MAX_SIZE:
                    stats = Context.get_memory_stats(request_msg.context)
                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))

                if (isinstance(verbose, InteractiveProgressLogger) and i == 0
                        and len(batches) > 1):
                    verbose.init_progress(
                        total=len(batches),
                        description='Predicting',
                    )

                for attempt in range(self.num_retries + 1):
                    try:
                        if explain_config is not None:
                            resp = global_state.client.rfm_api.explain(
                                request=_bytes,
                                skip_summary=explain_config.skip_summary,
                            )
                            summary = resp.summary
                            details = resp.details
                        else:
                            resp = global_state.client.rfm_api.predict(_bytes)
                        df = pd.DataFrame(**resp.prediction)

                        # Cast 'ENTITY' to correct data type:
                        if 'ENTITY' in df:
                            entity = query_def.entity_table
                            pkey_map = self._graph_store.pkey_map_dict[entity]
                            df['ENTITY'] = df['ENTITY'].astype(
                                type(pkey_map.index[0]))

                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
                        if 'ANCHOR_TIMESTAMP' in df:
                            ser = df['ANCHOR_TIMESTAMP']
                            if not pd.api.types.is_datetime64_any_dtype(ser):
                                if isinstance(ser.iloc[0], str):
                                    unit = None
                                else:
                                    unit = 'ms'
                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
                                    ser, errors='coerce', unit=unit)

                        predictions.append(df)

                        if (isinstance(verbose, InteractiveProgressLogger)
                                and len(batches) > 1):
                            verbose.step()

                        break
                    except HTTPException as e:
                        if attempt == self.num_retries:
                            try:
                                msg = json.loads(e.detail)['detail']
                            except Exception:
                                msg = e.detail
                            raise RuntimeError(
                                f"An unexpected exception occurred. Please "
                                f"create an issue at "
                                f"'https://github.com/kumo-ai/kumo-rfm'. "
                                f"{msg}") from None

                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...

        if len(predictions) == 1:
            prediction = predictions[0]
        else:
            prediction = pd.concat(predictions, ignore_index=True)

        if explain_config is not None:
            assert len(predictions) == 1
            assert summary is not None
            assert details is not None
            return Explanation(
                prediction=prediction,
                summary=summary,
                details=details,
            )

        return prediction

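    # -- Editor's note (not part of rfm.py): a usage sketch for batched
    # prediction. Single `predict` calls are capped at _MAX_PRED_SIZE entities
    # per task type; `batch_mode` splits `indices` into batches and retries
    # failed requests. `rfm`, `query`, and `user_ids` are hypothetical:
    #
    #     with rfm.batch_mode(batch_size='max', num_retries=2):
    #         df = rfm.predict(query, indices=user_ids)
    #     # Per-batch frames are concatenated with
    #     # pd.concat(..., ignore_index=True), one row per entity.
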
    def is_valid_entity(
        self,
        query: str,
        indices: Union[List[str], List[float], List[int], None] = None,
        *,
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
    ) -> np.ndarray:
        r"""Returns a mask that denotes which entities are valid for the
        given predictive query, *i.e.*, which entities fulfill (temporal)
        entity filter constraints.

        Args:
            query: The predictive query.
            indices: The entity primary keys to predict for. Will override
                the indices given as part of the predictive query.
            anchor_time: The anchor timestamp for the prediction. If set to
                ``None``, will use the maximum timestamp in the data.
                If set to ``"entity"``, will use the timestamp of the entity.
        """
        query_def = self._parse_query(query)

        if indices is None:
            if query_def.rfm_entity_ids is None:
                raise ValueError("Cannot find entities to predict for. "
                                 "Please pass them via "
                                 "`is_valid_entity(query, indices=...)`")
            indices = query_def.get_rfm_entity_id_list()

        if len(indices) == 0:
            raise ValueError("At least one entity is required")

        if anchor_time is None:
            anchor_time = self._graph_store.max_time

        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, False)
        else:
            assert anchor_time == 'entity'
            if query_def.entity_table not in self._graph_store.time_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column.")

        node = self._graph_store.get_node_id(
            table_name=query_def.entity_table,
            pkey=pd.Series(indices),
        )
        query_driver = LocalPQueryDriver(self._graph_store, query_def)
        return query_driver.is_valid(node, anchor_time)

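    # -- Editor's note (not part of rfm.py): `predict` does not apply entity
    # filters to explicitly passed `indices`, so `is_valid_entity` can serve
    # as a pre-filter. A sketch with hypothetical `rfm`, `query`, `user_ids`:
    #
    #     import numpy as np
    #     mask = rfm.is_valid_entity(query, indices=user_ids)  # bool ndarray
    #     valid = list(np.asarray(user_ids)[mask])
    #     df = rfm.predict(query, indices=valid)
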
    def evaluate(
        self,
        query: str,
        *,
        metrics: Optional[List[str]] = None,
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
        context_anchor_time: Union[pd.Timestamp, None] = None,
        run_mode: Union[RunMode, str] = RunMode.FAST,
        num_neighbors: Optional[List[int]] = None,
        num_hops: int = 2,
        max_pq_iterations: int = 20,
        random_seed: Optional[int] = _RANDOM_SEED,
        verbose: Union[bool, ProgressLogger] = True,
        use_prediction_time: bool = False,
    ) -> pd.DataFrame:
        """Evaluates a predictive query.

        Args:
            query: The predictive query.
            metrics: The metrics to use.
            anchor_time: The anchor timestamp for the prediction. If set to
                ``None``, will use the maximum timestamp in the data.
                If set to ``"entity"``, will use the timestamp of the entity.
            context_anchor_time: The maximum anchor timestamp for context
                examples. If set to ``None``, ``anchor_time`` will determine
                the anchor time for context examples.
            run_mode: The :class:`RunMode` for the query.
            num_neighbors: The number of neighbors to sample for each hop.
                If specified, the ``num_hops`` option will be ignored.
            num_hops: The number of hops to sample when generating the
                context.
            max_pq_iterations: The maximum number of iterations to perform to
                collect valid labels. It is advised to increase the number of
                iterations in case the predictive query has strict entity
                filters, in which case :class:`KumoRFM` needs to sample more
                entities to find valid labels.
            random_seed: A manual seed for generating pseudo-random numbers.
            verbose: Whether to print verbose output.
            use_prediction_time: Whether to use the anchor timestamp as an
                additional feature during prediction. This is typically
                beneficial for time series forecasting tasks.

        Returns:
            The metrics as a :class:`pandas.DataFrame`.
        """
        query_def = self._parse_query(query)

        if num_hops != 2 and num_neighbors is not None:
            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
                          f"custom 'num_hops={num_hops}' option")

        if query_def.rfm_entity_ids is not None:
            query_def = replace(
                query_def,
                rfm_entity_ids=None,
            )

        query_repr = query_def.to_string(rich=True, exclude_predict=True)
        msg = f'[bold]EVALUATE[/bold] {query_repr}'

        if not isinstance(verbose, ProgressLogger):
            verbose = InteractiveProgressLogger(msg, verbose=verbose)

        with verbose as logger:
            context = self._get_context(
                query=query_def,
                indices=None,
                anchor_time=anchor_time,
                context_anchor_time=context_anchor_time,
                run_mode=RunMode(run_mode),
                num_neighbors=num_neighbors,
                num_hops=num_hops,
                max_pq_iterations=max_pq_iterations,
                evaluate=True,
                random_seed=random_seed,
                logger=logger if verbose else None,
            )
            if metrics is not None and len(metrics) > 0:
                self._validate_metrics(metrics, context.task_type)
                metrics = list(dict.fromkeys(metrics))
            request = RFMEvaluateRequest(
                context=context,
                run_mode=RunMode(run_mode),
                metrics=metrics,
                use_prediction_time=use_prediction_time,
            )
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', message='Protobuf gencode')
                request_msg = request.to_protobuf()
                request_bytes = request_msg.SerializeToString()
            logger.log(f"Generated context of size "
                       f"{len(request_bytes) / (1024*1024):.2f}MB")

            if len(request_bytes) > _MAX_SIZE:
                stats_msg = Context.get_memory_stats(request_msg.context)
                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))

            try:
                resp = global_state.client.rfm_api.evaluate(request_bytes)
            except HTTPException as e:
                try:
                    msg = json.loads(e.detail)['detail']
                except Exception:
                    msg = e.detail
                raise RuntimeError(f"An unexpected exception occurred. "
                                   f"Please create an issue at "
                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
                                   f"{msg}") from None

        return pd.DataFrame.from_dict(
            resp.metrics,
            orient='index',
            columns=['value'],
        ).reset_index(names='metric')

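    # -- Editor's note (not part of rfm.py): an evaluation sketch. Metric
    # names must match `_validate_metrics` below; ranking metrics carry a
    # 'top_k' suffix. `rfm`, `query`, and `link_query` are hypothetical:
    #
    #     df = rfm.evaluate(query, metrics=['auroc', 'ap'])  # binary task
    #     df = rfm.evaluate(link_query, metrics=['map@10', 'hit_ratio@5'])
    #     print(df)  # one row per metric, columns: 'metric', 'value'
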
    def get_train_table(
        self,
        query: str,
        size: int,
        *,
        anchor_time: Union[pd.Timestamp, Literal['entity'], None] = None,
        random_seed: Optional[int] = _RANDOM_SEED,
        max_iterations: int = 20,
    ) -> pd.DataFrame:
        """Returns the labels of a predictive query for a specified anchor
        time.

        Args:
            query: The predictive query.
            size: The maximum number of entities to generate labels for.
            anchor_time: The anchor timestamp for the query. If set to
                ``None``, will use the maximum timestamp in the data.
                If set to ``"entity"``, will use the timestamp of the entity.
            random_seed: A manual seed for generating pseudo-random numbers.
            max_iterations: The number of steps to run before aborting.

        Returns:
            The labels as a :class:`pandas.DataFrame`.
        """
        query_def = self._parse_query(query)

        if anchor_time is None:
            anchor_time = self._graph_store.max_time
            if query_def.target_ast.date_offset_range is not None:
                anchor_time = anchor_time - (
                    query_def.target_ast.date_offset_range.end_date_offset *
                    query_def.num_forecasts)

        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
            self._validate_time(query_def, anchor_time, None, evaluate=True)
        else:
            assert anchor_time == 'entity'
            if query_def.entity_table not in self._graph_store.time_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query_def.entity_table}' "
                                 f"to have a time column")

        query_driver = LocalPQueryDriver(self._graph_store, query_def,
                                         random_seed)

        node, time, y = query_driver.collect_test(
            size=size,
            anchor_time=anchor_time,
            batch_size=min(10_000, size),
            max_iterations=max_iterations,
            guarantee_train_examples=False,
        )

        entity = self._graph_store.pkey_map_dict[
            query_def.entity_table].index[node]

        return pd.DataFrame({
            'ENTITY': entity,
            'ANCHOR_TIMESTAMP': time,
            'TARGET': y,
        })

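    # -- Editor's note (not part of rfm.py): `get_train_table` is handy for
    # sanity-checking label distributions before predicting. A sketch with
    # hypothetical `rfm` and `query`:
    #
    #     labels = rfm.get_train_table(query, size=1_000)
    #     print(list(labels.columns))  # ['ENTITY', 'ANCHOR_TIMESTAMP', 'TARGET']
    #     print(labels['TARGET'].describe())
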
    # Helpers #################################################################

    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
        if isinstance(query, ValidatedPredictiveQuery):
            return query

        if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
            raise ValueError("'EVALUATE PREDICT ...' queries are not "
                             "supported in the SDK. Instead, use either "
                             "`predict()` or `evaluate()` methods to perform "
                             "predictions or evaluations.")

        try:
            request = RFMParseQueryRequest(
                query=query,
                graph_definition=self._graph_def,
            )

            resp = global_state.client.rfm_api.parse_query(request)
            # TODO Expose validation warnings.

            if len(resp.validation_response.warnings) > 0:
                msg = '\n'.join([
                    f'{i+1}. {warning.title}: {warning.message}' for i, warning
                    in enumerate(resp.validation_response.warnings)
                ])
                warnings.warn(f"Encountered the following warnings during "
                              f"parsing:\n{msg}")

            return resp.query
        except HTTPException as e:
            try:
                msg = json.loads(e.detail)['detail']
            except Exception:
                msg = e.detail
            raise ValueError(f"Failed to parse query '{query}'. "
                             f"{msg}") from None

    def _validate_time(
        self,
        query: ValidatedPredictiveQuery,
        anchor_time: pd.Timestamp,
        context_anchor_time: Union[pd.Timestamp, None],
        evaluate: bool,
    ) -> None:

        if self._graph_store.min_time == pd.Timestamp.max:
            return  # Graph without timestamps

        if anchor_time < self._graph_store.min_time:
            raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
                             f"the earliest timestamp "
                             f"'{self._graph_store.min_time}' in the data.")

        if (context_anchor_time is not None
                and context_anchor_time < self._graph_store.min_time):
            raise ValueError(f"Context anchor timestamp is too early or "
                             f"aggregation time range is too large. To make "
                             f"this prediction, we would need data back to "
                             f"'{context_anchor_time}', however, your data "
                             f"only contains data back to "
                             f"'{self._graph_store.min_time}'.")

        if query.target_ast.date_offset_range is not None:
            end_offset = query.target_ast.date_offset_range.end_date_offset
        else:
            end_offset = pd.DateOffset(0)
        forecast_end_offset = end_offset * query.num_forecasts
        if (context_anchor_time is not None
                and context_anchor_time > anchor_time):
            warnings.warn(f"Context anchor timestamp "
                          f"(got '{context_anchor_time}') is set to a later "
                          f"date than the prediction anchor timestamp "
                          f"(got '{anchor_time}'). Please make sure this is "
                          f"intended.")
        elif (query.query_type == QueryType.TEMPORAL
                and context_anchor_time is not None
                and context_anchor_time + forecast_end_offset > anchor_time):
            warnings.warn(f"Aggregation for context examples at timestamp "
                          f"'{context_anchor_time}' will leak information "
                          f"from the prediction anchor timestamp "
                          f"'{anchor_time}'. Please make sure this is "
                          f"intended.")

        elif (context_anchor_time is not None
                and context_anchor_time - forecast_end_offset
                < self._graph_store.min_time):
            _time = context_anchor_time - forecast_end_offset
            warnings.warn(f"Context anchor timestamp is too early or "
                          f"aggregation time range is too large. To form "
                          f"proper input data, we would need data back to "
                          f"'{_time}', however, your data only contains "
                          f"data back to '{self._graph_store.min_time}'.")

        if (not evaluate and anchor_time
                > self._graph_store.max_time + pd.DateOffset(days=1)):
            warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
                          f"latest timestamp '{self._graph_store.max_time}' "
                          f"in the data. Please make sure this is intended.")

        max_eval_time = self._graph_store.max_time - forecast_end_offset
        if evaluate and anchor_time > max_eval_time:
            raise ValueError(
                f"Anchor timestamp for evaluation is after the latest "
                f"supported timestamp '{max_eval_time}'.")

    def _get_context(
        self,
        query: ValidatedPredictiveQuery,
        indices: Union[List[str], List[float], List[int], None],
        anchor_time: Union[pd.Timestamp, Literal['entity'], None],
        context_anchor_time: Union[pd.Timestamp, None],
        run_mode: RunMode,
        num_neighbors: Optional[List[int]],
        num_hops: int,
        max_pq_iterations: int,
        evaluate: bool,
        random_seed: Optional[int] = _RANDOM_SEED,
        logger: Optional[ProgressLogger] = None,
    ) -> Context:

        if num_neighbors is not None:
            num_hops = len(num_neighbors)

        if num_hops < 0:
            raise ValueError(f"'num_hops' must be non-negative "
                             f"(got {num_hops})")
        if num_hops > 6:
            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
                             f"hops (got {num_hops}). Please reduce the "
                             f"number of hops and try again. Please create a "
                             f"feature request at "
                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")

        query_driver = LocalPQueryDriver(self._graph_store, query, random_seed)
        task_type = LocalPQueryDriver.get_task_type(
            query,
            edge_types=self._graph_store.edge_types,
        )

        if logger is not None:
            if task_type == TaskType.BINARY_CLASSIFICATION:
                task_type_repr = 'binary classification'
            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
                task_type_repr = 'multi-class classification'
            elif task_type == TaskType.REGRESSION:
                task_type_repr = 'regression'
            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
                task_type_repr = 'link prediction'
            else:
                task_type_repr = str(task_type)
            logger.log(f"Identified {query.query_type} {task_type_repr} task")

        if task_type.is_link_pred and num_hops < 2:
            raise ValueError(f"Cannot perform link prediction on subgraphs "
                             f"with less than 2 hops (got {num_hops}) since "
                             f"historical target entities need to be part of "
                             f"the context. Please increase the number of "
                             f"hops and try again.")

        if num_neighbors is None:
            if run_mode == RunMode.DEBUG:
                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
            elif run_mode == RunMode.FAST or task_type.is_link_pred:
                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
            else:
                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]

        if query.target_ast.date_offset_range is None:
            end_offset = pd.DateOffset(0)
        else:
            end_offset = query.target_ast.date_offset_range.end_date_offset
        forecast_end_offset = end_offset * query.num_forecasts
        if anchor_time is None:
            anchor_time = self._graph_store.max_time
            if evaluate:
                anchor_time = anchor_time - forecast_end_offset
            if logger is not None:
                assert isinstance(anchor_time, pd.Timestamp)
                if anchor_time == pd.Timestamp.min:
                    pass  # Static graph
                elif (anchor_time.hour == 0 and anchor_time.minute == 0
                        and anchor_time.second == 0
                        and anchor_time.microsecond == 0):
                    logger.log(f"Derived anchor time {anchor_time.date()}")
                else:
                    logger.log(f"Derived anchor time {anchor_time}")

        assert anchor_time is not None
        if isinstance(anchor_time, pd.Timestamp):
            if context_anchor_time is None:
                context_anchor_time = anchor_time - forecast_end_offset
            self._validate_time(query, anchor_time, context_anchor_time,
                                evaluate)
        else:
            assert anchor_time == 'entity'
            if query.entity_table not in self._graph_store.time_dict:
                raise ValueError(f"Anchor time 'entity' requires the entity "
                                 f"table '{query.entity_table}' to "
                                 f"have a time column")
            if context_anchor_time is not None:
                warnings.warn("Ignoring option 'context_anchor_time' for "
                              "`anchor_time='entity'`")
                context_anchor_time = None

        y_test: Optional[pd.Series] = None
        if evaluate:
            max_test_size = _MAX_TEST_SIZE[run_mode]
            if task_type.is_link_pred:
                max_test_size = max_test_size // 5

            test_node, test_time, y_test = query_driver.collect_test(
                size=max_test_size,
                anchor_time=anchor_time,
                max_iterations=max_pq_iterations,
                guarantee_train_examples=True,
            )
            if logger is not None:
                if task_type == TaskType.BINARY_CLASSIFICATION:
                    pos = 100 * int((y_test > 0).sum()) / len(y_test)
                    msg = (f"Collected {len(y_test):,} test examples with "
                           f"{pos:.2f}% positive cases")
                elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
                    msg = (f"Collected {len(y_test):,} test examples "
                           f"holding {y_test.nunique()} classes")
                elif task_type == TaskType.REGRESSION:
                    _min, _max = float(y_test.min()), float(y_test.max())
                    msg = (f"Collected {len(y_test):,} test examples with "
                           f"targets between {format_value(_min)} and "
                           f"{format_value(_max)}")
                elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
                    num_rhs = y_test.explode().nunique()
                    msg = (f"Collected {len(y_test):,} test examples with "
                           f"{num_rhs:,} unique items")
                else:
                    raise NotImplementedError
                logger.log(msg)

        else:
            assert indices is not None

            if len(indices) > _MAX_PRED_SIZE[task_type]:
                raise ValueError(f"Cannot predict for more than "
                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
                                 f"once (got {len(indices):,}). Use "
                                 f"`KumoRFM.batch_mode` to process entities "
                                 f"in batches")

            test_node = self._graph_store.get_node_id(
                table_name=query.entity_table,
                pkey=pd.Series(indices),
            )

            if isinstance(anchor_time, pd.Timestamp):
                test_time = pd.Series(anchor_time).repeat(
                    len(test_node)).reset_index(drop=True)
            else:
                time = self._graph_store.time_dict[query.entity_table]
                time = time[test_node] * 1000**3
                test_time = pd.Series(time, dtype='datetime64[ns]')

        train_node, train_time, y_train = query_driver.collect_train(
            size=_MAX_CONTEXT_SIZE[run_mode],
            anchor_time=context_anchor_time or 'entity',
            exclude_node=test_node if (query.query_type == QueryType.STATIC
                                       or anchor_time == 'entity') else None,
            max_iterations=max_pq_iterations,
        )

        if logger is not None:
            if task_type == TaskType.BINARY_CLASSIFICATION:
                pos = 100 * int((y_train > 0).sum()) / len(y_train)
                msg = (f"Collected {len(y_train):,} in-context examples with "
                       f"{pos:.2f}% positive cases")
            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
                msg = (f"Collected {len(y_train):,} in-context examples "
                       f"holding {y_train.nunique()} classes")
            elif task_type == TaskType.REGRESSION:
                _min, _max = float(y_train.min()), float(y_train.max())
                msg = (f"Collected {len(y_train):,} in-context examples with "
                       f"targets between {format_value(_min)} and "
                       f"{format_value(_max)}")
            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
                num_rhs = y_train.explode().nunique()
                msg = (f"Collected {len(y_train):,} in-context examples with "
                       f"{num_rhs:,} unique items")
            else:
                raise NotImplementedError
            logger.log(msg)

        entity_table_names: Tuple[str, ...]
        if task_type.is_link_pred:
            final_aggr = query.get_final_target_aggregation()
            assert final_aggr is not None
            edge_fkey = final_aggr._get_target_column_name()
            for edge_type in self._graph_store.edge_types:
                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
                    entity_table_names = (
                        query.entity_table,
                        edge_type[2],
                    )
        else:
            entity_table_names = (query.entity_table, )

        # Exclude the entity anchor time from the feature set to prevent
        # running out-of-distribution between in-context and test examples:
        exclude_cols_dict = query.get_exclude_cols_dict()
        if anchor_time == 'entity':
            if entity_table_names[0] not in exclude_cols_dict:
                exclude_cols_dict[entity_table_names[0]] = []
            time_column_dict = self._graph_store.time_column_dict
            time_column = time_column_dict[entity_table_names[0]]
            exclude_cols_dict[entity_table_names[0]].append(time_column)

        subgraph = self._graph_sampler(
            entity_table_names=entity_table_names,
            node=np.concatenate([train_node, test_node]),
            time=np.concatenate([
                train_time.astype('datetime64[ns]').astype(int).to_numpy(),
                test_time.astype('datetime64[ns]').astype(int).to_numpy(),
            ]),
            run_mode=run_mode,
            num_neighbors=num_neighbors,
            exclude_cols_dict=exclude_cols_dict,
        )

        if len(subgraph.table_dict) >= 15:
            raise ValueError(f"Cannot query from a graph with more than 15 "
                             f"tables (got {len(subgraph.table_dict)}). "
                             f"Please create a feature request at "
                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
                             f"must go beyond this for your use-case.")

        step_size: Optional[int] = None
        if query.query_type == QueryType.TEMPORAL:
            step_size = date_offset_to_seconds(end_offset)

        return Context(
            task_type=task_type,
            entity_table_names=entity_table_names,
            subgraph=subgraph,
            y_train=y_train,
            y_test=y_test,
            top_k=query.top_k,
            step_size=step_size,
        )

    @staticmethod
    def _validate_metrics(
        metrics: List[str],
        task_type: TaskType,
    ) -> None:

        if task_type == TaskType.BINARY_CLASSIFICATION:
            supported_metrics = [
                'acc', 'precision', 'recall', 'f1', 'auroc', 'auprc', 'ap'
            ]
        elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
            supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
        elif task_type == TaskType.REGRESSION:
            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
        elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
            supported_metrics = [
                'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
                'hit_ratio@'
            ]
        else:
            raise NotImplementedError

        for metric in metrics:
            if '@' in metric:
                metric_split = metric.split('@')
                if len(metric_split) != 2:
                    raise ValueError(f"Unsupported metric '{metric}'. "
                                     f"Available metrics "
                                     f"are {supported_metrics}.")

                name, top_k = f'{metric_split[0]}@', metric_split[1]

                if not top_k.isdigit():
                    raise ValueError(f"Metric '{metric}' does not define a "
                                     f"valid 'top_k' value (got '{top_k}').")

                if int(top_k) <= 0:
                    raise ValueError(f"Metric '{metric}' needs to define a "
                                     f"positive 'top_k' value (got '{top_k}')")

                if int(top_k) > 100:
                    raise ValueError(f"Metric '{metric}' defines a 'top_k' "
                                     f"value greater than 100 "
                                     f"(got '{top_k}'). Please create a "
                                     f"feature request at "
                                     f"'https://github.com/kumo-ai/kumo-rfm' "
                                     f"if you must go beyond this for your "
                                     f"use-case.")

                metric = name

            if metric not in supported_metrics:
                raise ValueError(f"Unsupported metric '{metric}'. Available "
                                 f"metrics are {supported_metrics}. If you "
                                 f"feel a metric is missing, please create a "
                                 f"feature request at "
                                 f"'https://github.com/kumo-ai/kumo-rfm'.")


def format_value(value: Union[int, float]) -> str:
    if value == int(value):
        return f'{int(value):,}'
    if abs(value) >= 1000:
        return f'{value:,.0f}'
    if abs(value) >= 10:
        return f'{value:.1f}'
    return f'{value:.2f}'
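
The class docstring above compresses the typical workflow into a few lines. The following editorial sketch expands it into a complete script. The toy tables and column names are illustrative assumptions (no such dataset ships with the wheel), it assumes LocalGraph.from_data can infer primary keys and time columns from these frames, and running it requires an authenticated kumoai session:

import pandas as pd

from kumoai.experimental.rfm import KumoRFM, LocalGraph

# Hypothetical toy tables; real schemas come from your own data.
df_users = pd.DataFrame({'user_id': [0, 1, 2]})
df_items = pd.DataFrame({'item_id': [10, 11]})
df_transactions = pd.DataFrame({
    'user_id': [0, 0, 1, 2],
    'item_id': [10, 11, 10, 11],
    'time': pd.to_datetime([
        '2024-01-05', '2024-02-01', '2024-02-20', '2024-03-15',
    ]),
})

graph = LocalGraph.from_data({
    'users': df_users,
    'items': df_items,
    'transactions': df_transactions,
})

rfm = KumoRFM(graph)

# Will each user place a transaction in the next 30 days?
query = ("PREDICT COUNT(transactions.*, 0, 30, days)>0 "
         "FOR users.user_id=0")
print(rfm.predict(query))                      # probability per entity
print(rfm.evaluate(query, metrics=['auroc']))  # entity ids are stripped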