kumoai 2.11.0.dev202510161830__py3-none-any.whl → 2.12.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +4 -2
- kumoai/_version.py +1 -1
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +3 -1
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +277 -223
- kumoai/experimental/rfm/rfm.py +140 -72
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/trainer/trainer.py +9 -10
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.12.1.dist-info}/METADATA +2 -2
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.12.1.dist-info}/RECORD +18 -18
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.12.1.dist-info}/WHEEL +0 -0
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.12.1.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.11.0.dev202510161830.dist-info → kumoai-2.12.1.dist-info}/top_level.txt +0 -0
kumoai/__init__.py
CHANGED
@@ -200,9 +200,11 @@ def init(
 
     logger = logging.getLogger('kumoai')
     log_level = logging.getLevelName(logger.getEffectiveLevel())
+
     logger.info(
-        "Successfully initialized the Kumo SDK
-        "
+        f"Successfully initialized the Kumo SDK (version {__version__}) "
+        f"against deployment {url}, with "
+        f"log level {log_level}.")
 
 
 def set_log_level(level: str) -> None:
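The new log line resolves the active level name through the standard logging API. A minimal standalone illustration (plain stdlib, not part of the diff):

```python
import logging

logger = logging.getLogger('kumoai')
logger.setLevel(logging.INFO)

# getEffectiveLevel() walks up the logger hierarchy to a numeric level;
# getLevelName() maps that number back to its string form, e.g. 'INFO'.
log_level = logging.getLevelName(logger.getEffectiveLevel())
print(log_level)  # INFO
```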
kumoai/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.11.0.dev202510161830'
+__version__ = '2.12.1'
kumoai/client/endpoints.py
CHANGED
@@ -147,3 +147,4 @@ class RFMEndpoints:
     explain = Endpoint(f"{BASE}/explain", HTTPMethod.POST)
     evaluate = Endpoint(f"{BASE}/evaluate", HTTPMethod.POST)
     validate_query = Endpoint(f"{BASE}/validate_query", HTTPMethod.POST)
+    parse_query = Endpoint(f"{BASE}/parse_query", HTTPMethod.POST)
kumoai/client/rfm.py
CHANGED
@@ -1,7 +1,11 @@
+from typing import Any
+
 from kumoapi.json_serde import to_json_dict
 from kumoapi.rfm import (
     RFMEvaluateResponse,
     RFMExplanationResponse,
+    RFMParseQueryRequest,
+    RFMParseQueryResponse,
     RFMPredictResponse,
     RFMValidateQueryRequest,
     RFMValidateQueryResponse,
@@ -26,25 +30,32 @@ class RFMAPI:
         Returns:
             RFMPredictResponse containing the predictions
         """
-        # Send binary data to the predict endpoint
         response = self._client._request(
-            RFMEndpoints.predict,
-
+            RFMEndpoints.predict,
+            data=request,
+            headers={'Content-Type': 'application/x-protobuf'},
+        )
         raise_on_error(response)
         return parse_response(RFMPredictResponse, response)
 
-    def explain(
+    def explain(
+        self,
+        request: bytes,
+        skip_summary: bool = False,
+    ) -> RFMExplanationResponse:
         """Explain the RFM model on the given context.
 
         Args:
             request: The predict request as serialized protobuf.
+            skip_summary: Whether to skip generating a human-readable summary
+                of the explanation.
 
         Returns:
             RFMPredictResponse containing the explanations
         """
-
+        params: dict[str, Any] = {'generate_summary': not skip_summary}
         response = self._client._request(
-            RFMEndpoints.explain, data=request,
+            RFMEndpoints.explain, data=request, params=params,
             headers={'Content-Type': 'application/x-protobuf'})
         raise_on_error(response)
         return parse_response(RFMExplanationResponse, response)
@@ -58,7 +69,6 @@ class RFMAPI:
         Returns:
             RFMEvaluateResponse containing the computed metrics
         """
-        # Send binary data to the evaluate endpoint
         response = self._client._request(
             RFMEndpoints.evaluate, data=request,
             headers={'Content-Type': 'application/x-protobuf'})
@@ -82,3 +92,21 @@ class RFMAPI:
             json=to_json_dict(request))
         raise_on_error(response)
         return parse_response(RFMValidateQueryResponse, response)
+
+    def parse_query(
+        self,
+        request: RFMParseQueryRequest,
+    ) -> RFMParseQueryResponse:
+        """Validate a predictive query against a graph.
+
+        Args:
+            request: The request object containing
+                the query and graph definition
+
+        Returns:
+            RFMParseQueryResponse containing the QueryDefinition
+        """
+        response = self._client._request(RFMEndpoints.parse_query,
+                                         json=to_json_dict(request))
+        raise_on_error(response)
+        return parse_response(RFMParseQueryResponse, response)
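For orientation, a rough sketch of how the reshaped client surface might be driven. Only the method signatures above come from the diff; the `api` object and the field names passed to `RFMParseQueryRequest` are assumptions, since the request construction is not part of this diff:

```python
from kumoapi.rfm import RFMParseQueryRequest


def parse_then_explain(api, query, graph, context_bytes: bytes):
    # New in this release: server-side parsing of a predictive query into a
    # QueryDefinition via the /parse_query endpoint.
    request = RFMParseQueryRequest(query=query, graph=graph)  # hypothetical fields
    parsed = api.parse_query(request)

    # explain() now accepts skip_summary, forwarded (inverted) to the endpoint
    # as the 'generate_summary' query parameter.
    explanation = api.explain(context_bytes, skip_summary=True)
    return parsed, explanation
```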
kumoai/experimental/rfm/__init__.py
CHANGED
@@ -36,7 +36,7 @@ import os
 import kumoai
 from .local_table import LocalTable
 from .local_graph import LocalGraph
-from .rfm import KumoRFM
+from .rfm import ExplainConfig, Explanation, KumoRFM
 from .authenticate import authenticate
 
 
@@ -60,6 +60,8 @@ __all__ = [
     'LocalTable',
     'LocalGraph',
     'KumoRFM',
+    'ExplainConfig',
+    'Explanation',
     'authenticate',
     'init',
 ]
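Both new names are re-exported next to KumoRFM, so downstream code can import them directly from the experimental namespace shown above:

```python
from kumoai.experimental.rfm import ExplainConfig, Explanation, KumoRFM
```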
kumoai/experimental/rfm/local_pquery_driver.py
CHANGED
@@ -1,23 +1,40 @@
 import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    ASTNode,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, DateOffset, Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 
 _coverage_warned = False
 
 
+class SamplingSpec(NamedTuple):
+    edge_type: Tuple[str, str, str]
+    hop: int
+    start_offset: Optional[DateOffset]
+    end_offset: Optional[DateOffset]
+
+
 class LocalPQueryDriver:
     def __init__(
         self,
         graph_store: LocalGraphStore,
-        query:
+        query: ValidatedPredictiveQuery,
         random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
@@ -33,7 +50,7 @@ class LocalPQueryDriver:
         if self._query.query_type == QueryType.TEMPORAL:
             assert exclude_node is None
 
-        table_name = self._query.
+        table_name = self._query.entity_table
         num_nodes = len(self._graph_store.df_dict[table_name])
         mask_dict = self._graph_store.mask_dict
 
@@ -66,7 +83,7 @@ class LocalPQueryDriver:
         anchor_time: pd.Timestamp,
     ) -> np.ndarray:
 
-        entity = self._query.
+        entity = self._query.entity_table
 
         # Filter out entities that do not exist yet in time:
         time_sec = self._graph_store.time_dict.get(entity)
@@ -124,8 +141,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
         y, mask = self(node, time)
@@ -223,8 +239,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
         y, mask = self(node, time)
@@ -249,7 +264,8 @@ class LocalPQueryDriver:
                 reached_end = True
                 break
             candidate_offset = 0
-
+            time_frame = self._query.target_timeframe.timeframe
+            anchor_time = anchor_time - (time_frame *
                                          self._query.num_forecasts)
             if anchor_time < self._graph_store.min_time:
                 reached_end = True
@@ -304,26 +320,25 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
-        if self._query.
+        if isinstance(self._query.entity_ast, Filter):
             # Mask out via (temporal) entity filter:
-
+            executor = PQueryPandasExecutor()
            masks: List[np.ndarray] = []
             for start in range(0, len(node), batch_size):
                 feat_dict, time_dict, batch_dict = self._sample(
                     node[start:start + batch_size],
                     time.iloc[start:start + batch_size],
                 )
-                _mask =
-                    filter=self._query.
+                _mask = executor.execute_filter(
+                    filter=self._query.entity_ast,
                     feat_dict=feat_dict,
                     time_dict=time_dict,
                     batch_dict=batch_dict,
                     anchor_time=time.iloc[start:start + batch_size],
-                )
+                )[1]
                 masks.append(_mask)
 
             _mask = np.concatenate(masks)
@@ -334,6 +349,96 @@ class LocalPQueryDriver:
 
         return mask
 
+    def _get_sampling_specs(
+        self,
+        node: ASTNode,
+        hop: int,
+        seed_table_name: str,
+        edge_types: List[Tuple[str, str, str]],
+        num_forecasts: int = 1,
+    ) -> List[SamplingSpec]:
+        if isinstance(node, (Aggregation, Column)):
+            if isinstance(node, Column):
+                table_name = node.fqn.split('.')[0]
+                if seed_table_name == table_name:
+                    return []
+            else:
+                table_name = node._get_target_column_name().split('.')[0]
+
+            target_edge_types = [
+                edge_type for edge_type in edge_types if
+                edge_type[2] == seed_table_name and edge_type[0] == table_name
+            ]
+            if len(target_edge_types) != 1:
+                raise ValueError(
+                    f"Could not find a unique foreign key from table "
+                    f"'{seed_table_name}' to '{table_name}'")
+
+            if isinstance(node, Column):
+                return [
+                    SamplingSpec(
+                        edge_type=target_edge_types[0],
+                        hop=hop + 1,
+                        start_offset=None,
+                        end_offset=None,
+                    )
+                ]
+            spec = SamplingSpec(
+                edge_type=target_edge_types[0],
+                hop=hop + 1,
+                start_offset=node.aggr_time_range.start_date_offset,
+                end_offset=node.aggr_time_range.end_date_offset *
+                num_forecasts,
+            )
+            return [spec] + self._get_sampling_specs(
+                node.target, hop=hop + 1, seed_table_name=table_name,
+                edge_types=edge_types, num_forecasts=num_forecasts)
+        specs = []
+        for child in node.children:
+            specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                              edge_types, num_forecasts)
+        return specs
+
+    def get_sampling_specs(self) -> List[SamplingSpec]:
+        edge_types = self._graph_store.edge_types
+        specs = self._get_sampling_specs(
+            self._query.target_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types,
+            num_forecasts=self._query.num_forecasts)
+        specs += self._get_sampling_specs(
+            self._query.entity_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types)
+        if self._query.whatif_ast is not None:
+            specs += self._get_sampling_specs(
+                self._query.whatif_ast, hop=0,
+                seed_table_name=self._query.entity_table,
+                edge_types=edge_types)
+        # Group specs according to edge type and hop:
+        spec_dict: Dict[
+            Tuple[Tuple[str, str, str], int],
+            Tuple[Optional[DateOffset], Optional[DateOffset]],
+        ] = {}
+        for spec in specs:
+            if (spec.edge_type, spec.hop) not in spec_dict:
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    spec.start_offset,
+                    spec.end_offset,
+                )
+            else:
+                start_offset, end_offset = spec_dict[(
+                    spec.edge_type,
+                    spec.hop,
+                )]
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    min_date_offset(start_offset, spec.start_offset),
+                    max_date_offset(end_offset, spec.end_offset),
+                )
+
+        return [
+            SamplingSpec(edge, hop, start_offset, end_offset)
+            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+        ]
+
     def _sample(
         self,
         node: np.ndarray,
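To make the grouping step concrete: specs that share an (edge_type, hop) key collapse into a single spec whose time window is the widest one seen, using the same anchor-timestamp comparison as the min_date_offset/max_date_offset helpers further down in this file. A standalone sketch with made-up edge types and offsets (not calling the library):

```python
import pandas as pd
from typing import NamedTuple, Optional, Tuple


class Spec(NamedTuple):  # stand-in for SamplingSpec above
    edge_type: Tuple[str, str, str]
    hop: int
    start_offset: Optional[pd.DateOffset]
    end_offset: Optional[pd.DateOffset]


# Two specs over the same foreign-key edge and hop, but different windows:
a = Spec(('orders', 'user_id', 'users'), 1,
         pd.DateOffset(days=-30), pd.DateOffset(days=0))
b = Spec(('orders', 'user_id', 'users'), 1,
         pd.DateOffset(days=-90), pd.DateOffset(days=0))

anchor = pd.Timestamp('2000-01-01')  # DateOffsets are compared via a fixed anchor
start = min((a.start_offset, b.start_offset), key=lambda off: anchor + off)
end = max((a.end_offset, b.end_offset), key=lambda off: anchor + off)

merged = Spec(a.edge_type, a.hop, start, end)
print(merged.start_offset)  # <DateOffset: days=-90>, i.e. the wider lookback wins
```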
@@ -354,7 +459,7 @@ class LocalPQueryDriver:
         The feature dictionary, the time column dictionary and the batch
         dictionary.
         """
-        specs = self.
+        specs = self.get_sampling_specs()
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
         time_offsets: Dict[
@@ -380,7 +485,7 @@ class LocalPQueryDriver:
 
         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
         node_types = list(
-            set([self._query.
+            set([self._query.entity_table])
             | set(src for src, _, _ in edge_types)
             | set(dst for _, _, dst in edge_types))
 
@@ -412,21 +517,33 @@ class LocalPQueryDriver:
                 '__'.join(edge_type): np.array(values)
                 for edge_type, values in time_offsets.items()
             },
-            self._query.
+            self._query.entity_table,
             node,
             anchor_time.astype(int).to_numpy() // 1000**3,
         )
 
         feat_dict: Dict[str, pd.DataFrame] = {}
         time_dict: Dict[str, pd.Series] = {}
-        column_dict =
-
+        column_dict: Dict[str, Set[str]] = {}
+        for col in self._query.all_query_columns:
+            table_name, col_name = col.split('.')
+            if table_name not in column_dict:
+                column_dict[table_name] = set()
+            if col_name != '*':
+                column_dict[table_name].add(col_name)
+        time_tables = self.find_time_tables()
         for table_name in set(list(column_dict.keys()) + time_tables):
             df = self._graph_store.df_dict[table_name]
             row_id = node_dict[table_name]
             df = df.iloc[row_id].reset_index(drop=True)
             if table_name in column_dict:
-
+                if len(column_dict[table_name]) == 0:
+                    # We are dealing with COUNT(table.*), insert a dummy col
+                    # to ensure we don't lose the information on node count
+                    feat_dict[table_name] = pd.DataFrame(
+                        {'ones': [1] * len(df)})
+                else:
+                    feat_dict[table_name] = df[list(column_dict[table_name])]
             if table_name in time_tables:
                 time_col = self._graph_store.time_column_dict[table_name]
                 time_dict[table_name] = df[time_col]
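A standalone paraphrase of the new column-collection step: query columns arrive as 'table.column' strings, and a 'table.*' entry (e.g. from a COUNT(table.*) target) keeps the table in play without naming a concrete column, which is what later triggers the dummy 'ones' frame. Table and column names below are made up:

```python
from typing import Dict, Set

all_query_columns = ['transactions.price', 'transactions.*', 'users.age']

column_dict: Dict[str, Set[str]] = {}
for col in all_query_columns:
    table_name, col_name = col.split('.')
    if table_name not in column_dict:
        column_dict[table_name] = set()
    if col_name != '*':  # '*' only marks the table, it adds no column
        column_dict[table_name].add(col_name)

print(column_dict)  # {'transactions': {'price'}, 'users': {'age'}}
# A table whose set stayed empty would receive a DataFrame of ones instead of
# real feature columns, preserving just the row count for COUNT(table.*).
```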
@@ -441,7 +558,7 @@ class LocalPQueryDriver:
 
         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
 
-        y, mask =
+        y, mask = PQueryPandasExecutor().execute(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
@@ -452,6 +569,62 @@ class LocalPQueryDriver:
 
         return y, mask
 
+    def find_time_tables(self) -> List[str]:
+        def _find_time_tables(node: ASTNode) -> List[str]:
+            time_tables = []
+            if isinstance(node, Aggregation):
+                time_tables.append(
+                    node._get_target_column_name().split('.')[0])
+            for child in node.children:
+                time_tables += _find_time_tables(child)
+            return time_tables
+
+        time_tables = _find_time_tables(
+            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+        if self._query.whatif_ast is not None:
+            time_tables += _find_time_tables(self._query.whatif_ast)
+        return list(set(time_tables))
+
+    @staticmethod
+    def get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
 
 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
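The LIST_DISTINCT branch only resolves to link prediction when the aggregated column is backed by exactly one foreign-key edge. A standalone sketch of that uniqueness check, with made-up edge types in the (source_table, foreign_key, destination_table) layout used above:

```python
from typing import List, Tuple

edge_types: List[Tuple[str, str, str]] = [
    ('orders', 'item_id', 'items'),
    ('orders', 'user_id', 'users'),
]

# Target of a hypothetical LIST_DISTINCT(orders.item_id, ...) aggregation:
table_name, col_name = 'orders', 'item_id'
matches = [e for e in edge_types if e[0] == table_name and e[1] == col_name]

# Exactly one match -> TaskType.TEMPORAL_LINK_PREDICTION; otherwise the driver
# raises NotImplementedError and asks for the column to be registered as a
# foreign key.
assert len(matches) == 1
```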
@@ -492,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
         total_ns += scaled_value
 
     return total_ns
+
+
+def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: DateOffset) -> DateOffset:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
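Both helpers sidestep the fact that pd.DateOffset objects are not directly orderable by adding them to a fixed anchor timestamp and comparing the results, and they short-circuit to None (an unbounded window) when any input is None. A minimal standalone illustration of the comparison trick:

```python
import pandas as pd

anchor = pd.Timestamp('2000-01-01')
offsets = [pd.DateOffset(days=30), pd.DateOffset(months=1)]

# Compare the anchored timestamps rather than the offsets themselves:
widest = max(offsets, key=lambda off: anchor + off)
print(widest)  # <DateOffset: months=1>, since 2000-02-01 > 2000-01-31
```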
kumoai/experimental/rfm/pquery/__init__.py
CHANGED
@@ -1,7 +1,7 @@
-from .backend import PQueryBackend
-from .pandas_backend import PQueryPandasBackend
+from .executor import PQueryExecutor
+from .pandas_executor import PQueryPandasExecutor
 
 __all__ = [
-    'PQueryBackend',
-    'PQueryPandasBackend',
+    'PQueryExecutor',
+    'PQueryPandasExecutor',
 ]
kumoai/experimental/rfm/pquery/{backend.py → executor.py}
CHANGED
@@ -1,23 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic,
+from typing import Dict, Generic, Tuple, TypeVar
 
-from kumoapi.
-from kumoapi.
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
     Aggregation,
-    AggregationType,
-    BoolOp,
     Column,
     Condition,
     Filter,
-
-    FloatList,
-    Int,
-    IntList,
+    Join,
     LogicalOperation,
-    MemberOp,
-    RelOp,
-    Str,
-    StrList,
 )
 
 TableData = TypeVar('TableData')
@@ -25,58 +16,33 @@ ColumnData = TypeVar('ColumnData')
 IndexData = TypeVar('IndexData')
 
 
-class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
+class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     @abstractmethod
-    def
+    def execute_column(
         self,
-
-
-        batch: IndexData,
-        batch_size: int,
+        column: Column,
+        feat_dict: Dict[str, TableData],
         filter_na: bool = True,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
-    def
-        self,
-        left: ColumnData,
-        op: RelOp,
-        right: Union[Int, Float, Str, None],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_member_op(
-        self,
-        left: ColumnData,
-        op: MemberOp,
-        right: Union[IntList, FloatList, StrList],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_bool_op(
-        self,
-        left: ColumnData,
-        op: BoolOp,
-        right: Optional[ColumnData],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_column(
+    def execute_aggregation(
         self,
-
+        aggr: Aggregation,
         feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
-    def
+    def execute_condition(
         self,
-
+        condition: Condition,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -87,9 +53,9 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass
 
     @abstractmethod
-    def
+    def execute_logical_operation(
         self,
-
+        logical_operation: LogicalOperation,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -100,9 +66,9 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass
 
     @abstractmethod
-    def
+    def execute_join(
         self,
-
+        join: Join,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -113,20 +79,20 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass
 
     @abstractmethod
-    def
+    def execute_filter(
         self,
         filter: Filter,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> IndexData:
+    ) -> Tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
-    def
+    def execute(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],