kumoai-2.13.0.dev202512041731-cp310-cp310-win_amd64.whl → kumoai-2.15.0.dev202601141731-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +178 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +22 -4
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +696 -47
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +384 -207
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +359 -187
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
@@ -1,48 +1,75 @@
  import copy
+ import re
+ import warnings
  from abc import ABC, abstractmethod
+ from collections import defaultdict
  from dataclasses import dataclass
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, Literal, NamedTuple

  import numpy as np
  import pandas as pd
- from kumoapi.rfm.context import Subgraph
+ from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+ from kumoapi.pquery.AST import Aggregation, ASTNode
+ from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
  from kumoapi.typing import Stype

+ from kumoai.utils import ProgressLogger
+
  if TYPE_CHECKING:
      from kumoai.experimental.rfm import Graph

-
- @dataclass
- class EdgeSpec:
-     num_neighbors: int | None = None
-     time_offsets: tuple[
-         pd.DateOffset | None,
-         pd.DateOffset,
-     ] | None = None
-
-     def __post_init__(self) -> None:
-         if (self.num_neighbors is None) == (self.time_offsets is None):
-             raise ValueError("Only one of 'num_neighbors' and 'time_offsets' "
-                              "must be provided")
+ _coverage_warned = False


  @dataclass
  class SamplerOutput:
+     anchor_time: np.ndarray
      df_dict: dict[str, pd.DataFrame]
-     batch_dict: dict[str, pd.DataFrame]
+     inverse_dict: dict[str, np.ndarray]
+     batch_dict: dict[str, np.ndarray]
      num_sampled_nodes_dict: dict[str, list[int]]
-     edge_index_dict: dict[tuple[str, str, str], np.ndarray] | None = None
-     num_sampled_edges_dict: dict[tuple[str, str, str], list[int]] | None = None
+     row_dict: dict[tuple[str, str, str], np.ndarray]
+     col_dict: dict[tuple[str, str, str], np.ndarray]
+     num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+ class TargetOutput(NamedTuple):
+     entity_pkey: pd.Series
+     anchor_time: pd.Series
+     target: pd.Series


  class Sampler(ABC):
-     def __init__(self, graph: 'Graph') -> None:
+     r"""A base class to sample relational data (*i.e.*, subgraphs and
+     ground-truth targets) from a custom backend.
+
+     Args:
+         graph: The graph.
+         verbose: Whether to print verbose output.
+     """
+     def __init__(
+         self,
+         graph: 'Graph',
+         verbose: bool | ProgressLogger = True,
+     ) -> None:
+
          self._edge_types: list[tuple[str, str, str]] = []
          for edge in graph.edges:
              edge_type = (edge.src_table, edge.fkey, edge.dst_table)
              self._edge_types.append(edge_type)
              self._edge_types.append(Subgraph.rev_edge_type(edge_type))

+         # Source Table -> [(Foreign Key, Destination Table)]
+         self._foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+         # Destination Table -> [(Source Table, Foreign Key)]
+         self._rev_foreign_key_dict: dict[str, list[tuple[str, str]]] = {}
+         for table in graph.tables.values():
+             self._foreign_key_dict[table.name] = []
+             self._rev_foreign_key_dict[table.name] = []
+         for src_table, fkey, dst_table in graph.edges:
+             self._foreign_key_dict[src_table].append((fkey, dst_table))
+             self._rev_foreign_key_dict[dst_table].append((src_table, fkey))
+
          self._primary_key_dict: dict[str, str] = {
              table.name: table._primary_key
              for table in graph.tables.values()
@@ -55,80 +82,702 @@ class Sampler(ABC):
              if table._time_column is not None
          }

+         self._end_time_column_dict: dict[str, str] = {
+             table.name: table._end_time_column
+             for table in graph.tables.values()
+             if table._end_time_column is not None
+         }
+
          foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
-         self._stype_dict: dict[str, dict[str, Stype]] = {}
+         self._table_stype_dict: dict[str, dict[str, Stype]] = {}
          for table in graph.tables.values():
-             self._stype_dict[table.name] = {}
+             self._table_stype_dict[table.name] = {}
              for column in table.columns:
                  if column == table.primary_key:
                      continue
                  if (table.name, column.name) in foreign_keys:
                      continue
-                 self._stype_dict[table.name][column.name] = column.stype
+                 self._table_stype_dict[table.name][column.name] = column.stype
+
+         self._min_time_dict: dict[str, pd.Timestamp] = {}
+         self._max_time_dict: dict[str, pd.Timestamp] = {}
+
+     # Properties ##############################################################

      @property
      def edge_types(self) -> list[tuple[str, str, str]]:
+         r"""All available edge types in the graph."""
          return self._edge_types

+     @property
+     def foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+         r"""The foreign keys for all tables in the graph."""
+         return self._foreign_key_dict
+
+     @property
+     def rev_foreign_key_dict(self) -> dict[str, list[tuple[str, str]]]:
+         r"""The foreign key back references for all tables in the graph."""
+         return self._rev_foreign_key_dict
+
      @property
      def primary_key_dict(self) -> dict[str, str]:
+         r"""All available primary keys in the graph."""
          return self._primary_key_dict

      @property
      def time_column_dict(self) -> dict[str, str]:
+         r"""All available time columns in the graph."""
          return self._time_column_dict

      @property
-     def stype_dict(self) -> dict[str, dict[str, Stype]]:
-         return self._stype_dict
+     def end_time_column_dict(self) -> dict[str, str]:
+         r"""All available end time columns in the graph."""
+         return self._end_time_column_dict
+
+     @property
+     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+         r"""The registered semantic types for all feature columns in all tables
+         in the graph.
+         """
+         return self._table_stype_dict
+
+     def get_min_time(
+         self,
+         table_names: list[str] | None = None,
+     ) -> pd.Timestamp:
+         r"""Returns the minimal timestamp in the union of a set of tables.
+
+         Args:
+             table_names: The set of tables.
+         """
+         if table_names is None or len(table_names) == 0:
+             table_names = list(self.time_column_dict.keys())
+         unknown = list(set(table_names) - set(self._min_time_dict.keys()))
+         if len(unknown) > 0:
+             min_max_time_dict = self._get_min_max_time_dict(unknown)
+             for table_name, (min_time, max_time) in min_max_time_dict.items():
+                 self._min_time_dict[table_name] = min_time
+                 self._max_time_dict[table_name] = max_time
+         return min([self._min_time_dict[table]
+                     for table in table_names] + [pd.Timestamp.max])
+
+     def get_max_time(
+         self,
+         table_names: list[str] | None = None,
+     ) -> pd.Timestamp:
+         r"""Returns the maximum timestamp in the union of a set of tables.
+
+         Args:
+             table_names: The set of tables.
+         """
+         if table_names is None or len(table_names) == 0:
+             table_names = list(self.time_column_dict.keys())
+         unknown = list(set(table_names) - set(self._max_time_dict.keys()))
+         if len(unknown) > 0:
+             min_max_time_dict = self._get_min_max_time_dict(unknown)
+             for table_name, (min_time, max_time) in min_max_time_dict.items():
+                 self._min_time_dict[table_name] = min_time
+                 self._max_time_dict[table_name] = max_time
+         return max([self._max_time_dict[table]
+                     for table in table_names] + [pd.Timestamp.min])
+
+     # Subgraph Sampling #######################################################

      def sample_subgraph(
          self,
          entity_table_names: tuple[str, ...],
          entity_pkey: pd.Series,
-         anchor_time: pd.Series,
+         anchor_time: pd.Series | Literal['entity'],
          num_neighbors: list[int],
          exclude_cols_dict: dict[str, list[str]] | None = None,
      ) -> Subgraph:
+         r"""Samples distinct subgraphs for each entity primary key.

-         edge_spec: dict[tuple[str, str, str], list[EdgeSpec]] = {
-             edge_type: [EdgeSpec(value) for value in num_neighbors]
-             for edge_type in self.edge_types
-         }
-
-         stype_dict: dict[str, dict[str, Stype]] = self._stype_dict
+         Args:
+             entity_table_names: The entity table names.
+             entity_pkey: The primary keys to use as seed nodes.
+             anchor_time: The anchor time of the subgraphs.
+             num_neighbors: The number of neighbors to sample for each hop.
+             exclude_cols_dict: The columns to exclude from the subgraph.
+         """
+         # Exclude all columns that leak target information:
+         table_stype_dict: dict[str, dict[str, Stype]] = self.table_stype_dict
          if exclude_cols_dict is not None:
-             stype_dict = copy.deepcopy(stype_dict)
+             table_stype_dict = copy.deepcopy(table_stype_dict)
              for table_name, exclude_cols in exclude_cols_dict.items():
                  for column_name in exclude_cols:
-                     del stype_dict[table_name][column_name]
+                     del table_stype_dict[table_name][column_name]

-         column_spec: dict[str, list[str]] = {
-             table_name: list(stypes.keys())
-             for table_name, stypes in stype_dict.items()
+         # Collect all columns being used as features:
+         columns_dict: dict[str, set[str]] = {
+             table_name: set(stype_dict.keys())
+             for table_name, stype_dict in table_stype_dict.items()
          }
+         # Make sure to store primary key information for entity tables:
          for table_name in entity_table_names:
-             column_spec[table_name].append(self.primary_key_dict[table_name])
+             columns_dict[table_name].add(self.primary_key_dict[table_name])

-         return self._sample(
+         if (isinstance(anchor_time, pd.Series)
+                 and anchor_time.dtype != 'datetime64[ns]'):
+             anchor_time = anchor_time.astype('datetime64[ns]')
+
+         out = self._sample_subgraph(
              entity_table_name=entity_table_names[0],
              entity_pkey=entity_pkey,
              anchor_time=anchor_time,
-             column_spec=column_spec,
-             edge_spec=edge_spec,
-             return_edges=True,
+             columns_dict=columns_dict,
+             num_neighbors=num_neighbors,
+         )
+
+         # Parse `SubgraphOutput` into `Subgraph` structure:
+         subgraph = Subgraph(
+             anchor_time=out.anchor_time,
+             table_dict={},
+             link_dict={},
+         )
+
+         for table_name, batch in out.batch_dict.items():
+             if len(batch) == 0:
+                 continue
+
+             primary_key: str | None = None
+             if table_name in entity_table_names:
+                 primary_key = self.primary_key_dict[table_name]
+
+             df = out.df_dict[table_name].reset_index(drop=True)
+             if end_time_column := self.end_time_column_dict.get(table_name):
+                 # Set end time to NaT for all values greater than anchor time:
+                 assert table_name not in out.inverse_dict
+                 ser = df[end_time_column]
+                 mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
+                 df.loc[mask, end_time_column] = pd.NaT
+
+             stype_dict = table_stype_dict[table_name]
+             for column_name, stype in stype_dict.items():
+                 if stype == Stype.text:
+                     df[column_name] = _normalize_text(df[column_name])
+
+             subgraph.table_dict[table_name] = Table(
+                 df=df,
+                 row=out.inverse_dict.get(table_name),
+                 batch=batch,
+                 num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
+                 stype_dict=stype_dict,
+                 primary_key=primary_key,
+             )
+
+         for edge_type in out.row_dict.keys():
+             row: np.ndarray | None = out.row_dict[edge_type]
+             col: np.ndarray | None = out.col_dict[edge_type]
+
+             if row is None or col is None or len(row) == 0:
+                 continue
+
+             # Do not store reverse edge type if it is an exact replica:
+             rev_edge_type = Subgraph.rev_edge_type(edge_type)
+             if (rev_edge_type in subgraph.link_dict
+                     and np.array_equal(row, out.col_dict[rev_edge_type])
+                     and np.array_equal(col, out.row_dict[rev_edge_type])):
+                 subgraph.link_dict[edge_type] = Link(
+                     layout=EdgeLayout.REV,
+                     row=None,
+                     col=None,
+                     num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+                 )
+                 continue
+
+             # Do not store non-informative edges:
+             layout = EdgeLayout.COO
+             if np.array_equal(row, np.arange(len(row))):
+                 row = None
+             if np.array_equal(col, np.arange(len(col))):
+                 col = None
+
+             # Store in compressed representation if more efficient:
+             num_cols = subgraph.table_dict[edge_type[2]].num_rows
+             if (col is not None and len(col) > num_cols + 1
+                     and ((col[1:] - col[:-1]) >= 0).all()):
+                 layout = EdgeLayout.CSC
+                 colcount = np.bincount(col, minlength=num_cols)
+                 col = np.empty(num_cols + 1, dtype=col.dtype)
+                 col[0] = 0
+                 np.cumsum(colcount, out=col[1:])
+
+             subgraph.link_dict[edge_type] = Link(
+                 layout=layout,
+                 row=row,
+                 col=col,
+                 num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+             )
+
+         return subgraph
+
+     # Predictive Query ########################################################
+
+     def _get_query_columns_dict(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> dict[str, set[str]]:
+         columns_dict: dict[str, set[str]] = defaultdict(set)
+         for fqn in query.all_query_columns + [query.entity_column]:
+             table_name, column_name = fqn.split('.')
+             if column_name == '*':
+                 continue
+             columns_dict[table_name].add(column_name)
+         if column_name := self.time_column_dict.get(query.entity_table):
+             columns_dict[table_name].add(column_name)
+         if column_name := self.end_time_column_dict.get(query.entity_table):
+             columns_dict[table_name].add(column_name)
+         return columns_dict
+
+     def _get_query_time_offset_dict(
+         self,
+         query: ValidatedPredictiveQuery,
+     ) -> dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+     ]:
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ] = {}
+
+         def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+             if isinstance(node, Aggregation):
+                 table_name = node._get_target_column_name().split('.')[0]
+
+                 edge_types = [
+                     edge_type for edge_type in self.edge_types
+                     if edge_type[0] == table_name
+                     and edge_type[2] == query.entity_table
+                 ]
+                 if len(edge_types) != 1:
+                     raise ValueError(f"Could not find a unique foreign key "
+                                      f"from table '{table_name}' to "
+                                      f"'{query.entity_table}'")
+                 if edge_types[0] not in time_offset_dict:
+                     start = node.aggr_time_range.start_date_offset
+                     end = node.aggr_time_range.end_date_offset * num_forecasts
+                 else:
+                     start, end = time_offset_dict[edge_types[0]]
+                     start = min_date_offset(
+                         start,
+                         node.aggr_time_range.start_date_offset,
+                     )
+                     end = max_date_offset(
+                         end,
+                         node.aggr_time_range.end_date_offset * num_forecasts,
+                     )
+                 time_offset_dict[edge_types[0]] = (start, end)
+
+             for child in node.children:
+                 _add_time_offset(child, num_forecasts)
+
+         _add_time_offset(query.target_ast, query.num_forecasts)
+         _add_time_offset(query.entity_ast)
+         if query.whatif_ast is not None:
+             _add_time_offset(query.whatif_ast)
+
+         return time_offset_dict
+
+     def sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         num_train_examples: int,
+         train_anchor_time: pd.Timestamp | Literal['entity'],
+         num_train_trials: int,
+         num_test_examples: int,
+         test_anchor_time: pd.Timestamp | Literal['entity'],
+         num_test_trials: int,
+         random_seed: int | None = None,
+     ) -> tuple[TargetOutput, TargetOutput]:
+         r"""Samples ground-truth targets given a predictive query, split into
+         training and test set.
+
+         Args:
+             query: The predictive query.
+             num_train_examples: How many training examples to produce.
+             train_anchor_time: The anchor timestamp for the training set.
+                 If set to ``"entity"``, will use the timestamp of the entity.
+             num_train_trials: The number of training examples to try before
+                 aborting.
+             num_test_examples: How many test examples to produce.
+             test_anchor_time: The anchor timestamp for the test set.
+                 If set to ``"entity"``, will use the timestamp of the entity.
+             num_test_trials: The number of test examples to try before
+                 aborting.
+             random_seed: A manual seed for generating pseudo-random numbers.
+         """
+         rng = np.random.default_rng(random_seed)
+
+         if num_train_examples == 0 or num_train_trials == 0:
+             num_train_examples = num_train_trials = 0
+         if num_test_examples == 0 or num_test_trials == 0:
+             num_test_examples = num_test_trials = 0
+
+         # 1. Collect information on what to query #############################
+         columns_dict = self._get_query_columns_dict(query)
+         time_offset_dict = self._get_query_time_offset_dict(query)
+         for table_name, _, _ in time_offset_dict.keys():
+             columns_dict[table_name].add(self.time_column_dict[table_name])
+
+         # 2. Sample random rows from entity table #############################
+         shared_train_test = query.query_type == QueryType.STATIC
+         shared_train_test &= train_anchor_time == test_anchor_time
+         if shared_train_test:
+             num_entity_rows = num_train_trials + num_test_trials
+         else:
+             num_entity_rows = max(num_train_trials, num_test_trials)
+         assert num_entity_rows > 0
+
+         entity_df = self._sample_entity_table(
+             table_name=query.entity_table,
+             columns=columns_dict[query.entity_table],
+             num_rows=num_entity_rows,
+             random_seed=random_seed,
+         )
+
+         if len(entity_df) == 0:
+             raise ValueError("Failed to find any rows in the entity table "
+                              "'{query.entity_table}'.")
+
+         entity_pkey = entity_df[self.primary_key_dict[query.entity_table]]
+         entity_time: pd.Series | None = None
+         if column_name := self.time_column_dict.get(query.entity_table):
+             entity_time = entity_df[column_name]
+         entity_end_time: pd.Series | None = None
+         if column_name := self.end_time_column_dict.get(query.entity_table):
+             entity_end_time = entity_df[column_name]
+
+         def get_valid_entity_index(
+             time: pd.Timestamp | Literal['entity'],
+             max_size: int | None = None,
+         ) -> np.ndarray:
+
+             if time == 'entity':
+                 index: np.ndarray = np.arange(len(entity_pkey))
+             elif entity_time is None and entity_end_time is None:
+                 index = np.arange(len(entity_pkey))
+             else:
+                 mask: np.ndarray | None = None
+                 if entity_time is not None:
+                     mask = (entity_time <= time).to_numpy()
+                 if entity_end_time is not None:
+                     _mask = (entity_end_time > time).to_numpy()
+                     _mask |= entity_end_time.isna().to_numpy()
+                     mask = _mask if mask is None else mask & _mask
+                 assert mask is not None
+                 index = mask.nonzero()[0]
+
+             rng.shuffle(index)
+
+             if max_size is not None:
+                 index = index[:max_size]
+
+             return index
+
+         # 3. Build training and test candidates ###############################
+         train_index = test_index = np.array([], dtype=np.int64)
+         train_time = test_time = pd.Series([], dtype='datetime64[ns]')
+
+         if shared_train_test:
+             train_index = get_valid_entity_index(train_anchor_time)
+             if train_anchor_time == 'entity':  # Sort by timestamp:
+                 assert entity_time is not None
+                 train_time = entity_time.iloc[train_index]
+                 train_time = train_time.reset_index(drop=True)
+                 train_time = train_time.sort_values(ascending=False)
+                 perm = train_time.index.to_numpy()
+                 train_index = train_index[perm]
+                 train_time = train_time.reset_index(drop=True)
+             else:
+                 train_time = to_ser(train_anchor_time, size=len(train_index))
+         else:
+             if num_test_examples > 0:
+                 test_index = get_valid_entity_index(  #
+                     test_anchor_time, max_size=num_test_trials)
+                 assert test_anchor_time != 'entity'
+                 test_time = to_ser(test_anchor_time, len(test_index))
+
+             if query.query_type == QueryType.STATIC and num_train_examples > 0:
+                 train_index = get_valid_entity_index(  #
+                     train_anchor_time, max_size=num_train_trials)
+                 assert train_anchor_time != 'entity'
+                 train_time = to_ser(train_anchor_time, len(train_index))
+             elif query.query_type == QueryType.TEMPORAL and num_train_examples:
+                 aggr_table_names = [
+                     aggr._get_target_column_name().split('.')[0]
+                     for aggr in query.get_all_target_aggregations()
+                 ]
+                 offset = query.target_timeframe.timeframe * query.num_forecasts
+
+                 train_indices: list[np.ndarray] = []
+                 train_times: list[pd.Series] = []
+                 while True:
+                     train_index = get_valid_entity_index(  #
+                         train_anchor_time, max_size=num_train_trials)
+                     assert train_anchor_time != 'entity'
+                     train_time = to_ser(train_anchor_time, len(train_index))
+                     train_indices.append(train_index)
+                     train_times.append(train_time)
+                     if sum(len(x) for x in train_indices) >= num_train_trials:
+                         break
+                     train_anchor_time -= offset
+                     if train_anchor_time < self.get_min_time(aggr_table_names):
+                         break
+                 train_index = np.concatenate(train_indices, axis=0)
+                 train_index = train_index[:num_train_trials]
+                 train_time = pd.concat(train_times, axis=0, ignore_index=True)
+                 train_time = train_time.iloc[:num_train_trials]
+
+         # 4. Sample training and test labels ##################################
+         train_y, train_mask, test_y, test_mask = self._sample_target(
+             query=query,
+             entity_df=entity_df,
+             train_index=train_index,
+             train_time=train_time,
+             num_train_examples=(num_train_examples + num_test_examples
+                                 if shared_train_test else num_train_examples),
+             test_index=test_index,
+             test_time=test_time,
+             num_test_examples=0 if shared_train_test else num_test_examples,
+             columns_dict=columns_dict,
+             time_offset_dict=time_offset_dict,
+         )
+
+         # 5. Post-processing ##################################################
+         if shared_train_test:
+             num_examples = num_train_examples + num_test_examples
+             train_index = train_index[train_mask][:num_examples]
+             train_time = train_time.iloc[train_mask].iloc[:num_examples]
+             train_y = train_y.iloc[:num_examples]
+
+             _num_test = num_test_examples
+             _num_train = min(num_train_examples, 1000)
+             if (num_test_examples > 0 and num_train_examples > 0
+                     and len(train_y) < num_examples
+                     and len(train_y) < _num_test + _num_train):
+                 # Not enough labels to satisfy requested split without losing
+                 # large number of training examples:
+                 _num_test = len(train_y) - _num_train
+                 if _num_test < _num_train:  # Fallback to 50/50 split:
+                     _num_test = len(train_y) // 2
+
+             test_index = train_index[:_num_test]
+             test_pkey = entity_pkey.iloc[test_index]
+             test_time = train_time.iloc[:_num_test]
+             test_y = train_y.iloc[:_num_test]
+
+             train_index = train_index[_num_test:]
+             train_pkey = entity_pkey.iloc[train_index]
+             train_time = train_time.iloc[_num_test:]
+             train_y = train_y.iloc[_num_test:]
+         else:
+             train_index = train_index[train_mask][:num_train_examples]
+             train_pkey = entity_pkey.iloc[train_index]
+             train_time = train_time.iloc[train_mask].iloc[:num_train_examples]
+             train_y = train_y.iloc[:num_train_examples]
+
+             test_index = test_index[test_mask][:num_test_examples]
+             test_pkey = entity_pkey.iloc[test_index]
+             test_time = test_time.iloc[test_mask].iloc[:num_test_examples]
+             test_y = test_y.iloc[:num_test_examples]
+
+         train_pkey = train_pkey.reset_index(drop=True)
+         train_time = train_time.reset_index(drop=True)
+         train_y = train_y.reset_index(drop=True)
+         test_pkey = test_pkey.reset_index(drop=True)
+         test_time = test_time.reset_index(drop=True)
+         test_y = test_y.reset_index(drop=True)
+
+         if num_train_examples > 0 and len(train_y) == 0:
+             raise RuntimeError("Failed to collect any context examples. Is "
+                                "your predictive query too restrictive?")
+
+         if num_test_examples > 0 and len(test_y) == 0:
+             raise RuntimeError("Failed to collect any test examples for "
+                                "evaluation. Is your predictive query too "
+                                "restrictive?")
+
+         global _coverage_warned
+         if (not num_train_examples > 0  #
+                 and not _coverage_warned  #
+                 and len(entity_df) >= num_entity_rows
+                 and len(train_y) < num_train_examples // 2):
+             _coverage_warned = True
+             warnings.warn(f"Failed to collect {num_train_examples:,} context "
+                           f"examples within {num_train_trials:,} candidates. "
+                           f"To improve coverage, consider increasing the "
+                           f"number of PQ iterations using the "
+                           f"'max_pq_iterations' option. This warning will not "
+                           f"be shown again in this run.")
+
+         if (not num_test_examples > 0  #
+                 and not _coverage_warned  #
+                 and len(entity_df) >= num_entity_rows
+                 and len(test_y) < num_test_examples // 2):
+             _coverage_warned = True
+             warnings.warn(f"Failed to collect {num_test_examples:,} test "
+                           f"examples within {num_test_trials:,} candidates. "
+                           f"To improve coverage, consider increasing the "
+                           f"number of PQ iterations using the "
+                           f"'max_pq_iterations' option. This warning will not "
+                           f"be shown again in this run.")
+
+         return (
+             TargetOutput(train_pkey, train_time, train_y),
+             TargetOutput(test_pkey, test_time, test_y),
          )

      # Abstract Methods ########################################################

      @abstractmethod
-     def _sample(
+     def _get_min_max_time_dict(
+         self,
+         table_names: list[str],
+     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+         r"""Returns the minimum and maximum timestamps for a set of tables.
+
+         Args:
+             table_names: The tables.
+         """
+
+     @abstractmethod
+     def _sample_subgraph(
          self,
          entity_table_name: str,
          entity_pkey: pd.Series,
-         anchor_time: pd.Series,
-         column_spec: dict[str, list[str]],
-         edge_spec: dict[tuple[str, str, str], list[EdgeSpec]],
-         return_edges: bool = False,
+         anchor_time: pd.Series | Literal['entity'],
+         columns_dict: dict[str, set[str]],
+         num_neighbors: list[int],
      ) -> SamplerOutput:
-         pass
+         r"""Samples distinct subgraphs for each entity primary key.
+
+         Args:
+             entity_table_name: The entity table name.
+             entity_pkey: The primary keys to use as seed nodes.
+             anchor_time: The anchor time of the subgraphs.
+             columns_dict: The columns to return for each table.
+             num_neighbors: The number of neighbors to sample for each hop.
+         """
+
+     @abstractmethod
+     def _sample_entity_table(
+         self,
+         table_name: str,
+         columns: set[str],
+         num_rows: int,
+         random_seed: int | None = None,
+     ) -> pd.DataFrame:
+         r"""Returns a random sample of rows from the entity table.
+
+         Args:
+             table_name: The table.
+             columns: The columns to return.
+             num_rows: Maximum number of rows to return. Can be smaller in case
+                 the entity table contains less rows.
+             random_seed: A manual seed for generating pseudo-random numbers.
+         """
+
+     @abstractmethod
+     def _sample_target(
+         self,
+         query: ValidatedPredictiveQuery,
+         entity_df: pd.DataFrame,
+         train_index: np.ndarray,
+         train_time: pd.Series,
+         num_train_examples: int,
+         test_index: np.ndarray,
+         test_time: pd.Series,
+         num_test_examples: int,
+         columns_dict: dict[str, set[str]],
+         time_offset_dict: dict[
+             tuple[str, str, str],
+             tuple[pd.DateOffset | None, pd.DateOffset],
+         ],
+     ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+         r"""Samples ground-truth targets given a predictive query from a set of
+         training and test candidates.
+
+         Args:
+             query: The predictive query.
+             entity_df: The entity data frame, containing the union of all train
+                 and test candidates.
+             train_index: The indices of training candidates.
+             train_time: The anchor time of training candidates.
+             num_train_examples: How many training examples to produce.
+             test_index: The indices of test candidates.
+             test_time: The anchor time of test candidates.
+             num_test_examples: How many test examples to produce.
+             columns_dict: The columns that are being used to compute
+                 ground-truth targets.
+             time_offset_dict: The date offsets to query for each edge type,
+                 relative to the anchor time.
+         """
+
+
+ # Helper Functions ############################################################
+
+ PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+ MULTISPACE = re.compile(r"\s+")
+
+
+ def _normalize_text(
+     ser: pd.Series,
+     max_words: int | None = 50,
+ ) -> pd.Series:
+     r"""Normalizes text into a list of lower-case words.
+
+     Args:
+         ser: The :class:`pandas.Series` to normalize.
+         max_words: The maximum number of words to return.
+             This will auto-shrink any large text column to avoid blowing up
+             context size.
+     """
+     if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+         return ser
+
+     def normalize_fn(line: str) -> list[str]:
+         line = PUNCTUATION.sub(" ", line)
+         line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+         line = MULTISPACE.sub(" ", line)
+         words = line.split()
+         if max_words is not None:
+             words = words[:max_words]
+         return words
+
+     ser = ser.fillna('').astype(str)
+
+     if max_words is not None:
+         # We estimate the number of words as 5 characters + 1 space in an
+         # English text on average. We need this pre-filter here, as word
+         # splitting on a giant text can be very expensive:
+         ser = ser.str[:6 * max_words]
+
+     ser = ser.str.lower()
+     ser = ser.map(normalize_fn)
+
+     return ser
+
+
+ def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+     if any(arg is None for arg in args):
+         return None
+
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmin]
+
+
+ def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+     anchor = pd.Timestamp('2000-01-01')
+     timestamps = [anchor + arg for arg in args]
+     assert len(timestamps) > 0
+     argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+     return args[argmax]
+
+
+ def to_ser(value: Any, size: int) -> pd.Series:
+     return pd.Series([value]).repeat(size).reset_index(drop=True)
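
The refactored Sampler base class keeps subgraph assembly and train/test target splitting backend-agnostic and delegates raw data access to four abstract methods (_get_min_max_time_dict, _sample_entity_table, _sample_subgraph, _sample_target), which this release implements for the local, SQLite, and Snowflake backends. The sketch below is illustrative only and is not part of the diff: the InMemorySampler name and the df_dict argument are hypothetical, and the two heavy methods are left unimplemented because their real logic lives in backend/local/sampler.py, backend/sqlite/sampler.py, and backend/snow/sampler.py.

import pandas as pd

from kumoai.experimental.rfm.base.sampler import Sampler


class InMemorySampler(Sampler):  # Hypothetical backend, for illustration only.
    def __init__(self, graph, df_dict: dict[str, pd.DataFrame]) -> None:
        super().__init__(graph)
        self._df_dict = df_dict  # Maps table name -> raw pandas data.

    def _get_min_max_time_dict(self, table_names):
        # Minimum and maximum timestamp of the time column, per table:
        return {
            name: (
                self._df_dict[name][self.time_column_dict[name]].min(),
                self._df_dict[name][self.time_column_dict[name]].max(),
            )
            for name in table_names
        }

    def _sample_entity_table(self, table_name, columns, num_rows,
                             random_seed=None):
        # Return up to `num_rows` random rows, restricted to `columns`:
        df = self._df_dict[table_name]
        df = df.sample(n=min(num_rows, len(df)), random_state=random_seed)
        return df[list(columns)].reset_index(drop=True)

    def _sample_subgraph(self, entity_table_name, entity_pkey, anchor_time,
                         columns_dict, num_neighbors):
        # A real backend performs temporal neighbor sampling here and returns
        # a `SamplerOutput` (see `backend/local/sampler.py` in this release).
        raise NotImplementedError

    def _sample_target(self, query, entity_df, train_index, train_time,
                       num_train_examples, test_index, test_time,
                       num_test_examples, columns_dict, time_offset_dict):
        # A real backend evaluates the predictive query over the candidate
        # anchor times and returns `(train_y, train_mask, test_y, test_mask)`.
        raise NotImplementedError

On top of these primitives, the base class's sample_subgraph() and sample_target() handle feature-column selection, anchor-time handling, end-time masking, and the train/test split, as shown in the diff above.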