kumoai 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +213 -14
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +25 -6
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +438 -38
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512081731.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sampler.py
@@ -1,23 +1,30 @@
 import copy
 import re
+import warnings
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
 from kumoapi.pquery.AST import Aggregation, ASTNode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
+from kumoai.experimental.rfm.base import SourceColumn
+from kumoai.utils import ProgressLogger
+
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph
 
+_coverage_warned = False
+
 
 @dataclass
-class BackwardSamplerOutput:
+class SamplerOutput:
+    anchor_time: np.ndarray
     df_dict: dict[str, pd.DataFrame]
     inverse_dict: dict[str, np.ndarray]
     batch_dict: dict[str, np.ndarray]
@@ -27,15 +34,25 @@ class BackwardSamplerOutput:
     num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
 
 
-@dataclass
-class ForwardSamplerOutput:
+class TargetOutput(NamedTuple):
     entity_pkey: pd.Series
     anchor_time: pd.Series
     target: pd.Series
 
 
 class Sampler(ABC):
-    def __init__(self, graph: 'Graph') -> None:
+    r"""A base class to sample relational data (*i.e.*, subgraphs and
+    ground-truth targets) from a custom backend.
+
+    Args:
+        graph: The graph.
+        verbose: Whether to print verbose output.
+    """
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
         self._edge_types: list[tuple[str, str, str]] = []
         for edge in graph.edges:
             edge_type = (edge.src_table, edge.fkey, edge.dst_table)
@@ -71,35 +88,106 @@ class Sampler(ABC):
                     continue
                 self._table_stype_dict[table.name][column.name] = column.stype
 
+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = table._source_column_dict
+
+        self._min_time_dict: dict[str, pd.Timestamp] = {}
+        self._max_time_dict: dict[str, pd.Timestamp] = {}
+
+    # Properties ##############################################################
+
     @property
     def edge_types(self) -> list[tuple[str, str, str]]:
+        r"""All available edge types in the graph."""
        return self._edge_types
 
     @property
     def primary_key_dict(self) -> dict[str, str]:
+        r"""All available primary keys in the graph."""
        return self._primary_key_dict
 
     @property
     def time_column_dict(self) -> dict[str, str]:
+        r"""All available time columns in the graph."""
        return self._time_column_dict
 
     @property
     def end_time_column_dict(self) -> dict[str, str]:
+        r"""All available end time columns in the graph."""
        return self._end_time_column_dict
 
     @property
     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        r"""The registered semantic types for all columns in all tables in
+        the graph.
+        """
        return self._table_stype_dict
 
+    @property
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""Source column information for all tables in the graph."""
+        return self._source_table_dict
+
+    def get_min_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the minimal timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._min_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return min([self._min_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.max])
+
+    def get_max_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the maximum timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._max_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return max([self._max_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.min])
+
+    # Subgraph Sampling #######################################################
+
     def sample_subgraph(
         self,
         entity_table_names: tuple[str, ...],
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         num_neighbors: list[int],
         exclude_cols_dict: dict[str, list[str]] | None = None,
     ) -> Subgraph:
-
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_names: The entity table names.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            num_neighbors: The number of neighbors to sample for each hop.
+            exclude_cols_dict: The columns to exclude from the subgraph.
+        """
         # Exclude all columns that leak target information:
         table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
         if exclude_cols_dict is not None:
@@ -117,10 +205,11 @@ class Sampler(ABC):
         for table_name in entity_table_names:
             columns_dict[table_name].add(self.primary_key_dict[table_name])
 
-        if anchor_time.dtype != 'datetime64[ns]':
+        if (isinstance(anchor_time, pd.Series)
+                and anchor_time.dtype != 'datetime64[ns]'):
             anchor_time = anchor_time.astype('datetime64[ns]')
 
-        out = self._sample_backward(
+        out = self._sample_subgraph(
             entity_table_name=entity_table_names[0],
             entity_pkey=entity_pkey,
             anchor_time=anchor_time,
@@ -128,8 +217,9 @@ class Sampler(ABC):
             num_neighbors=num_neighbors,
         )
 
+        # Parse `SubgraphOutput` into `Subgraph` structure:
         subgraph = Subgraph(
-            anchor_time=anchor_time.astype(int).to_numpy(),
+            anchor_time=out.anchor_time,
             table_dict={},
             link_dict={},
         )
@@ -149,7 +239,7 @@ class Sampler(ABC):
                 ser = df[end_time_column]
                 if ser.dtype != 'datetime64[ns]':
                     ser = ser.astype('datetime64[ns]')
-                mask = ser > anchor_time.iloc[batch]
+                mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
                 ser.iloc[mask] = pd.NaT
                 df[end_time_column] = ser
 
@@ -212,24 +302,31 @@ class Sampler(ABC):
 
         return subgraph
 
-    def sample_forward(
+    # Predictive Query ########################################################
+
+    def _get_query_columns_dict(
         self,
         query: ValidatedPredictiveQuery,
-        num_examples: int,
-        anchor_time: pd.Timestamp | Literal['entity'],
-        random_seed: int | None = None,
-    ) -> ForwardSamplerOutput:
-
+    ) -> dict[str, set[str]]:
         columns_dict: dict[str, set[str]] = defaultdict(set)
         for fqn in query.all_query_columns + [query.entity_column]:
             table_name, column_name = fqn.split('.')
+            if column_name == '*':
+                continue
             columns_dict[table_name].add(column_name)
+        if column_name := self.time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        return columns_dict
 
-        if time_column := self.time_column_dict[query.entity_table]:
-            columns_dict[table_name].add(time_column)
-        if end_time_column := self.end_time_column_dict[query.entity_table]:
-            columns_dict[table_name].add(end_time_column)
-
+    def _get_query_time_offset_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[
+        tuple[str, str, str],
+        tuple[pd.DateOffset | None, pd.DateOffset],
+    ]:
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
@@ -238,7 +335,6 @@ class Sampler(ABC):
 
         def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
             if isinstance(node, Aggregation):
                 table_name = node._get_target_column_name().split('.')[0]
-                columns_dict[table_name].add(self.time_column_dict[table_name])
 
                 edge_types = [
                     edge_type for edge_type in self.edge_types
@@ -272,42 +368,342 @@ class Sampler(ABC):
         if query.whatif_ast is not None:
             _add_time_offset(query.whatif_ast)
 
-        return self._sample_forward(
+        return time_offset_dict
+
+    def sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_train_examples: int,
+        train_anchor_time: pd.Timestamp | Literal['entity'],
+        num_train_trials: int,
+        num_test_examples: int,
+        test_anchor_time: pd.Timestamp | Literal['entity'],
+        num_test_trials: int,
+        random_seed: int | None = None,
+    ) -> tuple[TargetOutput, TargetOutput]:
+        r"""Samples ground-truth targets given a predictive query, split into
+        training and test set.
+
+        Args:
+            query: The predictive query.
+            num_train_examples: How many training examples to produce.
+            train_anchor_time: The anchor timestamp for the training set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_train_trials: The number of training examples to try before
+                aborting.
+            num_test_examples: How many test examples to produce.
+            test_anchor_time: The anchor timestamp for the test set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_test_trials: The number of test examples to try before
+                aborting.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+        rng = np.random.default_rng(random_seed)
+
+        if num_train_examples == 0 or num_train_trials == 0:
+            num_train_examples = num_train_trials = 0
+        if num_test_examples == 0 or num_test_trials == 0:
+            num_test_examples = num_test_trials = 0
+
+        # 1. Collect information on what to query #############################
+        columns_dict = self._get_query_columns_dict(query)
+        time_offset_dict = self._get_query_time_offset_dict(query)
+        for table_name, _, _ in time_offset_dict.keys():
+            columns_dict[table_name].add(self.time_column_dict[table_name])
+
+        # 2. Sample random rows from entity table #############################
+        shared_train_test = query.query_type == QueryType.STATIC
+        shared_train_test &= train_anchor_time == test_anchor_time
+        if shared_train_test:
+            num_entity_rows = num_train_trials + num_test_trials
+        else:
+            num_entity_rows = max(num_train_trials, num_test_trials)
+        assert num_entity_rows > 0
+
+        entity_df = self._sample_entity_table(
+            table_name=query.entity_table,
+            columns=columns_dict[query.entity_table],
+            num_rows=num_entity_rows,
+            random_seed=random_seed,
+        )
+
+        if len(entity_df) == 0:
+            raise ValueError("Failed to find any rows in the entity table "
+                             "'{query.entity_table}'.")
+
+        entity_pkey = entity_df[self.primary_key_dict[query.entity_table]]
+        entity_time: pd.Series | None = None
+        if column_name := self.time_column_dict.get(query.entity_table):
+            entity_time = entity_df[column_name]
+        entity_end_time: pd.Series | None = None
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            entity_end_time = entity_df[column_name]
+
+        def get_valid_entity_index(
+            time: pd.Timestamp | Literal['entity'],
+            max_size: int | None = None,
+        ) -> np.ndarray:
+
+            if time == 'entity':
+                index: np.ndarray = np.arange(len(entity_pkey))
+            elif entity_time is None and entity_end_time is None:
+                index = np.arange(len(entity_pkey))
+            else:
+                mask: np.ndarray | None = None
+                if entity_time is not None:
+                    mask = (entity_time <= time).to_numpy()
+                if entity_end_time is not None:
+                    _mask = (entity_end_time > time).to_numpy()
+                    _mask |= entity_end_time.isna().to_numpy()
+                    mask = _mask if mask is None else mask & _mask
+                assert mask is not None
+                index = mask.nonzero()[0]
+
+            rng.shuffle(index)
+
+            if max_size is not None:
+                index = index[:max_size]
+
+            return index
+
+        # 3. Build training and test candidates ###############################
+        train_index = test_index = np.array([], dtype=np.int64)
+        train_time = test_time = pd.Series([], dtype='datetime64[ns]')
+
+        if shared_train_test:
+            train_index = get_valid_entity_index(train_anchor_time)
+            if train_anchor_time == 'entity':  # Sort by timestamp:
+                assert entity_time is not None
+                train_time = entity_time.iloc[train_index]
+                train_time = train_time.reset_index(drop=True)
+                train_time = train_time.sort_values(ascending=False)
+                perm = train_time.index.to_numpy()
+                train_index = train_index[perm]
+                train_time = train_time.reset_index(drop=True)
+            else:
+                train_time = to_ser(train_anchor_time, size=len(train_index))
+        else:
+            if num_test_examples > 0:
+                test_index = get_valid_entity_index(  #
+                    test_anchor_time, max_size=num_test_trials)
+                assert test_anchor_time != 'entity'
+                test_time = to_ser(test_anchor_time, len(test_index))
+
+            if query.query_type == QueryType.STATIC and num_train_examples > 0:
+                train_index = get_valid_entity_index(  #
+                    train_anchor_time, max_size=num_train_trials)
+                assert train_anchor_time != 'entity'
+                train_time = to_ser(train_anchor_time, len(train_index))
+            elif query.query_type == QueryType.TEMPORAL and num_train_examples:
+                aggr_table_names = [
+                    aggr._get_target_column_name().split('.')[0]
+                    for aggr in query.get_all_target_aggregations()
+                ]
+                offset = query.target_timeframe.timeframe * query.num_forecasts
+
+                train_indices: list[np.ndarray] = []
+                train_times: list[pd.Series] = []
+                while True:
+                    train_index = get_valid_entity_index(  #
+                        train_anchor_time, max_size=num_train_trials)
+                    assert train_anchor_time != 'entity'
+                    train_time = to_ser(train_anchor_time, len(train_index))
+                    train_indices.append(train_index)
+                    train_times.append(train_time)
+                    if sum(len(x) for x in train_indices) >= num_train_trials:
+                        break
+                    train_anchor_time -= offset
+                    if train_anchor_time < self.get_min_time(aggr_table_names):
+                        break
+                train_index = np.concatenate(train_indices, axis=0)
+                train_index = train_index[:num_train_trials]
+                train_time = pd.concat(train_times, axis=0, ignore_index=True)
+                train_time = train_time.iloc[:num_train_trials]
+
+        # 4. Sample training and test labels ##################################
+        train_y, train_mask, test_y, test_mask = self._sample_target(
             query=query,
-            num_examples=num_examples,
-            anchor_time=anchor_time,
+            entity_df=entity_df,
+            train_index=train_index,
+            train_time=train_time,
+            num_train_examples=(num_train_examples + num_test_examples
+                                if shared_train_test else num_train_examples),
+            test_index=test_index,
+            test_time=test_time,
+            num_test_examples=0 if shared_train_test else num_test_examples,
             columns_dict=columns_dict,
             time_offset_dict=time_offset_dict,
-            random_seed=random_seed,
+        )
+
+        # 5. Post-processing ##################################################
+        if shared_train_test:
+            num_examples = num_train_examples + num_test_examples
+            train_index = train_index[train_mask][:num_examples]
+            train_time = train_time.iloc[train_mask].iloc[:num_examples]
+            train_y = train_y.iloc[:num_examples]
+
+            _num_test = num_test_examples
+            _num_train = min(num_train_examples, 1000)
+            if (num_test_examples > 0 and num_train_examples > 0
+                    and len(train_y) < num_examples
+                    and len(train_y) < _num_test + _num_train):
+                # Not enough labels to satisfy requested split without losing
+                # large number of training examples:
+                _num_test = len(train_y) - _num_train
+                if _num_test < _num_train:  # Fallback to 50/50 split:
+                    _num_test = len(train_y) // 2
+
+            test_index = train_index[:_num_test]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = train_time.iloc[:_num_test]
+            test_y = train_y.iloc[:_num_test]
+
+            train_index = train_index[_num_test:]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[_num_test:]
+            train_y = train_y.iloc[_num_test:]
+        else:
+            train_index = train_index[train_mask][:num_train_examples]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[train_mask].iloc[:num_train_examples]
+            train_y = train_y.iloc[:num_train_examples]
+
+            test_index = test_index[test_mask][:num_test_examples]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = test_time.iloc[test_mask].iloc[:num_test_examples]
+            test_y = test_y.iloc[:num_test_examples]
+
+        train_pkey = train_pkey.reset_index(drop=True)
+        train_time = train_time.reset_index(drop=True)
+        train_y = train_y.reset_index(drop=True)
+        test_pkey = test_pkey.reset_index(drop=True)
+        test_time = test_time.reset_index(drop=True)
+        test_y = test_y.reset_index(drop=True)
+
+        if num_train_examples > 0 and len(train_y) == 0:
+            raise RuntimeError("Failed to collect any context examples. Is "
+                               "your predictive query too restrictive?")
+
+        if num_test_examples > 0 and len(test_y) == 0:
+            raise RuntimeError("Failed to collect any test examples for "
+                               "evaluation. Is your predictive query too "
+                               "restrictive?")
+
+        global _coverage_warned
+        if (not num_train_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(train_y) < num_train_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_train_examples:,} context "
+                          f"examples within {num_train_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        if (not num_test_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(test_y) < num_test_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_test_examples:,} test "
+                          f"examples within {num_test_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        return (
+            TargetOutput(train_pkey, train_time, train_y),
+            TargetOutput(test_pkey, test_time, test_y),
         )
 
     # Abstract Methods ########################################################
 
     @abstractmethod
-    def _sample_backward(
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        r"""Returns the minimum and maximum timestamps for a set of tables.
+
+        Args:
+            table_names: The tables.
+        """
+
+    @abstractmethod
+    def _sample_subgraph(
         self,
         entity_table_name: str,
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         columns_dict: dict[str, set[str]],
         num_neighbors: list[int],
-    ) -> BackwardSamplerOutput:
-        pass
+    ) -> SamplerOutput:
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_name: The entity table name.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            columns_dict: The columns to return for each table.
+            num_neighbors: The number of neighbors to sample for each hop.
+        """
 
     @abstractmethod
-    def _sample_forward(
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        r"""Returns a random sample of rows from the entity table.
+
+        Args:
+            table_name: The table.
+            columns: The columns to return.
+            num_rows: Maximum number of rows to return. Can be smaller in case
+                the entity table contains less rows.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+
+    @abstractmethod
+    def _sample_target(
         self,
         query: ValidatedPredictiveQuery,
-        num_examples: int,
-        anchor_time: pd.Timestamp | Literal['entity'],
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
         columns_dict: dict[str, set[str]],
         time_offset_dict: dict[
             tuple[str, str, str],
             tuple[pd.DateOffset | None, pd.DateOffset],
         ],
-        random_seed: int | None = None,
-    ) -> ForwardSamplerOutput:
-        pass
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        r"""Samples ground-truth targets given a predictive query from a set of
+        training and test candidates.
+
+        Args:
+            query: The predictive query.
+            entity_df: The entity data frame, containing the union of all train
+                and test candidates.
+            train_index: The indices of training candidates.
+            train_time: The anchor time of training candidates.
+            num_train_examples: How many training examples to produce.
+            test_index: The indices of test candidates.
+            test_time: The anchor time of test candidates.
+            num_test_examples: How many test examples to produce.
+            columns_dict: The columns that are being used to compute
+                ground-truth targets.
+            time_offset_dict: The date offsets to query for each edge type,
+                relative to the anchor time.
+        """
 
 
# Helper Functions ############################################################
@@ -371,3 +767,7 @@ def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
     assert len(timestamps) > 0
     argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
     return args[argmax]
+
+
+def to_ser(value: Any, size: int) -> pd.Series:
+    return pd.Series([value]).repeat(size).reset_index(drop=True)
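
Note: the new Sampler.sample_target() flow above filters entity candidates by time validity before drawing ground-truth targets. Below is a small, self-contained sketch of that validity rule (the same check get_valid_entity_index applies): an entity is a valid candidate at an anchor time if its time column is at or before the anchor and its end-time column is either missing or after the anchor. The sketch uses plain pandas/NumPy only; the table and column names are made up for illustration and are not part of the kumoai API.

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

entity_df = pd.DataFrame({
    'user_id': [1, 2, 3, 4],
    'signup_time': pd.to_datetime(
        ['2024-01-01', '2024-03-01', '2024-06-01', '2024-02-15']),
    'churn_time': pd.to_datetime(
        ['2024-05-01', None, None, '2024-04-01']),
})

anchor_time = pd.Timestamp('2024-04-15')

# An entity is a valid candidate if it already exists at the anchor time ...
mask = (entity_df['signup_time'] <= anchor_time).to_numpy()
# ... and has not ended yet (a missing end time counts as still active):
mask &= ((entity_df['churn_time'] > anchor_time)
         | entity_df['churn_time'].isna()).to_numpy()

index = mask.nonzero()[0]
rng.shuffle(index)  # candidates are drawn in random order
print(entity_df.iloc[index]['user_id'].tolist())  # e.g. [2, 1]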
kumoai/experimental/rfm/base/source.py
@@ -9,6 +9,7 @@ class SourceColumn:
     dtype: Dtype
     is_primary_key: bool
     is_unique_key: bool
+    is_nullable: bool
 
 
 @dataclass
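
Note: SourceColumn gains an is_nullable flag in this release. As background only, the standalone sketch below shows where such a flag typically comes from in a SQLite-backed source: the notnull field of PRAGMA table_info. The table and columns are hypothetical, and the snippet does not use kumoai's own metadata readers.

import sqlite3

con = sqlite3.connect(':memory:')
con.execute('CREATE TABLE users ('
            'id INTEGER NOT NULL PRIMARY KEY, '
            'email TEXT NOT NULL, '
            'age INTEGER)')

# PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk) per column:
for _cid, name, _dtype, notnull, _default, pk in con.execute('PRAGMA table_info(users)'):
    print(f'{name}: is_primary_key={bool(pk)}, is_nullable={not notnull}')
# id: is_primary_key=True, is_nullable=False
# email: is_primary_key=False, is_nullable=False
# age: is_primary_key=False, is_nullable=True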