kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/pquery.py +6 -2
  5. kumoai/connector/utils.py +23 -2
  6. kumoai/experimental/rfm/__init__.py +162 -46
  7. kumoai/experimental/rfm/backend/__init__.py +0 -0
  8. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  10. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  11. kumoai/experimental/rfm/backend/local/table.py +119 -0
  12. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  18. kumoai/experimental/rfm/base/__init__.py +23 -0
  19. kumoai/experimental/rfm/base/column.py +66 -0
  20. kumoai/experimental/rfm/base/sampler.py +773 -0
  21. kumoai/experimental/rfm/base/source.py +19 -0
  22. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  23. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  24. kumoai/experimental/rfm/infer/__init__.py +6 -0
  25. kumoai/experimental/rfm/infer/dtype.py +79 -0
  26. kumoai/experimental/rfm/infer/pkey.py +126 -0
  27. kumoai/experimental/rfm/infer/time_col.py +62 -0
  28. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  29. kumoai/experimental/rfm/rfm.py +233 -174
  30. kumoai/experimental/rfm/sagemaker.py +138 -0
  31. kumoai/spcs.py +1 -3
  32. kumoai/testing/decorators.py +1 -1
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +2 -0
  35. kumoai/utils/sql.py +3 -0
  36. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
  38. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  39. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/base/sampler.py (new file)
@@ -0,0 +1,773 @@
+import copy
+import re
+import warnings
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import Aggregation, ASTNode
+from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
+from kumoapi.typing import Stype
+
+from kumoai.experimental.rfm.base import SourceColumn
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+_coverage_warned = False
+
+
+@dataclass
+class SamplerOutput:
+    anchor_time: np.ndarray
+    df_dict: dict[str, pd.DataFrame]
+    inverse_dict: dict[str, np.ndarray]
+    batch_dict: dict[str, np.ndarray]
+    num_sampled_nodes_dict: dict[str, list[int]]
+    row_dict: dict[tuple[str, str, str], np.ndarray]
+    col_dict: dict[tuple[str, str, str], np.ndarray]
+    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+class TargetOutput(NamedTuple):
+    entity_pkey: pd.Series
+    anchor_time: pd.Series
+    target: pd.Series
+
+
+class Sampler(ABC):
+    r"""A base class to sample relational data (*i.e.*, subgraphs and
+    ground-truth targets) from a custom backend.
+
+    Args:
+        graph: The graph.
+        verbose: Whether to print verbose output.
+    """
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        self._edge_types: list[tuple[str, str, str]] = []
+        for edge in graph.edges:
+            edge_type = (edge.src_table, edge.fkey, edge.dst_table)
+            self._edge_types.append(edge_type)
+            self._edge_types.append(Subgraph.rev_edge_type(edge_type))
+
+        self._primary_key_dict: dict[str, str] = {
+            table.name: table._primary_key
+            for table in graph.tables.values()
+            if table._primary_key is not None
+        }
+
+        self._time_column_dict: dict[str, str] = {
+            table.name: table._time_column
+            for table in graph.tables.values()
+            if table._time_column is not None
+        }
+
+        self._end_time_column_dict: dict[str, str] = {
+            table.name: table._end_time_column
+            for table in graph.tables.values()
+            if table._end_time_column is not None
+        }
+
+        foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
+        self._table_stype_dict: dict[str, dict[str, Stype]] = {}
+        for table in graph.tables.values():
+            self._table_stype_dict[table.name] = {}
+            for column in table.columns:
+                if column == table.primary_key:
+                    continue
+                if (table.name, column.name) in foreign_keys:
+                    continue
+                self._table_stype_dict[table.name][column.name] = column.stype
+
+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = table._source_column_dict
+
+        self._min_time_dict: dict[str, pd.Timestamp] = {}
+        self._max_time_dict: dict[str, pd.Timestamp] = {}
+
+    # Properties ##############################################################
+
+    @property
+    def edge_types(self) -> list[tuple[str, str, str]]:
+        r"""All available edge types in the graph."""
+        return self._edge_types
+
+    @property
+    def primary_key_dict(self) -> dict[str, str]:
+        r"""All available primary keys in the graph."""
+        return self._primary_key_dict
+
+    @property
+    def time_column_dict(self) -> dict[str, str]:
+        r"""All available time columns in the graph."""
+        return self._time_column_dict
+
+    @property
+    def end_time_column_dict(self) -> dict[str, str]:
+        r"""All available end time columns in the graph."""
+        return self._end_time_column_dict
+
+    @property
+    def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        r"""The registered semantic types for all columns in all tables in
+        the graph.
+        """
+        return self._table_stype_dict
+
+    @property
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""Source column information for all tables in the graph."""
+        return self._source_table_dict
+
+    def get_min_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the minimal timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._min_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return min([self._min_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.max])
+
+    def get_max_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the maximum timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._max_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return max([self._max_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.min])
+
+    # Subgraph Sampling #######################################################
+
+    def sample_subgraph(
+        self,
+        entity_table_names: tuple[str, ...],
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
+        num_neighbors: list[int],
+        exclude_cols_dict: dict[str, list[str]] | None = None,
+    ) -> Subgraph:
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_names: The entity table names.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            num_neighbors: The number of neighbors to sample for each hop.
+            exclude_cols_dict: The columns to exclude from the subgraph.
+        """
+        # Exclude all columns that leak target information:
+        table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
+        if exclude_cols_dict is not None:
+            table_stype_dict = copy.deepcopy(table_stype_dict)
+            for table_name, exclude_cols in exclude_cols_dict.items():
+                for column_name in exclude_cols:
+                    del table_stype_dict[table_name][column_name]
+
+        # Collect all columns being used as features:
+        columns_dict: dict[str, set[str]] = {
+            table_name: set(stype_dict.keys())
+            for table_name, stype_dict in table_stype_dict.items()
+        }
+        # Make sure to store primary key information for entity tables:
+        for table_name in entity_table_names:
+            columns_dict[table_name].add(self.primary_key_dict[table_name])
+
+        if (isinstance(anchor_time, pd.Series)
+                and anchor_time.dtype != 'datetime64[ns]'):
+            anchor_time = anchor_time.astype('datetime64[ns]')
+
+        out = self._sample_subgraph(
+            entity_table_name=entity_table_names[0],
+            entity_pkey=entity_pkey,
+            anchor_time=anchor_time,
+            columns_dict=columns_dict,
+            num_neighbors=num_neighbors,
+        )
+
+        # Parse `SamplerOutput` into `Subgraph` structure:
+        subgraph = Subgraph(
+            anchor_time=out.anchor_time,
+            table_dict={},
+            link_dict={},
+        )
+
+        for table_name, batch in out.batch_dict.items():
+            if len(batch) == 0:
+                continue
+
+            primary_key: str | None = None
+            if table_name in entity_table_names:
+                primary_key = self.primary_key_dict[table_name]
+
+            df = out.df_dict[table_name].reset_index(drop=True)
+            if end_time_column := self.end_time_column_dict.get(table_name):
+                # Set end time to NaT for all values greater than anchor time:
+                assert table_name not in out.inverse_dict
+                ser = df[end_time_column]
+                if ser.dtype != 'datetime64[ns]':
+                    ser = ser.astype('datetime64[ns]')
+                mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
+                ser.iloc[mask] = pd.NaT
+                df[end_time_column] = ser
+
+            stype_dict = table_stype_dict[table_name]
+            for column_name, stype in stype_dict.items():
+                if stype == Stype.text:
+                    df[column_name] = _normalize_text(df[column_name])
+
+            subgraph.table_dict[table_name] = Table(
+                df=df,
+                row=out.inverse_dict.get(table_name),
+                batch=batch,
+                num_sampled_nodes=out.num_sampled_nodes_dict[table_name],
+                stype_dict=stype_dict,
+                primary_key=primary_key,
+            )
+
+        for edge_type in out.row_dict.keys():
+            row: np.ndarray | None = out.row_dict[edge_type]
+            col: np.ndarray | None = out.col_dict[edge_type]
+
+            if row is None or col is None or len(row) == 0:
+                continue
+
+            # Do not store reverse edge type if it is an exact replica:
+            rev_edge_type = Subgraph.rev_edge_type(edge_type)
+            if (rev_edge_type in subgraph.link_dict
+                    and np.array_equal(row, out.col_dict[rev_edge_type])
+                    and np.array_equal(col, out.row_dict[rev_edge_type])):
+                subgraph.link_dict[edge_type] = Link(
+                    layout=EdgeLayout.REV,
+                    row=None,
+                    col=None,
+                    num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+                )
+                continue
+
+            # Do not store non-informative edges:
+            layout = EdgeLayout.COO
+            if np.array_equal(row, np.arange(len(row))):
+                row = None
+            if np.array_equal(col, np.arange(len(col))):
+                col = None
+
+            # Store in compressed representation if more efficient:
+            num_cols = subgraph.table_dict[edge_type[2]].num_rows
+            if col is not None and len(col) > num_cols + 1:
+                layout = EdgeLayout.CSC
+                colcount = np.bincount(col, minlength=num_cols)
+                col = np.empty(num_cols + 1, dtype=col.dtype)
+                col[0] = 0
+                np.cumsum(colcount, out=col[1:])
+
+            subgraph.link_dict[edge_type] = Link(
+                layout=layout,
+                row=row,
+                col=col,
+                num_sampled_edges=out.num_sampled_edges_dict[edge_type],
+            )
+
+        return subgraph
+
+    # Predictive Query ########################################################
+
+    def _get_query_columns_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[str, set[str]]:
+        columns_dict: dict[str, set[str]] = defaultdict(set)
+        for fqn in query.all_query_columns + [query.entity_column]:
+            table_name, column_name = fqn.split('.')
+            if column_name == '*':
+                continue
+            columns_dict[table_name].add(column_name)
+        if column_name := self.time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        return columns_dict
+
+    def _get_query_time_offset_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+    ]:
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ] = {}
+
+        def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+            if isinstance(node, Aggregation):
+                table_name = node._get_target_column_name().split('.')[0]
+
+                edge_types = [
+                    edge_type for edge_type in self.edge_types
+                    if edge_type[0] == table_name
+                    and edge_type[2] == query.entity_table
+                ]
+                if len(edge_types) != 1:
+                    raise ValueError(f"Could not find a unique foreign key "
+                                     f"from table '{table_name}' to "
+                                     f"'{query.entity_table}'")
+                if edge_types[0] not in time_offset_dict:
+                    start = node.aggr_time_range.start_date_offset
+                    end = node.aggr_time_range.end_date_offset * num_forecasts
+                else:
+                    start, end = time_offset_dict[edge_types[0]]
+                    start = min_date_offset(
+                        start,
+                        node.aggr_time_range.start_date_offset,
+                    )
+                    end = max_date_offset(
+                        end,
+                        node.aggr_time_range.end_date_offset * num_forecasts,
+                    )
+                time_offset_dict[edge_types[0]] = (start, end)
+
+            for child in node.children:
+                _add_time_offset(child, num_forecasts)
+
+        _add_time_offset(query.target_ast, query.num_forecasts)
+        _add_time_offset(query.entity_ast)
+        if query.whatif_ast is not None:
+            _add_time_offset(query.whatif_ast)
+
+        return time_offset_dict
+
+    def sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_train_examples: int,
+        train_anchor_time: pd.Timestamp | Literal['entity'],
+        num_train_trials: int,
+        num_test_examples: int,
+        test_anchor_time: pd.Timestamp | Literal['entity'],
+        num_test_trials: int,
+        random_seed: int | None = None,
+    ) -> tuple[TargetOutput, TargetOutput]:
+        r"""Samples ground-truth targets given a predictive query, split into
+        training and test set.
+
+        Args:
+            query: The predictive query.
+            num_train_examples: How many training examples to produce.
+            train_anchor_time: The anchor timestamp for the training set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_train_trials: The number of training examples to try before
+                aborting.
+            num_test_examples: How many test examples to produce.
+            test_anchor_time: The anchor timestamp for the test set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_test_trials: The number of test examples to try before
+                aborting.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+        rng = np.random.default_rng(random_seed)
+
+        if num_train_examples == 0 or num_train_trials == 0:
+            num_train_examples = num_train_trials = 0
+        if num_test_examples == 0 or num_test_trials == 0:
+            num_test_examples = num_test_trials = 0
+
+        # 1. Collect information on what to query #############################
+        columns_dict = self._get_query_columns_dict(query)
+        time_offset_dict = self._get_query_time_offset_dict(query)
+        for table_name, _, _ in time_offset_dict.keys():
+            columns_dict[table_name].add(self.time_column_dict[table_name])
+
+        # 2. Sample random rows from entity table #############################
+        shared_train_test = query.query_type == QueryType.STATIC
+        shared_train_test &= train_anchor_time == test_anchor_time
+        if shared_train_test:
+            num_entity_rows = num_train_trials + num_test_trials
+        else:
+            num_entity_rows = max(num_train_trials, num_test_trials)
+        assert num_entity_rows > 0
+
+        entity_df = self._sample_entity_table(
+            table_name=query.entity_table,
+            columns=columns_dict[query.entity_table],
+            num_rows=num_entity_rows,
+            random_seed=random_seed,
+        )
+
+        if len(entity_df) == 0:
+            raise ValueError("Failed to find any rows in the entity table "
+                             f"'{query.entity_table}'.")
+
+        entity_pkey = entity_df[self.primary_key_dict[query.entity_table]]
+        entity_time: pd.Series | None = None
+        if column_name := self.time_column_dict.get(query.entity_table):
+            entity_time = entity_df[column_name]
+        entity_end_time: pd.Series | None = None
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            entity_end_time = entity_df[column_name]
+
+        def get_valid_entity_index(
+            time: pd.Timestamp | Literal['entity'],
+            max_size: int | None = None,
+        ) -> np.ndarray:
+
+            if time == 'entity':
+                index: np.ndarray = np.arange(len(entity_pkey))
+            elif entity_time is None and entity_end_time is None:
+                index = np.arange(len(entity_pkey))
+            else:
+                mask: np.ndarray | None = None
+                if entity_time is not None:
+                    mask = (entity_time <= time).to_numpy()
+                if entity_end_time is not None:
+                    _mask = (entity_end_time > time).to_numpy()
+                    _mask |= entity_end_time.isna().to_numpy()
+                    mask = _mask if mask is None else mask & _mask
+                assert mask is not None
+                index = mask.nonzero()[0]
+
+            rng.shuffle(index)
+
+            if max_size is not None:
+                index = index[:max_size]
+
+            return index
+
+        # 3. Build training and test candidates ###############################
+        train_index = test_index = np.array([], dtype=np.int64)
+        train_time = test_time = pd.Series([], dtype='datetime64[ns]')
+
+        if shared_train_test:
+            train_index = get_valid_entity_index(train_anchor_time)
+            if train_anchor_time == 'entity':  # Sort by timestamp:
+                assert entity_time is not None
+                train_time = entity_time.iloc[train_index]
+                train_time = train_time.reset_index(drop=True)
+                train_time = train_time.sort_values(ascending=False)
+                perm = train_time.index.to_numpy()
+                train_index = train_index[perm]
+                train_time = train_time.reset_index(drop=True)
+            else:
+                train_time = to_ser(train_anchor_time, size=len(train_index))
+        else:
+            if num_test_examples > 0:
+                test_index = get_valid_entity_index(  #
+                    test_anchor_time, max_size=num_test_trials)
+                assert test_anchor_time != 'entity'
+                test_time = to_ser(test_anchor_time, len(test_index))
+
+            if query.query_type == QueryType.STATIC and num_train_examples > 0:
+                train_index = get_valid_entity_index(  #
+                    train_anchor_time, max_size=num_train_trials)
+                assert train_anchor_time != 'entity'
+                train_time = to_ser(train_anchor_time, len(train_index))
+            elif query.query_type == QueryType.TEMPORAL and num_train_examples:
+                aggr_table_names = [
+                    aggr._get_target_column_name().split('.')[0]
+                    for aggr in query.get_all_target_aggregations()
+                ]
+                offset = query.target_timeframe.timeframe * query.num_forecasts
+
+                train_indices: list[np.ndarray] = []
+                train_times: list[pd.Series] = []
+                while True:
+                    train_index = get_valid_entity_index(  #
+                        train_anchor_time, max_size=num_train_trials)
+                    assert train_anchor_time != 'entity'
+                    train_time = to_ser(train_anchor_time, len(train_index))
+                    train_indices.append(train_index)
+                    train_times.append(train_time)
+                    if sum(len(x) for x in train_indices) >= num_train_trials:
+                        break
+                    train_anchor_time -= offset
+                    if train_anchor_time < self.get_min_time(aggr_table_names):
+                        break
+                train_index = np.concatenate(train_indices, axis=0)
+                train_index = train_index[:num_train_trials]
+                train_time = pd.concat(train_times, axis=0, ignore_index=True)
+                train_time = train_time.iloc[:num_train_trials]
+
+        # 4. Sample training and test labels ##################################
+        train_y, train_mask, test_y, test_mask = self._sample_target(
+            query=query,
+            entity_df=entity_df,
+            train_index=train_index,
+            train_time=train_time,
+            num_train_examples=(num_train_examples + num_test_examples
+                                if shared_train_test else num_train_examples),
+            test_index=test_index,
+            test_time=test_time,
+            num_test_examples=0 if shared_train_test else num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        # 5. Post-processing ##################################################
+        if shared_train_test:
+            num_examples = num_train_examples + num_test_examples
+            train_index = train_index[train_mask][:num_examples]
+            train_time = train_time.iloc[train_mask].iloc[:num_examples]
+            train_y = train_y.iloc[:num_examples]
+
+            _num_test = num_test_examples
+            _num_train = min(num_train_examples, 1000)
+            if (num_test_examples > 0 and num_train_examples > 0
+                    and len(train_y) < num_examples
+                    and len(train_y) < _num_test + _num_train):
+                # Not enough labels to satisfy requested split without losing
+                # large number of training examples:
+                _num_test = len(train_y) - _num_train
+                if _num_test < _num_train:  # Fallback to 50/50 split:
+                    _num_test = len(train_y) // 2
+
+            test_index = train_index[:_num_test]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = train_time.iloc[:_num_test]
+            test_y = train_y.iloc[:_num_test]
+
+            train_index = train_index[_num_test:]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[_num_test:]
+            train_y = train_y.iloc[_num_test:]
+        else:
+            train_index = train_index[train_mask][:num_train_examples]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[train_mask].iloc[:num_train_examples]
+            train_y = train_y.iloc[:num_train_examples]
+
+            test_index = test_index[test_mask][:num_test_examples]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = test_time.iloc[test_mask].iloc[:num_test_examples]
+            test_y = test_y.iloc[:num_test_examples]
+
+        train_pkey = train_pkey.reset_index(drop=True)
+        train_time = train_time.reset_index(drop=True)
+        train_y = train_y.reset_index(drop=True)
+        test_pkey = test_pkey.reset_index(drop=True)
+        test_time = test_time.reset_index(drop=True)
+        test_y = test_y.reset_index(drop=True)
+
+        if num_train_examples > 0 and len(train_y) == 0:
+            raise RuntimeError("Failed to collect any context examples. Is "
+                               "your predictive query too restrictive?")
+
+        if num_test_examples > 0 and len(test_y) == 0:
+            raise RuntimeError("Failed to collect any test examples for "
+                               "evaluation. Is your predictive query too "
+                               "restrictive?")
+
+        global _coverage_warned
+        if (num_train_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(train_y) < num_train_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_train_examples:,} context "
+                          f"examples within {num_train_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        if (num_test_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(test_y) < num_test_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_test_examples:,} test "
+                          f"examples within {num_test_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        return (
+            TargetOutput(train_pkey, train_time, train_y),
+            TargetOutput(test_pkey, test_time, test_y),
+        )
+
+    # Abstract Methods ########################################################
+
+    @abstractmethod
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        r"""Returns the minimum and maximum timestamps for a set of tables.
+
+        Args:
+            table_names: The tables.
+        """
+
+    @abstractmethod
+    def _sample_subgraph(
+        self,
+        entity_table_name: str,
+        entity_pkey: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
+    ) -> SamplerOutput:
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_name: The entity table name.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            columns_dict: The columns to return for each table.
+            num_neighbors: The number of neighbors to sample for each hop.
+        """
+
+    @abstractmethod
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        r"""Returns a random sample of rows from the entity table.
+
+        Args:
+            table_name: The table.
+            columns: The columns to return.
+            num_rows: Maximum number of rows to return. Can be smaller in case
+                the entity table contains fewer rows.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+
+    @abstractmethod
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        r"""Samples ground-truth targets given a predictive query from a set of
+        training and test candidates.
+
+        Args:
+            query: The predictive query.
+            entity_df: The entity data frame, containing the union of all train
+                and test candidates.
+            train_index: The indices of training candidates.
+            train_time: The anchor time of training candidates.
+            num_train_examples: How many training examples to produce.
+            test_index: The indices of test candidates.
+            test_time: The anchor time of test candidates.
+            num_test_examples: How many test examples to produce.
+            columns_dict: The columns that are being used to compute
+                ground-truth targets.
+            time_offset_dict: The date offsets to query for each edge type,
+                relative to the anchor time.
+        """
+
+
+# Helper Functions ############################################################
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def _normalize_text(
+    ser: pd.Series,
+    max_words: int | None = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate the number of words as 5 characters + 1 space in an
+        # English text on average. We need this pre-filter here, as word
+        # splitting on a giant text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
+
+
+def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
+
+
+def to_ser(value: Any, size: int) -> pd.Series:
+    return pd.Series([value]).repeat(size).reset_index(drop=True)
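
Note: the Sampler base class above delegates all backend-specific work to four abstract hooks (_get_min_max_time_dict, _sample_subgraph, _sample_entity_table, _sample_target); the new backend/local, backend/sqlite and backend/snow sampler modules listed in this diff presumably provide the concrete implementations. As a rough orientation only, the sketch below shows what a minimal custom backend could look like. It is not part of this wheel: the class name InMemorySampler and its df_dict argument are invented for illustration, and the two harder hooks are left unimplemented.

    import pandas as pd

    from kumoai.experimental.rfm.base.sampler import Sampler, SamplerOutput


    class InMemorySampler(Sampler):  # hypothetical example, not shipped
        def __init__(self, graph, df_dict: dict[str, pd.DataFrame]) -> None:
            super().__init__(graph)
            self.df_dict = df_dict  # table name -> raw data frame

        def _get_min_max_time_dict(self, table_names):
            # Scan the time column of each requested table once:
            return {
                name: (
                    self.df_dict[name][self.time_column_dict[name]].min(),
                    self.df_dict[name][self.time_column_dict[name]].max(),
                )
                for name in table_names
            }

        def _sample_entity_table(self, table_name, columns, num_rows,
                                 random_seed=None):
            # Uniform random rows, capped at the table size:
            df = self.df_dict[table_name]
            return df.sample(n=min(num_rows, len(df)),
                             random_state=random_seed)[sorted(columns)]

        def _sample_subgraph(self, entity_table_name, entity_pkey,
                             anchor_time, columns_dict,
                             num_neighbors) -> SamplerOutput:
            raise NotImplementedError  # backend-specific neighbor sampling

        def _sample_target(self, query, entity_df, train_index, train_time,
                           num_train_examples, test_index, test_time,
                           num_test_examples, columns_dict, time_offset_dict):
            raise NotImplementedError  # backend-specific target computation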