kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/pquery/{backend.py → executor.py}

@@ -1,23 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic,
+from typing import Dict, Generic, Tuple, TypeVar

-from kumoapi.
-from kumoapi.
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
     Aggregation,
-    AggregationType,
-    BoolOp,
     Column,
     Condition,
     Filter,
-
-    FloatList,
-    Int,
-    IntList,
+    Join,
     LogicalOperation,
-    MemberOp,
-    RelOp,
-    Str,
-    StrList,
 )

 TableData = TypeVar('TableData')
@@ -25,108 +16,87 @@ ColumnData = TypeVar('ColumnData')
 IndexData = TypeVar('IndexData')


-class
+class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     @abstractmethod
-    def
+    def execute_column(
         self,
-
-
-        batch: IndexData,
-        batch_size: int,
+        column: Column,
+        feat_dict: Dict[str, TableData],
         filter_na: bool = True,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute_aggregation(
         self,
-
-        op: RelOp,
-        right: Union[Int, Float, Str, None],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_member_op(
-        self,
-        left: ColumnData,
-        op: MemberOp,
-        right: Union[IntList, FloatList, StrList],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_bool_op(
-        self,
-        left: ColumnData,
-        op: BoolOp,
-        right: Optional[ColumnData],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_column(
-        self,
-        column: Column,
+        aggr: Aggregation,
         feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute_condition(
         self,
-
+        condition: Condition,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute_logical_operation(
         self,
-
+        logical_operation: LogicalOperation,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute_join(
         self,
-
+        join: Join,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute_filter(
         self,
         filter: Filter,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> IndexData:
+    ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def
+    def execute(
         self,
-        query:
+        query: ValidatedPredictiveQuery,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass
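The renamed executor interface above is an abstract base class that is generic over its backend data containers (`TableData`, `ColumnData`, `IndexData`); the new pandas backend in `pandas_executor.py` below binds them to `pd.DataFrame`, `pd.Series`, and `np.ndarray`. The following is a minimal, self-contained sketch of that pattern in isolation; the class and method names are illustrative only and are not part of the kumoai API.

from abc import ABC, abstractmethod
from typing import Dict, Generic, Tuple, TypeVar

import numpy as np
import pandas as pd

TableData = TypeVar('TableData')
ColumnData = TypeVar('ColumnData')
IndexData = TypeVar('IndexData')


class ColumnExecutor(Generic[TableData, ColumnData, IndexData], ABC):
    """Toy interface mirroring the generic-over-backend executor pattern."""
    @abstractmethod
    def execute_column(
        self,
        table_name: str,
        column_name: str,
        feat_dict: Dict[str, TableData],
    ) -> Tuple[ColumnData, IndexData]:
        ...


class PandasColumnExecutor(ColumnExecutor[pd.DataFrame, pd.Series,
                                          np.ndarray]):
    """Pandas binding: tables are DataFrames, columns are Series, masks are ndarrays."""
    def execute_column(
        self,
        table_name: str,
        column_name: str,
        feat_dict: Dict[str, pd.DataFrame],
    ) -> Tuple[pd.Series, np.ndarray]:
        out = feat_dict[table_name][column_name].reset_index(drop=True)
        mask = out.notna().to_numpy()  # boolean validity mask per row
        return out[mask].reset_index(drop=True), mask


if __name__ == '__main__':
    feat = {'users': pd.DataFrame({'age': [25, None, 40]})}
    col, mask = PandasColumnExecutor().execute_column('users', 'age', feat)
    print(col.tolist(), mask.tolist())  # [25.0, 40.0] [True, False, True]

In principle, the same interface could be bound by a different backend to other container types without changing callers, which appears to be the motivation for splitting the former `backend.py` into an executor interface plus a pandas implementation.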
kumoai/experimental/rfm/pquery/pandas_executor.py

@@ -0,0 +1,532 @@
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Constant,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
+
+from kumoai.experimental.rfm.pquery import PQueryExecutor
+
+
+class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
+                                          np.ndarray]):
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: Dict[str, pd.DataFrame],
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        table_name, column_name = column.fqn.split(".")
+        if column_name == '*':
+            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
+        else:
+            out = feat_dict[table_name][column_name]
+            out = out.reset_index(drop=True)
+
+        if pd.api.types.is_float_dtype(out):
+            out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        mask = out.notna().to_numpy()
+
+        if not filter_na:
+            return out, mask
+
+        out = out[mask].reset_index(drop=True)
+
+        # Cast to primitive dtype:
+        if pd.api.types.is_integer_dtype(out):
+            out = out.astype('int64')
+        elif pd.api.types.is_bool_dtype(out):
+            out = out.astype('bool')
+
+        return out, mask
+
+    def execute_aggregation_type(
+        self,
+        op: AggregationType,
+        feat: pd.Series,
+        batch: np.ndarray,
+        batch_size: int,
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+
+        mask = feat.notna()
+        feat, batch = feat[mask], batch[mask]
+
+        if op == AggregationType.LIST_DISTINCT:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            df = df.drop_duplicates()
+            out = df.groupby('batch')['feat'].agg(list)
+
+        else:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            if op == AggregationType.AVG:
+                agg = 'mean'
+            elif op == AggregationType.COUNT:
+                agg = 'size'
+            else:
+                agg = op.lower()
+            out = df.groupby('batch')['feat'].agg(agg)
+
+            if not pd.api.types.is_datetime64_any_dtype(out):
+                out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        if op in {AggregationType.SUM, AggregationType.COUNT}:
+            out = out.reindex(range(batch_size), fill_value=0)
+            mask = np.ones(batch_size, dtype=bool)
+            return out, mask
+
+        mask = np.zeros(batch_size, dtype=bool)
+        mask[batch] = True
+
+        if filter_na:
+            return out.reset_index(drop=True), mask
+
+        out = out.reindex(range(batch_size), fill_value=pd.NA)
+
+        return out, mask
+
+    def execute_aggregation(
+        self,
+        aggr: Aggregation,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        target_table = aggr._get_target_column_name().split('.')[0]
+        target_batch = batch_dict[target_table]
+        target_time = time_dict[target_table]
+        if isinstance(aggr.target, Column):
+            target_feat, target_mask = self.execute_column(
+                column=aggr.target,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(aggr.target, Filter)
+            target_feat, target_mask = self.execute_filter(
+                filter=aggr.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+            )
+
+        outs: List[pd.Series] = []
+        masks: List[np.ndarray] = []
+        for _ in range(num_forecasts):
+            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_target_time.reset_index(drop=True)
+
+            time_filter_mask = (target_time <= anchor_target_time +
+                                aggr.aggr_time_range.end_date_offset)
+            if aggr.aggr_time_range.start is not None:
+                start_offset = aggr.aggr_time_range.start_date_offset
+                time_filter_mask &= (target_time
+                                     > anchor_target_time + start_offset)
+            else:
+                assert num_forecasts == 1
+            curr_target_mask = target_mask & time_filter_mask
+
+            out, mask = self.execute_aggregation_type(
+                aggr.aggr,
+                feat=target_feat[time_filter_mask[target_mask].reset_index(
+                    drop=True)],
+                batch=target_batch[curr_target_mask],
+                batch_size=len(anchor_time),
+                filter_na=False if num_forecasts > 1 else filter_na,
+            )
+            outs.append(out)
+            masks.append(mask)
+
+            if num_forecasts > 1:
+                anchor_time = (anchor_time +
+                               aggr.aggr_time_range.end_date_offset)
+        if len(outs) == 1:
+            assert len(masks) == 1
+            return outs[0], masks[0]
+
+        out = pd.Series([list(ser) for ser in zip(*outs)])
+        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_rel_op(
+        self,
+        left: pd.Series,
+        op: RelOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if right.typed_value() is None:
+            if op == RelOp.EQ:
+                return left.isna()
+            assert op == RelOp.NEQ
+            return left.notna()
+
+        # Promote left to float if right is a float to avoid lossy coercion.
+        right_value = right.typed_value()
+        if pd.api.types.is_integer_dtype(left) and isinstance(
+                right_value, float):
+            left = left.astype('float64')
+        value = pd.Series([right_value], dtype=left.dtype).iloc[0]
+
+        if op == RelOp.EQ:
+            return (left == value).fillna(False).astype(bool)
+        if op == RelOp.NEQ:
+            out = (left != value).fillna(False).astype(bool)
+            out[left.isna()] = False  # N/A != right should always be `False`.
+            return out
+        if op == RelOp.LEQ:
+            return (left <= value).fillna(False).astype(bool)
+        if op == RelOp.GEQ:
+            return (left >= value).fillna(False).astype(bool)
+        if op == RelOp.LT:
+            return (left < value).fillna(False).astype(bool)
+        if op == RelOp.GT:
+            return (left > value).fillna(False).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_member_op(
+        self,
+        left: pd.Series,
+        op: MemberOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if op == MemberOp.IN:
+            ser = pd.Series(right.typed_value(), dtype=left.dtype)
+            return left.isin(ser).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_condition(
+        self,
+        condition: Condition,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        assert isinstance(condition.value, Constant)
+        value_is_na = condition.value.typed_value() is None
+        if isinstance(condition.target, Column):
+            left, mask = self.execute_column(
+                column=condition.target,
+                feat_dict=feat_dict,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        elif isinstance(condition.target, Join):
+            left, mask = self.execute_join(
+                join=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        else:
+            assert isinstance(condition.target, Aggregation)
+            left, mask = self.execute_aggregation(
+                aggr=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+
+        if filter_na and value_is_na:
+            mask = np.ones(len(left), dtype=bool)
+
+        if isinstance(condition.op, RelOp):
+            out = self.execute_rel_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+        else:
+            assert isinstance(condition.op, MemberOp)
+            out = self.execute_member_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+
+        return out, mask
+
+    def execute_bool_op(
+        self,
+        left: pd.Series,
+        op: BoolOp,
+        right: pd.Series | None,
+    ) -> pd.Series:
+
+        # TODO Implement Kleene-Priest three-value logic.
+        if op == BoolOp.AND:
+            assert right is not None
+            return left & right
+        if op == BoolOp.OR:
+            assert right is not None
+            return left | right
+        if op == BoolOp.NOT:
+            return ~left
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_logical_operation(
+        self,
+        logical_operation: LogicalOperation,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        if isinstance(logical_operation.left, Condition):
+            left, mask = self.execute_condition(
+                condition=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        else:
+            assert isinstance(logical_operation.left, LogicalOperation)
+            left, mask = self.execute_logical_operation(
+                logical_operation=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        right = right_mask = None
+        if isinstance(logical_operation.right, Condition):
+            right, right_mask = self.execute_condition(
+                condition=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        elif isinstance(logical_operation.right, LogicalOperation):
+            right, right_mask = self.execute_logical_operation(
+                logical_operation=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        out = self.execute_bool_op(left, logical_operation.bool_op, right)
+
+        if right_mask is not None:
+            mask &= right_mask
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if isinstance(join.rhs_target, Aggregation):
+            return self.execute_aggregation(
+                aggr=join.rhs_target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        raise NotImplementedError(
+            f'Unexpected {type(join.rhs_target)} nested in Join')
+
+    def execute_filter(
+        self,
+        filter: Filter,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        out, mask = self.execute_column(
+            column=filter.target,
+            feat_dict=feat_dict,
+            filter_na=False,
+        )
+        if isinstance(filter.condition, Condition):
+            _mask = self.execute_condition(
+                condition=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        else:
+            assert isinstance(filter.condition, LogicalOperation)
+            _mask = self.execute_logical_operation(
+                logical_operation=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        if filter_na:
+            return out[_mask & mask].reset_index(drop=True), _mask & mask
+        else:
+            return out[_mask].reset_index(drop=True), mask & _mask
+
+    def execute(
+        self,
+        query: ValidatedPredictiveQuery,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if isinstance(query.entity_ast, Column):
+            out, mask = self.execute_column(
+                column=query.entity_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(query.entity_ast, Filter)
+            out, mask = self.execute_filter(
+                filter=query.entity_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+            )
+        if isinstance(query.target_ast, Column):
+            out, _mask = self.execute_column(
+                column=query.target_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        elif isinstance(query.target_ast, Condition):
+            out, _mask = self.execute_condition(
+                condition=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Aggregation):
+            out, _mask = self.execute_aggregation(
+                aggr=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Join):
+            out, _mask = self.execute_join(
+                join=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, LogicalOperation):
+            out, _mask = self.execute_logical_operation(
+                logical_operation=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        else:
+            raise NotImplementedError(
+                f'{type(query.target_ast)} compilation missing.')
+        if query.whatif_ast is not None:
+            if isinstance(query.whatif_ast, Condition):
+                mask &= self.execute_condition(
+                    condition=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            elif isinstance(query.whatif_ast, LogicalOperation):
+                mask &= self.execute_logical_operation(
+                    logical_operation=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            else:
+                raise ValueError(
+                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
+
+        out = out[mask[_mask]]
+        mask &= _mask
+        out = out.reset_index(drop=True)
+        return out, mask
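In `execute_aggregation_type` above, per-entity aggregation reduces to a pandas `groupby` over a batch-index vector: `AVG` maps to `'mean'`, `COUNT` to `'size'`, and `SUM`/`COUNT` results are reindexed with a fill value of 0 so that entities with no matching rows still receive a value. The following is a small standalone sketch of that reduction; the helper name and arguments are illustrative, not kumoai code.

import numpy as np
import pandas as pd


def aggregate_by_batch(feat: pd.Series, batch: np.ndarray, batch_size: int,
                       agg: str) -> pd.Series:
    """Group a flat feature vector by its batch (entity) index and aggregate.

    Mirrors the groupby-based reduction used by the pandas executor:
    'count' is implemented via 'size', and sum/count are reindexed with 0 so
    entities without any rows still get a defined value.
    """
    df = pd.DataFrame({'feat': feat.to_numpy(), 'batch': batch})
    pandas_agg = 'size' if agg == 'count' else agg
    out = df.groupby('batch')['feat'].agg(pandas_agg)
    if agg in ('sum', 'count'):
        out = out.reindex(range(batch_size), fill_value=0)
    out.name = None
    out.index.name = None
    return out


if __name__ == '__main__':
    # Three entities (0, 1, 2); entity 2 has no matching rows at all.
    amounts = pd.Series([10.0, 5.0, 2.5])
    batch = np.array([0, 0, 1])
    print(aggregate_by_batch(amounts, batch, batch_size=3, agg='sum').tolist())
    # -> [15.0, 2.5, 0.0]
    print(aggregate_by_batch(amounts, batch, batch_size=3, agg='count').tolist())
    # -> [2, 1, 0]

Aggregates other than SUM and COUNT have no natural neutral element, which is presumably why the executor instead returns a boolean mask (`mask[batch] = True`) marking which entities actually had rows rather than filling a default.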