kumoai 2.9.0.dev202509081831__cp312-cp312-win_amd64.whl → 2.13.0.dev202511201731__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +71 -102
- kumoai/connector/utils.py +1367 -236
- kumoai/experimental/rfm/__init__.py +153 -10
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph.py +90 -80
- kumoai/experimental/rfm/local_graph_sampler.py +16 -10
- kumoai/experimental/rfm/local_graph_store.py +22 -6
- kumoai/experimental/rfm/local_pquery_driver.py +336 -42
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -222
- kumoai/experimental/rfm/rfm.py +523 -124
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/jobs.py +1 -0
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/progress_logger.py +68 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/METADATA +13 -5
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/RECORD +30 -29
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/WHEEL +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.9.0.dev202509081831.dist-info → kumoai-2.13.0.dev202511201731.dist-info}/top_level.txt +0 -0
|
@@ -1,24 +1,41 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
2
|
+
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
6
|
-
from kumoapi.pquery import QueryType
|
|
7
|
-
from kumoapi.
|
|
6
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
7
|
+
from kumoapi.pquery.AST import (
|
|
8
|
+
Aggregation,
|
|
9
|
+
ASTNode,
|
|
10
|
+
Column,
|
|
11
|
+
Condition,
|
|
12
|
+
Filter,
|
|
13
|
+
Join,
|
|
14
|
+
LogicalOperation,
|
|
15
|
+
)
|
|
16
|
+
from kumoapi.task import TaskType
|
|
17
|
+
from kumoapi.typing import AggregationType, DateOffset, Stype
|
|
8
18
|
|
|
9
19
|
import kumoai.kumolib as kumolib
|
|
10
20
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
11
|
-
from kumoai.experimental.rfm.pquery import
|
|
21
|
+
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
12
22
|
|
|
13
23
|
_coverage_warned = False
|
|
14
24
|
|
|
15
25
|
|
|
26
|
+
class SamplingSpec(NamedTuple):
    r"""A single sampling instruction derived from a predictive query: which
    edge type to follow at which hop, and the time window attached to it.
    """
    # Edge type triplet. Within this file, element 0 is matched against the
    # table of the referenced column, element 1 against a foreign key column
    # name, and element 2 against the seed table.
    edge_type: Tuple[str, str, str]
    # Hop at which the edge is followed (seed hop + 1).
    hop: int
    # Start/end of the aggregation window; both are `None` for specs that
    # stem from a plain column reference rather than an aggregation.
    start_offset: Optional[DateOffset]
    end_offset: Optional[DateOffset]
