kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +44 -9
- kumoai/experimental/rfm/__init__.py +70 -68
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
- kumoai/experimental/rfm/backend/snow/table.py +245 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +783 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +722 -0
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +84 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +63 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +783 -481
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +55 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
|
@@ -1,689 +0,0 @@
|
|
|
1
|
-
import warnings
|
|
2
|
-
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
|
|
3
|
-
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pandas as pd
|
|
6
|
-
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
7
|
-
from kumoapi.pquery.AST import (
|
|
8
|
-
Aggregation,
|
|
9
|
-
ASTNode,
|
|
10
|
-
Column,
|
|
11
|
-
Condition,
|
|
12
|
-
Filter,
|
|
13
|
-
Join,
|
|
14
|
-
LogicalOperation,
|
|
15
|
-
)
|
|
16
|
-
from kumoapi.task import TaskType
|
|
17
|
-
from kumoapi.typing import AggregationType, DateOffset, Stype
|
|
18
|
-
|
|
19
|
-
import kumoai.kumolib as kumolib
|
|
20
|
-
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
21
|
-
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
22
|
-
|
|
23
|
-
# Module-level flag so that the "Failed to collect ... examples" coverage
# warning of `LocalPQueryDriver.collect_test` / `collect_train` is emitted at
# most once per process (both methods set it to `True` after warning).
_coverage_warned = False
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
class SamplingSpec(NamedTuple):
    r"""Describes one edge type that has to be traversed to evaluate a
    predictive query, together with the hop at which it is traversed and the
    (optional) aggregation time window.
    """
    # The `(src_table, key, dst_table)` edge type to sample along:
    edge_type: Tuple[str, str, str]
    # 1-based hop at which `edge_type` is traversed from the entity table:
    hop: int
    # Start of the aggregation time window (`None` if not bounded):
    start_offset: Optional[DateOffset]
    # End of the aggregation time window (`None` if no time window applies):
    end_offset: Optional[DateOffset]
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class LocalPQueryDriver:
    r"""Evaluates a predictive query on a :class:`LocalGraphStore`.

    The driver draws (shuffled) entity candidates, samples the subgraph the
    query needs around them via :class:`kumolib.NeighborSampler`, and computes
    labels and masks with :class:`PQueryPandasExecutor`.

    Args:
        graph_store: The graph store holding tables, edges and time columns.
        query: The validated predictive query.
        random_seed: Optional seed for shuffling entity candidates.
    """
    def __init__(
        self,
        graph_store: LocalGraphStore,
        query: ValidatedPredictiveQuery,
        random_seed: Optional[int] = None,
    ) -> None:
        self._graph_store = graph_store
        self._query = query
        self._random_seed = random_seed
        # Dedicated generator so candidate shuffling is reproducible for a
        # fixed `random_seed`:
        self._rng = np.random.default_rng(random_seed)

    def _get_candidates(
        self,
        exclude_node: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        r"""Returns the shuffled row indices of candidate entities.

        Args:
            exclude_node: Row indices to exclude (*e.g.*, nodes already used as
                test examples). Must be :obj:`None` for temporal queries.

        Returns:
            The shuffled candidate row indices of the entity table.
        """
        if self._query.query_type == QueryType.TEMPORAL:
            assert exclude_node is None

        table_name = self._query.entity_table
        num_nodes = len(self._graph_store.df_dict[table_name])
        mask_dict = self._graph_store.mask_dict

        candidate: np.ndarray

        # Case 1: All nodes are valid and nothing to exclude:
        if exclude_node is None and table_name not in mask_dict:
            candidate = np.arange(num_nodes)

        # Case 2: Not all nodes are valid - lookup valid nodes:
        elif exclude_node is None:
            pkey_map = self._graph_store.pkey_map_dict[table_name]
            candidate = pkey_map['arange'].to_numpy().copy()

        # Case 3: Exclude nodes - use a mask to exclude them:
        else:
            mask = np.full((num_nodes, ), fill_value=True, dtype=bool)
            mask[exclude_node] = False
            if table_name in mask_dict:
                mask &= mask_dict[table_name]
            candidate = mask.nonzero()[0]

        self._rng.shuffle(candidate)

        return candidate

    def _get_time_mask(
        self,
        candidate: np.ndarray,
        anchor_time: pd.Timestamp,
    ) -> np.ndarray:
        r"""Returns a boolean mask, aligned with ``candidate``, that denotes
        which entities exist at ``anchor_time``.

        An entity exists if its time column (if any) is not after
        ``anchor_time`` and its end time column (if any) is missing or
        strictly after ``anchor_time``.
        """
        entity = self._query.entity_table
        mask = np.ones(len(candidate), dtype=bool)

        # Filter out entities that do not exist yet in time. `time_dict`
        # stores seconds, while `Timestamp.value` is in nanoseconds:
        time_sec = self._graph_store.time_dict.get(entity)
        if time_sec is not None:
            mask &= time_sec[candidate] <= (anchor_time.value // (1000**3))

        # Filter out entities that no longer exist in time:
        end_time_col = self._graph_store.end_time_column_dict.get(entity)
        if end_time_col is not None:
            ser = self._graph_store.df_dict[entity][end_time_col]
            ser = ser.iloc[candidate]
            mask &= (anchor_time < ser).to_numpy() | ser.isna().to_numpy()

        return mask

    def _filter_candidates_by_time(
        self,
        candidate: np.ndarray,
        anchor_time: pd.Timestamp,
    ) -> np.ndarray:
        r"""Returns the subset of ``candidate`` that exists at
        ``anchor_time``.
        """
        return candidate[self._get_time_mask(candidate, anchor_time)]

    def _get_node_anchor_times(
        self,
        node: np.ndarray,
        anchor_time: Union[pd.Timestamp, Literal['entity']],
    ) -> Tuple[np.ndarray, pd.Series]:
        r"""Pairs nodes with their anchor times.

        For a fixed timestamp, nodes that do not exist at that time are
        dropped. For ``'entity'``, each node uses its own timestamp.

        Returns:
            The (possibly filtered) nodes and a ``datetime64[ns]`` series of
            their anchor times.
        """
        if isinstance(anchor_time, pd.Timestamp):
            node = self._filter_candidates_by_time(node, anchor_time)
            time = pd.Series(anchor_time).repeat(len(node))
            time = time.astype('datetime64[ns]').reset_index(drop=True)
            return node, time

        assert anchor_time == 'entity'
        time_sec = self._graph_store.time_dict[self._query.entity_table]
        time = pd.Series(time_sec[node] * 1000**3, dtype='datetime64[ns]')
        return node, time

    @staticmethod
    def _concat_collected(
        nodes: List[np.ndarray],
        times: List[pd.Series],
        ys: List[pd.Series],
        size: int,
    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
        r"""Concatenates per-batch results and truncates them to ``size``."""
        if len(nodes) == 0:  # Only happens for `max_iterations <= 0`.
            return (
                np.empty(0, dtype=np.int64),
                pd.Series([], dtype='datetime64[ns]'),
                pd.Series([], dtype='float64'),
            )
        if len(nodes) == 1:
            return nodes[0][:size], times[0].iloc[:size], ys[0].iloc[:size]

        node = np.concatenate(nodes, axis=0)[:size]
        time = pd.concat(times, axis=0).reset_index(drop=True).iloc[:size]
        y = pd.concat(ys, axis=0).reset_index(drop=True).iloc[:size]
        return node, time, y

    def collect_test(
        self,
        size: int,
        anchor_time: Union[pd.Timestamp, Literal['entity']],
        batch_size: Optional[int] = None,
        max_iterations: int = 20,
        guarantee_train_examples: bool = True,
    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
        r"""Collects test nodes and their labels used for evaluation.

        Args:
            size: The number of test nodes to collect.
            anchor_time: The anchor time, or ``'entity'`` to use each entity's
                own timestamp as its anchor time.
            batch_size: How many nodes to process in a single batch (defaults
                to ``size``).
            max_iterations: The number of steps to run before aborting.
            guarantee_train_examples: Ensures that test examples do not occupy
                the entire set of entity candidates.

        Returns:
            A triplet holding the nodes, timestamps and labels.

        Raises:
            RuntimeError: If no test example could be collected.
        """
        global _coverage_warned

        batch_size = size if batch_size is None else batch_size

        candidate = self._get_candidates()

        nodes: List[np.ndarray] = []
        times: List[pd.Series] = []
        ys: List[pd.Series] = []

        reached_end = False
        num_labels = candidate_offset = 0
        for _ in range(max_iterations):
            node = candidate[candidate_offset:candidate_offset + batch_size]
            node, time = self._get_node_anchor_times(node, anchor_time)

            y, mask = self(node, time)

            nodes.append(node[mask])
            times.append(time[mask].reset_index(drop=True))
            ys.append(y)

            num_labels += len(y)

            # Requiring *strictly* more than `size` labels ensures that the
            # truncation to `size` below leaves at least one labeled entity
            # unused:
            if num_labels > size:
                reached_end = True
                break  # Sufficient number of labels collected. Abort.

            candidate_offset += batch_size
            if candidate_offset >= len(candidate):
                reached_end = True
                break

        node, time, y = self._concat_collected(nodes, times, ys, size)

        if len(node) == 0:
            raise RuntimeError("Failed to collect any test examples for "
                               "evaluation. Is your predictive query too "
                               "restrictive?")

        if not _coverage_warned and not reached_end and len(node) < size // 2:
            _coverage_warned = True
            warnings.warn(f"Failed to collect {size:,} test examples within "
                          f"{max_iterations} iterations. To improve coverage, "
                          f"consider increasing the number of PQ iterations "
                          f"using the 'max_pq_iterations' option. This "
                          f"warning will not be shown again in this run.")

        if (guarantee_train_examples
                and self._query.query_type == QueryType.STATIC
                and candidate_offset >= len(candidate)):
            # In case all valid entities are used as test examples, we can no
            # longer find any training example. Fallback to a 50/50 split:
            half = len(node) // 2
            node = node[:half]
            time = time.iloc[:half]
            y = y.iloc[:half]

        return node, time, y

    def collect_train(
        self,
        size: int,
        anchor_time: Union[pd.Timestamp, Literal['entity']],
        exclude_node: Optional[np.ndarray] = None,
        batch_size: Optional[int] = None,
        max_iterations: int = 20,
    ) -> Tuple[np.ndarray, pd.Series, pd.Series]:
        r"""Collects training nodes and their labels.

        For temporal queries with a fixed anchor time, the anchor time is moved
        back by ``num_forecasts`` target time frames whenever all candidates
        are exhausted, until the graph's minimum time is reached.

        Args:
            size: The number of training (in-context) nodes to collect.
            anchor_time: The anchor time, or ``'entity'`` to use each entity's
                own timestamp as its anchor time.
            exclude_node: The nodes to exclude for use as in-context examples.
            batch_size: How many nodes to process in a single batch (defaults
                to ``size``).
            max_iterations: The number of steps to run before aborting.

        Returns:
            A triplet holding the nodes, timestamps and labels.

        Raises:
            RuntimeError: If there are no candidate entities at all.
            ValueError: If no training example could be collected.
        """
        global _coverage_warned

        batch_size = size if batch_size is None else batch_size

        candidate = self._get_candidates(exclude_node)

        if len(candidate) == 0:
            raise RuntimeError("Failed to generate any context examples "
                               "since not enough entities exist")

        nodes: List[np.ndarray] = []
        times: List[pd.Series] = []
        ys: List[pd.Series] = []

        reached_end = False
        num_labels = candidate_offset = 0
        for _ in range(max_iterations):
            node = candidate[candidate_offset:candidate_offset + batch_size]
            node, time = self._get_node_anchor_times(node, anchor_time)

            y, mask = self(node, time)

            nodes.append(node[mask])
            times.append(time[mask].reset_index(drop=True))
            ys.append(y)

            num_labels += len(y)

            if num_labels > size:
                reached_end = True
                break  # Sufficient number of labels collected. Abort.

            candidate_offset += batch_size
            if candidate_offset < len(candidate):
                continue

            # All candidates consumed. Restart with an earlier anchor time
            # (if applicable):
            if self._query.query_type == QueryType.STATIC:
                reached_end = True
                break  # Cannot jump back in time for static PQs. Abort.
            if not isinstance(anchor_time, pd.Timestamp):
                reached_end = True
                break  # Per-entity anchor times cannot be shifted. Abort.
            candidate_offset = 0
            time_frame = self._query.target_timeframe.timeframe
            anchor_time = anchor_time - (time_frame *
                                         self._query.num_forecasts)
            if anchor_time < self._graph_store.min_time:
                reached_end = True
                break  # No earlier anchor time left. Abort.

        node, time, y = self._concat_collected(nodes, times, ys, size)

        if len(node) == 0:
            raise ValueError("Failed to collect any context examples. Is your "
                             "predictive query too restrictive?")

        if not _coverage_warned and not reached_end and len(node) < size // 2:
            _coverage_warned = True
            warnings.warn(f"Failed to collect {size:,} context examples "
                          f"within {max_iterations} iterations. To improve "
                          f"coverage, consider increasing the number of PQ "
                          f"iterations using the 'max_pq_iterations' option. "
                          f"This warning will not be shown again in this run.")

        return node, time, y

    def is_valid(
        self,
        node: np.ndarray,
        anchor_time: Union[pd.Timestamp, Literal['entity']],
        batch_size: int = 10_000,
    ) -> np.ndarray:
        r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
        which nodes exist at that time and fulfill entity filter constraints.

        Args:
            node: The nodes to check for.
            anchor_time: The anchor time, or ``'entity'`` to use each entity's
                own timestamp as its anchor time.
            batch_size: How many nodes to process in a single batch.

        Returns:
            A boolean mask of length ``len(node)``, aligned with ``node``.
        """
        if isinstance(anchor_time, pd.Timestamp):
            # Entities that do not exist at `anchor_time` are invalid. Keep
            # them as `False` entries so the mask stays aligned with `node`:
            mask = self._get_time_mask(node, anchor_time)
            valid_node = node[mask]
            time = pd.Series(anchor_time).repeat(len(valid_node))
            time = time.astype('datetime64[ns]').reset_index(drop=True)
        else:
            assert anchor_time == 'entity'
            mask = np.ones(len(node), dtype=bool)
            valid_node = node
            time_sec = self._graph_store.time_dict[self._query.entity_table]
            time = pd.Series(time_sec[node] * 1000**3, dtype='datetime64[ns]')

        if isinstance(self._query.entity_ast, Filter) and len(valid_node) > 0:
            # Mask out via (temporal) entity filter:
            executor = PQueryPandasExecutor()
            filter_masks: List[np.ndarray] = []
            for start in range(0, len(valid_node), batch_size):
                batch_time = time.iloc[start:start + batch_size]
                feat_dict, time_dict, batch_dict = self._sample(
                    valid_node[start:start + batch_size],
                    batch_time,
                )
                filter_mask = executor.execute_filter(
                    filter=self._query.entity_ast,
                    feat_dict=feat_dict,
                    time_dict=time_dict,
                    batch_dict=batch_dict,
                    anchor_time=batch_time,
                )[1]
                filter_masks.append(filter_mask)

            # Scatter filter results back to the positions that passed the
            # time check:
            mask[np.flatnonzero(mask)] = np.concatenate(filter_masks)

        return mask

    def _get_sampling_specs(
        self,
        node: ASTNode,
        hop: int,
        seed_table_name: str,
        edge_types: List[Tuple[str, str, str]],
        num_forecasts: int = 1,
    ) -> List[SamplingSpec]:
        r"""Recursively collects the edge types (and aggregation time windows)
        needed to evaluate ``node`` starting from ``seed_table_name``.

        Raises:
            ValueError: If no unique edge connects ``seed_table_name`` to a
                referenced table.
        """
        if isinstance(node, (Aggregation, Column)):
            if isinstance(node, Column):
                table_name = node.fqn.split('.')[0]
                if seed_table_name == table_name:
                    return []  # Column lives on the seed table itself.
            else:
                table_name = node._get_target_column_name().split('.')[0]

            target_edge_types = [
                edge_type for edge_type in edge_types if
                edge_type[2] == seed_table_name and edge_type[0] == table_name
            ]
            if len(target_edge_types) != 1:
                raise ValueError(
                    f"Could not find a unique foreign key from table "
                    f"'{seed_table_name}' to '{table_name}'")

            if isinstance(node, Column):
                return [
                    SamplingSpec(
                        edge_type=target_edge_types[0],
                        hop=hop + 1,
                        start_offset=None,
                        end_offset=None,
                    )
                ]
            spec = SamplingSpec(
                edge_type=target_edge_types[0],
                hop=hop + 1,
                start_offset=node.aggr_time_range.start_date_offset,
                # Cover the aggregation window of every forecast step:
                end_offset=node.aggr_time_range.end_date_offset *
                num_forecasts,
            )
            return [spec] + self._get_sampling_specs(
                node.target, hop=hop + 1, seed_table_name=table_name,
                edge_types=edge_types, num_forecasts=num_forecasts)

        specs = []
        for child in node.children:
            specs += self._get_sampling_specs(child, hop, seed_table_name,
                                              edge_types, num_forecasts)
        return specs

    def get_sampling_specs(self) -> List[SamplingSpec]:
        r"""Returns the sampling specs required by the target, entity and
        what-if parts of the query, merged per ``(edge_type, hop)`` into the
        widest time window.
        """
        edge_types = self._graph_store.edge_types
        specs = self._get_sampling_specs(
            self._query.target_ast, hop=0,
            seed_table_name=self._query.entity_table, edge_types=edge_types,
            num_forecasts=self._query.num_forecasts)
        specs += self._get_sampling_specs(
            self._query.entity_ast, hop=0,
            seed_table_name=self._query.entity_table, edge_types=edge_types)
        if self._query.whatif_ast is not None:
            specs += self._get_sampling_specs(
                self._query.whatif_ast, hop=0,
                seed_table_name=self._query.entity_table,
                edge_types=edge_types)

        # Group specs according to edge type and hop:
        spec_dict: Dict[
            Tuple[Tuple[str, str, str], int],
            Tuple[Optional[DateOffset], Optional[DateOffset]],
        ] = {}
        for spec in specs:
            key = (spec.edge_type, spec.hop)
            if key not in spec_dict:
                spec_dict[key] = (spec.start_offset, spec.end_offset)
            else:
                start_offset, end_offset = spec_dict[key]
                spec_dict[key] = (
                    min_date_offset(start_offset, spec.start_offset),
                    max_date_offset(end_offset, spec.end_offset),
                )

        return [
            SamplingSpec(edge, hop, start_offset, end_offset)
            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
        ]

    def _sample(
        self,
        node: np.ndarray,
        anchor_time: pd.Series,
    ) -> Tuple[
            Dict[str, pd.DataFrame],
            Dict[str, pd.Series],
            Dict[str, np.ndarray],
    ]:
        r"""Samples a subgraph that contains all relevant information to
        evaluate the predictive query.

        Args:
            node: The nodes to check for.
            anchor_time: The anchor time.

        Returns:
            The feature dictionary, the time column dictionary and the batch
            dictionary.
        """
        specs = self.get_sampling_specs()
        num_hops = max([spec.hop for spec in specs] + [0])
        num_neighbors: Dict[Tuple[str, str, str], List[int]] = {}
        time_offsets: Dict[
            Tuple[str, str, str],
            List[List[Optional[int]]],
        ] = {}
        for spec in specs:
            if spec.end_offset is not None:
                # Time-windowed edge: `[start, end]` offsets in seconds per hop.
                if spec.edge_type not in time_offsets:
                    time_offsets[spec.edge_type] = [[0, 0]
                                                    for _ in range(num_hops)]
                hop_offsets = time_offsets[spec.edge_type][spec.hop - 1]
                hop_offsets[1] = date_offset_to_seconds(spec.end_offset)
                if spec.start_offset is not None:
                    hop_offsets[0] = date_offset_to_seconds(spec.start_offset)
                else:
                    hop_offsets[0] = None
            else:
                # Edge without a time window gets `-1` neighbors at this hop.
                # NOTE(review): presumably "sample all" in kumolib - confirm.
                if spec.edge_type not in num_neighbors:
                    num_neighbors[spec.edge_type] = [0] * num_hops
                num_neighbors[spec.edge_type][spec.hop - 1] = -1

        edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
        node_types = list(
            set([self._query.entity_table])
            | set(src for src, _, _ in edge_types)
            | set(dst for _, _, dst in edge_types))

        sampler = kumolib.NeighborSampler(
            node_types,
            edge_types,
            {
                '__'.join(edge_type): self._graph_store.colptr_dict[edge_type]
                for edge_type in edge_types
            },
            {
                '__'.join(edge_type): self._graph_store.row_dict[edge_type]
                for edge_type in edge_types
            },
            {
                node_type: time
                for node_type, time in self._graph_store.time_dict.items()
                if node_type in node_types
            },
        )

        anchor_time = anchor_time.astype('datetime64[ns]')
        # Explicit `int64` (not `int`) so the conversion does not depend on
        # the platform's default integer width:
        anchor_sec = anchor_time.astype('int64').to_numpy() // 1000**3
        _, _, node_dict, batch_dict, _, _ = sampler.sample(
            {
                '__'.join(edge_type): np.array(values)
                for edge_type, values in num_neighbors.items()
            },
            {
                '__'.join(edge_type): np.array(values)
                for edge_type, values in time_offsets.items()
            },
            self._query.entity_table,
            node,
            anchor_sec,
        )

        feat_dict: Dict[str, pd.DataFrame] = {}
        time_dict: Dict[str, pd.Series] = {}
        column_dict: Dict[str, Set[str]] = {}
        for col in self._query.all_query_columns:
            table_name, col_name = col.split('.')
            if table_name not in column_dict:
                column_dict[table_name] = set()
            if col_name != '*':
                column_dict[table_name].add(col_name)
        time_tables = self.find_time_tables()
        for table_name in set(list(column_dict.keys()) + time_tables):
            df = self._graph_store.df_dict[table_name]
            row_id = node_dict[table_name]
            df = df.iloc[row_id].reset_index(drop=True)
            if table_name in column_dict:
                if len(column_dict[table_name]) == 0:
                    # We are dealing with COUNT(table.*), insert a dummy col
                    # to ensure we don't lose the information on node count
                    feat_dict[table_name] = pd.DataFrame(
                        {'ones': [1] * len(df)})
                else:
                    feat_dict[table_name] = df[list(column_dict[table_name])]
            if table_name in time_tables:
                time_col = self._graph_store.time_column_dict[table_name]
                time_dict[table_name] = df[time_col]

        return feat_dict, time_dict, batch_dict

    def __call__(
        self,
        node: np.ndarray,
        anchor_time: pd.Series,
    ) -> Tuple[pd.Series, np.ndarray]:
        r"""Computes labels for ``node`` at ``anchor_time``.

        Returns:
            The labels and a mask denoting which nodes received a label.
        """
        feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)

        y, mask = PQueryPandasExecutor().execute(
            query=self._query,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            num_forecasts=self._query.num_forecasts,
        )

        return y, mask

    def find_time_tables(self) -> List[str]:
        r"""Returns the (deduplicated) names of tables targeted by an
        aggregation in the target, entity or what-if part of the query.
        """
        def _find_time_tables(node: ASTNode) -> List[str]:
            time_tables = []
            if isinstance(node, Aggregation):
                time_tables.append(
                    node._get_target_column_name().split('.')[0])
            for child in node.children:
                time_tables += _find_time_tables(child)
            return time_tables

        time_tables = _find_time_tables(
            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
        if self._query.whatif_ast is not None:
            time_tables += _find_time_tables(self._query.whatif_ast)
        return list(set(time_tables))

    @staticmethod
    def get_task_type(
        query: ValidatedPredictiveQuery,
        edge_types: List[Tuple[str, str, str]],
    ) -> TaskType:
        r"""Infers the task type from the query's target expression.

        Conditions and logical operations map to binary classification.
        ``LIST_DISTINCT`` over a foreign key maps to temporal link prediction.
        Other aggregations map to regression. ID/categorical columns map to
        multiclass classification, and numerical columns to regression.

        Raises:
            NotImplementedError: For unsupported targets.
        """
        if isinstance(query.target_ast, (Condition, LogicalOperation)):
            return TaskType.BINARY_CLASSIFICATION

        target = query.target_ast
        if isinstance(target, Join):
            target = target.rhs_target
        if isinstance(target, Aggregation):
            if target.aggr == AggregationType.LIST_DISTINCT:
                table_name, col_name = target._get_target_column_name().split(
                    '.')
                target_edge_types = [
                    edge_type for edge_type in edge_types
                    if edge_type[0] == table_name and edge_type[1] == col_name
                ]
                if len(target_edge_types) != 1:
                    raise NotImplementedError(
                        f"Multilabel-classification queries based on "
                        f"'LIST_DISTINCT' are not supported yet. If you "
                        f"planned to write a link prediction query instead, "
                        f"make sure to register '{col_name}' as a "
                        f"foreign key.")
                return TaskType.TEMPORAL_LINK_PREDICTION

            return TaskType.REGRESSION

        assert isinstance(target, Column)

        if target.stype in {Stype.ID, Stype.categorical}:
            return TaskType.MULTICLASS_CLASSIFICATION

        if target.stype in {Stype.numerical}:
            return TaskType.REGRESSION

        raise NotImplementedError("Task type not yet supported")
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
    r"""Convert a :class:`pandas.DateOffset` into a maximum number of
    seconds.

    .. note::
        We are conservative and take months and years as their maximum value.
        Additional values are then dropped in label computation where we know
        the actual dates.

    Args:
        offset: The date offset. Its relative components (``years``,
            ``months``, ``weeks``, ``days``, ``hours``, ``minutes`` and
            ``seconds``) are scaled by the offset's multiplier ``n`` (if
            present). All other attributes are ignored.

    Returns:
        An upper bound on the offset's duration, in seconds.
    """
    # Upper bounds on the number of days in a month and in a year:
    MAX_DAYS_IN_MONTH = 31
    MAX_DAYS_IN_YEAR = 366

    # Conversion factors:
    SECONDS_IN_MINUTE = 60
    SECONDS_IN_HOUR = 60 * SECONDS_IN_MINUTE
    SECONDS_IN_DAY = 24 * SECONDS_IN_HOUR

    seconds_per_unit = {
        'years': MAX_DAYS_IN_YEAR * SECONDS_IN_DAY,
        'months': MAX_DAYS_IN_MONTH * SECONDS_IN_DAY,
        'weeks': 7 * SECONDS_IN_DAY,
        'days': SECONDS_IN_DAY,
        'hours': SECONDS_IN_HOUR,
        'minutes': SECONDS_IN_MINUTE,
        'seconds': 1,
    }

    total_seconds = 0
    multiplier = getattr(offset, 'n', 1)  # The multiplier (if present).

    # `DateOffset` stores its relative keyword components as instance
    # attributes; only the known duration units are taken into account:
    for attr, value in offset.__dict__.items():
        unit = seconds_per_unit.get(attr)
        if unit is None or value is None or value == 0:
            continue
        total_seconds += value * multiplier * unit

    return total_seconds
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the shortest of the given date offsets.

    Offsets are compared by adding them to a fixed reference timestamp, so
    calendar-based offsets (*e.g.*, months) and fixed durations are comparable.
    If any argument is :obj:`None`, :obj:`None` is returned. Ties resolve to
    the first argument.

    Raises:
        ValueError: If no offset is given.
    """
    if any(arg is None for arg in args):
        return None
    if len(args) == 0:
        raise ValueError("Expected at least one date offset")

    anchor = pd.Timestamp('2000-01-01')
    timestamps = [anchor + arg for arg in args]
    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
    return args[argmin]
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
def max_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the longest of the given date offsets.

    Offsets are compared by adding them to a fixed reference timestamp, so
    calendar-based offsets (*e.g.*, months) and fixed durations are comparable.
    If any argument is :obj:`None`, :obj:`None` is returned. Ties resolve to
    the first argument.

    Raises:
        ValueError: If no offset is given.
    """
    if any(arg is None for arg in args):
        return None
    if len(args) == 0:
        raise ValueError("Expected at least one date offset")

    anchor = pd.Timestamp('2000-01-01')
    timestamps = [anchor + arg for arg in args]
    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
    return args[argmax]
|