kumoai-2.10.0.dev202510061830-cp313-cp313-macosx_11_0_arm64.whl → kumoai-2.13.0.dev202511261731-cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +10 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +153 -10
- kumoai/experimental/rfm/infer/timestamp.py +5 -4
- kumoai/experimental/rfm/local_graph.py +90 -74
- kumoai/experimental/rfm/local_graph_sampler.py +16 -10
- kumoai/experimental/rfm/local_graph_store.py +13 -1
- kumoai/experimental/rfm/local_pquery_driver.py +249 -49
- kumoai/experimental/rfm/local_table.py +100 -22
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +174 -91
- kumoai/experimental/rfm/sagemaker.py +130 -0
- kumoai/jobs.py +1 -0
- kumoai/spcs.py +1 -3
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/METADATA +13 -5
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/RECORD +26 -25
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202510061830.dist-info → kumoai-2.13.0.dev202511261731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/local_graph_sampler.py:

@@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.model_plan import RunMode
 from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
 from kumoapi.typing import Stype
 
@@ -33,7 +32,6 @@ class LocalGraphSampler:
         entity_table_names: Tuple[str, ...],
         node: np.ndarray,
         time: np.ndarray,
-        run_mode: RunMode,
         num_neighbors: List[int],
         exclude_cols_dict: Dict[str, List[str]],
     ) -> Subgraph:
@@ -92,15 +90,23 @@ class LocalGraphSampler:
                 )
                 continue
 
-
-
-
-            df = df.iloc[
-
+            row: Optional[np.ndarray] = None
+            if table_name in self._graph_store.end_time_column_dict:
+                # Set end time to NaT for all values greater than anchor time:
+                df = df.iloc[node].reset_index(drop=True)
+                col_name = self._graph_store.end_time_column_dict[table_name]
+                ser = df[col_name]
+                value = ser.astype('datetime64[ns]').astype(int).to_numpy()
+                mask = value > time[batch]
+                df.loc[mask, col_name] = pd.NaT
             else:
-                df
-
-
+                # Only store unique rows in `df` above a certain threshold:
+                unique_node, inverse = np.unique(node, return_inverse=True)
+                if len(node) > 1.05 * len(unique_node):
+                    df = df.iloc[unique_node].reset_index(drop=True)
+                    row = inverse
+                else:
+                    df = df.iloc[node].reset_index(drop=True)
 
             # Filter data frame to minimal set of columns:
             df = df[columns]
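Two behaviors change in this hunk: end-time values later than the anchor time are masked to `NaT` so that future information cannot leak into the sampled context, and repeated node gathers are deduplicated before storage. The deduplication rests on `np.unique(..., return_inverse=True)`; the following is a minimal, self-contained sketch of that trick, with an invented table and node IDs:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'value': [10.0, 20.0, 30.0, 40.0]})
node = np.array([2, 0, 2, 2, 1, 0])  # sampled row IDs, with repeats

unique_node, inverse = np.unique(node, return_inverse=True)
if len(node) > 1.05 * len(unique_node):  # enough repetition to pay off
    stored = df.iloc[unique_node].reset_index(drop=True)
    row = inverse  # maps each batch position back to a stored row
else:
    stored = df.iloc[node].reset_index(drop=True)
    row = None

# Reassemble the full batch from the deduplicated storage:
full = stored if row is None else stored.iloc[row].reset_index(drop=True)
assert full['value'].tolist() == df['value'].iloc[node].tolist()
```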
kumoai/experimental/rfm/local_graph_store.py:

@@ -45,6 +45,7 @@ class LocalGraphStore:
 
         (
             self.time_column_dict,
+            self.end_time_column_dict,
             self.time_dict,
             self.min_time,
             self.max_time,
@@ -219,16 +220,21 @@ class LocalGraphStore:
         self,
         graph: LocalGraph,
     ) -> Tuple[
+        Dict[str, str],
         Dict[str, str],
         Dict[str, np.ndarray],
         pd.Timestamp,
         pd.Timestamp,
     ]:
         time_column_dict: Dict[str, str] = {}
+        end_time_column_dict: Dict[str, str] = {}
         time_dict: Dict[str, np.ndarray] = {}
         min_time = pd.Timestamp.max
         max_time = pd.Timestamp.min
         for table in graph.tables.values():
+            if table._end_time_column is not None:
+                end_time_column_dict[table.name] = table._end_time_column
+
             if table._time_column is None:
                 continue
 
@@ -243,7 +249,13 @@ class LocalGraphStore:
             min_time = min(min_time, time.min())
             max_time = max(max_time, time.max())
 
-        return
+        return (
+            time_column_dict,
+            end_time_column_dict,
+            time_dict,
+            min_time,
+            max_time,
+        )
 
     def get_csc(
         self,
kumoai/experimental/rfm/local_pquery_driver.py:

@@ -1,23 +1,40 @@
 import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    ASTNode,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, DateOffset, Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 
 _coverage_warned = False
 
 
+class SamplingSpec(NamedTuple):
+    edge_type: Tuple[str, str, str]
+    hop: int
+    start_offset: Optional[DateOffset]
+    end_offset: Optional[DateOffset]
+
+
 class LocalPQueryDriver:
     def __init__(
         self,
         graph_store: LocalGraphStore,
-        query:
+        query: ValidatedPredictiveQuery,
         random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
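`SamplingSpec` is a plain `NamedTuple` describing one sampling instruction: which edge to traverse at which hop, plus an optional aggregation time window. A hypothetical spec could look as follows; the edge and offsets are invented, and `pd.DateOffset` stands in for `kumoapi.typing.DateOffset`:

```python
from typing import NamedTuple, Optional, Tuple

import pandas as pd

class SamplingSpec(NamedTuple):
    edge_type: Tuple[str, str, str]
    hop: int
    start_offset: Optional[pd.DateOffset]
    end_offset: Optional[pd.DateOffset]

# Judging by the lookups later in this diff, an edge type reads as
# (source table, foreign key column, destination table):
spec = SamplingSpec(
    edge_type=('orders', 'user_id', 'users'),
    hop=1,
    start_offset=pd.DateOffset(days=-30),  # aggregate over the last 30 days
    end_offset=pd.DateOffset(days=0),
)
```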
@@ -27,14 +44,13 @@ class LocalPQueryDriver:
 
     def _get_candidates(
         self,
-        anchor_time: Union[pd.Timestamp, Literal['entity']],
         exclude_node: Optional[np.ndarray] = None,
     ) -> np.ndarray:
 
         if self._query.query_type == QueryType.TEMPORAL:
             assert exclude_node is None
 
-            table_name = self._query.
+            table_name = self._query.entity_table
             num_nodes = len(self._graph_store.df_dict[table_name])
             mask_dict = self._graph_store.mask_dict
 
@@ -61,6 +77,30 @@ class LocalPQueryDriver:
 
         return candidate
 
+    def _filter_candidates_by_time(
+        self,
+        candidate: np.ndarray,
+        anchor_time: pd.Timestamp,
+    ) -> np.ndarray:
+
+        entity = self._query.entity_table
+
+        # Filter out entities that do not exist yet in time:
+        time_sec = self._graph_store.time_dict.get(entity)
+        if time_sec is not None:
+            mask = time_sec[candidate] <= (anchor_time.value // (1000**3))
+            candidate = candidate[mask]
+
+        # Filter out entities that no longer exist in time:
+        end_time_col = self._graph_store.end_time_column_dict.get(entity)
+        if end_time_col is not None:
+            ser = self._graph_store.df_dict[entity][end_time_col]
+            ser = ser.iloc[candidate]
+            mask = (anchor_time < ser) | ser.isna().to_numpy()
+            candidate = candidate[mask]
+
+        return candidate
+
     def collect_test(
         self,
         size: int,
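The new helper centralizes a time-validity check that was previously duplicated across the collection paths: an entity's creation time (stored as integer epoch seconds) must lie at or before the anchor, and its end time must lie strictly after the anchor or be missing. A standalone approximation with toy data in place of the graph store:

```python
import numpy as np
import pandas as pd

candidate = np.array([0, 1, 2, 3])
anchor_time = pd.Timestamp('2022-01-01')

# Creation times as integer seconds since epoch, one per entity row:
time_sec = pd.to_datetime(
    ['2020-01-01', '2023-01-01', '2021-06-01', '2019-01-01'],
).astype('int64').to_numpy() // 10**9

# Keep entities that already exist at the anchor time:
candidate = candidate[time_sec[candidate] <= anchor_time.value // 10**9]

# ... and drop entities that ended at or before the anchor time:
end_time = pd.Series(pd.to_datetime(['2021-01-01', None, None, None]))
ser = end_time.iloc[candidate]
mask = ((anchor_time < ser) | ser.isna()).to_numpy()
candidate = candidate[mask]

print(candidate)  # [2 3]
```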
@@ -84,7 +124,7 @@ class LocalPQueryDriver:
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(
+        candidate = self._get_candidates()
 
         nodes: List[np.ndarray] = []
         times: List[pd.Series] = []
@@ -96,19 +136,12 @@ class LocalPQueryDriver:
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
                 assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[self._query.entity_table]
                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
             y, mask = self(node, time)
@@ -185,7 +218,7 @@ class LocalPQueryDriver:
         """
         batch_size = size if batch_size is None else batch_size
 
-        candidate = self._get_candidates(
+        candidate = self._get_candidates(exclude_node)
 
         if len(candidate) == 0:
             raise RuntimeError("Failed to generate any context examples "
@@ -201,19 +234,12 @@ class LocalPQueryDriver:
             node = candidate[candidate_offset:candidate_offset + batch_size]
 
             if isinstance(anchor_time, pd.Timestamp):
-
-                time = self._graph_store.time_dict.get(
-                    self._query.entity.pkey.table_name)
-                if time is not None:
-                    node = node[time[node] <= (anchor_time.value // (1000**3))]
-
-            if isinstance(anchor_time, pd.Timestamp):
+                node = self._filter_candidates_by_time(node, anchor_time)
                 time = pd.Series(anchor_time).repeat(len(node))
                 time = time.astype('datetime64[ns]').reset_index(drop=True)
             else:
                 assert anchor_time == 'entity'
-                time = self._graph_store.time_dict[
-                    self._query.entity.pkey.table_name]
+                time = self._graph_store.time_dict[self._query.entity_table]
                 time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
             y, mask = self(node, time)
@@ -238,7 +264,8 @@ class LocalPQueryDriver:
                     reached_end = True
                     break
                 candidate_offset = 0
-
+            time_frame = self._query.target_timeframe.timeframe
+            anchor_time = anchor_time - (time_frame *
                                          self._query.num_forecasts)
             if anchor_time < self._graph_store.min_time:
                 reached_end = True
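For forecasting queries, training-context collection now walks the anchor backwards by the full forecast horizon, i.e. the target timeframe multiplied by the number of forecast steps, and stops once the anchor would precede the earliest timestamp in the graph. A quick worked example with an invented timeframe and step count:

```python
import pandas as pd

anchor_time = pd.Timestamp('2024-03-01')
time_frame = pd.DateOffset(days=7)  # the query's target timeframe
num_forecasts = 4

# One collection round steps back over the whole horizon:
anchor_time = anchor_time - (time_frame * num_forecasts)
print(anchor_time)  # 2024-02-02 00:00:00
```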
@@ -288,37 +315,30 @@ class LocalPQueryDriver:
         mask: Optional[np.ndarray] = None
 
         if isinstance(anchor_time, pd.Timestamp):
-
-            time = self._graph_store.time_dict.get(
-                self._query.entity.pkey.table_name)
-            if time is not None:
-                mask = time[node] <= (anchor_time.value // (1000**3))
-
-        if isinstance(anchor_time, pd.Timestamp):
+            node = self._filter_candidates_by_time(node, anchor_time)
             time = pd.Series(anchor_time).repeat(len(node))
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
-        if self._query.
+        if isinstance(self._query.entity_ast, Filter):
             # Mask out via (temporal) entity filter:
-
+            executor = PQueryPandasExecutor()
             masks: List[np.ndarray] = []
             for start in range(0, len(node), batch_size):
                 feat_dict, time_dict, batch_dict = self._sample(
                     node[start:start + batch_size],
                     time.iloc[start:start + batch_size],
                 )
-                _mask =
-                    filter=self._query.
+                _mask = executor.execute_filter(
+                    filter=self._query.entity_ast,
                     feat_dict=feat_dict,
                     time_dict=time_dict,
                     batch_dict=batch_dict,
                     anchor_time=time.iloc[start:start + batch_size],
-                )
+                )[1]
                 masks.append(_mask)
 
             _mask = np.concatenate(masks)
@@ -329,6 +349,96 @@ class LocalPQueryDriver:
 
         return mask
 
+    def _get_sampling_specs(
+        self,
+        node: ASTNode,
+        hop: int,
+        seed_table_name: str,
+        edge_types: List[Tuple[str, str, str]],
+        num_forecasts: int = 1,
+    ) -> List[SamplingSpec]:
+        if isinstance(node, (Aggregation, Column)):
+            if isinstance(node, Column):
+                table_name = node.fqn.split('.')[0]
+                if seed_table_name == table_name:
+                    return []
+            else:
+                table_name = node._get_target_column_name().split('.')[0]
+
+            target_edge_types = [
+                edge_type for edge_type in edge_types if
+                edge_type[2] == seed_table_name and edge_type[0] == table_name
+            ]
+            if len(target_edge_types) != 1:
+                raise ValueError(
+                    f"Could not find a unique foreign key from table "
+                    f"'{seed_table_name}' to '{table_name}'")
+
+            if isinstance(node, Column):
+                return [
+                    SamplingSpec(
+                        edge_type=target_edge_types[0],
+                        hop=hop + 1,
+                        start_offset=None,
+                        end_offset=None,
+                    )
+                ]
+            spec = SamplingSpec(
+                edge_type=target_edge_types[0],
+                hop=hop + 1,
+                start_offset=node.aggr_time_range.start_date_offset,
+                end_offset=node.aggr_time_range.end_date_offset *
+                num_forecasts,
+            )
+            return [spec] + self._get_sampling_specs(
+                node.target, hop=hop + 1, seed_table_name=table_name,
+                edge_types=edge_types, num_forecasts=num_forecasts)
+        specs = []
+        for child in node.children:
+            specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                              edge_types, num_forecasts)
+        return specs
+
+    def get_sampling_specs(self) -> List[SamplingSpec]:
+        edge_types = self._graph_store.edge_types
+        specs = self._get_sampling_specs(
+            self._query.target_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types,
+            num_forecasts=self._query.num_forecasts)
+        specs += self._get_sampling_specs(
+            self._query.entity_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types)
+        if self._query.whatif_ast is not None:
+            specs += self._get_sampling_specs(
+                self._query.whatif_ast, hop=0,
+                seed_table_name=self._query.entity_table,
+                edge_types=edge_types)
+        # Group specs according to edge type and hop:
+        spec_dict: Dict[
+            Tuple[Tuple[str, str, str], int],
+            Tuple[Optional[DateOffset], Optional[DateOffset]],
+        ] = {}
+        for spec in specs:
+            if (spec.edge_type, spec.hop) not in spec_dict:
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    spec.start_offset,
+                    spec.end_offset,
+                )
+            else:
+                start_offset, end_offset = spec_dict[(
+                    spec.edge_type,
+                    spec.hop,
+                )]
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    min_date_offset(start_offset, spec.start_offset),
+                    max_date_offset(end_offset, spec.end_offset),
+                )
+
+        return [
+            SamplingSpec(edge, hop, start_offset, end_offset)
+            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+        ]
+
     def _sample(
         self,
         node: np.ndarray,
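`get_sampling_specs` first collects one spec per aggregation and hop from the target, entity, and what-if ASTs, then merges specs sharing an `(edge_type, hop)` key so each edge is sampled once with the widest window: the earliest start offset and the latest end offset, with `None` (unbounded) absorbing everything. A toy version of the merge rule, with invented offsets and the same anchor-comparison trick as the `min_date_offset`/`max_date_offset` helpers at the bottom of this file:

```python
import pandas as pd

ANCHOR = pd.Timestamp('2000-01-01')

def merge_window(old, new):
    """Merge two (start_offset, end_offset) windows into the widest one."""
    starts, ends = [old[0], new[0]], [old[1], new[1]]
    # None means "unbounded" and dominates the merge:
    start = None if None in starts else min(starts, key=lambda o: ANCHOR + o)
    end = None if None in ends else max(ends, key=lambda o: ANCHOR + o)
    return start, end

window = merge_window(
    (pd.DateOffset(days=-30), pd.DateOffset(days=0)),
    (pd.DateOffset(days=-90), pd.DateOffset(days=7)),
)
print(window)  # (DateOffset(days=-90), DateOffset(days=7))
```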
@@ -349,7 +459,7 @@ class LocalPQueryDriver:
         The feature dictionary, the time column dictionary and the batch
         dictionary.
         """
-        specs = self.
+        specs = self.get_sampling_specs()
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
         time_offsets: Dict[
@@ -375,7 +485,7 @@ class LocalPQueryDriver:
 
         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
         node_types = list(
-            set([self._query.
+            set([self._query.entity_table])
             | set(src for src, _, _ in edge_types)
             | set(dst for _, _, dst in edge_types))
@@ -407,21 +517,33 @@ class LocalPQueryDriver:
                 '__'.join(edge_type): np.array(values)
                 for edge_type, values in time_offsets.items()
             },
-            self._query.
+            self._query.entity_table,
             node,
             anchor_time.astype(int).to_numpy() // 1000**3,
         )
 
         feat_dict: Dict[str, pd.DataFrame] = {}
         time_dict: Dict[str, pd.Series] = {}
-        column_dict =
-
+        column_dict: Dict[str, Set[str]] = {}
+        for col in self._query.all_query_columns:
+            table_name, col_name = col.split('.')
+            if table_name not in column_dict:
+                column_dict[table_name] = set()
+            if col_name != '*':
+                column_dict[table_name].add(col_name)
+        time_tables = self.find_time_tables()
         for table_name in set(list(column_dict.keys()) + time_tables):
             df = self._graph_store.df_dict[table_name]
             row_id = node_dict[table_name]
             df = df.iloc[row_id].reset_index(drop=True)
             if table_name in column_dict:
-
+                if len(column_dict[table_name]) == 0:
+                    # We are dealing with COUNT(table.*), insert a dummy col
+                    # to ensure we don't lose the information on node count
+                    feat_dict[table_name] = pd.DataFrame(
+                        {'ones': [1] * len(df)})
+                else:
+                    feat_dict[table_name] = df[list(column_dict[table_name])]
             if table_name in time_tables:
                 time_col = self._graph_store.time_column_dict[table_name]
                 time_dict[table_name] = df[time_col]
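`all_query_columns` yields fully qualified `table.column` names, where `table.*` arises from `COUNT(table.*)`. Splitting on `'.'` buckets the concrete columns per table, and a table whose bucket stays empty later receives the dummy `ones` column so a pure row count survives the column pruning. A quick sketch of the bucketing with an invented column list:

```python
from typing import Dict, Set

all_query_columns = ['orders.price', 'sessions.*', 'users.age']

column_dict: Dict[str, Set[str]] = {}
for col in all_query_columns:
    table_name, col_name = col.split('.')
    if table_name not in column_dict:
        column_dict[table_name] = set()
    if col_name != '*':
        column_dict[table_name].add(col_name)

print(column_dict)
# {'orders': {'price'}, 'sessions': set(), 'users': {'age'}}
# 'sessions' keeps an empty bucket and will get the dummy 'ones' column.
```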
@@ -436,7 +558,7 @@ class LocalPQueryDriver:
 
         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
 
-        y, mask =
+        y, mask = PQueryPandasExecutor().execute(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
@@ -447,6 +569,62 @@ class LocalPQueryDriver:
 
         return y, mask
 
+    def find_time_tables(self) -> List[str]:
+        def _find_time_tables(node: ASTNode) -> List[str]:
+            time_tables = []
+            if isinstance(node, Aggregation):
+                time_tables.append(
+                    node._get_target_column_name().split('.')[0])
+            for child in node.children:
+                time_tables += _find_time_tables(child)
+            return time_tables
+
+        time_tables = _find_time_tables(
+            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+        if self._query.whatif_ast is not None:
+            time_tables += _find_time_tables(self._query.whatif_ast)
+        return list(set(time_tables))
+
+    @staticmethod
+    def get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
 
 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
@@ -487,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
         total_ns += scaled_value
 
     return total_ns
+
+
+def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: DateOffset) -> DateOffset:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
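The two new helpers make `DateOffset` values comparable even though the class itself has no meaningful ordering (a month has no fixed length in seconds): each offset is applied to a fixed anchor timestamp, and the resulting timestamps are ordered instead. Note that for calendar-based offsets the outcome can depend on the chosen anchor. A small demonstration:

```python
import pandas as pd

anchor = pd.Timestamp('2000-01-01')
offsets = [pd.DateOffset(months=1), pd.DateOffset(days=35)]

# Compare offsets by their effect on a concrete anchor date:
timestamps = [anchor + off for off in offsets]
print(timestamps)  # [2000-02-01, 2000-02-05]

argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
print(offsets[argmin])  # <DateOffset: months=1>
```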