kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202601051732__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/graph_store.py +52 -104
  10. kumoai/experimental/rfm/backend/local/sampler.py +125 -55
  11. kumoai/experimental/rfm/backend/local/table.py +35 -31
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +174 -49
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  18. kumoai/experimental/rfm/base/__init__.py +21 -5
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +422 -35
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +144 -0
  24. kumoai/experimental/rfm/base/table.py +386 -195
  25. kumoai/experimental/rfm/graph.py +350 -178
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +7 -4
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +630 -408
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +290 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/testing/snow.py +50 -0
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +3 -2
  42. kumoai/utils/display.py +51 -0
  43. kumoai/utils/progress_logger.py +190 -12
  44. kumoai/utils/sql.py +3 -0
  45. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +3 -2
  46. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +49 -40
  47. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  48. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  49. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  50. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  51. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
```diff
--- a/kumoai/experimental/rfm/base/sampler.py
+++ b/kumoai/experimental/rfm/base/sampler.py
@@ -1,23 +1,29 @@
 import copy
 import re
+import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
 from kumoapi.pquery.AST import Aggregation, ASTNode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
+from kumoai.utils import ProgressLogger
+
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
+_coverage_warned = False
+
 
 @dataclass
 class SamplerOutput:
+    anchor_time: np.ndarray
     df_dict: dict[str, pd.DataFrame]
     inverse_dict: dict[str, np.ndarray]
     batch_dict: dict[str, np.ndarray]
```
```diff
@@ -27,16 +33,26 @@ class SamplerOutput:
     num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
 
 
-@dataclass
-class TargetOutput:
+class TargetOutput(NamedTuple):
     entity_pkey: pd.Series
     anchor_time: pd.Series
     target: pd.Series
-    num_trials: int
 
 
 class Sampler(ABC):
-    def __init__(self, graph: 'Graph') -> None:
+    r"""A base class to sample relational data (*i.e.*, subgraphs and
+    ground-truth targets) from a custom backend.
+
+    Args:
+        graph: The graph.
+        verbose: Whether to print verbose output.
+    """
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+
         self._edge_types: list[tuple[str, str, str]] = []
         for edge in graph.edges:
             edge_type = (edge.src_table, edge.fkey, edge.dst_table)
```
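With `TargetOutput` now a `NamedTuple`, sampled targets unpack positionally, which is how the reworked `sample_target` returns its train/test halves. A minimal sketch (field values are made up for illustration, and the import path is assumed from this diff):

```python
import pandas as pd

from kumoai.experimental.rfm.base.sampler import TargetOutput  # Path assumed.

# Hypothetical values, for illustration only:
out = TargetOutput(
    entity_pkey=pd.Series([1, 2, 3]),
    anchor_time=pd.Series(pd.to_datetime(['2024-01-01'] * 3)),
    target=pd.Series([0.0, 1.0, 0.0]),
)
entity_pkey, anchor_time, target = out  # Tuple-style unpacking now works.
```

The `num_trials` field is gone; trial bookkeeping now happens inside `sample_target` itself via the new `num_train_trials`/`num_test_trials` arguments.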
```diff
@@ -72,37 +88,99 @@ class Sampler(ABC):
                     continue
                 self._table_stype_dict[table.name][column.name] = column.stype
 
+        self._min_time_dict: dict[str, pd.Timestamp] = {}
+        self._max_time_dict: dict[str, pd.Timestamp] = {}
+
+    # Properties ##############################################################
+
     @property
     def edge_types(self) -> list[tuple[str, str, str]]:
+        r"""All available edge types in the graph."""
         return self._edge_types
 
     @property
     def primary_key_dict(self) -> dict[str, str]:
+        r"""All available primary keys in the graph."""
         return self._primary_key_dict
 
     @property
     def time_column_dict(self) -> dict[str, str]:
+        r"""All available time columns in the graph."""
         return self._time_column_dict
 
     @property
     def end_time_column_dict(self) -> dict[str, str]:
+        r"""All available end time columns in the graph."""
         return self._end_time_column_dict
 
     @property
     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        r"""The registered semantic types for all feature columns in all tables
+        in the graph.
+        """
        return self._table_stype_dict
 
+    def get_min_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the minimal timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._min_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return min([self._min_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.max])
+
+    def get_max_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the maximum timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._max_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return max([self._max_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.min])
+
+    # Subgraph Sampling #######################################################
+
     def sample_subgraph(
         self,
         entity_table_names: tuple[str, ...],
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         num_neighbors: list[int],
         exclude_cols_dict: dict[str, list[str]] | None = None,
     ) -> Subgraph:
-
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_names: The entity table names.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            num_neighbors: The number of neighbors to sample for each hop.
+            exclude_cols_dict: The columns to exclude from the subgraph.
+        """
         # Exclude all columns that leak target information:
-        table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+        table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
         if exclude_cols_dict is not None:
             table_stype_dict = copy.deepcopy(table_stype_dict)
             for table_name, exclude_cols in exclude_cols_dict.items():
```
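The new `get_min_time`/`get_max_time` helpers cache per-table bounds in `_min_time_dict`/`_max_time_dict` and only call the backend hook `_get_min_max_time_dict` for tables they have not seen yet. A minimal sketch of what a backend might return, assuming a hypothetical in-memory sampler that keeps one `DataFrame` per table (the real backends live under `backend/local`, `backend/snow`, and `backend/sqlite`):

```python
import pandas as pd

from kumoai.experimental.rfm.base.sampler import Sampler  # Path assumed.


class InMemorySampler(Sampler):
    # `self._df_dict` is a hypothetical attribute mapping table names to
    # pandas DataFrames; a real backend would query its own storage.
    def _get_min_max_time_dict(
        self,
        table_names: list[str],
    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
        out = {}
        for name in table_names:
            ser = self._df_dict[name][self.time_column_dict[name]]
            out[name] = (ser.min(), ser.max())
        return out
```

Repeated calls such as `get_min_time(['orders'])` then hit the cache instead of the backend.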
```diff
@@ -118,7 +196,8 @@ class Sampler(ABC):
         for table_name in entity_table_names:
             columns_dict[table_name].add(self.primary_key_dict[table_name])
 
-        if anchor_time.dtype != 'datetime64[ns]':
+        if (isinstance(anchor_time, pd.Series)
+                and anchor_time.dtype != 'datetime64[ns]'):
             anchor_time = anchor_time.astype('datetime64[ns]')
 
         out = self._sample_subgraph(
```
```diff
@@ -129,8 +208,9 @@ class Sampler(ABC):
             num_neighbors=num_neighbors,
         )
 
+        # Parse `SubgraphOutput` into `Subgraph` structure:
         subgraph = Subgraph(
-            anchor_time=anchor_time.astype(int).to_numpy(),
+            anchor_time=out.anchor_time,
             table_dict={},
             link_dict={},
         )
```
```diff
@@ -148,11 +228,8 @@ class Sampler(ABC):
                 # Set end time to NaT for all values greater than anchor time:
                 assert table_name not in out.inverse_dict
                 ser = df[end_time_column]
-                if ser.dtype != 'datetime64[ns]':
-                    ser = ser.astype('datetime64[ns]')
-                mask = ser > anchor_time.iloc[batch]
-                ser.iloc[mask] = pd.NaT
-                df[end_time_column] = ser
+                mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
+                df.loc[mask, end_time_column] = pd.NaT
 
             stype_dict = table_stype_dict[table_name]
             for column_name, stype in stype_dict.items():
```
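The end-time masking no longer round-trips through a per-row datetime comparison; it compares integer nanoseconds against the integer `anchor_time` carried by `SamplerOutput`. Roughly, under the assumption that both sides are nanoseconds since the epoch:

```python
import numpy as np
import pandas as pd

# Illustration of the comparison only: `astype('int64')` on a
# `datetime64[ns]` Series yields nanoseconds since the epoch, mirroring the
# `ser.astype(int)` call in the diff.
end_time = pd.Series(pd.to_datetime(['2024-01-01', '2024-06-01']))
anchor = np.full(2, pd.Timestamp('2024-03-01').value)  # int64 nanoseconds
mask = end_time.astype('int64').to_numpy() > anchor
# mask == array([False, True]); masked rows are reset to pd.NaT.
```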
```diff
@@ -213,24 +290,31 @@ class Sampler(ABC):
 
         return subgraph
 
-    def sample_target(
+    # Predictive Query ########################################################
+
+    def _get_query_columns_dict(
         self,
         query: ValidatedPredictiveQuery,
-        num_examples: int,
-        anchor_time: pd.Timestamp | Literal['entity'],
-        random_seed: int | None = None,
-    ) -> TargetOutput:
-
+    ) -> dict[str, set[str]]:
         columns_dict: dict[str, set[str]] = defaultdict(set)
         for fqn in query.all_query_columns + [query.entity_column]:
             table_name, column_name = fqn.split('.')
+            if column_name == '*':
+                continue
             columns_dict[table_name].add(column_name)
-
         if column_name := self.time_column_dict.get(query.entity_table):
             columns_dict[table_name].add(column_name)
         if column_name := self.end_time_column_dict.get(query.entity_table):
             columns_dict[table_name].add(column_name)
+        return columns_dict
 
+    def _get_query_time_offset_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+    ]:
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
```
```diff
@@ -239,7 +323,6 @@ class Sampler(ABC):
         def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
             if isinstance(node, Aggregation):
                 table_name = node._get_target_column_name().split('.')[0]
-                columns_dict[table_name].add(self.time_column_dict[table_name])
 
                 edge_types = [
                     edge_type for edge_type in self.edge_types
```
```diff
@@ -273,42 +356,342 @@ class Sampler(ABC):
         if query.whatif_ast is not None:
             _add_time_offset(query.whatif_ast)
 
-        return self._sample_target(
+        return time_offset_dict
+
+    def sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_train_examples: int,
+        train_anchor_time: pd.Timestamp | Literal['entity'],
+        num_train_trials: int,
+        num_test_examples: int,
+        test_anchor_time: pd.Timestamp | Literal['entity'],
+        num_test_trials: int,
+        random_seed: int | None = None,
+    ) -> tuple[TargetOutput, TargetOutput]:
+        r"""Samples ground-truth targets given a predictive query, split into
+        training and test set.
+
+        Args:
+            query: The predictive query.
+            num_train_examples: How many training examples to produce.
+            train_anchor_time: The anchor timestamp for the training set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_train_trials: The number of training examples to try before
+                aborting.
+            num_test_examples: How many test examples to produce.
+            test_anchor_time: The anchor timestamp for the test set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_test_trials: The number of test examples to try before
+                aborting.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+        rng = np.random.default_rng(random_seed)
+
+        if num_train_examples == 0 or num_train_trials == 0:
+            num_train_examples = num_train_trials = 0
+        if num_test_examples == 0 or num_test_trials == 0:
+            num_test_examples = num_test_trials = 0
+
+        # 1. Collect information on what to query #############################
+        columns_dict = self._get_query_columns_dict(query)
+        time_offset_dict = self._get_query_time_offset_dict(query)
+        for table_name, _, _ in time_offset_dict.keys():
+            columns_dict[table_name].add(self.time_column_dict[table_name])
+
+        # 2. Sample random rows from entity table #############################
+        shared_train_test = query.query_type == QueryType.STATIC
+        shared_train_test &= train_anchor_time == test_anchor_time
+        if shared_train_test:
+            num_entity_rows = num_train_trials + num_test_trials
+        else:
+            num_entity_rows = max(num_train_trials, num_test_trials)
+        assert num_entity_rows > 0
+
+        entity_df = self._sample_entity_table(
+            table_name=query.entity_table,
+            columns=columns_dict[query.entity_table],
+            num_rows=num_entity_rows,
+            random_seed=random_seed,
+        )
+
+        if len(entity_df) == 0:
+            raise ValueError("Failed to find any rows in the entity table "
+                             "'{query.entity_table}'.")
+
+        entity_pkey = entity_df[self.primary_key_dict[query.entity_table]]
+        entity_time: pd.Series | None = None
+        if column_name := self.time_column_dict.get(query.entity_table):
+            entity_time = entity_df[column_name]
+        entity_end_time: pd.Series | None = None
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            entity_end_time = entity_df[column_name]
+
+        def get_valid_entity_index(
+            time: pd.Timestamp | Literal['entity'],
+            max_size: int | None = None,
+        ) -> np.ndarray:
+
+            if time == 'entity':
+                index: np.ndarray = np.arange(len(entity_pkey))
+            elif entity_time is None and entity_end_time is None:
+                index = np.arange(len(entity_pkey))
+            else:
+                mask: np.ndarray | None = None
+                if entity_time is not None:
+                    mask = (entity_time <= time).to_numpy()
+                if entity_end_time is not None:
+                    _mask = (entity_end_time > time).to_numpy()
+                    _mask |= entity_end_time.isna().to_numpy()
+                    mask = _mask if mask is None else mask & _mask
+                assert mask is not None
+                index = mask.nonzero()[0]
+
+            rng.shuffle(index)
+
+            if max_size is not None:
+                index = index[:max_size]
+
+            return index
+
+        # 3. Build training and test candidates ###############################
+        train_index = test_index = np.array([], dtype=np.int64)
+        train_time = test_time = pd.Series([], dtype='datetime64[ns]')
+
+        if shared_train_test:
+            train_index = get_valid_entity_index(train_anchor_time)
+            if train_anchor_time == 'entity':  # Sort by timestamp:
+                assert entity_time is not None
+                train_time = entity_time.iloc[train_index]
+                train_time = train_time.reset_index(drop=True)
+                train_time = train_time.sort_values(ascending=False)
+                perm = train_time.index.to_numpy()
+                train_index = train_index[perm]
+                train_time = train_time.reset_index(drop=True)
+            else:
+                train_time = to_ser(train_anchor_time, size=len(train_index))
+        else:
+            if num_test_examples > 0:
+                test_index = get_valid_entity_index(  #
+                    test_anchor_time, max_size=num_test_trials)
+                assert test_anchor_time != 'entity'
+                test_time = to_ser(test_anchor_time, len(test_index))
+
+            if query.query_type == QueryType.STATIC and num_train_examples > 0:
+                train_index = get_valid_entity_index(  #
+                    train_anchor_time, max_size=num_train_trials)
+                assert train_anchor_time != 'entity'
+                train_time = to_ser(train_anchor_time, len(train_index))
+            elif query.query_type == QueryType.TEMPORAL and num_train_examples:
+                aggr_table_names = [
+                    aggr._get_target_column_name().split('.')[0]
+                    for aggr in query.get_all_target_aggregations()
+                ]
+                offset = query.target_timeframe.timeframe * query.num_forecasts
+
+                train_indices: list[np.ndarray] = []
+                train_times: list[pd.Series] = []
+                while True:
+                    train_index = get_valid_entity_index(  #
+                        train_anchor_time, max_size=num_train_trials)
+                    assert train_anchor_time != 'entity'
+                    train_time = to_ser(train_anchor_time, len(train_index))
+                    train_indices.append(train_index)
+                    train_times.append(train_time)
+                    if sum(len(x) for x in train_indices) >= num_train_trials:
+                        break
+                    train_anchor_time -= offset
+                    if train_anchor_time < self.get_min_time(aggr_table_names):
+                        break
+                train_index = np.concatenate(train_indices, axis=0)
+                train_index = train_index[:num_train_trials]
+                train_time = pd.concat(train_times, axis=0, ignore_index=True)
+                train_time = train_time.iloc[:num_train_trials]
+
+        # 4. Sample training and test labels ##################################
+        train_y, train_mask, test_y, test_mask = self._sample_target(
             query=query,
-            num_examples=num_examples,
-            anchor_time=anchor_time,
+            entity_df=entity_df,
+            train_index=train_index,
+            train_time=train_time,
+            num_train_examples=(num_train_examples + num_test_examples
+                                if shared_train_test else num_train_examples),
+            test_index=test_index,
+            test_time=test_time,
+            num_test_examples=0 if shared_train_test else num_test_examples,
             columns_dict=columns_dict,
             time_offset_dict=time_offset_dict,
-            random_seed=random_seed,
+        )
+
+        # 5. Post-processing ##################################################
+        if shared_train_test:
+            num_examples = num_train_examples + num_test_examples
+            train_index = train_index[train_mask][:num_examples]
+            train_time = train_time.iloc[train_mask].iloc[:num_examples]
+            train_y = train_y.iloc[:num_examples]
+
+            _num_test = num_test_examples
+            _num_train = min(num_train_examples, 1000)
+            if (num_test_examples > 0 and num_train_examples > 0
+                    and len(train_y) < num_examples
+                    and len(train_y) < _num_test + _num_train):
+                # Not enough labels to satisfy requested split without losing
+                # large number of training examples:
+                _num_test = len(train_y) - _num_train
+                if _num_test < _num_train:  # Fallback to 50/50 split:
+                    _num_test = len(train_y) // 2
+
+            test_index = train_index[:_num_test]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = train_time.iloc[:_num_test]
+            test_y = train_y.iloc[:_num_test]
+
+            train_index = train_index[_num_test:]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[_num_test:]
+            train_y = train_y.iloc[_num_test:]
+        else:
+            train_index = train_index[train_mask][:num_train_examples]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[train_mask].iloc[:num_train_examples]
+            train_y = train_y.iloc[:num_train_examples]
+
+            test_index = test_index[test_mask][:num_test_examples]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = test_time.iloc[test_mask].iloc[:num_test_examples]
+            test_y = test_y.iloc[:num_test_examples]
+
+        train_pkey = train_pkey.reset_index(drop=True)
+        train_time = train_time.reset_index(drop=True)
+        train_y = train_y.reset_index(drop=True)
+        test_pkey = test_pkey.reset_index(drop=True)
+        test_time = test_time.reset_index(drop=True)
+        test_y = test_y.reset_index(drop=True)
+
+        if num_train_examples > 0 and len(train_y) == 0:
+            raise RuntimeError("Failed to collect any context examples. Is "
+                               "your predictive query too restrictive?")
+
+        if num_test_examples > 0 and len(test_y) == 0:
+            raise RuntimeError("Failed to collect any test examples for "
+                               "evaluation. Is your predictive query too "
+                               "restrictive?")
+
+        global _coverage_warned
+        if (not num_train_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(train_y) < num_train_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_train_examples:,} context "
+                          f"examples within {num_train_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        if (not num_test_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(test_y) < num_test_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_test_examples:,} test "
+                          f"examples within {num_test_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        return (
+            TargetOutput(train_pkey, train_time, train_y),
+            TargetOutput(test_pkey, test_time, test_y),
         )
 
     # Abstract Methods ########################################################
 
+    @abstractmethod
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        r"""Returns the minimum and maximum timestamps for a set of tables.
+
+        Args:
+            table_names: The tables.
+        """
+
     @abstractmethod
     def _sample_subgraph(
         self,
         entity_table_name: str,
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         columns_dict: dict[str, set[str]],
         num_neighbors: list[int],
     ) -> SamplerOutput:
-        pass
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_name: The entity table name.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            columns_dict: The columns to return for each table.
+            num_neighbors: The number of neighbors to sample for each hop.
+        """
+
+    @abstractmethod
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        r"""Returns a random sample of rows from the entity table.
+
+        Args:
+            table_name: The table.
+            columns: The columns to return.
+            num_rows: Maximum number of rows to return. Can be smaller in case
+                the entity table contains less rows.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
 
     @abstractmethod
     def _sample_target(
         self,
         query: ValidatedPredictiveQuery,
-        num_examples: int,
-        anchor_time: pd.Timestamp | Literal['entity'],
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
         columns_dict: dict[str, set[str]],
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
         ],
-        random_seed: int | None = None,
-    ) -> TargetOutput:
-        pass
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        r"""Samples ground-truth targets given a predictive query from a set of
+        training and test candidates.
+
+        Args:
+            query: The predictive query.
+            entity_df: The entity data frame, containing the union of all train
+                and test candidates.
+            train_index: The indices of training candidates.
+            train_time: The anchor time of training candidates.
+            num_train_examples: How many training examples to produce.
+            test_index: The indices of test candidates.
+            test_time: The anchor time of test candidates.
+            num_test_examples: How many test examples to produce.
+            columns_dict: The columns that are being used to compute
+                ground-truth targets.
+            time_offset_dict: The date offsets to query for each edge type,
+                relative to the anchor time.
+        """
 
 
 # Helper Functions ############################################################
```
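Callers now request training and test targets in a single pass instead of invoking the old `sample_target(query, num_examples, anchor_time)` twice. A hedged usage sketch (the concrete sampler backend and validated query are assumed to come from elsewhere; the counts and timestamps are placeholders):

```python
import pandas as pd

from kumoapi.pquery import ValidatedPredictiveQuery


def build_context(sampler, query: ValidatedPredictiveQuery):
    # Placeholder sizes and anchor times; both halves are `TargetOutput`s.
    train, test = sampler.sample_target(
        query=query,
        num_train_examples=1_000,
        train_anchor_time=pd.Timestamp('2024-03-01'),
        num_train_trials=4_000,
        num_test_examples=200,
        test_anchor_time=pd.Timestamp('2024-06-01'),
        num_test_trials=800,
        random_seed=42,
    )
    return train, test
```

For static queries with identical anchor times, the base class samples one candidate pool and carves the test split out of it, which is why `_sample_target` receives `num_test_examples=0` in that case.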
```diff
@@ -372,3 +755,7 @@ def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
     assert len(timestamps) > 0
     argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
     return args[argmax]
+
+
+def to_ser(value: Any, size: int) -> pd.Series:
+    return pd.Series([value]).repeat(size).reset_index(drop=True)
```
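The new `to_ser` helper just broadcasts a scalar anchor time into a Series with a clean `RangeIndex`; for example:

```python
import pandas as pd


def to_ser(value, size: int) -> pd.Series:
    # Same body as the helper added above, repeated here to stay self-contained.
    return pd.Series([value]).repeat(size).reset_index(drop=True)


print(to_ser(pd.Timestamp('2024-01-01'), 3))
# 0   2024-01-01
# 1   2024-01-01
# 2   2024-01-01
# dtype: datetime64[ns]
```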
```diff
--- a/kumoai/experimental/rfm/base/source.py
+++ b/kumoai/experimental/rfm/base/source.py
@@ -6,9 +6,10 @@ from kumoapi.typing import Dtype
 @dataclass
 class SourceColumn:
     name: str
-    dtype: Dtype
+    dtype: Dtype | None
     is_primary_key: bool
     is_unique_key: bool
+    is_nullable: bool
 
 
 @dataclass
```
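`SourceColumn` can now be constructed before the backend has resolved a concrete dtype, and it records nullability explicitly. A small sketch (module path assumed from this diff):

```python
from kumoai.experimental.rfm.base.source import SourceColumn  # Path assumed.

# `dtype=None` is now allowed for columns whose type is not yet known:
col = SourceColumn(
    name='email',
    dtype=None,
    is_primary_key=False,
    is_unique_key=True,
    is_nullable=True,
)
```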