kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic; see the registry's advisory for this release for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
|
@@ -1,24 +1,41 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
2
|
+
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
6
|
-
from kumoapi.pquery import QueryType
|
|
7
|
-
from kumoapi.
|
|
6
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
7
|
+
from kumoapi.pquery.AST import (
|
|
8
|
+
Aggregation,
|
|
9
|
+
ASTNode,
|
|
10
|
+
Column,
|
|
11
|
+
Condition,
|
|
12
|
+
Filter,
|
|
13
|
+
Join,
|
|
14
|
+
LogicalOperation,
|
|
15
|
+
)
|
|
16
|
+
from kumoapi.task import TaskType
|
|
17
|
+
from kumoapi.typing import AggregationType, DateOffset, Stype
|
|
8
18
|
|
|
9
19
|
import kumoai.kumolib as kumolib
|
|
10
20
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
11
|
-
from kumoai.experimental.rfm.pquery import
|
|
21
|
+
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
12
22
|
|
|
13
23
|
_coverage_warned = False
|
|
14
24
|
|
|
15
25
|
|
|
26
|
+
class SamplingSpec(NamedTuple):
    r"""A single neighbor-sampling requirement derived from a predictive
    query.

    Specs are created by ``LocalPQueryDriver._get_sampling_specs`` and merged
    per ``(edge_type, hop)`` pair in ``LocalPQueryDriver.get_sampling_specs``.
    """
    # Edge to sample along, presumably (src_table, fkey_column, dst_table) —
    # TODO confirm against LocalGraphStore.edge_types.
    edge_type: Tuple[str, str, str]
    # 1-based hop at which the edge is traversed (seed nodes are hop 0).
    hop: int
    # Start of the sampling time window; ``None`` for plain column lookups.
    start_offset: Optional[DateOffset]
    # End of the sampling time window; ``None`` for plain column lookups.
    end_offset: Optional[DateOffset]
|
|
31
|
+
|
|
32
|
+
|
|
16
33
|
class LocalPQueryDriver:
|
|
17
34
|
def __init__(
|
|
18
35
|
self,
|
|
19
36
|
graph_store: LocalGraphStore,
|
|
20
|
-
query:
|
|
21
|
-
random_seed: Optional[int],
|
|
37
|
+
query: ValidatedPredictiveQuery,
|
|
38
|
+
random_seed: Optional[int] = None,
|
|
22
39
|
) -> None:
|
|
23
40
|
self._graph_store = graph_store
|
|
24
41
|
self._query = query
|
|
@@ -27,14 +44,13 @@ class LocalPQueryDriver:
|
|
|
27
44
|
|
|
28
45
|
def _get_candidates(
|
|
29
46
|
self,
|
|
30
|
-
anchor_time: Union[pd.Timestamp, Literal['entity']],
|
|
31
47
|
exclude_node: Optional[np.ndarray] = None,
|
|
32
48
|
) -> np.ndarray:
|
|
33
49
|
|
|
34
50
|
if self._query.query_type == QueryType.TEMPORAL:
|
|
35
51
|
assert exclude_node is None
|
|
36
52
|
|
|
37
|
-
table_name = self._query.
|
|
53
|
+
table_name = self._query.entity_table
|
|
38
54
|
num_nodes = len(self._graph_store.df_dict[table_name])
|
|
39
55
|
mask_dict = self._graph_store.mask_dict
|
|
40
56
|
|
|
@@ -61,12 +77,37 @@ class LocalPQueryDriver:
|
|
|
61
77
|
|
|
62
78
|
return candidate
|
|
63
79
|
|
|
80
|
+
def _filter_candidates_by_time(
    self,
    candidate: np.ndarray,
    anchor_time: pd.Timestamp,
) -> np.ndarray:
    r"""Restricts entity candidates to those that exist at ``anchor_time``.

    Two filters are applied in order:

    1. If the entity table has an entry in ``time_dict`` (creation times in
       seconds), candidates created after ``anchor_time`` are dropped.
    2. If the entity table has an end-time column, candidates whose end time
       is at or before ``anchor_time`` are dropped. A missing end time
       keeps the candidate.

    Args:
        candidate: Row indices into the entity table.
        anchor_time: The anchor time.

    Returns:
        The surviving candidates, in their original order.
    """
    table = self._query.entity_table
    store = self._graph_store

    created_sec = store.time_dict.get(table)
    if created_sec is not None:
        anchor_sec = anchor_time.value // (1000**3)  # ns -> s
        candidate = candidate[created_sec[candidate] <= anchor_sec]

    end_col = store.end_time_column_dict.get(table)
    if end_col is not None:
        end_time = store.df_dict[table][end_col].iloc[candidate]
        still_alive = (anchor_time < end_time) | end_time.isna().to_numpy()
        candidate = candidate[still_alive]

    return candidate
