kumoai 2.14.0.dev202601011731__cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +300 -0
- kumoai/_logging.py +29 -0
- kumoai/_singleton.py +25 -0
- kumoai/_version.py +1 -0
- kumoai/artifact_export/__init__.py +9 -0
- kumoai/artifact_export/config.py +209 -0
- kumoai/artifact_export/job.py +108 -0
- kumoai/client/__init__.py +5 -0
- kumoai/client/client.py +223 -0
- kumoai/client/connector.py +110 -0
- kumoai/client/endpoints.py +150 -0
- kumoai/client/graph.py +120 -0
- kumoai/client/jobs.py +471 -0
- kumoai/client/online.py +78 -0
- kumoai/client/pquery.py +207 -0
- kumoai/client/rfm.py +112 -0
- kumoai/client/source_table.py +53 -0
- kumoai/client/table.py +101 -0
- kumoai/client/utils.py +130 -0
- kumoai/codegen/__init__.py +19 -0
- kumoai/codegen/cli.py +100 -0
- kumoai/codegen/context.py +16 -0
- kumoai/codegen/edits.py +473 -0
- kumoai/codegen/exceptions.py +10 -0
- kumoai/codegen/generate.py +222 -0
- kumoai/codegen/handlers/__init__.py +4 -0
- kumoai/codegen/handlers/connector.py +118 -0
- kumoai/codegen/handlers/graph.py +71 -0
- kumoai/codegen/handlers/pquery.py +62 -0
- kumoai/codegen/handlers/table.py +109 -0
- kumoai/codegen/handlers/utils.py +42 -0
- kumoai/codegen/identity.py +114 -0
- kumoai/codegen/loader.py +93 -0
- kumoai/codegen/naming.py +94 -0
- kumoai/codegen/registry.py +121 -0
- kumoai/connector/__init__.py +31 -0
- kumoai/connector/base.py +153 -0
- kumoai/connector/bigquery_connector.py +200 -0
- kumoai/connector/databricks_connector.py +213 -0
- kumoai/connector/file_upload_connector.py +189 -0
- kumoai/connector/glue_connector.py +150 -0
- kumoai/connector/s3_connector.py +278 -0
- kumoai/connector/snowflake_connector.py +252 -0
- kumoai/connector/source_table.py +471 -0
- kumoai/connector/utils.py +1796 -0
- kumoai/databricks.py +14 -0
- kumoai/encoder/__init__.py +4 -0
- kumoai/exceptions.py +26 -0
- kumoai/experimental/__init__.py +0 -0
- kumoai/experimental/rfm/__init__.py +210 -0
- kumoai/experimental/rfm/authenticate.py +432 -0
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/backend/local/graph_store.py +297 -0
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +736 -0
- kumoai/experimental/rfm/graph.py +1237 -0
- kumoai/experimental/rfm/infer/__init__.py +19 -0
- kumoai/experimental/rfm/infer/categorical.py +40 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/id.py +46 -0
- kumoai/experimental/rfm/infer/multicategorical.py +48 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/infer/timestamp.py +41 -0
- kumoai/experimental/rfm/pquery/__init__.py +7 -0
- kumoai/experimental/rfm/pquery/executor.py +102 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +530 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +1184 -0
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/experimental/rfm/task_table.py +231 -0
- kumoai/formatting.py +30 -0
- kumoai/futures.py +99 -0
- kumoai/graph/__init__.py +12 -0
- kumoai/graph/column.py +106 -0
- kumoai/graph/graph.py +948 -0
- kumoai/graph/table.py +838 -0
- kumoai/jobs.py +80 -0
- kumoai/kumolib.cpython-310-x86_64-linux-gnu.so +0 -0
- kumoai/mixin.py +28 -0
- kumoai/pquery/__init__.py +25 -0
- kumoai/pquery/prediction_table.py +287 -0
- kumoai/pquery/predictive_query.py +641 -0
- kumoai/pquery/training_table.py +424 -0
- kumoai/spcs.py +121 -0
- kumoai/testing/__init__.py +8 -0
- kumoai/testing/decorators.py +57 -0
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/__init__.py +42 -0
- kumoai/trainer/baseline_trainer.py +93 -0
- kumoai/trainer/config.py +2 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/trainer/job.py +1192 -0
- kumoai/trainer/online_serving.py +258 -0
- kumoai/trainer/trainer.py +475 -0
- kumoai/trainer/util.py +103 -0
- kumoai/utils/__init__.py +11 -0
- kumoai/utils/datasets.py +83 -0
- kumoai/utils/display.py +51 -0
- kumoai/utils/forecasting.py +209 -0
- kumoai/utils/progress_logger.py +343 -0
- kumoai/utils/sql.py +3 -0
- kumoai-2.14.0.dev202601011731.dist-info/METADATA +71 -0
- kumoai-2.14.0.dev202601011731.dist-info/RECORD +122 -0
- kumoai-2.14.0.dev202601011731.dist-info/WHEEL +6 -0
- kumoai-2.14.0.dev202601011731.dist-info/licenses/LICENSE +9 -0
- kumoai-2.14.0.dev202601011731.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,530 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from kumoapi.pquery import ValidatedPredictiveQuery
|
|
4
|
+
from kumoapi.pquery.AST import (
|
|
5
|
+
Aggregation,
|
|
6
|
+
Column,
|
|
7
|
+
Condition,
|
|
8
|
+
Constant,
|
|
9
|
+
Filter,
|
|
10
|
+
Join,
|
|
11
|
+
LogicalOperation,
|
|
12
|
+
)
|
|
13
|
+
from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
|
|
14
|
+
|
|
15
|
+
from kumoai.experimental.rfm.pquery import PQueryExecutor
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                                          np.ndarray]):
    r"""Evaluates a validated predictive query on :mod:`pandas` data.

    Every ``execute_*`` method returns a tuple ``(out, mask)``, where ``mask``
    is a boolean :class:`numpy.ndarray` marking which input rows produced a
    valid value. With ``filter_na=True``, ``out`` only holds the rows where
    ``mask`` is ``True``; with ``filter_na=False``, ``out`` is generally kept
    aligned with ``mask`` instead (see :meth:`execute_filter` for the one
    method that behaves differently).
    """
    def execute_column(
        self,
        column: Column,
        feat_dict: dict[str, pd.DataFrame],
        filter_na: bool = True,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Reads the values of ``column`` from ``feat_dict``.

        ``column.fqn`` is split on ``"."`` into a table and a column name.
        The special column name ``"*"`` yields a series of ones with one entry
        per row of the table. Float columns are cast to ``float32``.

        Args:
            column: The column to read.
            feat_dict: Feature frames keyed by table name.
            filter_na: Whether to drop missing values from the output.

        Returns:
            The column values and a boolean mask of non-missing rows.
        """
        table_name, column_name = column.fqn.split(".")
        if column_name == '*':
            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
        else:
            out = feat_dict[table_name][column_name]
            out = out.reset_index(drop=True)

        if pd.api.types.is_float_dtype(out):
            out = out.astype('float32')

        out.name = None
        out.index.name = None

        mask = out.notna().to_numpy()

        if not filter_na:
            return out, mask

        out = out[mask].reset_index(drop=True)

        # Cast to primitive dtype now that missing values have been removed:
        if pd.api.types.is_integer_dtype(out):
            out = out.astype('int64')
        elif pd.api.types.is_bool_dtype(out):
            out = out.astype('bool')

        return out, mask

    def execute_aggregation_type(
        self,
        op: AggregationType,
        feat: pd.Series,
        batch: np.ndarray,
        batch_size: int,
        filter_na: bool = True,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Aggregates ``feat`` per output row using ``op``.

        Args:
            op: The aggregation to apply.
            feat: Values to aggregate, aligned with ``batch``.
            batch: For every value in ``feat``, the output row index (in
                ``[0, batch_size)``) it contributes to.
            batch_size: The number of output rows.
            filter_na: Whether to drop output rows that received no value.

        Returns:
            The aggregated values and a boolean mask over the ``batch_size``
            output rows. ``SUM`` and ``COUNT`` fill empty groups with ``0``,
            so their mask is always all ``True`` (regardless of
            ``filter_na``).
        """
        # Missing values never contribute to any aggregation:
        mask = feat.notna()
        feat, batch = feat[mask], batch[mask]

        if op == AggregationType.LIST_DISTINCT:
            df = pd.DataFrame(dict(feat=feat, batch=batch))
            df = df.drop_duplicates()
            out = df.groupby('batch')['feat'].agg(list)

        else:
            df = pd.DataFrame(dict(feat=feat, batch=batch))
            if op == AggregationType.AVG:
                agg = 'mean'
            elif op == AggregationType.COUNT:
                agg = 'size'
            else:
                agg = op.lower()
            out = df.groupby('batch')['feat'].agg(agg)

            if not pd.api.types.is_datetime64_any_dtype(out):
                out = out.astype('float32')

        out.name = None
        out.index.name = None

        if op in {AggregationType.SUM, AggregationType.COUNT}:
            out = out.reindex(range(batch_size), fill_value=0)
            mask = np.ones(batch_size, dtype=bool)
            return out, mask

        mask = np.zeros(batch_size, dtype=bool)
        mask[batch] = True

        if filter_na:
            # `groupby` sorts by batch index, matching the order of `mask`:
            return out.reset_index(drop=True), mask

        out = out.reindex(range(batch_size), fill_value=pd.NA)

        return out, mask

    def execute_aggregation(
        self,
        aggr: Aggregation,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Evaluates a temporal aggregation for every anchor row.

        Target rows are mapped to anchor rows via ``batch_dict`` and kept if
        their time lies in ``(anchor + start_offset, anchor + end_offset]``
        of ``aggr.aggr_time_range``. Only the upper bound applies when
        ``start`` is ``None``, which is only allowed for a single forecast.

        With ``num_forecasts > 1``, the window is evaluated repeatedly, with
        the anchor time shifted by the end offset after every step. Per-row
        results are collected into lists, and the returned mask is the
        logical OR of the per-step masks.

        Returns:
            The aggregated values and a boolean mask over the anchor rows.
        """
        target_table = aggr._get_target_column_name().split('.')[0]
        target_batch = batch_dict[target_table]
        target_time = time_dict[target_table]
        if isinstance(aggr.target, Column):
            target_feat, target_mask = self.execute_column(
                column=aggr.target,
                feat_dict=feat_dict,
                filter_na=True,
            )
        else:
            assert isinstance(aggr.target, Filter)
            target_feat, target_mask = self.execute_filter(
                filter=aggr.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
            )

        outs: list[pd.Series] = []
        masks: list[np.ndarray] = []
        for _ in range(num_forecasts):
            anchor_target_time = anchor_time.iloc[target_batch]
            anchor_target_time = anchor_target_time.reset_index(drop=True)

            time_filter_mask = (target_time <= anchor_target_time +
                                aggr.aggr_time_range.end_date_offset)
            if aggr.aggr_time_range.start is not None:
                start_offset = aggr.aggr_time_range.start_date_offset
                time_filter_mask &= (target_time
                                     > anchor_target_time + start_offset)
            else:
                assert num_forecasts == 1
            curr_target_mask = target_mask & time_filter_mask

            out, mask = self.execute_aggregation_type(
                aggr.aggr,
                # `target_feat` is already restricted to `target_mask`, so
                # apply the time filter on that restricted view only:
                feat=target_feat[time_filter_mask[target_mask].reset_index(
                    drop=True)],
                batch=target_batch[curr_target_mask],
                batch_size=len(anchor_time),
                # Keep forecasts aligned so they can be zipped row-wise:
                filter_na=False if num_forecasts > 1 else filter_na,
            )
            outs.append(out)
            masks.append(mask)

            if num_forecasts > 1:
                anchor_time = (anchor_time +
                               aggr.aggr_time_range.end_date_offset)

        if len(outs) == 1:
            assert len(masks) == 1
            return outs[0], masks[0]

        out = pd.Series([list(ser) for ser in zip(*outs)])
        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore

        if filter_na:
            out = out[mask].reset_index(drop=True)

        return out, mask

    def execute_rel_op(
        self,
        left: pd.Series,
        op: RelOp,
        right: Constant,
    ) -> pd.Series:
        r"""Compares ``left`` element-wise against the constant ``right``.

        Comparing against a ``None`` constant tests for missing values, and
        only ``EQ`` and ``NEQ`` are allowed in that case. Otherwise, missing
        values in ``left`` never satisfy the comparison.

        Raises:
            NotImplementedError: If ``op`` is not supported.
        """
        right_value = right.typed_value()

        if right_value is None:
            if op == RelOp.EQ:
                return left.isna()
            assert op == RelOp.NEQ
            return left.notna()

        # Promote left to float if right is a float to avoid lossy coercion.
        if pd.api.types.is_integer_dtype(left) and isinstance(
                right_value, float):
            left = left.astype('float64')
        value = pd.Series([right_value], dtype=left.dtype).iloc[0]

        if op == RelOp.EQ:
            return (left == value).fillna(False).astype(bool)
        if op == RelOp.NEQ:
            out = (left != value).fillna(False).astype(bool)
            out[left.isna()] = False  # N/A != right should always be `False`.
            return out
        if op == RelOp.LEQ:
            return (left <= value).fillna(False).astype(bool)
        if op == RelOp.GEQ:
            return (left >= value).fillna(False).astype(bool)
        if op == RelOp.LT:
            return (left < value).fillna(False).astype(bool)
        if op == RelOp.GT:
            return (left > value).fillna(False).astype(bool)

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_member_op(
        self,
        left: pd.Series,
        op: MemberOp,
        right: Constant,
    ) -> pd.Series:
        r"""Tests element-wise membership of ``left`` in the constant
        ``right``.

        Raises:
            NotImplementedError: If ``op`` is not ``MemberOp.IN``.
        """
        if op == MemberOp.IN:
            ser = pd.Series(right.typed_value(), dtype=left.dtype)
            return left.isin(ser).astype(bool)

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_condition(
        self,
        condition: Condition,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Evaluates ``condition.target <op> condition.value``.

        If the constant is ``None``, the target is evaluated without NA
        filtering so that missing rows can be matched. In that case, when
        ``filter_na=True``, the returned mask is all ``True``.

        Raises:
            NotImplementedError: If ``num_forecasts > 1``.
        """
        if num_forecasts > 1:
            raise NotImplementedError("Forecasting not yet implemented for "
                                      "non-regression tasks")

        assert isinstance(condition.value, Constant)
        value_is_na = condition.value.typed_value() is None
        if isinstance(condition.target, Column):
            left, mask = self.execute_column(
                column=condition.target,
                feat_dict=feat_dict,
                filter_na=filter_na if not value_is_na else False,
            )
        elif isinstance(condition.target, Join):
            left, mask = self.execute_join(
                join=condition.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=filter_na if not value_is_na else False,
            )
        else:
            assert isinstance(condition.target, Aggregation)
            left, mask = self.execute_aggregation(
                aggr=condition.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=filter_na if not value_is_na else False,
            )

        if filter_na and value_is_na:
            mask = np.ones(len(left), dtype=bool)

        if isinstance(condition.op, RelOp):
            out = self.execute_rel_op(
                left=left,
                op=condition.op,
                right=condition.value,
            )
        else:
            assert isinstance(condition.op, MemberOp)
            out = self.execute_member_op(
                left=left,
                op=condition.op,
                right=condition.value,
            )

        return out, mask

    def execute_bool_op(
        self,
        left: pd.Series,
        op: BoolOp,
        right: pd.Series | None,
    ) -> pd.Series:
        r"""Combines boolean series with ``AND``, ``OR`` or ``NOT``.

        ``right`` is required for ``AND``/``OR`` and ignored for ``NOT``.

        Raises:
            NotImplementedError: If ``op`` is not supported.
        """
        # TODO Implement Kleene-Priest three-value logic.
        if op == BoolOp.AND:
            assert right is not None
            return left & right
        if op == BoolOp.OR:
            assert right is not None
            return left | right
        if op == BoolOp.NOT:
            return ~left

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_logical_operation(
        self,
        logical_operation: LogicalOperation,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Evaluates a (possibly nested) boolean combination of conditions.

        Both operands are evaluated with ``filter_na=False`` so that they
        stay row-aligned. The returned mask is the AND of the operand masks.

        Raises:
            NotImplementedError: If ``num_forecasts > 1``.
        """
        if num_forecasts > 1:
            raise NotImplementedError("Forecasting not yet implemented for "
                                      "non-regression tasks")

        if isinstance(logical_operation.left, Condition):
            left, mask = self.execute_condition(
                condition=logical_operation.left,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )
        else:
            assert isinstance(logical_operation.left, LogicalOperation)
            left, mask = self.execute_logical_operation(
                logical_operation=logical_operation.left,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )

        # `right` stays `None` for unary operators such as `NOT`:
        right = right_mask = None
        if isinstance(logical_operation.right, Condition):
            right, right_mask = self.execute_condition(
                condition=logical_operation.right,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )
        elif isinstance(logical_operation.right, LogicalOperation):
            right, right_mask = self.execute_logical_operation(
                logical_operation=logical_operation.right,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )

        out = self.execute_bool_op(left, logical_operation.bool_op, right)

        if right_mask is not None:
            # Do not mutate the array returned by the left operand in-place:
            mask = mask & right_mask

        if filter_na:
            out = out[mask].reset_index(drop=True)

        return out, mask

    def execute_join(
        self,
        join: Join,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Evaluates a join whose right-hand side is an aggregation.

        Raises:
            NotImplementedError: If ``join.rhs_target`` is not an
                :class:`Aggregation`.
        """
        if isinstance(join.rhs_target, Aggregation):
            return self.execute_aggregation(
                aggr=join.rhs_target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                # Honor the caller's request: callers that need row-aligned
                # output (e.g., logical operations, `IS NULL` checks) pass
                # `filter_na=False`.
                filter_na=filter_na,
                num_forecasts=num_forecasts,
            )
        raise NotImplementedError(
            f'Unexpected {type(join.rhs_target)} nested in Join')

    def execute_filter(
        self,
        filter: Filter,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Reads ``filter.target`` and keeps the rows satisfying
        ``filter.condition``.

        Returns:
            The kept values and the mask ``condition & non-missing``.
        """
        out, mask = self.execute_column(
            column=filter.target,
            feat_dict=feat_dict,
            filter_na=False,
        )
        if isinstance(filter.condition, Condition):
            _mask = self.execute_condition(
                condition=filter.condition,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )[0].to_numpy()
        else:
            assert isinstance(filter.condition, LogicalOperation)
            _mask = self.execute_logical_operation(
                logical_operation=filter.condition,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )[0].to_numpy()
        if filter_na:
            return out[_mask & mask].reset_index(drop=True), _mask & mask
        else:
            # NOTE(review): unlike the other methods, `out` here is filtered
            # by the condition but not aligned with the returned mask; no
            # caller in this class uses this branch -- confirm intent.
            return out[_mask].reset_index(drop=True), mask & _mask

    def execute(
        self,
        query: ValidatedPredictiveQuery,
        feat_dict: dict[str, pd.DataFrame],
        time_dict: dict[str, pd.Series],
        batch_dict: dict[str, np.ndarray],
        anchor_time: pd.Series,
        num_forecasts: int = 1,
    ) -> tuple[pd.Series, np.ndarray]:
        r"""Evaluates a full predictive query.

        Returns:
            The target values for the entity rows that are valid entities,
            satisfy the optional ``ASSUMING`` condition and have a valid
            target, together with the boolean mask marking those rows.

        Raises:
            NotImplementedError: If the target expression type is unknown.
            ValueError: If the ``ASSUMING`` expression type is unsupported.
        """
        if isinstance(query.entity_ast, Column):
            out, mask = self.execute_column(
                column=query.entity_ast,
                feat_dict=feat_dict,
                filter_na=True,
            )
        else:
            assert isinstance(query.entity_ast, Filter)
            out, mask = self.execute_filter(
                filter=query.entity_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
            )
        if isinstance(query.target_ast, Column):
            out, _mask = self.execute_column(
                column=query.target_ast,
                feat_dict=feat_dict,
                filter_na=True,
            )
        elif isinstance(query.target_ast, Condition):
            out, _mask = self.execute_condition(
                condition=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, Aggregation):
            out, _mask = self.execute_aggregation(
                aggr=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, Join):
            out, _mask = self.execute_join(
                join=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, LogicalOperation):
            out, _mask = self.execute_logical_operation(
                logical_operation=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        else:
            raise NotImplementedError(
                f'{type(query.target_ast)} compilation missing.')
        if query.whatif_ast is not None:
            # Evaluate without NA filtering so that the result stays aligned
            # with the entity rows in `mask` (a filtered result would have a
            # different length and could not be combined element-wise):
            if isinstance(query.whatif_ast, Condition):
                whatif, _ = self.execute_condition(
                    condition=query.whatif_ast,
                    feat_dict=feat_dict,
                    time_dict=time_dict,
                    batch_dict=batch_dict,
                    anchor_time=anchor_time,
                    filter_na=False,
                    num_forecasts=num_forecasts,
                )
            elif isinstance(query.whatif_ast, LogicalOperation):
                whatif, _ = self.execute_logical_operation(
                    logical_operation=query.whatif_ast,
                    feat_dict=feat_dict,
                    time_dict=time_dict,
                    batch_dict=batch_dict,
                    anchor_time=anchor_time,
                    filter_na=False,
                    num_forecasts=num_forecasts,
                )
            else:
                raise ValueError(
                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
            mask = mask & whatif.to_numpy(dtype=bool)

        # `out` only holds rows with a valid target (`_mask`), so select the
        # entity mask on those rows before combining both masks:
        out = out[mask[_mask]]
        mask = mask & _mask
        out = out.reset_index(drop=True)
        return out, mask
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import json
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from urllib.request import urlopen
|
|
5
|
+
|
|
6
|
+
import pooch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm import Graph
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
11
|
+
|
|
12
|
+
# Prefix shared by RelBench dataset names (e.g., 'rel-amazon').
PREFIX = 'rel-'

# Local cache directory (managed by pooch) for downloaded RelBench files.
CACHE_DIR = pooch.os_cache('relbench')

# JSON file listing the RelBench download files and their hashes.
HASH_URL = (
    'https://raw.githubusercontent.com/snap-stanford/relbench/main/'
    'relbench/datasets/hashes.json'
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache
def get_registry() -> pooch.Pooch:
    r"""Returns a :class:`pooch.Pooch` registry for RelBench downloads.

    The file hashes are downloaded from ``HASH_URL`` on first use; the
    resulting registry is memoized via :func:`functools.lru_cache`.
    """
    with urlopen(HASH_URL) as response:
        file_hashes = json.load(response)

    download_url = 'https://relbench.stanford.edu/download/'
    return pooch.create(
        path=CACHE_DIR,
        base_url=download_url,
        registry=file_hashes,
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def from_relbench(dataset: str, verbose: bool = True) -> Graph:
    r"""Loads a RelBench dataset as a :class:`Graph`.

    The dataset's ``db.zip`` archive is fetched (and cached) through the
    registry returned by :func:`get_registry`. Every extracted Parquet file
    becomes a :class:`LocalTable`, using the primary key and time column
    stored in the Parquet schema metadata. Foreign keys listed in that
    metadata are linked once all tables have been added.

    Args:
        dataset: The name of the RelBench dataset, case-insensitive, with or
            without the ``"rel-"`` prefix.
        verbose: Whether to show a download progress bar.

    Raises:
        ValueError: If ``dataset`` is not part of the RelBench registry.
    """
    # `import pyarrow` does not load the `parquet` submodule, so accessing
    # `pyarrow.parquet` without importing it explicitly can fail:
    import pyarrow.parquet as pq

    dataset = dataset.lower()
    if dataset.startswith(PREFIX):
        dataset = dataset[len(PREFIX):]

    registry = get_registry()

    # The first path component of every registry key is the prefixed dataset
    # name. De-duplicate (preserving order) in case a dataset owns several
    # files, so that each name is reported only once:
    datasets = list(
        dict.fromkeys(
            key.split('/')[0][len(PREFIX):] for key in registry.registry))
    if dataset not in datasets:
        matches = difflib.get_close_matches(dataset, datasets, n=1)
        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
                         f"datasets are {str(datasets)[1:-1]}.")

    registry.fetch(
        f'{PREFIX}{dataset}/db.zip',
        processor=pooch.Unzip(extract_dir='.'),
        progressbar=verbose,
    )

    graph = Graph(tables=[])
    edges: list[tuple[str, str, str]] = []
    db_dir = CACHE_DIR / f'{PREFIX}{dataset}' / 'db'
    # Sort for a deterministic table order (`glob` order is OS-dependent):
    for path in sorted(db_dir.glob('*.parquet')):
        data = pq.read_table(path)
        metadata = {
            key.decode('utf-8'): json.loads(value.decode('utf-8'))
            for key, value in data.schema.metadata.items()
            if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
        }

        table = LocalTable(
            df=data.to_pandas(),
            name=path.stem,
            primary_key=metadata['pkey_col'],
            time_column=metadata['time_col'],
        )
        graph.add_table(table)

        edges.extend([
            (path.stem, fkey, dst_table)
            for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
        ])

    # Link only after all tables exist, since edges may point to tables that
    # are read later:
    for edge in edges:
        graph.link(*edge)

    return graph
|