|
|
31
|
+
|
|
32
|
+
|
|
16
33
|
class LocalPQueryDriver:
|
|
17
34
|
def __init__(
|
|
18
35
|
self,
|
|
19
36
|
graph_store: LocalGraphStore,
|
|
20
|
-
query:
|
|
21
|
-
random_seed: Optional[int],
|
|
37
|
+
query: ValidatedPredictiveQuery,
|
|
38
|
+
random_seed: Optional[int] = None,
|
|
22
39
|
) -> None:
|
|
23
40
|
self._graph_store = graph_store
|
|
24
41
|
self._query = query
|
|
@@ -27,14 +44,13 @@ class LocalPQueryDriver:
|
|
|
27
44
|
|
|
28
45
|
def _get_candidates(
|
|
29
46
|
self,
|
|
30
|
-
anchor_time: Union[pd.Timestamp, Literal['entity']],
|
|
31
47
|
exclude_node: Optional[np.ndarray] = None,
|
|
32
48
|
) -> np.ndarray:
|
|
33
49
|
|
|
34
50
|
if self._query.query_type == QueryType.TEMPORAL:
|
|
35
51
|
assert exclude_node is None
|
|
36
52
|
|
|
37
|
-
table_name = self._query.
|
|
53
|
+
table_name = self._query.entity_table
|
|
38
54
|
num_nodes = len(self._graph_store.df_dict[table_name])
|
|
39
55
|
mask_dict = self._graph_store.mask_dict
|
|
40
56
|
|
|
@@ -61,12 +77,37 @@ class LocalPQueryDriver:
|
|
|
61
77
|
|
|
62
78
|
return candidate
|
|
63
79
|
|
|
80
|
+
def _filter_candidates_by_time(
|
|
81
|
+
self,
|
|
82
|
+
candidate: np.ndarray,
|
|
83
|
+
anchor_time: pd.Timestamp,
|
|
84
|
+
) -> np.ndarray:
|
|
85
|
+
|
|
86
|
+
entity = self._query.entity_table
|
|
87
|
+
|
|
88
|
+
# Filter out entities that do not exist yet in time:
|
|
89
|
+
time_sec = self._graph_store.time_dict.get(entity)
|
|
90
|
+
if time_sec is not None:
|
|
91
|
+
mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
|
|
92
|
+
candidate = candidate[mask]
|
|
93
|
+
|
|
94
|
+
# Filter out entities that no longer exist in time:
|
|
95
|
+
end_time_col = self._graph_store.end_time_column_dict.get(entity)
|
|
96
|
+
if end_time_col is not None:
|
|
97
|
+
ser = self._graph_store.df_dict[entity][end_time_col]
|
|
98
|
+
ser = ser.iloc[candidate]
|
|
99
|
+
mask = (anchor_time < ser) | ser.isna().to_numpy()
|
|
100
|
+
candidate = candidate[mask]
|
|
101
|
+
|
|
102
|
+
return candidate
|
|
103
|
+
|
|
64
104
|
def collect_test(
|
|
65
105
|
self,
|
|
66
106
|
size: int,
|
|
67
107
|
anchor_time: Union[pd.Timestamp, Literal['entity']],
|
|
68
108
|
batch_size: Optional[int] = None,
|
|
69
109
|
max_iterations: int = 20,
|
|
110
|
+
guarantee_train_examples: bool = True,
|
|
70
111
|
) -> Tuple[np.ndarray, pd.Series, pd.Series]:
|
|
71
112
|
r"""Collects test nodes and their labels used for evaluation.
|
|
72
113
|
|
|
@@ -75,13 +116,15 @@ class LocalPQueryDriver:
|
|
|
75
116
|
anchor_time: The anchor time.
|
|
76
117
|
batch_size: How many nodes to process in a single batch.
|
|
77
118
|
max_iterations: The number of steps to run before aborting.
|
|
119
|
+
guarantee_train_examples: Ensures that test examples do not occupy
|
|
120
|
+
the entire set of entity candidates.
|
|
78
121
|
|
|
79
122
|
Returns:
|
|
80
123
|
A triplet holding the nodes, timestamps and labels.
|
|
81
124
|
"""
|
|
82
125
|
batch_size = size if batch_size is None else batch_size
|
|
83
126
|
|
|
84
|
-
candidate = self._get_candidates(
|
|
127
|
+
candidate = self._get_candidates()
|
|
85
128
|
|
|
86
129
|
nodes: List[np.ndarray] = []
|
|
87
130
|
times: List[pd.Series] = []
|
|
@@ -93,19 +136,12 @@ class LocalPQueryDriver:
|
|
|
93
136
|
node = candidate[candidate_offset:candidate_offset + batch_size]
|
|
94
137
|
|
|
95
138
|
if isinstance(anchor_time, pd.Timestamp):
|
|
96
|
-
|
|
97
|
-
time = self._graph_store.time_dict.get(
|
|
98
|
-
self._query.entity.pkey.table_name)
|
|
99
|
-
if time is not None:
|
|
100
|
-
node = node[time[node] <= (anchor_time.value // (1000**3))]
|
|
101
|
-
|
|
102
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
139
|
+
node = self._filter_candidates_by_time(node, anchor_time)
|
|
103
140
|
time = pd.Series(anchor_time).repeat(len(node))
|
|
104
141
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
105
142
|
else:
|
|
106
143
|
assert anchor_time == 'entity'
|
|
107
|
-
time = self._graph_store.time_dict[
|
|
108
|
-
self._query.entity.pkey.table_name]
|
|
144
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
109
145
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
110
146
|
|
|
111
147
|
y, mask = self(node, time)
|
|
@@ -148,6 +184,16 @@ class LocalPQueryDriver:
|
|
|
148
184
|
f"using the 'max_pq_iterations' option. This "
|
|
149
185
|
f"warning will not be shown again in this run.")
|
|
150
186
|
|
|
187
|
+
if (guarantee_train_examples
|
|
188
|
+
and self._query.query_type == QueryType.STATIC
|
|
189
|
+
and candidate_offset >= len(candidate)):
|
|
190
|
+
# In case all valid entities are used as test examples, we can no
|
|
191
|
+
# longer find any training example. Fallback to a 50/50 split:
|
|
192
|
+
size = len(node) // 2
|
|
193
|
+
node = node[:size]
|
|
194
|
+
time = time.iloc[:size]
|
|
195
|
+
y = y.iloc[:size]
|
|
196
|
+
|
|
151
197
|
return node, time, y
|
|
152
198
|
|
|
153
199
|
def collect_train(
|
|
@@ -172,7 +218,7 @@ class LocalPQueryDriver:
|
|
|
172
218
|
"""
|
|
173
219
|
batch_size = size if batch_size is None else batch_size
|
|
174
220
|
|
|
175
|
-
candidate = self._get_candidates(
|
|
221
|
+
candidate = self._get_candidates(exclude_node)
|
|
176
222
|
|
|
177
223
|
if len(candidate) == 0:
|
|
178
224
|
raise RuntimeError("Failed to generate any context examples "
|
|
@@ -182,29 +228,18 @@ class LocalPQueryDriver:
|
|
|
182
228
|
times: List[pd.Series] = []
|
|
183
229
|
ys: List[pd.Series] = []
|
|
184
230
|
|
|
185
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
186
|
-
anchor_time = anchor_time - (self._query.target.end_offset *
|
|
187
|
-
self._query.num_forecasts)
|
|
188
|
-
|
|
189
231
|
reached_end = False
|
|
190
232
|
num_labels = candidate_offset = 0
|
|
191
233
|
for _ in range(max_iterations):
|
|
192
234
|
node = candidate[candidate_offset:candidate_offset + batch_size]
|
|
193
235
|
|
|
194
236
|
if isinstance(anchor_time, pd.Timestamp):
|
|
195
|
-
|
|
196
|
-
time = self._graph_store.time_dict.get(
|
|
197
|
-
self._query.entity.pkey.table_name)
|
|
198
|
-
if time is not None:
|
|
199
|
-
node = node[time[node] <= (anchor_time.value // (1000**3))]
|
|
200
|
-
|
|
201
|
-
if isinstance(anchor_time, pd.Timestamp):
|
|
237
|
+
node = self._filter_candidates_by_time(node, anchor_time)
|
|
202
238
|
time = pd.Series(anchor_time).repeat(len(node))
|
|
203
239
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
204
240
|
else:
|
|
205
241
|
assert anchor_time == 'entity'
|
|
206
|
-
time = self._graph_store.time_dict[
|
|
207
|
-
self._query.entity.pkey.table_name]
|
|
242
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
208
243
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
209
244
|
|
|
210
245
|
y, mask = self(node, time)
|
|
@@ -229,7 +264,8 @@ class LocalPQueryDriver:
|
|
|
229
264
|
reached_end = True
|
|
230
265
|
break
|
|
231
266
|
candidate_offset = 0
|
|
232
|
-
|
|
267
|
+
time_frame = self._query.target_timeframe.timeframe
|
|
268
|
+
anchor_time = anchor_time - (time_frame *
|
|
233
269
|
self._query.num_forecasts)
|
|
234
270
|
if anchor_time < self._graph_store.min_time:
|
|
235
271
|
reached_end = True
|
|
@@ -259,13 +295,171 @@ class LocalPQueryDriver:
|
|
|
259
295
|
|
|
260
296
|
return node, time, y
|
|
261
297
|
|
|
262
|
-
def
|
|
298
|
+
def is_valid(
    self,
    node: np.ndarray,
    anchor_time: Union[pd.Timestamp, Literal['entity']],
    batch_size: int = 10_000,
) -> np.ndarray:
    r"""Denotes which nodes are valid for a given anchor time, *e.g.*,
    which nodes fulfill entity filter constraints.

    The returned mask always has the same length as :obj:`node` and is
    aligned with it. Nodes that are removed by the temporal existence check
    are reported as invalid instead of being dropped from the mask.

    Args:
        node: The nodes to check for.
        anchor_time: The anchor time.
        batch_size: How many nodes to process in a single batch.

    Returns:
        The mask.
    """
    if isinstance(anchor_time, pd.Timestamp):
        # The time filter returns the surviving subset of `node`. Whether a
        # node survives only depends on the node itself, so membership in
        # the surviving set yields a mask aligned with the input:
        kept = self._filter_candidates_by_time(node, anchor_time)
        mask = np.isin(node, kept)
        node = node[mask]
        time = pd.Series(anchor_time).repeat(len(node))
        time = time.astype('datetime64[ns]').reset_index(drop=True)
    else:
        assert anchor_time == 'entity'
        mask = np.ones(len(node), dtype=bool)
        time = self._graph_store.time_dict[self._query.entity_table]
        time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

    # Skip the filter if no node is left, since `np.concatenate` cannot
    # handle an empty list of batches:
    if isinstance(self._query.entity_ast, Filter) and len(node) > 0:
        # Mask out via (temporal) entity filter:
        executor = PQueryPandasExecutor()
        filter_masks: List[np.ndarray] = []
        for start in range(0, len(node), batch_size):
            feat_dict, time_dict, batch_dict = self._sample(
                node[start:start + batch_size],
                time.iloc[start:start + batch_size],
            )
            _mask = executor.execute_filter(
                filter=self._query.entity_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=time.iloc[start:start + batch_size],
            )[1]
            filter_masks.append(_mask)

        # Write the filter result back to the positions of the nodes that
        # passed the time check:
        mask[np.flatnonzero(mask)] = np.concatenate(filter_masks)

    return mask
|
|
351
|
+
|
|
352
|
+
def _get_sampling_specs(
    self,
    node: ASTNode,
    hop: int,
    seed_table_name: str,
    edge_types: List[Tuple[str, str, str]],
    num_forecasts: int = 1,
) -> List[SamplingSpec]:
    r"""Recursively collects the sampling specs required to evaluate an AST
    node.

    Args:
        node: The AST node to inspect.
        hop: The hop of the seed table.
        seed_table_name: The table from which sampling starts.
        edge_types: All available edge types.
        num_forecasts: The factor applied to the end offset of aggregation
            windows.

    Returns:
        One spec per column or aggregation that references a table other
        than its seed table.

    Raises:
        ValueError: If there is not exactly one edge type connecting the
            referenced table to the seed table.
    """
    if not isinstance(node, (Aggregation, Column)):
        # Composite node: gather the specs of all children at the same hop.
        collected = []
        for child in node.children:
            collected += self._get_sampling_specs(
                child, hop, seed_table_name, edge_types, num_forecasts)
        return collected

    is_column = isinstance(node, Column)
    if is_column:
        table_name = node.fqn.split('.')[0]
        if table_name == seed_table_name:
            # Columns of the seed table require no sampling.
            return []
    else:
        table_name = node._get_target_column_name().split('.')[0]

    matches = [
        candidate for candidate in edge_types
        if candidate[2] == seed_table_name and candidate[0] == table_name
    ]
    if len(matches) != 1:
        raise ValueError(
            f"Could not find a unique foreign key from table "
            f"'{seed_table_name}' to '{table_name}'")
    edge_type = matches[0]
    next_hop = hop + 1

    if is_column:
        # Plain column references carry no time window.
        return [
            SamplingSpec(
                edge_type=edge_type,
                hop=next_hop,
                start_offset=None,
                end_offset=None,
            )
        ]

    head = SamplingSpec(
        edge_type=edge_type,
        hop=next_hop,
        start_offset=node.aggr_time_range.start_date_offset,
        end_offset=node.aggr_time_range.end_date_offset * num_forecasts,
    )
    # Continue below the aggregation, now seeded at the aggregated table:
    nested = self._get_sampling_specs(
        node.target,
        hop=next_hop,
        seed_table_name=table_name,
        edge_types=edge_types,
        num_forecasts=num_forecasts,
    )
    return [head] + nested
|
|
401
|
+
|
|
402
|
+
def get_sampling_specs(self) -> List[SamplingSpec]:
    r"""Returns the sampling specs of the whole query, *i.e.* of its target,
    entity and (if present) what-if ASTs, with a single spec per edge type
    and hop.

    Specs that share the same edge type and hop are merged by combining
    their start offsets via :meth:`min_date_offset` and their end offsets
    via :meth:`max_date_offset`.
    """
    query = self._query
    edge_types = self._graph_store.edge_types

    specs = self._get_sampling_specs(
        query.target_ast,
        hop=0,
        seed_table_name=query.entity_table,
        edge_types=edge_types,
        num_forecasts=query.num_forecasts,
    )
    specs += self._get_sampling_specs(
        query.entity_ast,
        hop=0,
        seed_table_name=query.entity_table,
        edge_types=edge_types,
    )
    if query.whatif_ast is not None:
        specs += self._get_sampling_specs(
            query.whatif_ast,
            hop=0,
            seed_table_name=query.entity_table,
            edge_types=edge_types,
        )

    # Group specs according to edge type and hop:
    merged: Dict[
        Tuple[Tuple[str, str, str], int],
        Tuple[Optional[DateOffset], Optional[DateOffset]],
    ] = {}
    for spec in specs:
        key = (spec.edge_type, spec.hop)
        if key in merged:
            start, end = merged[key]
            merged[key] = (
                min_date_offset(start, spec.start_offset),
                max_date_offset(end, spec.end_offset),
            )
        else:
            merged[key] = (spec.start_offset, spec.end_offset)

    return [
        SamplingSpec(edge_type, hop, start, end)
        for (edge_type, hop), (start, end) in merged.items()
    ]
|
|
441
|
+
|
|
442
|
+
def _sample(
|
|
263
443
|
self,
|
|
264
444
|
node: np.ndarray,
|
|
265
445
|
anchor_time: pd.Series,
|
|
266
|
-
) -> Tuple[
|
|
446
|
+
) -> Tuple[
|
|
447
|
+
Dict[str, pd.DataFrame],
|
|
448
|
+
Dict[str, pd.Series],
|
|
449
|
+
Dict[str, np.ndarray],
|
|
450
|
+
]:
|
|
451
|
+
r"""Samples a subgraph that contains all relevant information to
|
|
452
|
+
evaluate the predictive query.
|
|
453
|
+
|
|
454
|
+
Args:
|
|
455
|
+
node: The nodes to check for.
|
|
456
|
+
anchor_time: The anchor time.
|
|
267
457
|
|
|
268
|
-
|
|
458
|
+
Returns:
|
|
459
|
+
The feature dictionary, the time column dictionary and the batch
|
|
460
|
+
dictionary.
|
|
461
|
+
"""
|
|
462
|
+
specs = self.get_sampling_specs()
|
|
269
463
|
num_hops = max([spec.hop for spec in specs] + [0])
|
|
270
464
|
num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
|
|
271
465
|
time_offsets: Dict[
|
|
@@ -291,7 +485,7 @@ class LocalPQueryDriver:
|
|
|
291
485
|
|
|
292
486
|
edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
|
|
293
487
|
node_types = list(
|
|
294
|
-
set([self._query.
|
|
488
|
+
set([self._query.entity_table])
|
|
295
489
|
| set(src for src, _, _ in edge_types)
|
|
296
490
|
| set(dst for _, _, dst in edge_types))
|
|
297
491
|
|
|
@@ -323,26 +517,48 @@ class LocalPQueryDriver:
|
|
|
323
517
|
'__'.join(edge_type): np.array(values)
|
|
324
518
|
for edge_type, values in time_offsets.items()
|
|
325
519
|
},
|
|
326
|
-
self._query.
|
|
520
|
+
self._query.entity_table,
|
|
327
521
|
node,
|
|
328
522
|
anchor_time.astype(int).to_numpy() // 1000**3,
|
|
329
523
|
)
|
|
330
524
|
|
|
331
525
|
feat_dict: Dict[str, pd.DataFrame] = {}
|
|
332
526
|
time_dict: Dict[str, pd.Series] = {}
|
|
333
|
-
column_dict =
|
|
334
|
-
|
|
527
|
+
column_dict: Dict[str, Set[str]] = {}
|
|
528
|
+
for col in self._query.all_query_columns:
|
|
529
|
+
table_name, col_name = col.split('.')
|
|
530
|
+
if table_name not in column_dict:
|
|
531
|
+
column_dict[table_name] = set()
|
|
532
|
+
if col_name != '*':
|
|
533
|
+
column_dict[table_name].add(col_name)
|
|
534
|
+
time_tables = self.find_time_tables()
|
|
335
535
|
for table_name in set(list(column_dict.keys()) + time_tables):
|
|
336
536
|
df = self._graph_store.df_dict[table_name]
|
|
337
537
|
row_id = node_dict[table_name]
|
|
338
538
|
df = df.iloc[row_id].reset_index(drop=True)
|
|
339
539
|
if table_name in column_dict:
|
|
340
|
-
|
|
540
|
+
if len(column_dict[table_name]) == 0:
|
|
541
|
+
# We are dealing with COUNT(table.*), insert a dummy col
|
|
542
|
+
# to ensure we don't lose the information on node count
|
|
543
|
+
feat_dict[table_name] = pd.DataFrame(
|
|
544
|
+
{'ones': [1] * len(df)})
|
|
545
|
+
else:
|
|
546
|
+
feat_dict[table_name] = df[list(column_dict[table_name])]
|
|
341
547
|
if table_name in time_tables:
|
|
342
548
|
time_col = self._graph_store.time_column_dict[table_name]
|
|
343
549
|
time_dict[table_name] = df[time_col]
|
|
344
550
|
|
|
345
|
-
|
|
551
|
+
return feat_dict, time_dict, batch_dict
|
|
552
|
+
|
|
553
|
+
def __call__(
|
|
554
|
+
self,
|
|
555
|
+
node: np.ndarray,
|
|
556
|
+
anchor_time: pd.Series,
|
|
557
|
+
) -> Tuple[pd.Series, np.ndarray]:
|
|
558
|
+
|
|
559
|
+
feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
|
|
560
|
+
|
|
561
|
+
y, mask = PQueryPandasExecutor().execute(
|
|
346
562
|
query=self._query,
|
|
347
563
|
feat_dict=feat_dict,
|
|
348
564
|
time_dict=time_dict,
|
|
@@ -353,6 +569,62 @@ class LocalPQueryDriver:
|
|
|
353
569
|
|
|
354
570
|
return y, mask
|
|
355
571
|
|
|
572
|
+
def find_time_tables(self) -> List[str]:
    r"""Returns the names of all tables that are the target of an
    aggregation in the target, entity or what-if AST of the query, without
    duplicates.
    """
    found = set()

    def _collect(node: ASTNode) -> None:
        # Record the aggregated table, then descend into all children:
        if isinstance(node, Aggregation):
            found.add(node._get_target_column_name().split('.')[0])
        for child in node.children:
            _collect(child)

    _collect(self._query.target_ast)
    _collect(self._query.entity_ast)
    if self._query.whatif_ast is not None:
        _collect(self._query.whatif_ast)

    return list(found)
|
|
587
|
+
|
|
588
|
+
@staticmethod
def get_task_type(
    query: ValidatedPredictiveQuery,
    edge_types: List[Tuple[str, str, str]],
) -> TaskType:
    r"""Derives the task type from the target AST of a predictive query.

    Args:
        query: The predictive query.
        edge_types: All available edge types, used to decide whether a
            ``LIST_DISTINCT`` aggregation refers to a foreign key.

    Returns:
        The task type.

    Raises:
        NotImplementedError: If a ``LIST_DISTINCT`` target does not map to
            exactly one edge type, or if the semantic type of a column
            target is not handled.
    """
    target = query.target_ast
    if isinstance(target, (Condition, LogicalOperation)):
        return TaskType.BINARY_CLASSIFICATION

    # For joins, the task is determined by the right-hand side target:
    if isinstance(target, Join):
        target = target.rhs_target

    if not isinstance(target, Aggregation):
        assert isinstance(target, Column)
        if target.stype in {Stype.ID, Stype.categorical}:
            return TaskType.MULTICLASS_CLASSIFICATION
        if target.stype in {Stype.numerical}:
            return TaskType.REGRESSION
        raise NotImplementedError("Task type not yet supported")

    if target.aggr != AggregationType.LIST_DISTINCT:
        return TaskType.REGRESSION

    # `LIST_DISTINCT` is only supported over a registered foreign key:
    table_name, col_name = target._get_target_column_name().split('.')
    matches = [
        edge_type for edge_type in edge_types
        if edge_type[0] == table_name and edge_type[1] == col_name
    ]
    if len(matches) != 1:
        raise NotImplementedError(
            f"Multilabel-classification queries based on "
            f"'LIST_DISTINCT' are not supported yet. If you "
            f"planned to write a link prediction query instead, "
            f"make sure to register '{col_name}' as a "
            f"foreign key.")
    return TaskType.TEMPORAL_LINK_PREDICTION
|
|
627
|
+
|
|
356
628
|
|
|
357
629
|
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
358
630
|
r"""Convert a :class:`pandas.DateOffset` into a maximum number of
|
|
@@ -393,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
|
393
665
|
total_ns += scaled_value
|
|
394
666
|
|
|
395
667
|
return total_ns
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the smallest of the given date offsets.

    Offsets are compared by applying each of them to a fixed reference
    timestamp and comparing the resulting timestamps. If several offsets
    are equally small, the first one is returned.

    Args:
        *args: The date offsets to compare. At least one is required.

    Returns:
        The offset that yields the earliest timestamp, or :obj:`None` if any
        of the given offsets is :obj:`None`.
    """
    if any(arg is None for arg in args):
        return None

    assert len(args) > 0

    anchor = pd.Timestamp('2000-01-01')
    return min(args, key=lambda offset: anchor + offset)
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
def max_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
    r"""Returns the largest of the given date offsets.

    Offsets are compared by applying each of them to a fixed reference
    timestamp and comparing the resulting timestamps. If several offsets
    are equally large, the first one is returned.

    Args:
        *args: The date offsets to compare. At least one is required.

    Returns:
        The offset that yields the latest timestamp, or :obj:`None` if any
        of the given offsets is :obj:`None`.
    """
    if any(arg is None for arg in args):
        return None

    assert len(args) > 0

    anchor = pd.Timestamp('2000-01-01')
    return max(args, key=lambda offset: anchor + offset)
|