|
|
103
|
+
|
|
64
104
|
def collect_test(
|
|
65
105
|
self,
|
|
66
106
|
size: int,
|
|
67
107
|
anchor_time: Union[pd.Timestamp, Literal['entity']],
|
|
68
108
|
batch_size: Optional[int] = None,
|
|
69
109
|
max_iterations: int = 20,
|
|
110
|
+
guarantee_train_examples: bool = True,
|
|
70
111
|
) -> Tuple[np.ndarray, pd.Series, pd.Series]:
|
|
71
112
|
r"""Collects test nodes and their labels used for evaluation.
|
|
72
113
|
|
|
@@ -75,13 +116,15 @@ class LocalPQueryDriver:
|
|
|
75
116
|
anchor_time: The anchor time.
|
|
76
117
|
batch_size: How many nodes to process in a single batch.
|
|
77
118
|
max_iterations: The number of steps to run before aborting.
|
|
119
|
+
guarantee_train_examples: Ensures that test examples do not occupy
|
|
120
|
+
the entire set of entity candidates.
|
|
78
121
|
|
|
79
122
|
Returns:
|
|
80
123
|
A triplet holding the nodes, timestamps and labels.
|
|
81
124
|
"""
|
|
82
125
|
batch_size = size if batch_size is None else batch_size
|
|
83
126
|
|
|
84
|
-
candidate = self._get_candidates(
|
|
127
|
+
candidate = self._get_candidates()
|
|
85
128
|
|
|
86
129
|
nodes: List[np.ndarray] = []
|
|
87
130
|
times: List[pd.Series] = []
|
|
@@ -93,19 +136,12 @@ class LocalPQueryDriver:
|
|
|
93
136
|
node = candidate[candidate_offset:candidate_offset + batch_size]
|
|
94
137
|
|
|
95
138
|
if isinstance(anchor_time, pd.Timestamp):
|
|
96
|
-
|
|
97
|
-
time = self._graph_store.time_dict.get(
|
|
98
|
-
self._query.entity.pkey.table_name)
|
|
99
|
-
if time is not None:
|
|
100
|
-
node = node[time[node] <= (anchor_time.value // (1000**3))]
|
|
101
|
-
|
|
102
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
139
|
+
node = self._filter_candidates_by_time(node, anchor_time)
|
|
103
140
|
time = pd.Series(anchor_time).repeat(len(node))
|
|
104
141
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
105
142
|
else:
|
|
106
143
|
assert anchor_time == 'entity'
|
|
107
|
-
time = self._graph_store.time_dict[
|
|
108
|
-
self._query.entity.pkey.table_name]
|
|
144
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
109
145
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
110
146
|
|
|
111
147
|
y, mask = self(node, time)
|
|
@@ -148,6 +184,16 @@ class LocalPQueryDriver:
|
|
|
148
184
|
f"using the 'max_pq_iterations' option. This "
|
|
149
185
|
f"warning will not be shown again in this run.")
|
|
150
186
|
|
|
187
|
+
if (guarantee_train_examples
|
|
188
|
+
and self._query.query_type == QueryType.STATIC
|
|
189
|
+
and candidate_offset >= len(candidate)):
|
|
190
|
+
# In case all valid entities are used as test examples, we can no
|
|
191
|
+
# longer find any training example. Fallback to a 50/50 split:
|
|
192
|
+
size = len(node) // 2
|
|
193
|
+
node = node[:size]
|
|
194
|
+
time = time.iloc[:size]
|
|
195
|
+
y = y.iloc[:size]
|
|
196
|
+
|
|
151
197
|
return node, time, y
|
|
152
198
|
|
|
153
199
|
def collect_train(
|
|
@@ -172,7 +218,7 @@ class LocalPQueryDriver:
|
|
|
172
218
|
"""
|
|
173
219
|
batch_size = size if batch_size is None else batch_size
|
|
174
220
|
|
|
175
|
-
candidate = self._get_candidates(
|
|
221
|
+
candidate = self._get_candidates(exclude_node)
|
|
176
222
|
|
|
177
223
|
if len(candidate) == 0:
|
|
178
224
|
raise RuntimeError("Failed to generate any context examples "
|
|
@@ -182,28 +228,18 @@ class LocalPQueryDriver:
|
|
|
182
228
|
times: List[pd.Series] = []
|
|
183
229
|
ys: List[pd.Series] = []
|
|
184
230
|
|
|
185
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
186
|
-
anchor_time = anchor_time - self._query.target.end_offset
|
|
187
|
-
|
|
188
231
|
reached_end = False
|
|
189
232
|
num_labels = candidate_offset = 0
|
|
190
233
|
for _ in range(max_iterations):
|
|
191
234
|
node = candidate[candidate_offset:candidate_offset + batch_size]
|
|
192
235
|
|
|
193
236
|
if isinstance(anchor_time, pd.Timestamp):
|
|
194
|
-
|
|
195
|
-
time = self._graph_store.time_dict.get(
|
|
196
|
-
self._query.entity.pkey.table_name)
|
|
197
|
-
if time is not None:
|
|
198
|
-
node = node[time[node] <= (anchor_time.value // (1000**3))]
|
|
199
|
-
|
|
200
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
237
|
+
node = self._filter_candidates_by_time(node, anchor_time)
|
|
201
238
|
time = pd.Series(anchor_time).repeat(len(node))
|
|
202
239
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
203
240
|
else:
|
|
204
241
|
assert anchor_time == 'entity'
|
|
205
|
-
time = self._graph_store.time_dict[
|
|
206
|
-
self._query.entity.pkey.table_name]
|
|
242
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
207
243
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
208
244
|
|
|
209
245
|
y, mask = self(node, time)
|
|
@@ -228,7 +264,9 @@ class LocalPQueryDriver:
|
|
|
228
264
|
reached_end = True
|
|
229
265
|
break
|
|
230
266
|
candidate_offset = 0
|
|
231
|
-
|
|
267
|
+
time_frame = self._query.target_timeframe.timeframe
|
|
268
|
+
anchor_time = anchor_time - (time_frame *
|
|
269
|
+
self._query.num_forecasts)
|
|
232
270
|
if anchor_time < self._graph_store.min_time:
|
|
233
271
|
reached_end = True
|
|
234
272
|
break # No earlier anchor time left. Abort.
|
|
@@ -257,13 +295,171 @@ class LocalPQueryDriver:
|
|
|
257
295
|
|
|
258
296
|
return node, time, y
|
|
259
297
|
|
|
260
|
-
def
|
|
298
|
+
def is_valid(
    self,
    node: np.ndarray,
    anchor_time: Union[pd.Timestamp, Literal['entity']],
    batch_size: int = 10_000,
) -> np.ndarray:
    r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
    which nodes fulfill entity filter constraints.

    A node is marked invalid if

    * ``anchor_time`` is a timestamp and the node does not exist at that
      point in time (see :meth:`_filter_candidates_by_time`), or
    * the query defines an entity :class:`Filter` that the node does not
      satisfy.

    Args:
        node: The nodes to check for.
        anchor_time: The anchor time, or ``'entity'`` to use each node's own
            timestamp from the entity table as its anchor time.
        batch_size: How many nodes to process in a single batch.

    Returns:
        A boolean mask of length ``len(node)``, aligned with :obj:`node`.
    """
    mask = np.ones(len(node), dtype=bool)

    if isinstance(anchor_time, pd.Timestamp):
        # Mark, rather than drop, nodes that do not exist at the anchor time
        # so that the returned mask stays aligned with the input `node`:
        alive = self._filter_candidates_by_time(node, anchor_time)
        mask &= np.isin(node, alive)
        sub_node = node[mask]
        time = pd.Series(anchor_time).repeat(len(sub_node))
        time = time.astype('datetime64[ns]').reset_index(drop=True)
    else:
        assert anchor_time == 'entity'
        sub_node = node
        time = self._graph_store.time_dict[self._query.entity_table]
        time = pd.Series(time[sub_node] * 1000**3, dtype='datetime64[ns]')

    # Guard on `len(sub_node)` since `np.concatenate([])` raises:
    if isinstance(self._query.entity_ast, Filter) and len(sub_node) > 0:
        # Mask out via (temporal) entity filter:
        executor = PQueryPandasExecutor()
        masks: List[np.ndarray] = []
        for start in range(0, len(sub_node), batch_size):
            batch_node = sub_node[start:start + batch_size]
            batch_time = time.iloc[start:start + batch_size]
            feat_dict, time_dict, batch_dict = self._sample(
                batch_node,
                batch_time,
            )
            _mask = executor.execute_filter(
                filter=self._query.entity_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=batch_time,
            )[1]
            masks.append(_mask)

        # Scatter the filter results back into the positions of the nodes
        # that survived the time filter:
        positions = np.flatnonzero(mask)
        mask[positions] = np.concatenate(masks)

    return mask
|
|
351
|
+
|
|
352
|
+
def _get_sampling_specs(
    self,
    node: ASTNode,
    hop: int,
    seed_table_name: str,
    edge_types: List[Tuple[str, str, str]],
    num_forecasts: int = 1,
) -> List[SamplingSpec]:
    r"""Recursively derives the sampling specs needed to evaluate ``node``.

    * A :class:`Column` on ``seed_table_name`` needs no sampling.
    * A :class:`Column` on another table needs one hop along the unique edge
      into the seed table, without a time window.
    * An :class:`Aggregation` needs one hop along the unique edge into the
      seed table. It uses the aggregation's time window, with the end offset
      scaled by ``num_forecasts``. The aggregated target is then processed,
      seeded at the aggregation's table.
    * Any other node yields the concatenated specs of its children.

    Raises:
        ValueError: If there is not exactly one edge from the referenced
            table into ``seed_table_name``.
    """
    if not isinstance(node, (Aggregation, Column)):
        collected: List[SamplingSpec] = []
        for child in node.children:
            collected.extend(
                self._get_sampling_specs(child, hop, seed_table_name,
                                         edge_types, num_forecasts))
        return collected

    is_column = isinstance(node, Column)
    if is_column:
        table_name = node.fqn.split('.')[0]
        if table_name == seed_table_name:
            return []  # Column lives on the seed table itself.
    else:
        table_name = node._get_target_column_name().split('.')[0]

    matches = [
        edge_type for edge_type in edge_types
        if edge_type[2] == seed_table_name and edge_type[0] == table_name
    ]
    if len(matches) != 1:
        raise ValueError(
            f"Could not find a unique foreign key from table "
            f"'{seed_table_name}' to '{table_name}'")
    edge_type = matches[0]
    next_hop = hop + 1

    if is_column:
        return [
            SamplingSpec(
                edge_type=edge_type,
                hop=next_hop,
                start_offset=None,
                end_offset=None,
            )
        ]

    time_range = node.aggr_time_range
    head = SamplingSpec(
        edge_type=edge_type,
        hop=next_hop,
        start_offset=time_range.start_date_offset,
        end_offset=time_range.end_date_offset * num_forecasts,
    )
    nested = self._get_sampling_specs(
        node.target, hop=next_hop, seed_table_name=table_name,
        edge_types=edge_types, num_forecasts=num_forecasts)
    return [head] + nested
|
|
401
|
+
|
|
402
|
+
def get_sampling_specs(self) -> List[SamplingSpec]:
    r"""Collects the sampling specs of the target, entity and (optional)
    what-if ASTs of the query.

    Specs that share the same ``(edge_type, hop)`` are merged into one spec.
    The merged spec uses :func:`min_date_offset` of the start offsets and
    :func:`max_date_offset` of the end offsets; both return ``None`` if any
    input is ``None``. Output order follows the first occurrence of each
    ``(edge_type, hop)`` pair.
    """
    edge_types = self._graph_store.edge_types
    query = self._query

    specs = self._get_sampling_specs(
        query.target_ast, hop=0, seed_table_name=query.entity_table,
        edge_types=edge_types, num_forecasts=query.num_forecasts)
    specs += self._get_sampling_specs(
        query.entity_ast, hop=0, seed_table_name=query.entity_table,
        edge_types=edge_types)
    if query.whatif_ast is not None:
        specs += self._get_sampling_specs(
            query.whatif_ast, hop=0, seed_table_name=query.entity_table,
            edge_types=edge_types)

    merged: Dict[
        Tuple[Tuple[str, str, str], int],
        Tuple[Optional[DateOffset], Optional[DateOffset]],
    ] = {}
    for spec in specs:
        key = (spec.edge_type, spec.hop)
        if key in merged:
            prev_start, prev_end = merged[key]
            merged[key] = (
                min_date_offset(prev_start, spec.start_offset),
                max_date_offset(prev_end, spec.end_offset),
            )
        else:
            merged[key] = (spec.start_offset, spec.end_offset)

    return [
        SamplingSpec(edge_type, hop, start, end)
        for (edge_type, hop), (start, end) in merged.items()
    ]
|
|
441
|
+
|
|
442
|
+
def _sample(
|
|
261
443
|
self,
|
|
262
444
|
node: np.ndarray,
|
|
263
445
|
anchor_time: pd.Series,
|
|
264
|
-
) -> Tuple[
|
|
446
|
+
) -> Tuple[
|
|
447
|
+
Dict[str, pd.DataFrame],
|
|
448
|
+
Dict[str, pd.Series],
|
|
449
|
+
Dict[str, np.ndarray],
|
|
450
|
+
]:
|
|
451
|
+
r"""Samples a subgraph that contains all relevant information to
|
|
452
|
+
evaluate the predictive query.
|
|
453
|
+
|
|
454
|
+
Args:
|
|
455
|
+
node: The nodes to check for.
|
|
456
|
+
anchor_time: The anchor time.
|
|
265
457
|
|
|
266
|
-
|
|
458
|
+
Returns:
|
|
459
|
+
The feature dictionary, the time column dictionary and the batch
|
|
460
|
+
dictionary.
|
|
461
|
+
"""
|
|
462
|
+
specs = self.get_sampling_specs()
|
|
267
463
|
num_hops = max([spec.hop for spec in specs] + [0])
|
|
268
464
|
num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
|
|
269
465
|
time_offsets: Dict[
|
|
@@ -275,11 +471,10 @@ class LocalPQueryDriver:
|
|
|
275
471
|
if spec.edge_type not in time_offsets:
|
|
276
472
|
time_offsets[spec.edge_type] = [[0, 0]
|
|
277
473
|
for _ in range(num_hops)]
|
|
278
|
-
offset: Optional[int] =
|
|
279
|
-
spec.end_offset)
|
|
474
|
+
offset: Optional[int] = date_offset_to_seconds(spec.end_offset)
|
|
280
475
|
time_offsets[spec.edge_type][spec.hop - 1][1] = offset
|
|
281
476
|
if spec.start_offset is not None:
|
|
282
|
-
offset =
|
|
477
|
+
offset = date_offset_to_seconds(spec.start_offset)
|
|
283
478
|
else:
|
|
284
479
|
offset = None
|
|
285
480
|
time_offsets[spec.edge_type][spec.hop - 1][0] = offset
|
|
@@ -290,7 +485,7 @@ class LocalPQueryDriver:
|
|
|
290
485
|
|
|
291
486
|
edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
|
|
292
487
|
node_types = list(
|
|
293
|
-
set([self._query.
|
|
488
|
+
set([self._query.entity_table])
|
|
294
489
|
| set(src for src, _, _ in edge_types)
|
|
295
490
|
| set(dst for _, _, dst in edge_types))
|
|
296
491
|
|
|
@@ -322,37 +517,116 @@ class LocalPQueryDriver:
|
|
|
322
517
|
'__'.join(edge_type): np.array(values)
|
|
323
518
|
for edge_type, values in time_offsets.items()
|
|
324
519
|
},
|
|
325
|
-
self._query.
|
|
520
|
+
self._query.entity_table,
|
|
326
521
|
node,
|
|
327
522
|
anchor_time.astype(int).to_numpy() // 1000**3,
|
|
328
523
|
)
|
|
329
524
|
|
|
330
525
|
feat_dict: Dict[str, pd.DataFrame] = {}
|
|
331
526
|
time_dict: Dict[str, pd.Series] = {}
|
|
332
|
-
column_dict =
|
|
333
|
-
|
|
527
|
+
column_dict: Dict[str, Set[str]] = {}
|
|
528
|
+
for col in self._query.all_query_columns:
|
|
529
|
+
table_name, col_name = col.split('.')
|
|
530
|
+
if table_name not in column_dict:
|
|
531
|
+
column_dict[table_name] = set()
|
|
532
|
+
if col_name != '*':
|
|
533
|
+
column_dict[table_name].add(col_name)
|
|
534
|
+
time_tables = self.find_time_tables()
|
|
334
535
|
for table_name in set(list(column_dict.keys()) + time_tables):
|
|
335
536
|
df = self._graph_store.df_dict[table_name]
|
|
336
537
|
row_id = node_dict[table_name]
|
|
337
538
|
df = df.iloc[row_id].reset_index(drop=True)
|
|
338
539
|
if table_name in column_dict:
|
|
339
|
-
|
|
540
|
+
if len(column_dict[table_name]) == 0:
|
|
541
|
+
# We are dealing with COUNT(table.*), insert a dummy col
|
|
542
|
+
# to ensure we don't lose the information on node count
|
|
543
|
+
feat_dict[table_name] = pd.DataFrame(
|
|
544
|
+
{'ones': [1] * len(df)})
|
|
545
|
+
else:
|
|
546
|
+
feat_dict[table_name] = df[list(column_dict[table_name])]
|
|
340
547
|
if table_name in time_tables:
|
|
341
548
|
time_col = self._graph_store.time_column_dict[table_name]
|
|
342
549
|
time_dict[table_name] = df[time_col]
|
|
343
550
|
|
|
344
|
-
|
|
551
|
+
return feat_dict, time_dict, batch_dict
|
|
552
|
+
|
|
553
|
+
def __call__(
    self,
    node: np.ndarray,
    anchor_time: pd.Series,
) -> Tuple[pd.Series, np.ndarray]:
    r"""Evaluates the predictive query for ``node`` at ``anchor_time``.

    The subgraph is sampled with :meth:`_sample`. The whole query is then
    executed on it with :class:`PQueryPandasExecutor`.

    Args:
        node: The entity nodes to evaluate.
        anchor_time: The anchor time of each node.

    Returns:
        The pair returned by the executor: the labels and a mask. The mask
        presumably marks which nodes received a valid label — TODO confirm
        against ``PQueryPandasExecutor.execute``.
    """
    feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)

    executor = PQueryPandasExecutor()
    labels, label_mask = executor.execute(
        query=self._query,
        feat_dict=feat_dict,
        time_dict=time_dict,
        batch_dict=batch_dict,
        anchor_time=anchor_time,
        num_forecasts=self._query.num_forecasts,
    )
    return labels, label_mask
|
|
353
571
|
|
|
354
|
-
|
|
355
|
-
def
|
|
572
|
+
def find_time_tables(self) -> List[str]:
    r"""Returns the tables referenced by any :class:`Aggregation` node in
    the target, entity and (optional) what-if ASTs.

    The result is deduplicated and in no particular order.
    """
    pending: List[ASTNode] = [
        self._query.target_ast,
        self._query.entity_ast,
    ]
    if self._query.whatif_ast is not None:
        pending.append(self._query.whatif_ast)

    tables: Set[str] = set()
    while pending:
        current = pending.pop()
        if isinstance(current, Aggregation):
            tables.add(current._get_target_column_name().split('.')[0])
        pending.extend(current.children)
    return list(tables)
|
|
587
|
+
|
|
588
|
+
@staticmethod
def get_task_type(
    query: ValidatedPredictiveQuery,
    edge_types: List[Tuple[str, str, str]],
) -> TaskType:
    r"""Infers the task type of a validated predictive query.

    * Conditions and logical operations yield binary classification.
    * For a :class:`Join`, its right-hand-side target is inspected instead.
    * ``LIST_DISTINCT`` aggregations over a column registered as the single
      matching edge yield temporal link prediction. Any other aggregation
      yields regression.
    * A plain column yields multiclass classification for ID or categorical
      columns, and regression for numerical columns.

    Raises:
        NotImplementedError: For ``LIST_DISTINCT`` targets without exactly
            one matching edge, and for unsupported column stypes.
    """
    target = query.target_ast
    if isinstance(target, (Condition, LogicalOperation)):
        return TaskType.BINARY_CLASSIFICATION

    if isinstance(target, Join):
        target = target.rhs_target

    if isinstance(target, Aggregation):
        if target.aggr != AggregationType.LIST_DISTINCT:
            return TaskType.REGRESSION

        table_name, col_name = target._get_target_column_name().split('.')
        num_matches = sum(
            1 for edge_type in edge_types
            if edge_type[0] == table_name and edge_type[1] == col_name)
        if num_matches != 1:
            raise NotImplementedError(
                f"Multilabel-classification queries based on "
                f"'LIST_DISTINCT' are not supported yet. If you "
                f"planned to write a link prediction query instead, "
                f"make sure to register '{col_name}' as a "
                f"foreign key.")
        return TaskType.TEMPORAL_LINK_PREDICTION

    assert isinstance(target, Column)

    stype = target.stype
    if stype in {Stype.ID, Stype.categorical}:
        return TaskType.MULTICLASS_CLASSIFICATION
    if stype in {Stype.numerical}:
        return TaskType.REGRESSION
    raise NotImplementedError("Task type not yet supported")
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
356
630
|
r"""Convert a :class:`pandas.DateOffset` into a maximum number of
|
|
357
631
|
nanoseconds.
|
|
358
632
|
|
|
@@ -391,3 +665,25 @@ def _date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
|
391
665
|
total_ns += scaled_value
|
|
392
666
|
|
|
393
667
|
return total_ns
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the smallest of the given date offsets.

    Offsets are compared by the timestamp they produce when added to the
    fixed reference date ``2000-01-01``. Calendar-dependent offsets, such as
    months, are therefore compared relative to that date. On ties, the first
    tied offset is returned.

    Args:
        *args: The date offsets to compare.

    Returns:
        The smallest offset, or ``None`` if any argument is ``None``.

    Raises:
        ValueError: If no offsets are given.
    """
    if len(args) == 0:
        raise ValueError("'min_date_offset' requires at least one argument")
    if any(arg is None for arg in args):
        return None

    anchor = pd.Timestamp('2000-01-01')
    # `min` returns the first minimal element, matching the tie behavior:
    return min(args, key=lambda offset: anchor + offset)
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
def max_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the largest of the given date offsets.

    Offsets are compared by the timestamp they produce when added to the
    fixed reference date ``2000-01-01``. Calendar-dependent offsets, such as
    months, are therefore compared relative to that date. On ties, the first
    tied offset is returned.

    Args:
        *args: The date offsets to compare.

    Returns:
        The largest offset, or ``None`` if any argument is ``None``.

    Raises:
        ValueError: If no offsets are given.
    """
    if len(args) == 0:
        raise ValueError("'max_date_offset' requires at least one argument")
    if any(arg is None for arg in args):
        return None

    anchor = pd.Timestamp('2000-01-01')
    # `max` returns the first maximal element, matching the tie behavior:
    return max(args, key=lambda offset: anchor + offset)
|
|
@@ -1,7 +1,7 @@
|
|
|
1
|
-
from .
|
|
2
|
-
from .
|
|
1
|
+
from .executor import PQueryExecutor
|
|
2
|
+
from .pandas_executor import PQueryPandasExecutor
|
|
3
3
|
|
|
4
4
|
__all__ = [
|
|
5
|
-
'
|
|
6
|
-
'
|
|
5
|
+
'PQueryExecutor',
|
|
6
|
+
'PQueryPandasExecutor',
|
|
7
7
|
]
|