kumoai 2.12.0.dev202511051731__cp313-cp313-macosx_11_0_arm64.whl → 2.12.0.dev202511101731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/_version.py +1 -1
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +35 -7
- kumoai/experimental/rfm/__init__.py +3 -1
- kumoai/experimental/rfm/local_pquery_driver.py +221 -26
- kumoai/experimental/rfm/pquery/__init__.py +0 -4
- kumoai/experimental/rfm/pquery/pandas_executor.py +34 -8
- kumoai/experimental/rfm/rfm.py +127 -71
- kumoai/utils/progress_logger.py +10 -4
- {kumoai-2.12.0.dev202511051731.dist-info → kumoai-2.12.0.dev202511101731.dist-info}/METADATA +2 -2
- {kumoai-2.12.0.dev202511051731.dist-info → kumoai-2.12.0.dev202511101731.dist-info}/RECORD +14 -16
- kumoai/experimental/rfm/pquery/backend.py +0 -136
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
- {kumoai-2.12.0.dev202511051731.dist-info → kumoai-2.12.0.dev202511101731.dist-info}/WHEEL +0 -0
- {kumoai-2.12.0.dev202511051731.dist-info → kumoai-2.12.0.dev202511101731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.12.0.dev202511051731.dist-info → kumoai-2.12.0.dev202511101731.dist-info}/top_level.txt +0 -0
kumoai/_version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = '2.12.0.
|
|
1
|
+
__version__ = '2.12.0.dev202511101731'
|
kumoai/client/endpoints.py
CHANGED
|
@@ -147,3 +147,4 @@ class RFMEndpoints:
|
|
|
147
147
|
explain = Endpoint(f"{BASE}/explain", HTTPMethod.POST)
|
|
148
148
|
evaluate = Endpoint(f"{BASE}/evaluate", HTTPMethod.POST)
|
|
149
149
|
validate_query = Endpoint(f"{BASE}/validate_query", HTTPMethod.POST)
|
|
150
|
+
parse_query = Endpoint(f"{BASE}/parse_query", HTTPMethod.POST)
|
kumoai/client/rfm.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
1
3
|
from kumoapi.json_serde import to_json_dict
|
|
2
4
|
from kumoapi.rfm import (
|
|
3
5
|
RFMEvaluateResponse,
|
|
4
6
|
RFMExplanationResponse,
|
|
7
|
+
RFMParseQueryRequest,
|
|
8
|
+
RFMParseQueryResponse,
|
|
5
9
|
RFMPredictResponse,
|
|
6
10
|
RFMValidateQueryRequest,
|
|
7
11
|
RFMValidateQueryResponse,
|
|
@@ -26,25 +30,32 @@ class RFMAPI:
|
|
|
26
30
|
Returns:
|
|
27
31
|
RFMPredictResponse containing the predictions
|
|
28
32
|
"""
|
|
29
|
-
# Send binary data to the predict endpoint
|
|
30
33
|
response = self._client._request(
|
|
31
|
-
RFMEndpoints.predict,
|
|
32
|
-
|
|
34
|
+
RFMEndpoints.predict,
|
|
35
|
+
data=request,
|
|
36
|
+
headers={'Content-Type': 'application/x-protobuf'},
|
|
37
|
+
)
|
|
33
38
|
raise_on_error(response)
|
|
34
39
|
return parse_response(RFMPredictResponse, response)
|
|
35
40
|
|
|
36
|
-
def explain(
|
|
41
|
+
def explain(
|
|
42
|
+
self,
|
|
43
|
+
request: bytes,
|
|
44
|
+
skip_summary: bool = False,
|
|
45
|
+
) -> RFMExplanationResponse:
|
|
37
46
|
"""Explain the RFM model on the given context.
|
|
38
47
|
|
|
39
48
|
Args:
|
|
40
49
|
request: The predict request as serialized protobuf.
|
|
50
|
+
skip_summary: Whether to skip generating a human-readable summary
|
|
51
|
+
of the explanation.
|
|
41
52
|
|
|
42
53
|
Returns:
|
|
43
54
|
RFMPredictResponse containing the explanations
|
|
44
55
|
"""
|
|
45
|
-
|
|
56
|
+
params: dict[str, Any] = {'generate_summary': not skip_summary}
|
|
46
57
|
response = self._client._request(
|
|
47
|
-
RFMEndpoints.explain, data=request,
|
|
58
|
+
RFMEndpoints.explain, data=request, params=params,
|
|
48
59
|
headers={'Content-Type': 'application/x-protobuf'})
|
|
49
60
|
raise_on_error(response)
|
|
50
61
|
return parse_response(RFMExplanationResponse, response)
|
|
@@ -58,7 +69,6 @@ class RFMAPI:
|
|
|
58
69
|
Returns:
|
|
59
70
|
RFMEvaluateResponse containing the computed metrics
|
|
60
71
|
"""
|
|
61
|
-
# Send binary data to the evaluate endpoint
|
|
62
72
|
response = self._client._request(
|
|
63
73
|
RFMEndpoints.evaluate, data=request,
|
|
64
74
|
headers={'Content-Type': 'application/x-protobuf'})
|
|
@@ -82,3 +92,21 @@ class RFMAPI:
|
|
|
82
92
|
json=to_json_dict(request))
|
|
83
93
|
raise_on_error(response)
|
|
84
94
|
return parse_response(RFMValidateQueryResponse, response)
|
|
95
|
+
|
|
96
|
+
def parse_query(
|
|
97
|
+
self,
|
|
98
|
+
request: RFMParseQueryRequest,
|
|
99
|
+
) -> RFMParseQueryResponse:
|
|
100
|
+
"""Validate a predictive query against a graph.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
request: The request object containing
|
|
104
|
+
the query and graph definition
|
|
105
|
+
|
|
106
|
+
Returns:
|
|
107
|
+
RFMParseQueryResponse containing the QueryDefinition
|
|
108
|
+
"""
|
|
109
|
+
response = self._client._request(RFMEndpoints.parse_query,
|
|
110
|
+
json=to_json_dict(request))
|
|
111
|
+
raise_on_error(response)
|
|
112
|
+
return parse_response(RFMParseQueryResponse, response)
|
|
@@ -36,7 +36,7 @@ import os
|
|
|
36
36
|
import kumoai
|
|
37
37
|
from .local_table import LocalTable
|
|
38
38
|
from .local_graph import LocalGraph
|
|
39
|
-
from .rfm import KumoRFM
|
|
39
|
+
from .rfm import ExplainConfig, Explanation, KumoRFM
|
|
40
40
|
from .authenticate import authenticate
|
|
41
41
|
|
|
42
42
|
|
|
@@ -60,6 +60,8 @@ __all__ = [
|
|
|
60
60
|
'LocalTable',
|
|
61
61
|
'LocalGraph',
|
|
62
62
|
'KumoRFM',
|
|
63
|
+
'ExplainConfig',
|
|
64
|
+
'Explanation',
|
|
63
65
|
'authenticate',
|
|
64
66
|
'init',
|
|
65
67
|
]
|
|
@@ -1,23 +1,40 @@
|
|
|
1
1
|
import warnings
|
|
2
|
-
from typing import Dict, List, Literal, Optional, Tuple, Union
|
|
2
|
+
from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
import pandas as pd
|
|
6
|
-
from kumoapi.pquery import QueryType
|
|
7
|
-
from kumoapi.
|
|
6
|
+
from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
|
|
7
|
+
from kumoapi.pquery.AST import (
|
|
8
|
+
Aggregation,
|
|
9
|
+
ASTNode,
|
|
10
|
+
Column,
|
|
11
|
+
Condition,
|
|
12
|
+
Filter,
|
|
13
|
+
Join,
|
|
14
|
+
LogicalOperation,
|
|
15
|
+
)
|
|
16
|
+
from kumoapi.task import TaskType
|
|
17
|
+
from kumoapi.typing import AggregationType, DateOffset, Stype
|
|
8
18
|
|
|
9
19
|
import kumoai.kumolib as kumolib
|
|
10
20
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
11
|
-
from kumoai.experimental.rfm.pquery import
|
|
21
|
+
from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
|
|
12
22
|
|
|
13
23
|
_coverage_warned = False
|
|
14
24
|
|
|
15
25
|
|
|
26
|
+
class SamplingSpec(NamedTuple):
|
|
27
|
+
edge_type: Tuple[str, str, str]
|
|
28
|
+
hop: int
|
|
29
|
+
start_offset: Optional[DateOffset]
|
|
30
|
+
end_offset: Optional[DateOffset]
|
|
31
|
+
|
|
32
|
+
|
|
16
33
|
class LocalPQueryDriver:
|
|
17
34
|
def __init__(
|
|
18
35
|
self,
|
|
19
36
|
graph_store: LocalGraphStore,
|
|
20
|
-
query:
|
|
37
|
+
query: ValidatedPredictiveQuery,
|
|
21
38
|
random_seed: Optional[int] = None,
|
|
22
39
|
) -> None:
|
|
23
40
|
self._graph_store = graph_store
|
|
@@ -33,7 +50,7 @@ class LocalPQueryDriver:
|
|
|
33
50
|
if self._query.query_type == QueryType.TEMPORAL:
|
|
34
51
|
assert exclude_node is None
|
|
35
52
|
|
|
36
|
-
table_name = self._query.
|
|
53
|
+
table_name = self._query.entity_table
|
|
37
54
|
num_nodes = len(self._graph_store.df_dict[table_name])
|
|
38
55
|
mask_dict = self._graph_store.mask_dict
|
|
39
56
|
|
|
@@ -66,7 +83,7 @@ class LocalPQueryDriver:
|
|
|
66
83
|
anchor_time: pd.Timestamp,
|
|
67
84
|
) -> np.ndarray:
|
|
68
85
|
|
|
69
|
-
entity = self._query.
|
|
86
|
+
entity = self._query.entity_table
|
|
70
87
|
|
|
71
88
|
# Filter out entities that do not exist yet in time:
|
|
72
89
|
time_sec = self._graph_store.time_dict.get(entity)
|
|
@@ -124,8 +141,7 @@ class LocalPQueryDriver:
|
|
|
124
141
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
125
142
|
else:
|
|
126
143
|
assert anchor_time == 'entity'
|
|
127
|
-
time = self._graph_store.time_dict[
|
|
128
|
-
self._query.entity.pkey.table_name]
|
|
144
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
129
145
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
130
146
|
|
|
131
147
|
y, mask = self(node, time)
|
|
@@ -223,8 +239,7 @@ class LocalPQueryDriver:
|
|
|
223
239
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
224
240
|
else:
|
|
225
241
|
assert anchor_time == 'entity'
|
|
226
|
-
time = self._graph_store.time_dict[
|
|
227
|
-
self._query.entity.pkey.table_name]
|
|
242
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
228
243
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
229
244
|
|
|
230
245
|
y, mask = self(node, time)
|
|
@@ -249,7 +264,8 @@ class LocalPQueryDriver:
|
|
|
249
264
|
reached_end = True
|
|
250
265
|
break
|
|
251
266
|
candidate_offset = 0
|
|
252
|
-
|
|
267
|
+
time_frame = self._query.target_timeframe.timeframe
|
|
268
|
+
anchor_time = anchor_time - (time_frame *
|
|
253
269
|
self._query.num_forecasts)
|
|
254
270
|
if anchor_time < self._graph_store.min_time:
|
|
255
271
|
reached_end = True
|
|
@@ -304,26 +320,25 @@ class LocalPQueryDriver:
|
|
|
304
320
|
time = time.astype('datetime64[ns]').reset_index(drop=True)
|
|
305
321
|
else:
|
|
306
322
|
assert anchor_time == 'entity'
|
|
307
|
-
time = self._graph_store.time_dict[
|
|
308
|
-
self._query.entity.pkey.table_name]
|
|
323
|
+
time = self._graph_store.time_dict[self._query.entity_table]
|
|
309
324
|
time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
|
|
310
325
|
|
|
311
|
-
if self._query.
|
|
326
|
+
if isinstance(self._query.entity_ast, Filter):
|
|
312
327
|
# Mask out via (temporal) entity filter:
|
|
313
|
-
|
|
328
|
+
executor = PQueryPandasExecutor()
|
|
314
329
|
masks: List[np.ndarray] = []
|
|
315
330
|
for start in range(0, len(node), batch_size):
|
|
316
331
|
feat_dict, time_dict, batch_dict = self._sample(
|
|
317
332
|
node[start:start + batch_size],
|
|
318
333
|
time.iloc[start:start + batch_size],
|
|
319
334
|
)
|
|
320
|
-
_mask =
|
|
321
|
-
filter=self._query.
|
|
335
|
+
_mask = executor.execute_filter(
|
|
336
|
+
filter=self._query.entity_ast,
|
|
322
337
|
feat_dict=feat_dict,
|
|
323
338
|
time_dict=time_dict,
|
|
324
339
|
batch_dict=batch_dict,
|
|
325
340
|
anchor_time=time.iloc[start:start + batch_size],
|
|
326
|
-
)
|
|
341
|
+
)[1]
|
|
327
342
|
masks.append(_mask)
|
|
328
343
|
|
|
329
344
|
_mask = np.concatenate(masks)
|
|
@@ -334,6 +349,96 @@ class LocalPQueryDriver:
|
|
|
334
349
|
|
|
335
350
|
return mask
|
|
336
351
|
|
|
352
|
+
def _get_sampling_specs(
|
|
353
|
+
self,
|
|
354
|
+
node: ASTNode,
|
|
355
|
+
hop: int,
|
|
356
|
+
seed_table_name: str,
|
|
357
|
+
edge_types: List[Tuple[str, str, str]],
|
|
358
|
+
num_forecasts: int = 1,
|
|
359
|
+
) -> List[SamplingSpec]:
|
|
360
|
+
if isinstance(node, (Aggregation, Column)):
|
|
361
|
+
if isinstance(node, Column):
|
|
362
|
+
table_name = node.fqn.split('.')[0]
|
|
363
|
+
if seed_table_name == table_name:
|
|
364
|
+
return []
|
|
365
|
+
else:
|
|
366
|
+
table_name = node._get_target_column_name().split('.')[0]
|
|
367
|
+
|
|
368
|
+
target_edge_types = [
|
|
369
|
+
edge_type for edge_type in edge_types if
|
|
370
|
+
edge_type[2] == seed_table_name and edge_type[0] == table_name
|
|
371
|
+
]
|
|
372
|
+
if len(target_edge_types) != 1:
|
|
373
|
+
raise ValueError(
|
|
374
|
+
f"Could not find a unique foreign key from table "
|
|
375
|
+
f"'{seed_table_name}' to '{table_name}'")
|
|
376
|
+
|
|
377
|
+
if isinstance(node, Column):
|
|
378
|
+
return [
|
|
379
|
+
SamplingSpec(
|
|
380
|
+
edge_type=target_edge_types[0],
|
|
381
|
+
hop=hop + 1,
|
|
382
|
+
start_offset=None,
|
|
383
|
+
end_offset=None,
|
|
384
|
+
)
|
|
385
|
+
]
|
|
386
|
+
spec = SamplingSpec(
|
|
387
|
+
edge_type=target_edge_types[0],
|
|
388
|
+
hop=hop + 1,
|
|
389
|
+
start_offset=node.aggr_time_range.start_date_offset,
|
|
390
|
+
end_offset=node.aggr_time_range.end_date_offset *
|
|
391
|
+
num_forecasts,
|
|
392
|
+
)
|
|
393
|
+
return [spec] + self._get_sampling_specs(
|
|
394
|
+
node.target, hop=hop + 1, seed_table_name=table_name,
|
|
395
|
+
edge_types=edge_types, num_forecasts=num_forecasts)
|
|
396
|
+
specs = []
|
|
397
|
+
for child in node.children:
|
|
398
|
+
specs += self._get_sampling_specs(child, hop, seed_table_name,
|
|
399
|
+
edge_types, num_forecasts)
|
|
400
|
+
return specs
|
|
401
|
+
|
|
402
|
+
def get_sampling_specs(self) -> List[SamplingSpec]:
|
|
403
|
+
edge_types = self._graph_store.edge_types
|
|
404
|
+
specs = self._get_sampling_specs(
|
|
405
|
+
self._query.target_ast, hop=0,
|
|
406
|
+
seed_table_name=self._query.entity_table, edge_types=edge_types,
|
|
407
|
+
num_forecasts=self._query.num_forecasts)
|
|
408
|
+
specs += self._get_sampling_specs(
|
|
409
|
+
self._query.entity_ast, hop=0,
|
|
410
|
+
seed_table_name=self._query.entity_table, edge_types=edge_types)
|
|
411
|
+
if self._query.whatif_ast is not None:
|
|
412
|
+
specs += self._get_sampling_specs(
|
|
413
|
+
self._query.whatif_ast, hop=0,
|
|
414
|
+
seed_table_name=self._query.entity_table,
|
|
415
|
+
edge_types=edge_types)
|
|
416
|
+
# Group specs according to edge type and hop:
|
|
417
|
+
spec_dict: Dict[
|
|
418
|
+
Tuple[Tuple[str, str, str], int],
|
|
419
|
+
Tuple[Optional[DateOffset], Optional[DateOffset]],
|
|
420
|
+
] = {}
|
|
421
|
+
for spec in specs:
|
|
422
|
+
if (spec.edge_type, spec.hop) not in spec_dict:
|
|
423
|
+
spec_dict[(spec.edge_type, spec.hop)] = (
|
|
424
|
+
spec.start_offset,
|
|
425
|
+
spec.end_offset,
|
|
426
|
+
)
|
|
427
|
+
else:
|
|
428
|
+
start_offset, end_offset = spec_dict[(
|
|
429
|
+
spec.edge_type,
|
|
430
|
+
spec.hop,
|
|
431
|
+
)]
|
|
432
|
+
spec_dict[(spec.edge_type, spec.hop)] = (
|
|
433
|
+
min_date_offset(start_offset, spec.start_offset),
|
|
434
|
+
max_date_offset(end_offset, spec.end_offset),
|
|
435
|
+
)
|
|
436
|
+
|
|
437
|
+
return [
|
|
438
|
+
SamplingSpec(edge, hop, start_offset, end_offset)
|
|
439
|
+
for (edge, hop), (start_offset, end_offset) in spec_dict.items()
|
|
440
|
+
]
|
|
441
|
+
|
|
337
442
|
def _sample(
|
|
338
443
|
self,
|
|
339
444
|
node: np.ndarray,
|
|
@@ -354,7 +459,7 @@ class LocalPQueryDriver:
|
|
|
354
459
|
The feature dictionary, the time column dictionary and the batch
|
|
355
460
|
dictionary.
|
|
356
461
|
"""
|
|
357
|
-
specs = self.
|
|
462
|
+
specs = self.get_sampling_specs()
|
|
358
463
|
num_hops = max([spec.hop for spec in specs] + [0])
|
|
359
464
|
num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
|
|
360
465
|
time_offsets: Dict[
|
|
@@ -380,7 +485,7 @@ class LocalPQueryDriver:
|
|
|
380
485
|
|
|
381
486
|
edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
|
|
382
487
|
node_types = list(
|
|
383
|
-
set([self._query.
|
|
488
|
+
set([self._query.entity_table])
|
|
384
489
|
| set(src for src, _, _ in edge_types)
|
|
385
490
|
| set(dst for _, _, dst in edge_types))
|
|
386
491
|
|
|
@@ -412,21 +517,33 @@ class LocalPQueryDriver:
|
|
|
412
517
|
'__'.join(edge_type): np.array(values)
|
|
413
518
|
for edge_type, values in time_offsets.items()
|
|
414
519
|
},
|
|
415
|
-
self._query.
|
|
520
|
+
self._query.entity_table,
|
|
416
521
|
node,
|
|
417
522
|
anchor_time.astype(int).to_numpy() // 1000**3,
|
|
418
523
|
)
|
|
419
524
|
|
|
420
525
|
feat_dict: Dict[str, pd.DataFrame] = {}
|
|
421
526
|
time_dict: Dict[str, pd.Series] = {}
|
|
422
|
-
column_dict =
|
|
423
|
-
|
|
527
|
+
column_dict: Dict[str, Set[str]] = {}
|
|
528
|
+
for col in self._query.all_query_columns:
|
|
529
|
+
table_name, col_name = col.split('.')
|
|
530
|
+
if table_name not in column_dict:
|
|
531
|
+
column_dict[table_name] = set()
|
|
532
|
+
if col_name != '*':
|
|
533
|
+
column_dict[table_name].add(col_name)
|
|
534
|
+
time_tables = self.find_time_tables()
|
|
424
535
|
for table_name in set(list(column_dict.keys()) + time_tables):
|
|
425
536
|
df = self._graph_store.df_dict[table_name]
|
|
426
537
|
row_id = node_dict[table_name]
|
|
427
538
|
df = df.iloc[row_id].reset_index(drop=True)
|
|
428
539
|
if table_name in column_dict:
|
|
429
|
-
|
|
540
|
+
if len(column_dict[table_name]) == 0:
|
|
541
|
+
# We are dealing with COUNT(table.*), insert a dummy col
|
|
542
|
+
# to ensure we don't lose the information on node count
|
|
543
|
+
feat_dict[table_name] = pd.DataFrame(
|
|
544
|
+
{'ones': [1] * len(df)})
|
|
545
|
+
else:
|
|
546
|
+
feat_dict[table_name] = df[list(column_dict[table_name])]
|
|
430
547
|
if table_name in time_tables:
|
|
431
548
|
time_col = self._graph_store.time_column_dict[table_name]
|
|
432
549
|
time_dict[table_name] = df[time_col]
|
|
@@ -441,7 +558,7 @@ class LocalPQueryDriver:
|
|
|
441
558
|
|
|
442
559
|
feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
|
|
443
560
|
|
|
444
|
-
y, mask =
|
|
561
|
+
y, mask = PQueryPandasExecutor().execute(
|
|
445
562
|
query=self._query,
|
|
446
563
|
feat_dict=feat_dict,
|
|
447
564
|
time_dict=time_dict,
|
|
@@ -452,6 +569,62 @@ class LocalPQueryDriver:
|
|
|
452
569
|
|
|
453
570
|
return y, mask
|
|
454
571
|
|
|
572
|
+
def find_time_tables(self) -> List[str]:
|
|
573
|
+
def _find_time_tables(node: ASTNode) -> List[str]:
|
|
574
|
+
time_tables = []
|
|
575
|
+
if isinstance(node, Aggregation):
|
|
576
|
+
time_tables.append(
|
|
577
|
+
node._get_target_column_name().split('.')[0])
|
|
578
|
+
for child in node.children:
|
|
579
|
+
time_tables += _find_time_tables(child)
|
|
580
|
+
return time_tables
|
|
581
|
+
|
|
582
|
+
time_tables = _find_time_tables(
|
|
583
|
+
self._query.target_ast) + _find_time_tables(self._query.entity_ast)
|
|
584
|
+
if self._query.whatif_ast is not None:
|
|
585
|
+
time_tables += _find_time_tables(self._query.whatif_ast)
|
|
586
|
+
return list(set(time_tables))
|
|
587
|
+
|
|
588
|
+
@staticmethod
|
|
589
|
+
def get_task_type(
|
|
590
|
+
query: ValidatedPredictiveQuery,
|
|
591
|
+
edge_types: List[Tuple[str, str, str]],
|
|
592
|
+
) -> TaskType:
|
|
593
|
+
if isinstance(query.target_ast, (Condition, LogicalOperation)):
|
|
594
|
+
return TaskType.BINARY_CLASSIFICATION
|
|
595
|
+
|
|
596
|
+
target = query.target_ast
|
|
597
|
+
if isinstance(target, Join):
|
|
598
|
+
target = target.rhs_target
|
|
599
|
+
if isinstance(target, Aggregation):
|
|
600
|
+
if target.aggr == AggregationType.LIST_DISTINCT:
|
|
601
|
+
table_name, col_name = target._get_target_column_name().split(
|
|
602
|
+
'.')
|
|
603
|
+
target_edge_types = [
|
|
604
|
+
edge_type for edge_type in edge_types
|
|
605
|
+
if edge_type[0] == table_name and edge_type[1] == col_name
|
|
606
|
+
]
|
|
607
|
+
if len(target_edge_types) != 1:
|
|
608
|
+
raise NotImplementedError(
|
|
609
|
+
f"Multilabel-classification queries based on "
|
|
610
|
+
f"'LIST_DISTINCT' are not supported yet. If you "
|
|
611
|
+
f"planned to write a link prediction query instead, "
|
|
612
|
+
f"make sure to register '{col_name}' as a "
|
|
613
|
+
f"foreign key.")
|
|
614
|
+
return TaskType.TEMPORAL_LINK_PREDICTION
|
|
615
|
+
|
|
616
|
+
return TaskType.REGRESSION
|
|
617
|
+
|
|
618
|
+
assert isinstance(target, Column)
|
|
619
|
+
|
|
620
|
+
if target.stype in {Stype.ID, Stype.categorical}:
|
|
621
|
+
return TaskType.MULTICLASS_CLASSIFICATION
|
|
622
|
+
|
|
623
|
+
if target.stype in {Stype.numerical}:
|
|
624
|
+
return TaskType.REGRESSION
|
|
625
|
+
|
|
626
|
+
raise NotImplementedError("Task type not yet supported")
|
|
627
|
+
|
|
455
628
|
|
|
456
629
|
def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
457
630
|
r"""Convert a :class:`pandas.DateOffset` into a maximum number of
|
|
@@ -492,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
|
|
|
492
665
|
total_ns += scaled_value
|
|
493
666
|
|
|
494
667
|
return total_ns
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
|
|
671
|
+
if any(arg is None for arg in args):
|
|
672
|
+
return None
|
|
673
|
+
|
|
674
|
+
anchor = pd.Timestamp('2000-01-01')
|
|
675
|
+
timestamps = [anchor + arg for arg in args]
|
|
676
|
+
assert len(timestamps) > 0
|
|
677
|
+
argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
|
|
678
|
+
return args[argmin]
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
def max_date_offset(*args: DateOffset) -> DateOffset:
|
|
682
|
+
if any(arg is None for arg in args):
|
|
683
|
+
return None
|
|
684
|
+
|
|
685
|
+
anchor = pd.Timestamp('2000-01-01')
|
|
686
|
+
timestamps = [anchor + arg for arg in args]
|
|
687
|
+
assert len(timestamps) > 0
|
|
688
|
+
argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
|
|
689
|
+
return args[argmax]
|
|
@@ -1,11 +1,7 @@
|
|
|
1
|
-
from .backend import PQueryBackend
|
|
2
|
-
from .pandas_backend import PQueryPandasBackend
|
|
3
1
|
from .executor import PQueryExecutor
|
|
4
2
|
from .pandas_executor import PQueryPandasExecutor
|
|
5
3
|
|
|
6
4
|
__all__ = [
|
|
7
|
-
'PQueryBackend',
|
|
8
|
-
'PQueryPandasBackend',
|
|
9
5
|
'PQueryExecutor',
|
|
10
6
|
'PQueryPandasExecutor',
|
|
11
7
|
]
|
|
@@ -118,7 +118,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
118
118
|
target_feat, target_mask = self.execute_column(
|
|
119
119
|
column=aggr.target,
|
|
120
120
|
feat_dict=feat_dict,
|
|
121
|
-
filter_na=
|
|
121
|
+
filter_na=True,
|
|
122
122
|
)
|
|
123
123
|
else:
|
|
124
124
|
assert isinstance(aggr.target, Filter)
|
|
@@ -128,7 +128,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
128
128
|
time_dict=time_dict,
|
|
129
129
|
batch_dict=batch_dict,
|
|
130
130
|
anchor_time=anchor_time,
|
|
131
|
-
filter_na=
|
|
131
|
+
filter_na=True,
|
|
132
132
|
)
|
|
133
133
|
|
|
134
134
|
outs: List[pd.Series] = []
|
|
@@ -137,19 +137,20 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
137
137
|
anchor_target_time = anchor_time[target_batch]
|
|
138
138
|
anchor_target_time = anchor_target_time.reset_index(drop=True)
|
|
139
139
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
<= anchor_target_time + aggr.aggr_time_range.end_date_offset)
|
|
140
|
+
time_filter_mask = (target_time <= anchor_target_time +
|
|
141
|
+
aggr.aggr_time_range.end_date_offset)
|
|
143
142
|
if aggr.aggr_time_range.start is not None:
|
|
144
143
|
start_offset = aggr.aggr_time_range.start_date_offset
|
|
145
|
-
|
|
144
|
+
time_filter_mask &= (target_time
|
|
146
145
|
> anchor_target_time + start_offset)
|
|
147
146
|
else:
|
|
148
147
|
assert num_forecasts == 1
|
|
148
|
+
curr_target_mask = target_mask & time_filter_mask
|
|
149
149
|
|
|
150
150
|
out, mask = self.execute_aggregation_type(
|
|
151
151
|
aggr.aggr,
|
|
152
|
-
feat=target_feat[
|
|
152
|
+
feat=target_feat[time_filter_mask[target_mask].reset_index(
|
|
153
|
+
drop=True)],
|
|
153
154
|
batch=target_batch[curr_target_mask],
|
|
154
155
|
batch_size=len(anchor_time),
|
|
155
156
|
filter_na=False if num_forecasts > 1 else filter_na,
|
|
@@ -499,7 +500,32 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
|
|
|
499
500
|
)
|
|
500
501
|
else:
|
|
501
502
|
raise NotImplementedError(
|
|
502
|
-
f'{type(query.
|
|
503
|
+
f'{type(query.target_ast)} compilation missing.')
|
|
504
|
+
if query.whatif_ast is not None:
|
|
505
|
+
if isinstance(query.whatif_ast, Condition):
|
|
506
|
+
mask &= self.execute_condition(
|
|
507
|
+
condition=query.whatif_ast,
|
|
508
|
+
feat_dict=feat_dict,
|
|
509
|
+
time_dict=time_dict,
|
|
510
|
+
batch_dict=batch_dict,
|
|
511
|
+
anchor_time=anchor_time,
|
|
512
|
+
filter_na=True,
|
|
513
|
+
num_forecasts=num_forecasts,
|
|
514
|
+
)[0]
|
|
515
|
+
elif isinstance(query.whatif_ast, LogicalOperation):
|
|
516
|
+
mask &= self.execute_logical_operation(
|
|
517
|
+
logical_operation=query.whatif_ast,
|
|
518
|
+
feat_dict=feat_dict,
|
|
519
|
+
time_dict=time_dict,
|
|
520
|
+
batch_dict=batch_dict,
|
|
521
|
+
anchor_time=anchor_time,
|
|
522
|
+
filter_na=True,
|
|
523
|
+
num_forecasts=num_forecasts,
|
|
524
|
+
)[0]
|
|
525
|
+
else:
|
|
526
|
+
raise ValueError(
|
|
527
|
+
f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
|
|
528
|
+
|
|
503
529
|
out = out[mask[_mask]]
|
|
504
530
|
mask &= _mask
|
|
505
531
|
out = out.reset_index(drop=True)
|