kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (122)
  1. kumoai/__init__.py +300 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +223 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +471 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +207 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1796 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +210 -0
  51. kumoai/experimental/rfm/authenticate.py +432 -0
  52. kumoai/experimental/rfm/backend/__init__.py +0 -0
  53. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  54. kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
  55. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  56. kumoai/experimental/rfm/backend/local/table.py +113 -0
  57. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  58. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  59. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  60. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  61. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  62. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  63. kumoai/experimental/rfm/base/__init__.py +30 -0
  64. kumoai/experimental/rfm/base/column.py +152 -0
  65. kumoai/experimental/rfm/base/expression.py +44 -0
  66. kumoai/experimental/rfm/base/sampler.py +761 -0
  67. kumoai/experimental/rfm/base/source.py +19 -0
  68. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  69. kumoai/experimental/rfm/base/table.py +736 -0
  70. kumoai/experimental/rfm/graph.py +1237 -0
  71. kumoai/experimental/rfm/infer/__init__.py +19 -0
  72. kumoai/experimental/rfm/infer/categorical.py +40 -0
  73. kumoai/experimental/rfm/infer/dtype.py +82 -0
  74. kumoai/experimental/rfm/infer/id.py +46 -0
  75. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  76. kumoai/experimental/rfm/infer/pkey.py +128 -0
  77. kumoai/experimental/rfm/infer/stype.py +35 -0
  78. kumoai/experimental/rfm/infer/time_col.py +61 -0
  79. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  80. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  81. kumoai/experimental/rfm/pquery/executor.py +102 -0
  82. kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
  83. kumoai/experimental/rfm/relbench.py +76 -0
  84. kumoai/experimental/rfm/rfm.py +1184 -0
  85. kumoai/experimental/rfm/sagemaker.py +138 -0
  86. kumoai/experimental/rfm/task_table.py +231 -0
  87. kumoai/formatting.py +30 -0
  88. kumoai/futures.py +99 -0
  89. kumoai/graph/__init__.py +12 -0
  90. kumoai/graph/column.py +106 -0
  91. kumoai/graph/graph.py +948 -0
  92. kumoai/graph/table.py +838 -0
  93. kumoai/jobs.py +80 -0
  94. kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
  95. kumoai/mixin.py +28 -0
  96. kumoai/pquery/__init__.py +25 -0
  97. kumoai/pquery/prediction_table.py +287 -0
  98. kumoai/pquery/predictive_query.py +641 -0
  99. kumoai/pquery/training_table.py +424 -0
  100. kumoai/spcs.py +121 -0
  101. kumoai/testing/__init__.py +8 -0
  102. kumoai/testing/decorators.py +57 -0
  103. kumoai/testing/snow.py +50 -0
  104. kumoai/trainer/__init__.py +42 -0
  105. kumoai/trainer/baseline_trainer.py +93 -0
  106. kumoai/trainer/config.py +2 -0
  107. kumoai/trainer/distilled_trainer.py +175 -0
  108. kumoai/trainer/job.py +1192 -0
  109. kumoai/trainer/online_serving.py +258 -0
  110. kumoai/trainer/trainer.py +475 -0
  111. kumoai/trainer/util.py +103 -0
  112. kumoai/utils/__init__.py +11 -0
  113. kumoai/utils/datasets.py +83 -0
  114. kumoai/utils/display.py +51 -0
  115. kumoai/utils/forecasting.py +209 -0
  116. kumoai/utils/progress_logger.py +343 -0
  117. kumoai/utils/sql.py +3 -0
  118. kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
  119. kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
  120. kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
  121. kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
  122. kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
kumoai/experimental/rfm/rfm.py
@@ -0,0 +1,1184 @@
+import json
+import time
+import warnings
+from collections import defaultdict
+from collections.abc import Generator, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
+from typing import Any, Literal, overload
+
+import numpy as np
+import pandas as pd
+from kumoapi.model_plan import RunMode
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.rfm import Context
+from kumoapi.rfm import Explanation as ExplanationConfig
+from kumoapi.rfm import (
+    RFMEvaluateRequest,
+    RFMParseQueryRequest,
+    RFMPredictRequest,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, Stype
+
+from kumoai.client.rfm import RFMAPI
+from kumoai.exceptions import HTTPException
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.base import DataBackend, Sampler
+from kumoai.mixin import CastMixin
+from kumoai.utils import ProgressLogger, display
+
+_RANDOM_SEED = 42
+
+_MAX_PRED_SIZE: dict[TaskType, int] = defaultdict(lambda: 1_000)
+_MAX_PRED_SIZE[TaskType.TEMPORAL_LINK_PREDICTION] = 200
+
+_MAX_CONTEXT_SIZE = {
+    RunMode.DEBUG: 100,
+    RunMode.FAST: 1_000,
+    RunMode.NORMAL: 5_000,
+    RunMode.BEST: 10_000,
+}
+_MAX_TEST_SIZE = {  # Share test set size across run modes for fair comparison:
+    RunMode.DEBUG: 100,
+    RunMode.FAST: 2_000,
+    RunMode.NORMAL: 2_000,
+    RunMode.BEST: 2_000,
+}
+
+_MAX_SIZE = 30 * 1024 * 1024
+_SIZE_LIMIT_MSG = ("Context size exceeds the 30MB limit. {stats}\nPlease "
+                   "reduce either the number of tables in the graph, their "
+                   "number of columns (e.g., large text columns), "
+                   "neighborhood configuration, or the run mode. If none of "
+                   "this is possible, please create a feature request at "
+                   "'https://github.com/kumo-ai/kumo-rfm' if you must go "
+                   "beyond this for your use-case.")
+
+
+@dataclass(repr=False)
+class ExplainConfig(CastMixin):
+    """Configuration for explainability.
+
+    Args:
+        skip_summary: Whether to skip generating a human-readable summary of
+            the explanation.
+    """
+    skip_summary: bool = False
+
+
+@dataclass(repr=False)
+class Explanation:
+    prediction: pd.DataFrame
+    summary: str
+    details: ExplanationConfig
+
+    @overload
+    def __getitem__(self, index: Literal[0]) -> pd.DataFrame:
+        pass
+
+    @overload
+    def __getitem__(self, index: Literal[1]) -> str:
+        pass
+
+    def __getitem__(self, index: int) -> pd.DataFrame | str:
+        if index == 0:
+            return self.prediction
+        if index == 1:
+            return self.summary
+        raise IndexError("Index out of range")
+
+    def __iter__(self) -> Iterator[pd.DataFrame | str]:
+        return iter((self.prediction, self.summary))
+
+    def __repr__(self) -> str:
+        return str((self.prediction, self.summary))
+
+    def print(self) -> None:
+        r"""Prints the explanation."""
+        display.dataframe(self.prediction)
+        display.message(self.summary)
+
+    def _ipython_display_(self) -> None:
+        self.print()
+
+
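The overloads above make an ``Explanation`` unpack like a ``(prediction, summary)`` pair while keeping the structured ``details`` payload available as an attribute. A minimal usage sketch (assuming ``rfm`` is a :class:`KumoRFM` instance as in the class example below, and ``query`` is a single-entity predictive query):

.. code-block:: python

    result = rfm.predict(query, explain=True)

    # Tuple-style unpacking, backed by __iter__/__getitem__:
    prediction, summary = result
    assert prediction is result[0]
    assert summary is result[1]

    # The structured explanation stays available as an attribute:
    details = result.details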
+class KumoRFM:
+    r"""The Kumo Relational Foundation model (RFM) from the `KumoRFM: A
+    Foundation Model for In-Context Learning on Relational Data
+    <https://kumo.ai/research/kumo_relational_foundation_model.pdf>`_ paper.
+
+    :class:`KumoRFM` is a foundation model to generate predictions for any
+    relational dataset without training.
+    The model is pre-trained and the class provides an interface to query the
+    model from a :class:`Graph` object.
+
+    .. code-block:: python
+
+        from kumoai.experimental.rfm import Graph, KumoRFM
+
+        df_users = pd.DataFrame(...)
+        df_items = pd.DataFrame(...)
+        df_orders = pd.DataFrame(...)
+
+        graph = Graph.from_data({
+            'users': df_users,
+            'items': df_items,
+            'orders': df_orders,
+        })
+
+        rfm = KumoRFM(graph)
+
+        query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
+                 "FOR users.user_id=1")
+        result = rfm.predict(query)
+
+        print(result)  # user_id  COUNT(orders.*, 0, 30, days) > 0
+                       # 1        0.85
+
+    Args:
+        graph: The graph.
+        verbose: Whether to print verbose output.
+        optimize: If set to ``True``, will optimize the underlying data
+            backend for optimal querying. For example, for transactional
+            database backends, will create any missing indices. Requires
+            write-access to the data backend.
+    """
+    def __init__(
+        self,
+        graph: Graph,
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
+    ) -> None:
+        graph = graph.validate()
+        self._graph_def = graph._to_api_graph_definition()
+
+        if graph.backend == DataBackend.LOCAL:
+            from kumoai.experimental.rfm.backend.local import LocalSampler
+            self._sampler: Sampler = LocalSampler(graph, verbose)
+        elif graph.backend == DataBackend.SQLITE:
+            from kumoai.experimental.rfm.backend.sqlite import SQLiteSampler
+            self._sampler = SQLiteSampler(graph, verbose, optimize)
+        elif graph.backend == DataBackend.SNOWFLAKE:
+            from kumoai.experimental.rfm.backend.snow import SnowSampler
+            self._sampler = SnowSampler(graph, verbose)
+        else:
+            raise NotImplementedError
+
+        self._client: RFMAPI | None = None
+
+        self._batch_size: int | Literal['max'] | None = None
+        self.num_retries: int = 0
+
+    @property
+    def _api_client(self) -> RFMAPI:
+        if self._client is not None:
+            return self._client
+
+        from kumoai.experimental.rfm import global_state
+        self._client = RFMAPI(global_state.client)
+        return self._client
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}()'
+
+    @contextmanager
+    def batch_mode(
+        self,
+        batch_size: int | Literal['max'] = 'max',
+        num_retries: int = 1,
+    ) -> Generator[None, None, None]:
+        """Context manager to predict in batches.
+
+        .. code-block:: python
+
+            with model.batch_mode(batch_size='max', num_retries=1):
+                df = model.predict(query, indices=...)
+
+        Args:
+            batch_size: The batch size. If set to ``"max"``, will use the
+                maximum applicable batch size for the given task.
+            num_retries: The maximum number of retries for failed queries due
+                to unexpected server issues.
+        """
+        if batch_size != 'max' and batch_size <= 0:
+            raise ValueError(f"'batch_size' must be greater than zero "
+                             f"(got {batch_size})")
+
+        if num_retries < 0:
+            raise ValueError(f"'num_retries' must be greater than or equal to "
+                             f"zero (got {num_retries})")
+
+        self._batch_size = batch_size
+        self.num_retries = num_retries
+        yield
+        self._batch_size = None
+        self.num_retries = 0
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: Literal[False] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        pass
+
+    @overload
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: Literal[True] | ExplainConfig | dict[str, Any],
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> Explanation:
+        pass
+
+    def predict(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        explain: bool | ExplainConfig | dict[str, Any] = False,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame | Explanation:
+        """Returns predictions for a predictive query.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override
+                the indices given as part of the predictive query.
+                Predictions will be generated for all indices, independent
+                of whether they fulfill entity filter constraints. To
+                pre-filter entities, use :meth:`~KumoRFM.is_valid_entity`.
+            explain: Configuration for explainability.
+                If set to ``True``, will additionally explain the prediction.
+                Passing in an :class:`ExplainConfig` instance provides
+                control over which parts of the explanation are generated.
+                Explainability is currently only supported for single entity
+                predictions with ``run_mode="FAST"``.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will determine
+                the anchor time for context examples.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the
+                context.
+            max_pq_iterations: The maximum number of iterations to perform to
+                collect valid labels. It is advised to increase the number of
+                iterations in case the predictive query has strict entity
+                filters, in which case :class:`KumoRFM` needs to sample more
+                entities to find valid labels.
+            random_seed: A manual seed for generating pseudo-random numbers.
+            verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The predictions as a :class:`pandas.DataFrame`.
+            If ``explain`` is provided, returns an :class:`Explanation`
+            object containing the prediction, summary, and details.
+        """
+        explain_config: ExplainConfig | None = None
+        if explain is True:
+            explain_config = ExplainConfig()
+        elif explain is not False:
+            explain_config = ExplainConfig._cast(explain)
+
+        query_def = self._parse_query(query)
+        query_str = query_def.to_string()
+
+        if num_hops != 2 and num_neighbors is not None:
+            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
+                          f"custom 'num_hops={num_hops}' option")
+
+        if explain_config is not None and run_mode in {
+                RunMode.NORMAL, RunMode.BEST
+        }:
+            warnings.warn(f"Explainability is currently only supported for "
+                          f"run mode 'FAST' (got '{run_mode}'). Provided run "
+                          f"mode has been reset. Please lower the run mode to "
+                          f"suppress this warning.")
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. "
+                                 "Please pass them via "
+                                 "`predict(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+        else:
+            query_def = replace(query_def, rfm_entity_ids=None)
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if explain_config is not None and len(indices) > 1:
+            raise ValueError(
+                f"Cannot explain predictions for more than a single entity "
+                f"(got {len(indices)})")
+
+        query_repr = query_def.to_string(rich=True, exclude_predict=True)
+        if explain_config is not None:
+            msg = f'[bold]EXPLAIN[/bold] {query_repr}'
+        else:
+            msg = f'[bold]PREDICT[/bold] {query_repr}'
+
+        if not isinstance(verbose, ProgressLogger):
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+
+            batch_size: int | None = None
+            if self._batch_size == 'max':
+                task_type = self._get_task_type(
+                    query=query_def,
+                    edge_types=self._sampler.edge_types,
+                )
+                batch_size = _MAX_PRED_SIZE[task_type]
+            else:
+                batch_size = self._batch_size
+
+            if batch_size is not None:
+                offsets = range(0, len(indices), batch_size)
+                batches = [indices[step:step + batch_size] for step in offsets]
+            else:
+                batches = [indices]
+
+            if len(batches) > 1:
+                logger.log(f"Splitting {len(indices):,} entities into "
+                           f"{len(batches):,} batches of size {batch_size:,}")
+
+            predictions: list[pd.DataFrame] = []
+            summary: str | None = None
+            details: ExplanationConfig | None = None
+            for i, batch in enumerate(batches):
+                # TODO Re-use the context for subsequent predictions.
+                context = self._get_context(
+                    query=query_def,
+                    indices=batch,
+                    anchor_time=anchor_time,
+                    context_anchor_time=context_anchor_time,
+                    run_mode=RunMode(run_mode),
+                    num_neighbors=num_neighbors,
+                    num_hops=num_hops,
+                    max_pq_iterations=max_pq_iterations,
+                    evaluate=False,
+                    random_seed=random_seed,
+                    logger=logger if i == 0 else None,
+                )
+                request = RFMPredictRequest(
+                    context=context,
+                    run_mode=RunMode(run_mode),
+                    query=query_str,
+                    use_prediction_time=use_prediction_time,
+                )
+                with warnings.catch_warnings():
+                    warnings.filterwarnings('ignore',
+                                            message='Protobuf gencode')
+                    request_msg = request.to_protobuf()
+                _bytes = request_msg.SerializeToString()
+                if i == 0:
+                    logger.log(f"Generated context of size "
+                               f"{len(_bytes) / (1024*1024):.2f}MB")
+
+                if len(_bytes) > _MAX_SIZE:
+                    stats = Context.get_memory_stats(request_msg.context)
+                    raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats))
+
+                if i == 0 and len(batches) > 1:
+                    verbose.init_progress(
+                        total=len(batches),
+                        description='Predicting',
+                    )
+
+                for attempt in range(self.num_retries + 1):
+                    try:
+                        if explain_config is not None:
+                            resp = self._api_client.explain(
+                                request=_bytes,
+                                skip_summary=explain_config.skip_summary,
+                            )
+                            summary = resp.summary
+                            details = resp.details
+                        else:
+                            resp = self._api_client.predict(_bytes)
+                        df = pd.DataFrame(**resp.prediction)
+
+                        # Cast 'ENTITY' to correct data type:
+                        if 'ENTITY' in df:
+                            table_dict = context.subgraph.table_dict
+                            table = table_dict[query_def.entity_table]
+                            ser = table.df[table.primary_key]
+                            df['ENTITY'] = df['ENTITY'].astype(ser.dtype)
+
+                        # Cast 'ANCHOR_TIMESTAMP' to correct data type:
+                        if 'ANCHOR_TIMESTAMP' in df:
+                            ser = df['ANCHOR_TIMESTAMP']
+                            if not pd.api.types.is_datetime64_any_dtype(ser):
+                                if isinstance(ser.iloc[0], str):
+                                    unit = None
+                                else:
+                                    unit = 'ms'
+                                df['ANCHOR_TIMESTAMP'] = pd.to_datetime(
+                                    ser, errors='coerce', unit=unit)
+
+                        predictions.append(df)
+
+                        if len(batches) > 1:
+                            verbose.step()
+
+                        break
+                    except HTTPException as e:
+                        if attempt == self.num_retries:
+                            try:
+                                msg = json.loads(e.detail)['detail']
+                            except Exception:
+                                msg = e.detail
+                            raise RuntimeError(
+                                f"An unexpected exception occurred. Please "
+                                f"create an issue at "
+                                f"'https://github.com/kumo-ai/kumo-rfm'. "
+                                f"{msg}") from None
+
+                        time.sleep(2**attempt)  # 1s, 2s, 4s, 8s, ...
+
+        if len(predictions) == 1:
+            prediction = predictions[0]
+        else:
+            prediction = pd.concat(predictions, ignore_index=True)
+
+        if explain_config is not None:
+            assert len(predictions) == 1
+            assert summary is not None
+            assert details is not None
+            return Explanation(
+                prediction=prediction,
+                summary=summary,
+                details=details,
+            )
+
+        return prediction
+
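For more entities than ``_MAX_PRED_SIZE`` permits per request, ``_get_context`` raises and points to ``batch_mode``; the loop above then splits the indices into batches and retries failed requests with exponential backoff (``time.sleep(2**attempt)``: 1s, 2s, 4s, ...). A minimal sketch of that calling pattern (``user_ids`` is an assumed list of primary keys; the query string follows the class-level example):

.. code-block:: python

    query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
             "FOR users.user_id=1")

    # `indices` overrides the entity ids embedded in the query string:
    with rfm.batch_mode(batch_size='max', num_retries=2):
        df = rfm.predict(query, indices=user_ids)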
+    def is_valid_entity(
+        self,
+        query: str,
+        indices: list[str] | list[float] | list[int] | None = None,
+        *,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+    ) -> np.ndarray:
+        r"""Returns a mask that denotes which entities are valid for the
+        given predictive query, *i.e.*, which entities fulfill (temporal)
+        entity filter constraints.
+
+        Args:
+            query: The predictive query.
+            indices: The entity primary keys to predict for. Will override
+                the indices given as part of the predictive query.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+        """
+        query_def = self._parse_query(query)
+
+        if indices is None:
+            if query_def.rfm_entity_ids is None:
+                raise ValueError("Cannot find entities to predict for. "
+                                 "Please pass them via "
+                                 "`is_valid_entity(query, indices=...)`")
+            indices = query_def.get_rfm_entity_id_list()
+
+        if len(indices) == 0:
+            raise ValueError("At least one entity is required")
+
+        if anchor_time is None:
+            anchor_time = self._get_default_anchor_time(query_def)
+
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, False)
+        else:
+            assert anchor_time == 'entity'
+            if query_def.entity_table not in self._sampler.time_column_dict:
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column.")
+
+        raise NotImplementedError
+
+    def evaluate(
+        self,
+        query: str,
+        *,
+        metrics: list[str] | None = None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        context_anchor_time: pd.Timestamp | None = None,
+        run_mode: RunMode | str = RunMode.FAST,
+        num_neighbors: list[int] | None = None,
+        num_hops: int = 2,
+        max_pq_iterations: int = 10,
+        random_seed: int | None = _RANDOM_SEED,
+        verbose: bool | ProgressLogger = True,
+        use_prediction_time: bool = False,
+    ) -> pd.DataFrame:
+        """Evaluates a predictive query.
+
+        Args:
+            query: The predictive query.
+            metrics: The metrics to use.
+            anchor_time: The anchor timestamp for the prediction. If set to
+                ``None``, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            context_anchor_time: The maximum anchor timestamp for context
+                examples. If set to ``None``, ``anchor_time`` will determine
+                the anchor time for context examples.
+            run_mode: The :class:`RunMode` for the query.
+            num_neighbors: The number of neighbors to sample for each hop.
+                If specified, the ``num_hops`` option will be ignored.
+            num_hops: The number of hops to sample when generating the
+                context.
+            max_pq_iterations: The maximum number of iterations to perform to
+                collect valid labels. It is advised to increase the number of
+                iterations in case the predictive query has strict entity
+                filters, in which case :class:`KumoRFM` needs to sample more
+                entities to find valid labels.
+            random_seed: A manual seed for generating pseudo-random numbers.
+            verbose: Whether to print verbose output.
+            use_prediction_time: Whether to use the anchor timestamp as an
+                additional feature during prediction. This is typically
+                beneficial for time series forecasting tasks.
+
+        Returns:
+            The metrics as a :class:`pandas.DataFrame`.
+        """
+        query_def = self._parse_query(query)
+
+        if num_hops != 2 and num_neighbors is not None:
+            warnings.warn(f"Received custom 'num_neighbors' option; ignoring "
+                          f"custom 'num_hops={num_hops}' option")
+
+        if query_def.rfm_entity_ids is not None:
+            query_def = replace(
+                query_def,
+                rfm_entity_ids=None,
+            )
+
+        query_repr = query_def.to_string(rich=True, exclude_predict=True)
+        msg = f'[bold]EVALUATE[/bold] {query_repr}'
+
+        if not isinstance(verbose, ProgressLogger):
+            verbose = ProgressLogger.default(msg=msg, verbose=verbose)
+
+        with verbose as logger:
+            context = self._get_context(
+                query=query_def,
+                indices=None,
+                anchor_time=anchor_time,
+                context_anchor_time=context_anchor_time,
+                run_mode=RunMode(run_mode),
+                num_neighbors=num_neighbors,
+                num_hops=num_hops,
+                max_pq_iterations=max_pq_iterations,
+                evaluate=True,
+                random_seed=random_seed,
+                logger=logger if verbose else None,
+            )
+            if metrics is not None and len(metrics) > 0:
+                self._validate_metrics(metrics, context.task_type)
+                metrics = list(dict.fromkeys(metrics))
+            request = RFMEvaluateRequest(
+                context=context,
+                run_mode=RunMode(run_mode),
+                metrics=metrics,
+                use_prediction_time=use_prediction_time,
+            )
+            with warnings.catch_warnings():
+                warnings.filterwarnings('ignore', message='Protobuf gencode')
+                request_msg = request.to_protobuf()
+            request_bytes = request_msg.SerializeToString()
+            logger.log(f"Generated context of size "
+                       f"{len(request_bytes) / (1024*1024):.2f}MB")
+
+            if len(request_bytes) > _MAX_SIZE:
+                stats_msg = Context.get_memory_stats(request_msg.context)
+                raise ValueError(_SIZE_LIMIT_MSG.format(stats=stats_msg))
+
+            try:
+                resp = self._api_client.evaluate(request_bytes)
+            except HTTPException as e:
+                try:
+                    msg = json.loads(e.detail)['detail']
+                except Exception:
+                    msg = e.detail
+                raise RuntimeError(f"An unexpected exception occurred. "
+                                   f"Please create an issue at "
+                                   f"'https://github.com/kumo-ai/kumo-rfm'. "
+                                   f"{msg}") from None
+
+        return pd.DataFrame.from_dict(
+            resp.metrics,
+            orient='index',
+            columns=['value'],
+        ).reset_index(names='metric')
+
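The returned frame has one row per metric, with ``metric`` and ``value`` columns produced by the final ``reset_index`` call. A sketch, assuming the binary classification query from the class example (metric names as accepted by ``_validate_metrics`` below):

.. code-block:: python

    metrics_df = rfm.evaluate(
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1",
        metrics=['auroc', 'ap'],
    )
    print(metrics_df)  #   metric  value
                       # 0  auroc    ...
                       # 1     ap    ...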
+    def get_train_table(
+        self,
+        query: str,
+        size: int,
+        *,
+        anchor_time: pd.Timestamp | Literal['entity'] | None = None,
+        random_seed: int | None = _RANDOM_SEED,
+        max_iterations: int = 10,
+    ) -> pd.DataFrame:
+        """Returns the labels of a predictive query for a specified anchor
+        time.
+
+        Args:
+            query: The predictive query.
+            size: The maximum number of entities to generate labels for.
+            anchor_time: The anchor timestamp for the query. If set to
+                :obj:`None`, will use the maximum timestamp in the data.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            random_seed: A manual seed for generating pseudo-random numbers.
+            max_iterations: The number of steps to run before aborting.
+
+        Returns:
+            The labels as a :class:`pandas.DataFrame`.
+        """
+        query_def = self._parse_query(query)
+
+        if anchor_time is None:
+            anchor_time = self._get_default_anchor_time(query_def)
+            if query_def.target_ast.date_offset_range is not None:
+                offset = query_def.target_ast.date_offset_range.end_date_offset
+                offset *= query_def.num_forecasts
+                anchor_time -= offset
+
+        assert anchor_time is not None
+        if isinstance(anchor_time, pd.Timestamp):
+            self._validate_time(query_def, anchor_time, None, evaluate=True)
+        else:
+            assert anchor_time == 'entity'
+            if query_def.entity_table not in self._sampler.time_column_dict:
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query_def.entity_table}' "
+                                 f"to have a time column")
+
+        train, test = self._sampler.sample_target(
+            query=query_def,
+            num_train_examples=0,
+            train_anchor_time=anchor_time,
+            num_train_trials=0,
+            num_test_examples=size,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_iterations * size,
+            random_seed=random_seed,
+        )
+
+        return pd.DataFrame({
+            'ENTITY': test.entity_pkey,
+            'ANCHOR_TIMESTAMP': test.anchor_time,
+            'TARGET': test.target,
+        })
+
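A sketch of pulling such a label table for offline inspection (the column names come from the ``return`` statement above; the query string is the running example):

.. code-block:: python

    labels = rfm.get_train_table(
        "PREDICT COUNT(orders.*, 0, 30, days)>0 FOR users.user_id=1",
        size=1_000,
    )
    # -> DataFrame with 'ENTITY', 'ANCHOR_TIMESTAMP', 'TARGET' columns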
+    # Helpers #################################################################
+
+    def _parse_query(self, query: str) -> ValidatedPredictiveQuery:
+        if isinstance(query, ValidatedPredictiveQuery):
+            return query
+
+        if isinstance(query, str) and query.strip()[:9].lower() == 'evaluate ':
+            raise ValueError("'EVALUATE PREDICT ...' queries are not "
+                             "supported in the SDK. Instead, use either "
+                             "`predict()` or `evaluate()` methods to perform "
+                             "predictions or evaluations.")
+
+        try:
+            request = RFMParseQueryRequest(
+                query=query,
+                graph_definition=self._graph_def,
+            )
+
+            resp = self._api_client.parse_query(request)
+
+            if len(resp.validation_response.warnings) > 0:
+                msg = '\n'.join([
+                    f'{i+1}. {warning.title}: {warning.message}'
+                    for i, warning in
+                    enumerate(resp.validation_response.warnings)
+                ])
+                warnings.warn(f"Encountered the following warnings during "
+                              f"parsing:\n{msg}")
+
+            return resp.query
+        except HTTPException as e:
+            try:
+                msg = json.loads(e.detail)['detail']
+            except Exception:
+                msg = e.detail
+            raise ValueError(f"Failed to parse query '{query}'. "
+                             f"{msg}") from None
+
+    @staticmethod
+    def _get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: list[tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
+    def _get_default_anchor_time(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> pd.Timestamp:
+        if query.query_type == QueryType.TEMPORAL:
+            aggr_table_names = [
+                aggr._get_target_column_name().split('.')[0]
+                for aggr in query.get_all_target_aggregations()
+            ]
+            return self._sampler.get_max_time(aggr_table_names)
+
+        assert query.query_type == QueryType.STATIC
+        return self._sampler.get_max_time()
+
+    def _validate_time(
+        self,
+        query: ValidatedPredictiveQuery,
+        anchor_time: pd.Timestamp,
+        context_anchor_time: pd.Timestamp | None,
+        evaluate: bool,
+    ) -> None:
+
+        if len(self._sampler.time_column_dict) == 0:
+            return  # Graph without timestamps
+
+        min_time = self._sampler.get_min_time()
+        max_time = self._sampler.get_max_time()
+
+        if anchor_time < min_time:
+            raise ValueError(f"Anchor timestamp '{anchor_time}' is before "
+                             f"the earliest timestamp '{min_time}' in the "
+                             f"data.")
+
+        if context_anchor_time is not None and context_anchor_time < min_time:
+            raise ValueError(f"Context anchor timestamp is too early or "
+                             f"aggregation time range is too large. To make "
+                             f"this prediction, we would need data back to "
+                             f"'{context_anchor_time}', however, your data "
+                             f"only contains data back to '{min_time}'.")
+
+        if query.target_ast.date_offset_range is not None:
+            end_offset = query.target_ast.date_offset_range.end_date_offset
+        else:
+            end_offset = pd.DateOffset(0)
+        end_offset = end_offset * query.num_forecasts
+
+        if (context_anchor_time is not None
+                and context_anchor_time > anchor_time):
+            warnings.warn(f"Context anchor timestamp "
+                          f"(got '{context_anchor_time}') is set to a later "
+                          f"date than the prediction anchor timestamp "
+                          f"(got '{anchor_time}'). Please make sure this is "
+                          f"intended.")
+        elif (query.query_type == QueryType.TEMPORAL
+                and context_anchor_time is not None
+                and context_anchor_time + end_offset > anchor_time):
+            warnings.warn(f"Aggregation for context examples at timestamp "
+                          f"'{context_anchor_time}' will leak information "
+                          f"from the prediction anchor timestamp "
+                          f"'{anchor_time}'. Please make sure this is "
+                          f"intended.")
+
+        elif (context_anchor_time is not None
+                and context_anchor_time - end_offset < min_time):
+            _time = context_anchor_time - end_offset
+            warnings.warn(f"Context anchor timestamp is too early or "
+                          f"aggregation time range is too large. To form "
+                          f"proper input data, we would need data back to "
+                          f"'{_time}', however, your data only contains "
+                          f"data back to '{min_time}'.")
+
+        if not evaluate and anchor_time > max_time + pd.DateOffset(days=1):
+            warnings.warn(f"Anchor timestamp '{anchor_time}' is after the "
+                          f"latest timestamp '{max_time}' in the data. Please "
+                          f"make sure this is intended.")
+
+        if evaluate and anchor_time > max_time - end_offset:
+            raise ValueError(
+                f"Anchor timestamp for evaluation is after the latest "
+                f"supported timestamp '{max_time - end_offset}'.")
+
+    def _get_context(
+        self,
+        query: ValidatedPredictiveQuery,
+        indices: list[str] | list[float] | list[int] | None,
+        anchor_time: pd.Timestamp | Literal['entity'] | None,
+        context_anchor_time: pd.Timestamp | None,
+        run_mode: RunMode,
+        num_neighbors: list[int] | None,
+        num_hops: int,
+        max_pq_iterations: int,
+        evaluate: bool,
+        random_seed: int | None = _RANDOM_SEED,
+        logger: ProgressLogger | None = None,
+    ) -> Context:
+
+        if num_neighbors is not None:
+            num_hops = len(num_neighbors)
+
+        if num_hops < 0:
+            raise ValueError(f"'num_hops' must be non-negative "
+                             f"(got {num_hops})")
+        if num_hops > 6:
+            raise ValueError(f"Cannot predict on subgraphs with more than 6 "
+                             f"hops (got {num_hops}). Please reduce the "
+                             f"number of hops and try again. Please create a "
+                             f"feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        task_type = self._get_task_type(
+            query=query,
+            edge_types=self._sampler.edge_types,
+        )
+
+        if logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                task_type_repr = 'binary classification'
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                task_type_repr = 'multi-class classification'
+            elif task_type == TaskType.REGRESSION:
+                task_type_repr = 'regression'
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                task_type_repr = 'link prediction'
+            else:
+                task_type_repr = str(task_type)
+            logger.log(f"Identified {query.query_type} {task_type_repr} task")
+
+        if task_type.is_link_pred and num_hops < 2:
+            raise ValueError(f"Cannot perform link prediction on subgraphs "
+                             f"with less than 2 hops (got {num_hops}) since "
+                             f"historical target entities need to be part of "
+                             f"the context. Please increase the number of "
+                             f"hops and try again.")
+
+        if num_neighbors is None:
+            if run_mode == RunMode.DEBUG:
+                num_neighbors = [16, 16, 4, 4, 1, 1][:num_hops]
+            elif run_mode == RunMode.FAST or task_type.is_link_pred:
+                num_neighbors = [32, 32, 8, 8, 4, 4][:num_hops]
+            else:
+                num_neighbors = [64, 64, 8, 8, 4, 4][:num_hops]
+
+        if query.target_ast.date_offset_range is None:
+            step_offset = pd.DateOffset(0)
+        else:
+            step_offset = query.target_ast.date_offset_range.end_date_offset
+        end_offset = step_offset * query.num_forecasts
+
+        if anchor_time is None:
+            anchor_time = self._get_default_anchor_time(query)
+
+            if evaluate:
+                anchor_time = anchor_time - end_offset
+
+            if logger is not None:
+                assert isinstance(anchor_time, pd.Timestamp)
+                if anchor_time == pd.Timestamp.min:
+                    pass  # Static graph
+                elif (anchor_time.hour == 0 and anchor_time.minute == 0
+                        and anchor_time.second == 0
+                        and anchor_time.microsecond == 0):
+                    logger.log(f"Derived anchor time {anchor_time.date()}")
+                else:
+                    logger.log(f"Derived anchor time {anchor_time}")
+
+        assert anchor_time is not None
+        if isinstance(anchor_time, pd.Timestamp):
+            if context_anchor_time == 'entity':
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            if context_anchor_time is None:
+                context_anchor_time = anchor_time - end_offset
+            self._validate_time(query, anchor_time, context_anchor_time,
+                                evaluate)
+        else:
+            assert anchor_time == 'entity'
+            if query.query_type != QueryType.STATIC:
+                raise ValueError("Anchor time 'entity' is only valid for "
+                                 "static predictive queries")
+            if query.entity_table not in self._sampler.time_column_dict:
+                raise ValueError(f"Anchor time 'entity' requires the entity "
+                                 f"table '{query.entity_table}' to "
+                                 f"have a time column")
+            if isinstance(context_anchor_time, pd.Timestamp):
+                raise ValueError("Anchor time 'entity' needs to be shared "
+                                 "for context and prediction examples")
+            context_anchor_time = 'entity'
+
+        num_train_examples = _MAX_CONTEXT_SIZE[run_mode]
+        if evaluate:
+            num_test_examples = _MAX_TEST_SIZE[run_mode]
+            if task_type.is_link_pred:
+                num_test_examples = num_test_examples // 5
+        else:
+            num_test_examples = 0
+
+        train, test = self._sampler.sample_target(
+            query=query,
+            num_train_examples=num_train_examples,
+            train_anchor_time=context_anchor_time,
+            num_train_trials=max_pq_iterations * num_train_examples,
+            num_test_examples=num_test_examples,
+            test_anchor_time=anchor_time,
+            num_test_trials=max_pq_iterations * num_test_examples,
+            random_seed=random_seed,
+        )
+        train_pkey, train_time, y_train = train
+        test_pkey, test_time, y_test = test
+
+        if evaluate and logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_test > 0).sum()) / len(y_test)
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_test):,} test examples holding "
+                       f"{y_test.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_test.min()), float(y_test.max())
+                msg = (f"Collected {len(y_test):,} test examples with targets "
+                       f"between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_test.explode().nunique()
+                msg = (f"Collected {len(y_test):,} test examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
+
+        if not evaluate:
+            assert indices is not None
+            if len(indices) > _MAX_PRED_SIZE[task_type]:
+                raise ValueError(f"Cannot predict for more than "
+                                 f"{_MAX_PRED_SIZE[task_type]:,} entities at "
+                                 f"once (got {len(indices):,}). Use "
+                                 f"`KumoRFM.batch_mode` to process entities "
+                                 f"in batches")
+
+            test_pkey = pd.Series(indices, dtype=train_pkey.dtype)
+            if isinstance(anchor_time, pd.Timestamp):
+                test_time = pd.Series([anchor_time]).repeat(
+                    len(indices)).reset_index(drop=True)
+            else:
+                train_time = test_time = 'entity'
+
+        if logger is not None:
+            if task_type == TaskType.BINARY_CLASSIFICATION:
+                pos = 100 * int((y_train > 0).sum()) / len(y_train)
+                msg = (f"Collected {len(y_train):,} in-context examples with "
+                       f"{pos:.2f}% positive cases")
+            elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+                msg = (f"Collected {len(y_train):,} in-context examples "
+                       f"holding {y_train.nunique()} classes")
+            elif task_type == TaskType.REGRESSION:
+                _min, _max = float(y_train.min()), float(y_train.max())
+                msg = (f"Collected {len(y_train):,} in-context examples with "
+                       f"targets between {format_value(_min)} and "
+                       f"{format_value(_max)}")
+            elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+                num_rhs = y_train.explode().nunique()
+                msg = (f"Collected {len(y_train):,} in-context examples with "
+                       f"{num_rhs:,} unique items")
+            else:
+                raise NotImplementedError
+            logger.log(msg)
+
+        entity_table_names: tuple[str, ...]
+        if task_type.is_link_pred:
+            final_aggr = query.get_final_target_aggregation()
+            assert final_aggr is not None
+            edge_fkey = final_aggr._get_target_column_name()
+            for edge_type in self._sampler.edge_types:
+                if edge_fkey == f'{edge_type[0]}.{edge_type[1]}':
+                    entity_table_names = (
+                        query.entity_table,
+                        edge_type[2],
+                    )
+        else:
+            entity_table_names = (query.entity_table, )
+
+        # Exclude the entity anchor time from the feature set to prevent
+        # running out-of-distribution between in-context and test examples:
+        exclude_cols_dict = query.get_exclude_cols_dict()
+        if entity_table_names[0] in self._sampler.time_column_dict:
+            if entity_table_names[0] not in exclude_cols_dict:
+                exclude_cols_dict[entity_table_names[0]] = []
+            time_column = self._sampler.time_column_dict[entity_table_names[0]]
+            exclude_cols_dict[entity_table_names[0]].append(time_column)
+
+        subgraph = self._sampler.sample_subgraph(
+            entity_table_names=entity_table_names,
+            entity_pkey=pd.concat(
+                [train_pkey, test_pkey],
+                axis=0,
+                ignore_index=True,
+            ),
+            anchor_time=pd.concat(
+                [train_time, test_time],
+                axis=0,
+                ignore_index=True,
+            ) if isinstance(train_time, pd.Series) else 'entity',
+            num_neighbors=num_neighbors,
+            exclude_cols_dict=exclude_cols_dict,
+        )
+
+        if len(subgraph.table_dict) >= 15:
+            raise ValueError(f"Cannot query from a graph with more than 15 "
+                             f"tables (got {len(subgraph.table_dict)}). "
+                             f"Please create a feature request at "
+                             f"'https://github.com/kumo-ai/kumo-rfm' if you "
+                             f"must go beyond this for your use-case.")
+
+        return Context(
+            task_type=task_type,
+            entity_table_names=entity_table_names,
+            subgraph=subgraph,
+            y_train=y_train,
+            y_test=y_test if evaluate else None,
+            top_k=query.top_k,
+            step_size=None,
+        )
+
+    @staticmethod
+    def _validate_metrics(
+        metrics: list[str],
+        task_type: TaskType,
+    ) -> None:
+
+        if task_type == TaskType.BINARY_CLASSIFICATION:
+            supported_metrics = [
+                'acc', 'precision', 'recall', 'f1', 'auroc', 'auprc', 'ap'
+            ]
+        elif task_type == TaskType.MULTICLASS_CLASSIFICATION:
+            supported_metrics = ['acc', 'precision', 'recall', 'f1', 'mrr']
+        elif task_type == TaskType.REGRESSION:
+            supported_metrics = ['mae', 'mape', 'mse', 'rmse', 'smape', 'r2']
+        elif task_type == TaskType.TEMPORAL_LINK_PREDICTION:
+            supported_metrics = [
+                'map@', 'ndcg@', 'mrr@', 'precision@', 'recall@', 'f1@',
+                'hit_ratio@'
+            ]
+        else:
+            raise NotImplementedError
+
+        for metric in metrics:
+            if '@' in metric:
+                metric_split = metric.split('@')
+                if len(metric_split) != 2:
+                    raise ValueError(f"Unsupported metric '{metric}'. "
+                                     f"Available metrics "
+                                     f"are {supported_metrics}.")
+
+                name, top_k = f'{metric_split[0]}@', metric_split[1]
+
+                if not top_k.isdigit():
+                    raise ValueError(f"Metric '{metric}' does not define a "
+                                     f"valid 'top_k' value (got '{top_k}').")
+
+                if int(top_k) <= 0:
+                    raise ValueError(f"Metric '{metric}' needs to define a "
+                                     f"positive 'top_k' value (got '{top_k}')")
+
+                if int(top_k) > 100:
+                    raise ValueError(f"Metric '{metric}' defines a 'top_k' "
+                                     f"value greater than 100 "
+                                     f"(got '{top_k}'). Please create a "
+                                     f"feature request at "
+                                     f"'https://github.com/kumo-ai/kumo-rfm' "
+                                     f"if you must go beyond this for your "
+                                     f"use-case.")
+
+                metric = name
+
+            if metric not in supported_metrics:
+                raise ValueError(f"Unsupported metric '{metric}'. Available "
+                                 f"metrics are {supported_metrics}. If you "
+                                 f"feel a metric is missing, please create a "
+                                 f"feature request at "
+                                 f"'https://github.com/kumo-ai/kumo-rfm'.")
+
+
+def format_value(value: int | float) -> str:
+    if value == int(value):
+        return f'{int(value):,}'
+    if abs(value) >= 1000:
+        return f'{value:,.0f}'
+    if abs(value) >= 10:
+        return f'{value:.1f}'
+    return f'{value:.2f}'
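For reference, the branches of ``format_value`` produce, for example:

.. code-block:: python

    format_value(42.0)     # '42'     (integral values are shown as integers)
    format_value(2500.75)  # '2,501'  (values >= 1000: thousands separator)
    format_value(12.345)   # '12.3'   (values >= 10: one decimal place)
    format_value(3.14159)  # '3.14'   (otherwise: two decimal places)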