kumoai 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/backend/local/graph_store.py +19 -62
- kumoai/experimental/rfm/backend/local/sampler.py +229 -45
- kumoai/experimental/rfm/backend/local/table.py +12 -2
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +264 -0
- kumoai/experimental/rfm/backend/snow/table.py +35 -17
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +354 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +36 -11
- kumoai/experimental/rfm/base/__init__.py +16 -5
- kumoai/experimental/rfm/base/sampler.py +538 -52
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +56 -0
- kumoai/experimental/rfm/base/table.py +12 -1
- kumoai/experimental/rfm/graph.py +26 -9
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +214 -151
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/METADATA +2 -2
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/RECORD +28 -25
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512061731.dist-info → kumoai-2.14.0.dev202512151351.dist-info}/top_level.txt +0 -0
@@ -1,45 +1,58 @@
 import copy
 import re
+import warnings
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple

 import numpy as np
 import pandas as pd
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import Aggregation, ASTNode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype

+from kumoai.experimental.rfm.base import SourceColumn
+from kumoai.utils import ProgressLogger
+
 if TYPE_CHECKING:
     from kumoai.experimental.rfm import Graph

-
-@dataclass
-class EdgeSpec:
-    num_neighbors: int | None = None
-    time_offsets: tuple[
-        pd.DateOffset | None,
-        pd.DateOffset,
-    ] | None = None
-
-    def __post_init__(self) -> None:
-        if (self.num_neighbors is None) == (self.time_offsets is None):
-            raise ValueError("Only one of 'num_neighbors' and 'time_offsets' "
-                             "must be provided")
+_coverage_warned = False


 @dataclass
 class SamplerOutput:
+    anchor_time: np.ndarray
     df_dict: dict[str, pd.DataFrame]
     inverse_dict: dict[str, np.ndarray]
     batch_dict: dict[str, np.ndarray]
     num_sampled_nodes_dict: dict[str, list[int]]
-    row_dict: dict[tuple[str, str, str], np.ndarray]
-    col_dict: dict[tuple[str, str, str], np.ndarray]
-    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+    row_dict: dict[tuple[str, str, str], np.ndarray]
+    col_dict: dict[tuple[str, str, str], np.ndarray]
+    num_sampled_edges_dict: dict[tuple[str, str, str], list[int]]
+
+
+class TargetOutput(NamedTuple):
+    entity_pkey: pd.Series
+    anchor_time: pd.Series
+    target: pd.Series


 class Sampler(ABC):
-
+    r"""A base class to sample relational data (*i.e.*, subgraphs and
+    ground-truth targets) from a custom backend.
+
+    Args:
+        graph: The graph.
+        verbose: Whether to print verbose output.
+    """
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
         self._edge_types: list[tuple[str, str, str]] = []
         for edge in graph.edges:
             edge_type = (edge.src_table, edge.fkey, edge.dst_table)
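
For orientation, a minimal sketch (not part of the package) of how the new `TargetOutput` tuple introduced above is meant to be consumed; the class is re-declared here so the snippet is self-contained, and all values are made up:

import pandas as pd
from typing import NamedTuple


class TargetOutput(NamedTuple):  # mirrors the definition added in this hunk
    entity_pkey: pd.Series
    anchor_time: pd.Series
    target: pd.Series


out = TargetOutput(
    entity_pkey=pd.Series([17, 42]),
    anchor_time=pd.Series(pd.to_datetime(['2024-01-01', '2024-01-08'])),
    target=pd.Series([0.0, 1.0]),
)
pkey, anchor, y = out          # unpacks positionally like a plain tuple
print(out.target.equals(y))    # True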
@@ -75,40 +88,106 @@ class Sampler(ABC):
                 continue
             self._table_stype_dict[table.name][column.name] = column.stype

+        self._source_table_dict: dict[str, dict[str, SourceColumn]] = {}
+        for table in graph.tables.values():
+            self._source_table_dict[table.name] = table._source_column_dict
+
+        self._min_time_dict: dict[str, pd.Timestamp] = {}
+        self._max_time_dict: dict[str, pd.Timestamp] = {}
+
+    # Properties ##############################################################
+
     @property
     def edge_types(self) -> list[tuple[str, str, str]]:
+        r"""All available edge types in the graph."""
         return self._edge_types

     @property
     def primary_key_dict(self) -> dict[str, str]:
+        r"""All available primary keys in the graph."""
         return self._primary_key_dict

     @property
     def time_column_dict(self) -> dict[str, str]:
+        r"""All available time columns in the graph."""
         return self._time_column_dict

     @property
     def end_time_column_dict(self) -> dict[str, str]:
+        r"""All available end time columns in the graph."""
         return self._end_time_column_dict

     @property
     def table_stype_dict(self) -> dict[str, dict[str, Stype]]:
+        r"""The registered semantic types for all columns in all tables in
+        the graph.
+        """
         return self._table_stype_dict

+    @property
+    def source_table_dict(self) -> dict[str, dict[str, SourceColumn]]:
+        r"""Source column information for all tables in the graph."""
+        return self._source_table_dict
+
+    def get_min_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the minimal timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._min_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return min([self._min_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.max])
+
+    def get_max_time(
+        self,
+        table_names: list[str] | None = None,
+    ) -> pd.Timestamp:
+        r"""Returns the maximum timestamp in the union of a set of tables.
+
+        Args:
+            table_names: The set of tables.
+        """
+        if table_names is None or len(table_names) == 0:
+            table_names = list(self.time_column_dict.keys())
+        unknown = list(set(table_names) - set(self._max_time_dict.keys()))
+        if len(unknown) > 0:
+            min_max_time_dict = self._get_min_max_time_dict(unknown)
+            for table_name, (min_time, max_time) in min_max_time_dict.items():
+                self._min_time_dict[table_name] = min_time
+                self._max_time_dict[table_name] = max_time
+        return max([self._max_time_dict[table]
+                    for table in table_names] + [pd.Timestamp.min])
+
+    # Subgraph Sampling #######################################################
+
     def sample_subgraph(
         self,
         entity_table_names: tuple[str, ...],
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
+        anchor_time: pd.Series | Literal['entity'],
         num_neighbors: list[int],
         exclude_cols_dict: dict[str, list[str]] | None = None,
     ) -> Subgraph:
-
-
-
-
-
-
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_names: The entity table names.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            num_neighbors: The number of neighbors to sample for each hop.
+            exclude_cols_dict: The columns to exclude from the subgraph.
+        """
         # Exclude all columns that leak target information:
         table_stype_dict: dict[str, dict[str, Stype]] = self._table_stype_dict
         if exclude_cols_dict is not None:
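
A standalone sketch (not part of the package) of the caching pattern behind the new `get_min_time`/`get_max_time` helpers added above: per-table bounds are fetched once via the backend, memoized, and the union is reduced against a `pd.Timestamp.max` sentinel. The table data here is invented:

import pandas as pd

_bounds = {  # pretend backend: table name -> (min_time, max_time)
    'orders': (pd.Timestamp('2020-01-01'), pd.Timestamp('2023-06-30')),
    'events': (pd.Timestamp('2021-05-01'), pd.Timestamp('2024-01-15')),
}
_min_cache: dict[str, pd.Timestamp] = {}


def get_min_time(table_names: list[str]) -> pd.Timestamp:
    unknown = set(table_names) - set(_min_cache)
    for name in unknown:  # fetch and memoize only the missing tables
        _min_cache[name] = _bounds[name][0]
    return min([_min_cache[name] for name in table_names] + [pd.Timestamp.max])


print(get_min_time(['orders', 'events']))  # 2020-01-01 00:00:00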
@@ -118,30 +197,29 @@ class Sampler(ABC):
                 del table_stype_dict[table_name][column_name]

         # Collect all columns being used as features:
-
-            table_name:
+        columns_dict: dict[str, set[str]] = {
+            table_name: set(stype_dict.keys())
             for table_name, stype_dict in table_stype_dict.items()
         }
         # Make sure to store primary key information for entity tables:
         for table_name in entity_table_names:
-
-                [self.primary_key_dict[table_name]] +
-                column_spec_dict[table_name])
+            columns_dict[table_name].add(self.primary_key_dict[table_name])

-        if anchor_time.
+        if (isinstance(anchor_time, pd.Series)
+                and anchor_time.dtype != 'datetime64[ns]'):
             anchor_time = anchor_time.astype('datetime64[ns]')
-
+
+        out = self._sample_subgraph(
             entity_table_name=entity_table_names[0],
             entity_pkey=entity_pkey,
             anchor_time=anchor_time,
-
-
-            drop_duplicates=True,
-            return_edges=True,
+            columns_dict=columns_dict,
+            num_neighbors=num_neighbors,
         )

+        # Parse `SubgraphOutput` into `Subgraph` structure:
         subgraph = Subgraph(
-            anchor_time=anchor_time
+            anchor_time=out.anchor_time,
             table_dict={},
             link_dict={},
         )
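
A minimal sketch (not part of the package) of the anchor-time handling shown above: a `pd.Series` that is not `datetime64[ns]` is cast before sampling, while the new `'entity'` literal is passed through untouched. The input values are made up:

import pandas as pd

anchor_time: pd.Series | str = pd.Series(['2024-01-01', '2024-02-01'])

if isinstance(anchor_time, pd.Series) and anchor_time.dtype != 'datetime64[ns]':
    anchor_time = anchor_time.astype('datetime64[ns]')

print(anchor_time.dtype)  # datetime64[ns]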
@@ -150,18 +228,18 @@ class Sampler(ABC):
             if len(batch) == 0:
                 continue

-            primary_key = None
+            primary_key: str | None = None
             if table_name in entity_table_names:
-                primary_key = self.primary_key_dict
+                primary_key = self.primary_key_dict[table_name]

             df = out.df_dict[table_name].reset_index(drop=True)
-            if
+            if end_time_column := self.end_time_column_dict.get(table_name):
                 # Set end time to NaT for all values greater than anchor time:
-
+                assert table_name not in out.inverse_dict
                 ser = df[end_time_column]
                 if ser.dtype != 'datetime64[ns]':
                     ser = ser.astype('datetime64[ns]')
-                mask = ser > anchor_time
+                mask = ser.astype(int).to_numpy() > out.anchor_time[batch]
                 ser.iloc[mask] = pd.NaT
                 df[end_time_column] = ser

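
A standalone sketch (not part of the package) of the end-time masking rewritten above: end times later than the per-row anchor time are hidden as `NaT`, with both sides compared as integer nanoseconds. The values are invented:

import pandas as pd

ser = pd.Series(pd.to_datetime(['2024-01-05', '2024-03-01', '2024-02-10']))
anchor_time = pd.Series(pd.to_datetime(['2024-02-01'] * 3)).astype(int).to_numpy()

mask = ser.astype(int).to_numpy() > anchor_time  # element-wise comparison in ns
ser.iloc[mask] = pd.NaT

print(ser.isna().tolist())  # [False, True, False]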
@@ -179,9 +257,6 @@ class Sampler(ABC):
                 primary_key=primary_key,
             )

-        assert out.row_dict is not None
-        assert out.col_dict is not None
-        assert out.num_sampled_edges_dict is not None
         for edge_type in out.row_dict.keys():
             row: np.ndarray | None = out.row_dict[edge_type]
             col: np.ndarray | None = out.col_dict[edge_type]
@@ -227,20 +302,408 @@ class Sampler(ABC):

         return subgraph

+    # Predictive Query ########################################################
+
+    def _get_query_columns_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[str, set[str]]:
+        columns_dict: dict[str, set[str]] = defaultdict(set)
+        for fqn in query.all_query_columns + [query.entity_column]:
+            table_name, column_name = fqn.split('.')
+            if column_name == '*':
+                continue
+            columns_dict[table_name].add(column_name)
+        if column_name := self.time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            columns_dict[table_name].add(column_name)
+        return columns_dict
+
+    def _get_query_time_offset_dict(
+        self,
+        query: ValidatedPredictiveQuery,
+    ) -> dict[
+        tuple[str, str, str],
+        tuple[pd.DateOffset | None, pd.DateOffset],
+    ]:
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ] = {}
+
+        def _add_time_offset(node: ASTNode, num_forecasts: int = 1) -> None:
+            if isinstance(node, Aggregation):
+                table_name = node._get_target_column_name().split('.')[0]
+
+                edge_types = [
+                    edge_type for edge_type in self.edge_types
+                    if edge_type[0] == table_name
+                    and edge_type[2] == query.entity_table
+                ]
+                if len(edge_types) != 1:
+                    raise ValueError(f"Could not find a unique foreign key "
+                                     f"from table '{table_name}' to "
+                                     f"'{query.entity_table}'")
+                if edge_types[0] not in time_offset_dict:
+                    start = node.aggr_time_range.start_date_offset
+                    end = node.aggr_time_range.end_date_offset * num_forecasts
+                else:
+                    start, end = time_offset_dict[edge_types[0]]
+                    start = min_date_offset(
+                        start,
+                        node.aggr_time_range.start_date_offset,
+                    )
+                    end = max_date_offset(
+                        end,
+                        node.aggr_time_range.end_date_offset * num_forecasts,
+                    )
+                time_offset_dict[edge_types[0]] = (start, end)
+
+            for child in node.children:
+                _add_time_offset(child, num_forecasts)
+
+        _add_time_offset(query.target_ast, query.num_forecasts)
+        _add_time_offset(query.entity_ast)
+        if query.whatif_ast is not None:
+            _add_time_offset(query.whatif_ast)
+
+        return time_offset_dict
+
+    def sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        num_train_examples: int,
+        train_anchor_time: pd.Timestamp | Literal['entity'],
+        num_train_trials: int,
+        num_test_examples: int,
+        test_anchor_time: pd.Timestamp | Literal['entity'],
+        num_test_trials: int,
+        random_seed: int | None = None,
+    ) -> tuple[TargetOutput, TargetOutput]:
+        r"""Samples ground-truth targets given a predictive query, split into
+        training and test set.
+
+        Args:
+            query: The predictive query.
+            num_train_examples: How many training examples to produce.
+            train_anchor_time: The anchor timestamp for the training set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_train_trials: The number of training examples to try before
+                aborting.
+            num_test_examples: How many test examples to produce.
+            test_anchor_time: The anchor timestamp for the test set.
+                If set to ``"entity"``, will use the timestamp of the entity.
+            num_test_trials: The number of test examples to try before
+                aborting.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+        rng = np.random.default_rng(random_seed)
+
+        if num_train_examples == 0 or num_train_trials == 0:
+            num_train_examples = num_train_trials = 0
+        if num_test_examples == 0 or num_test_trials == 0:
+            num_test_examples = num_test_trials = 0
+
+        # 1. Collect information on what to query #############################
+        columns_dict = self._get_query_columns_dict(query)
+        time_offset_dict = self._get_query_time_offset_dict(query)
+        for table_name, _, _ in time_offset_dict.keys():
+            columns_dict[table_name].add(self.time_column_dict[table_name])
+
+        # 2. Sample random rows from entity table #############################
+        shared_train_test = query.query_type == QueryType.STATIC
+        shared_train_test &= train_anchor_time == test_anchor_time
+        if shared_train_test:
+            num_entity_rows = num_train_trials + num_test_trials
+        else:
+            num_entity_rows = max(num_train_trials, num_test_trials)
+        assert num_entity_rows > 0
+
+        entity_df = self._sample_entity_table(
+            table_name=query.entity_table,
+            columns=columns_dict[query.entity_table],
+            num_rows=num_entity_rows,
+            random_seed=random_seed,
+        )
+
+        if len(entity_df) == 0:
+            raise ValueError("Failed to find any rows in the entity table "
+                             "'{query.entity_table}'.")
+
+        entity_pkey = entity_df[self.primary_key_dict[query.entity_table]]
+        entity_time: pd.Series | None = None
+        if column_name := self.time_column_dict.get(query.entity_table):
+            entity_time = entity_df[column_name]
+        entity_end_time: pd.Series | None = None
+        if column_name := self.end_time_column_dict.get(query.entity_table):
+            entity_end_time = entity_df[column_name]
+
+        def get_valid_entity_index(
+            time: pd.Timestamp | Literal['entity'],
+            max_size: int | None = None,
+        ) -> np.ndarray:
+
+            if time == 'entity':
+                index: np.ndarray = np.arange(len(entity_pkey))
+            elif entity_time is None and entity_end_time is None:
+                index = np.arange(len(entity_pkey))
+            else:
+                mask: np.ndarray | None = None
+                if entity_time is not None:
+                    mask = (entity_time <= time).to_numpy()
+                if entity_end_time is not None:
+                    _mask = (entity_end_time > time).to_numpy()
+                    _mask |= entity_end_time.isna().to_numpy()
+                    mask = _mask if mask is None else mask & _mask
+                assert mask is not None
+                index = mask.nonzero()[0]
+
+            rng.shuffle(index)
+
+            if max_size is not None:
+                index = index[:max_size]
+
+            return index
+
+        # 3. Build training and test candidates ###############################
+        train_index = test_index = np.array([], dtype=np.int64)
+        train_time = test_time = pd.Series([], dtype='datetime64[ns]')
+
+        if shared_train_test:
+            train_index = get_valid_entity_index(train_anchor_time)
+            if train_anchor_time == 'entity':  # Sort by timestamp:
+                assert entity_time is not None
+                train_time = entity_time.iloc[train_index]
+                train_time = train_time.reset_index(drop=True)
+                train_time = train_time.sort_values(ascending=False)
+                perm = train_time.index.to_numpy()
+                train_index = train_index[perm]
+                train_time = train_time.reset_index(drop=True)
+            else:
+                train_time = to_ser(train_anchor_time, size=len(train_index))
+        else:
+            if num_test_examples > 0:
+                test_index = get_valid_entity_index(  #
+                    test_anchor_time, max_size=num_test_trials)
+                assert test_anchor_time != 'entity'
+                test_time = to_ser(test_anchor_time, len(test_index))
+
+            if query.query_type == QueryType.STATIC and num_train_examples > 0:
+                train_index = get_valid_entity_index(  #
+                    train_anchor_time, max_size=num_train_trials)
+                assert train_anchor_time != 'entity'
+                train_time = to_ser(train_anchor_time, len(train_index))
+            elif query.query_type == QueryType.TEMPORAL and num_train_examples:
+                aggr_table_names = [
+                    aggr._get_target_column_name().split('.')[0]
+                    for aggr in query.get_all_target_aggregations()
+                ]
+                offset = query.target_timeframe.timeframe * query.num_forecasts
+
+                train_indices: list[np.ndarray] = []
+                train_times: list[pd.Series] = []
+                while True:
+                    train_index = get_valid_entity_index(  #
+                        train_anchor_time, max_size=num_train_trials)
+                    assert train_anchor_time != 'entity'
+                    train_time = to_ser(train_anchor_time, len(train_index))
+                    train_indices.append(train_index)
+                    train_times.append(train_time)
+                    if sum(len(x) for x in train_indices) >= num_train_trials:
+                        break
+                    train_anchor_time -= offset
+                    if train_anchor_time < self.get_min_time(aggr_table_names):
+                        break
+                train_index = np.concatenate(train_indices, axis=0)
+                train_index = train_index[:num_train_trials]
+                train_time = pd.concat(train_times, axis=0, ignore_index=True)
+                train_time = train_time.iloc[:num_train_trials]
+
+        # 4. Sample training and test labels ##################################
+        train_y, train_mask, test_y, test_mask = self._sample_target(
+            query=query,
+            entity_df=entity_df,
+            train_index=train_index,
+            train_time=train_time,
+            num_train_examples=(num_train_examples + num_test_examples
+                                if shared_train_test else num_train_examples),
+            test_index=test_index,
+            test_time=test_time,
+            num_test_examples=0 if shared_train_test else num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        # 5. Post-processing ##################################################
+        if shared_train_test:
+            num_examples = num_train_examples + num_test_examples
+            train_index = train_index[train_mask][:num_examples]
+            train_time = train_time.iloc[train_mask].iloc[:num_examples]
+            train_y = train_y.iloc[:num_examples]
+
+            _num_test = num_test_examples
+            _num_train = min(num_train_examples, 1000)
+            if (num_test_examples > 0 and num_train_examples > 0
+                    and len(train_y) < num_examples
+                    and len(train_y) < _num_test + _num_train):
+                # Not enough labels to satisfy requested split without losing
+                # large number of training examples:
+                _num_test = len(train_y) - _num_train
+                if _num_test < _num_train:  # Fallback to 50/50 split:
+                    _num_test = len(train_y) // 2
+
+            test_index = train_index[:_num_test]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = train_time.iloc[:_num_test]
+            test_y = train_y.iloc[:_num_test]
+
+            train_index = train_index[_num_test:]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[_num_test:]
+            train_y = train_y.iloc[_num_test:]
+        else:
+            train_index = train_index[train_mask][:num_train_examples]
+            train_pkey = entity_pkey.iloc[train_index]
+            train_time = train_time.iloc[train_mask].iloc[:num_train_examples]
+            train_y = train_y.iloc[:num_train_examples]
+
+            test_index = test_index[test_mask][:num_test_examples]
+            test_pkey = entity_pkey.iloc[test_index]
+            test_time = test_time.iloc[test_mask].iloc[:num_test_examples]
+            test_y = test_y.iloc[:num_test_examples]
+
+        train_pkey = train_pkey.reset_index(drop=True)
+        train_time = train_time.reset_index(drop=True)
+        train_y = train_y.reset_index(drop=True)
+        test_pkey = test_pkey.reset_index(drop=True)
+        test_time = test_time.reset_index(drop=True)
+        test_y = test_y.reset_index(drop=True)
+
+        if num_train_examples > 0 and len(train_y) == 0:
+            raise RuntimeError("Failed to collect any context examples. Is "
+                               "your predictive query too restrictive?")
+
+        if num_test_examples > 0 and len(test_y) == 0:
+            raise RuntimeError("Failed to collect any test examples for "
+                               "evaluation. Is your predictive query too "
+                               "restrictive?")
+
+        global _coverage_warned
+        if (not num_train_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(train_y) < num_train_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_train_examples:,} context "
+                          f"examples within {num_train_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        if (not num_test_examples > 0  #
+                and not _coverage_warned  #
+                and len(entity_df) >= num_entity_rows
+                and len(test_y) < num_test_examples // 2):
+            _coverage_warned = True
+            warnings.warn(f"Failed to collect {num_test_examples:,} test "
+                          f"examples within {num_test_trials:,} candidates. "
+                          f"To improve coverage, consider increasing the "
+                          f"number of PQ iterations using the "
+                          f"'max_pq_iterations' option. This warning will not "
+                          f"be shown again in this run.")
+
+        return (
+            TargetOutput(train_pkey, train_time, train_y),
+            TargetOutput(test_pkey, test_time, test_y),
+        )
+
     # Abstract Methods ########################################################

     @abstractmethod
-    def
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        r"""Returns the minimum and maximum timestamps for a set of tables.
+
+        Args:
+            table_names: The tables.
+        """
+
+    @abstractmethod
+    def _sample_subgraph(
         self,
         entity_table_name: str,
         entity_pkey: pd.Series,
-        anchor_time: pd.Series,
-
-
-        drop_duplicates: bool = False,
-        return_edges: bool = False,
+        anchor_time: pd.Series | Literal['entity'],
+        columns_dict: dict[str, set[str]],
+        num_neighbors: list[int],
     ) -> SamplerOutput:
-
+        r"""Samples distinct subgraphs for each entity primary key.
+
+        Args:
+            entity_table_name: The entity table name.
+            entity_pkey: The primary keys to use as seed nodes.
+            anchor_time: The anchor time of the subgraphs.
+            columns_dict: The columns to return for each table.
+            num_neighbors: The number of neighbors to sample for each hop.
+        """
+
+    @abstractmethod
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        r"""Returns a random sample of rows from the entity table.
+
+        Args:
+            table_name: The table.
+            columns: The columns to return.
+            num_rows: Maximum number of rows to return. Can be smaller in case
+                the entity table contains less rows.
+            random_seed: A manual seed for generating pseudo-random numbers.
+        """
+
+    @abstractmethod
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        r"""Samples ground-truth targets given a predictive query from a set of
+        training and test candidates.
+
+        Args:
+            query: The predictive query.
+            entity_df: The entity data frame, containing the union of all train
+                and test candidates.
+            train_index: The indices of training candidates.
+            train_time: The anchor time of training candidates.
+            num_train_examples: How many training examples to produce.
+            test_index: The indices of test candidates.
+            test_time: The anchor time of test candidates.
+            num_test_examples: How many test examples to produce.
+            columns_dict: The columns that are being used to compute
+                ground-truth targets.
+            time_offset_dict: The date offsets to query for each edge type,
+                relative to the anchor time.
+        """


 # Helper Functions ############################################################
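
A standalone sketch (not part of the package) of the column-collection pattern used by the new `_get_query_columns_dict` above: fully-qualified `'table.column'` names are grouped into a `defaultdict(set)` and wildcard `'*'` entries are skipped. The input list is made up:

from collections import defaultdict

fqns = ['orders.price', 'orders.status', 'users.*', 'users.user_id']

columns_dict: dict[str, set[str]] = defaultdict(set)
for fqn in fqns:
    table_name, column_name = fqn.split('.')
    if column_name == '*':  # wildcard columns carry no concrete column name
        continue
    columns_dict[table_name].add(column_name)

print(dict(columns_dict))  # {'orders': {'price', 'status'}, 'users': {'user_id'}} (set order may vary)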
@@ -285,3 +748,26 @@ def _normalize_text(
     ser = ser.map(normalize_fn)

     return ser
+
+
+def min_date_offset(*args: pd.DateOffset | None) -> pd.DateOffset | None:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: pd.DateOffset) -> pd.DateOffset:
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
+
+
+def to_ser(value: Any, size: int) -> pd.Series:
+    return pd.Series([value]).repeat(size).reset_index(drop=True)
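
A usage sketch (not part of the package) for the new helpers above: `pd.DateOffset` objects are not directly orderable, which is presumably why `min_date_offset`/`max_date_offset` resolve them against a fixed anchor timestamp, while `to_ser` broadcasts a scalar into a Series. The offsets below are arbitrary:

import pandas as pd

offsets = [pd.DateOffset(days=7), pd.DateOffset(months=1)]

# DateOffsets cannot be compared with '<' directly, so resolve each one against
# a fixed anchor timestamp and compare the resulting timestamps instead:
anchor = pd.Timestamp('2000-01-01')
resolved = [anchor + offset for offset in offsets]
smallest = offsets[min(range(len(offsets)), key=lambda i: resolved[i])]
print(smallest)  # <DateOffset: days=7>

# Broadcast a single anchor timestamp into a Series of a given size, mirroring
# the `to_ser` helper:
print(pd.Series([pd.Timestamp('2024-01-01')]).repeat(3).reset_index(drop=True))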