kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic. Click here for more details.

Files changed (52):
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/rfm.py +37 -8
  6. kumoai/connector/file_upload_connector.py +94 -85
  7. kumoai/connector/utils.py +1399 -210
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/authenticate.py +8 -5
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  12. kumoai/experimental/rfm/backend/local/table.py +109 -0
  13. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  16. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  17. kumoai/experimental/rfm/base/__init__.py +10 -0
  18. kumoai/experimental/rfm/base/column.py +66 -0
  19. kumoai/experimental/rfm/base/source.py +18 -0
  20. kumoai/experimental/rfm/base/table.py +545 -0
  21. kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
  22. kumoai/experimental/rfm/infer/__init__.py +6 -0
  23. kumoai/experimental/rfm/infer/dtype.py +79 -0
  24. kumoai/experimental/rfm/infer/pkey.py +126 -0
  25. kumoai/experimental/rfm/infer/time_col.py +62 -0
  26. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  27. kumoai/experimental/rfm/local_graph_sampler.py +58 -11
  28. kumoai/experimental/rfm/local_graph_store.py +45 -37
  29. kumoai/experimental/rfm/local_pquery_driver.py +342 -46
  30. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  31. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
  32. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  33. kumoai/experimental/rfm/rfm.py +559 -148
  34. kumoai/experimental/rfm/sagemaker.py +138 -0
  35. kumoai/jobs.py +27 -1
  36. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  37. kumoai/pquery/prediction_table.py +5 -3
  38. kumoai/pquery/training_table.py +5 -3
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/trainer/job.py +9 -30
  42. kumoai/trainer/trainer.py +19 -10
  43. kumoai/utils/__init__.py +2 -1
  44. kumoai/utils/progress_logger.py +96 -16
  45. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
  46. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
  47. kumoai/experimental/rfm/local_table.py +0 -448
  48. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
  49. kumoai/experimental/rfm/utils.py +0 -347
  50. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
  51. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
  52. {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,14 @@
1
1
  from abc import ABC, abstractmethod
2
- from typing import Dict, Generic, Optional, Tuple, TypeVar, Union
2
+ from typing import Dict, Generic, Tuple, TypeVar
3
3
 
4
- from kumoapi.rfm import PQueryDefinition
5
- from kumoapi.rfm.pquery import (
4
+ from kumoapi.pquery import ValidatedPredictiveQuery
5
+ from kumoapi.pquery.AST import (
6
6
  Aggregation,
7
- AggregationType,
8
- BoolOp,
9
7
  Column,
10
8
  Condition,
11
9
  Filter,
12
- Float,
13
- FloatList,
14
- Int,
15
- IntList,
10
+ Join,
16
11
  LogicalOperation,
17
- MemberOp,
18
- RelOp,
19
- Str,
20
- StrList,
21
12
  )
22
13
 
23
14
  TableData = TypeVar('TableData')
@@ -25,108 +16,87 @@ ColumnData = TypeVar('ColumnData')
25
16
  IndexData = TypeVar('IndexData')
26
17
 
27
18
 
28
- class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
19
+ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
29
20
  @abstractmethod
30
- def eval_aggregation_type(
21
+ def execute_column(
31
22
  self,
32
- op: AggregationType,
33
- feat: Optional[ColumnData],
34
- batch: IndexData,
35
- batch_size: int,
23
+ column: Column,
24
+ feat_dict: Dict[str, TableData],
36
25
  filter_na: bool = True,
37
26
  ) -> Tuple[ColumnData, IndexData]:
38
27
  pass
39
28
 
40
29
  @abstractmethod
41
- def eval_rel_op(
30
+ def execute_aggregation(
42
31
  self,
43
- left: ColumnData,
44
- op: RelOp,
45
- right: Union[Int, Float, Str, None],
46
- ) -> ColumnData:
47
- pass
48
-
49
- @abstractmethod
50
- def eval_member_op(
51
- self,
52
- left: ColumnData,
53
- op: MemberOp,
54
- right: Union[IntList, FloatList, StrList],
55
- ) -> ColumnData:
56
- pass
57
-
58
- @abstractmethod
59
- def eval_bool_op(
60
- self,
61
- left: ColumnData,
62
- op: BoolOp,
63
- right: Optional[ColumnData],
64
- ) -> ColumnData:
65
- pass
66
-
67
- @abstractmethod
68
- def eval_column(
69
- self,
70
- column: Column,
32
+ aggr: Aggregation,
71
33
  feat_dict: Dict[str, TableData],
34
+ time_dict: Dict[str, ColumnData],
35
+ batch_dict: Dict[str, IndexData],
36
+ anchor_time: ColumnData,
72
37
  filter_na: bool = True,
38
+ num_forecasts: int = 1,
73
39
  ) -> Tuple[ColumnData, IndexData]:
74
40
  pass
75
41
 
76
42
  @abstractmethod
77
- def eval_aggregation(
43
+ def execute_condition(
78
44
  self,
79
- aggr: Aggregation,
45
+ condition: Condition,
80
46
  feat_dict: Dict[str, TableData],
81
47
  time_dict: Dict[str, ColumnData],
82
48
  batch_dict: Dict[str, IndexData],
83
49
  anchor_time: ColumnData,
84
50
  filter_na: bool = True,
51
+ num_forecasts: int = 1,
85
52
  ) -> Tuple[ColumnData, IndexData]:
86
53
  pass
87
54
 
88
55
  @abstractmethod
89
- def eval_condition(
56
+ def execute_logical_operation(
90
57
  self,
91
- condition: Condition,
58
+ logical_operation: LogicalOperation,
92
59
  feat_dict: Dict[str, TableData],
93
60
  time_dict: Dict[str, ColumnData],
94
61
  batch_dict: Dict[str, IndexData],
95
62
  anchor_time: ColumnData,
96
63
  filter_na: bool = True,
64
+ num_forecasts: int = 1,
97
65
  ) -> Tuple[ColumnData, IndexData]:
98
66
  pass
99
67
 
100
68
  @abstractmethod
101
- def eval_logical_operation(
69
+ def execute_join(
102
70
  self,
103
- logical_operation: LogicalOperation,
71
+ join: Join,
104
72
  feat_dict: Dict[str, TableData],
105
73
  time_dict: Dict[str, ColumnData],
106
74
  batch_dict: Dict[str, IndexData],
107
75
  anchor_time: ColumnData,
108
76
  filter_na: bool = True,
77
+ num_forecasts: int = 1,
109
78
  ) -> Tuple[ColumnData, IndexData]:
110
79
  pass
111
80
 
112
81
  @abstractmethod
113
- def eval_filter(
82
+ def execute_filter(
114
83
  self,
115
84
  filter: Filter,
116
85
  feat_dict: Dict[str, TableData],
117
86
  time_dict: Dict[str, ColumnData],
118
87
  batch_dict: Dict[str, IndexData],
119
88
  anchor_time: ColumnData,
120
- ) -> IndexData:
89
+ ) -> Tuple[ColumnData, IndexData]:
121
90
  pass
122
91
 
123
92
  @abstractmethod
124
- def eval_pquery(
93
+ def execute(
125
94
  self,
126
- query: PQueryDefinition,
95
+ query: ValidatedPredictiveQuery,
127
96
  feat_dict: Dict[str, TableData],
128
97
  time_dict: Dict[str, ColumnData],
129
98
  batch_dict: Dict[str, IndexData],
130
99
  anchor_time: ColumnData,
100
+ num_forecasts: int = 1,
131
101
  ) -> Tuple[ColumnData, IndexData]:
132
102
  pass
@@ -0,0 +1,532 @@
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from kumoapi.pquery import ValidatedPredictiveQuery
6
+ from kumoapi.pquery.AST import (
7
+ Aggregation,
8
+ Column,
9
+ Condition,
10
+ Constant,
11
+ Filter,
12
+ Join,
13
+ LogicalOperation,
14
+ )
15
+ from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
16
+
17
+ from kumoai.experimental.rfm.pquery import PQueryExecutor
18
+
19
+
20
+ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
21
+ np.ndarray]):
22
def execute_column(
    self,
    column: Column,
    feat_dict: Dict[str, pd.DataFrame],
    filter_na: bool = True,
) -> Tuple[pd.Series, np.ndarray]:
    """Materialize a single column reference as a series plus validity mask.

    Args:
        column: Column node whose ``fqn`` has the form ``"<table>.<column>"``.
            A column name of ``'*'`` yields a series of ones with one entry
            per row of the table.
        feat_dict: Mapping from table name to its feature frame.
        filter_na: If ``True``, rows holding N/A are dropped from the
            returned series; otherwise the series keeps every row.

    Returns:
        The column values (unnamed, with a fresh positional index) and a
        boolean array marking which rows of the unfiltered column are not
        N/A. The mask always has the unfiltered length.
    """
    table_name, column_name = column.fqn.split(".")
    table = feat_dict[table_name]

    if column_name != '*':
        series = table[column_name].reset_index(drop=True)
    else:
        series = pd.Series(np.ones(len(table), dtype='int64'))

    # Floating point data is always handed out in single precision.
    if pd.api.types.is_float_dtype(series):
        series = series.astype('float32')

    series.name = None
    series.index.name = None

    valid = series.notna().to_numpy()

    if filter_na:
        series = series[valid].reset_index(drop=True)

        # With N/A rows gone, fall back to primitive dtypes:
        if pd.api.types.is_integer_dtype(series):
            series = series.astype('int64')
        elif pd.api.types.is_bool_dtype(series):
            series = series.astype('bool')

    return series, valid
55
+
56
def execute_aggregation_type(
    self,
    op: AggregationType,
    feat: pd.Series,
    batch: np.ndarray,
    batch_size: int,
    filter_na: bool = True,
) -> Tuple[pd.Series, np.ndarray]:
    """Reduce ``feat`` per batch index with the requested aggregation.

    Args:
        op: Aggregation to apply. ``LIST_DISTINCT`` collects the distinct
            values per batch, ``AVG`` maps to ``'mean'``, ``COUNT`` maps to
            ``'size'``, and any other operator is forwarded to pandas under
            its lower-cased name.
        feat: Values to aggregate; N/A entries are discarded up front.
        batch: Batch index of each entry in ``feat``.
        batch_size: Total number of batch entries.
        filter_na: If ``True``, only batches that received at least one
            value are returned; otherwise the result is re-indexed to
            ``range(batch_size)`` with ``pd.NA`` for missing batches.
            Ignored for ``SUM`` and ``COUNT``, which always cover every
            batch and use ``0`` for empty ones.

    Returns:
        The aggregated series and a boolean array of length ``batch_size``
        flagging the batches that hold a value.
    """
    valid = feat.notna()
    feat, batch = feat[valid], batch[valid]
    frame = pd.DataFrame(dict(feat=feat, batch=batch))

    if op == AggregationType.LIST_DISTINCT:
        out = frame.drop_duplicates().groupby('batch')['feat'].agg(list)
    else:
        if op == AggregationType.COUNT:
            how = 'size'
        elif op == AggregationType.AVG:
            how = 'mean'
        else:
            how = op.lower()
        out = frame.groupby('batch')['feat'].agg(how)

        # Numeric results are reported in single precision; datetime
        # results keep their dtype.
        if not pd.api.types.is_datetime64_any_dtype(out):
            out = out.astype('float32')

    out.name = None
    out.index.name = None

    # Sums and counts are defined for empty groups as well:
    if op in {AggregationType.SUM, AggregationType.COUNT}:
        out = out.reindex(range(batch_size), fill_value=0)
        return out, np.ones(batch_size, dtype=bool)

    present = np.zeros(batch_size, dtype=bool)
    present[batch] = True

    if not filter_na:
        out = out.reindex(range(batch_size), fill_value=pd.NA)
    else:
        out = out.reset_index(drop=True)

    return out, present
103
+
104
def execute_aggregation(
    self,
    aggr: Aggregation,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    filter_na: bool = True,
    num_forecasts: int = 1,
) -> Tuple[pd.Series, np.ndarray]:
    """Evaluate a time-windowed aggregation, optionally over several steps.

    Args:
        aggr: Aggregation node; its target is either a ``Column`` or a
            ``Filter``.
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        filter_na: Forwarded to the per-window aggregation when
            ``num_forecasts == 1``. For multiple forecasts, it controls
            whether rows without any value in any window are dropped from
            the final result.
        num_forecasts: Number of consecutive windows to evaluate. Values
            above one require the aggregation to define a window start
            (asserted below).

    Returns:
        For a single forecast, the output of ``execute_aggregation_type``.
        Otherwise a series holding one list per batch entry (one value per
        window) and a mask that is ``True`` where any window has a value.
    """
    # The table being aggregated is the prefix of the target's "table.col".
    target_table = aggr._get_target_column_name().split('.')[0]
    target_batch = batch_dict[target_table]
    target_time = time_dict[target_table]
    # `target_feat` is N/A-filtered (compact), while `target_mask` is
    # aligned with the rows of the target table.
    if isinstance(aggr.target, Column):
        target_feat, target_mask = self.execute_column(
            column=aggr.target,
            feat_dict=feat_dict,
            filter_na=True,
        )
    else:
        assert isinstance(aggr.target, Filter)
        target_feat, target_mask = self.execute_filter(
            filter=aggr.target,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=True,
        )

    outs: List[pd.Series] = []
    masks: List[np.ndarray] = []
    for _ in range(num_forecasts):
        # Anchor time of the batch entry each target row belongs to:
        anchor_target_time = anchor_time[target_batch]
        anchor_target_time = anchor_target_time.reset_index(drop=True)

        # Keep rows inside (anchor + start, anchor + end]; the lower bound
        # only applies when the aggregation defines a start.
        time_filter_mask = (target_time <= anchor_target_time +
                            aggr.aggr_time_range.end_date_offset)
        if aggr.aggr_time_range.start is not None:
            start_offset = aggr.aggr_time_range.start_date_offset
            time_filter_mask &= (target_time
                                 > anchor_target_time + start_offset)
        else:
            assert num_forecasts == 1
        curr_target_mask = target_mask & time_filter_mask

        # `target_feat` only holds rows where `target_mask` is set, so the
        # time mask is restricted to those rows before indexing into it.
        out, mask = self.execute_aggregation_type(
            aggr.aggr,
            feat=target_feat[time_filter_mask[target_mask].reset_index(
                drop=True)],
            batch=target_batch[curr_target_mask],
            batch_size=len(anchor_time),
            # Per-window outputs must stay aligned across windows so they
            # can be zipped together below.
            filter_na=False if num_forecasts > 1 else filter_na,
        )
        outs.append(out)
        masks.append(mask)

        # Slide the anchor forward by the window end offset for the next
        # forecast step.
        if num_forecasts > 1:
            anchor_time = (anchor_time +
                           aggr.aggr_time_range.end_date_offset)
    if len(outs) == 1:
        assert len(masks) == 1
        return outs[0], masks[0]

    # One list per batch entry containing the value of every window:
    out = pd.Series([list(ser) for ser in zip(*outs)])
    mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore

    if filter_na:
        out = out[mask].reset_index(drop=True)

    return out, mask
175
+
176
def execute_rel_op(
    self,
    left: pd.Series,
    op: RelOp,
    right: Constant,
) -> pd.Series:
    """Compare every entry of ``left`` against the constant ``right``.

    A ``None`` constant turns ``EQ``/``NEQ`` into an is-N/A / is-not-N/A
    test. For any other constant, N/A entries in ``left`` never satisfy the
    comparison, including ``NEQ``.

    Args:
        left: Values on the left-hand side of the comparison.
        op: Relational operator to apply.
        right: Constant on the right-hand side.

    Returns:
        Boolean series with the outcome per entry.

    Raises:
        NotImplementedError: If ``op`` is not a supported operator.
    """
    if right.typed_value() is None:
        if op == RelOp.EQ:
            return left.isna()
        assert op == RelOp.NEQ
        return left.notna()

    # Promote left to float if right is a float to avoid lossy coercion.
    right_value = right.typed_value()
    if pd.api.types.is_integer_dtype(left) and isinstance(
            right_value, float):
        left = left.astype('float64')
    # Coerce the constant into the dtype of the left-hand side.
    value = pd.Series([right_value], dtype=left.dtype).iloc[0]

    def _as_bool(result: pd.Series) -> pd.Series:
        # Comparisons on nullable dtypes may yield N/A; treat it as False.
        return result.fillna(False).astype(bool)

    if op == RelOp.EQ:
        return _as_bool(left == value)
    if op == RelOp.NEQ:
        differs = _as_bool(left != value)
        differs[left.isna()] = False  # N/A != right should always be `False`.
        return differs
    if op == RelOp.LEQ:
        return _as_bool(left <= value)
    if op == RelOp.GEQ:
        return _as_bool(left >= value)
    if op == RelOp.LT:
        return _as_bool(left < value)
    if op == RelOp.GT:
        return _as_bool(left > value)

    raise NotImplementedError(f"Operator '{op}' not implemented")
212
+
213
def execute_member_op(
    self,
    left: pd.Series,
    op: MemberOp,
    right: Constant,
) -> pd.Series:
    """Test every entry of ``left`` for membership in the constant ``right``.

    Args:
        left: Values to test.
        op: Membership operator; only ``MemberOp.IN`` is supported.
        right: Constant whose typed value holds the candidate values.

    Returns:
        Boolean series that is ``True`` where the entry is a candidate.

    Raises:
        NotImplementedError: If ``op`` is anything but ``MemberOp.IN``.
    """
    if op != MemberOp.IN:
        raise NotImplementedError(f"Operator '{op}' not implemented")

    # Candidates are cast to the dtype of `left` before matching.
    candidates = pd.Series(right.typed_value(), dtype=left.dtype)
    return left.isin(candidates).astype(bool)
225
+
226
def execute_condition(
    self,
    condition: Condition,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    filter_na: bool = True,
    num_forecasts: int = 1,
) -> Tuple[pd.Series, np.ndarray]:
    """Evaluate ``<target> <op> <constant>`` for a column, join or aggregation.

    Args:
        condition: Condition node; its value must be a ``Constant``.
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        filter_na: Whether N/A entries of the target are dropped before the
            comparison. When the constant is ``None``, N/A entries are
            always kept, since they are what the condition tests for.
        num_forecasts: Must be ``1``.

    Returns:
        Boolean outcome of the comparison and the validity mask of the
        target. The mask is all ``True`` when ``filter_na`` is set and the
        constant is ``None``.

    Raises:
        NotImplementedError: If ``num_forecasts`` is greater than one.
    """
    if num_forecasts > 1:
        raise NotImplementedError("Forecasting not yet implemented for "
                                  "non-regression tasks")

    assert isinstance(condition.value, Constant)
    value_is_na = condition.value.typed_value() is None
    # Comparing against N/A requires the N/A entries to survive.
    child_filter_na = False if value_is_na else filter_na

    target = condition.target
    if isinstance(target, Column):
        left, mask = self.execute_column(
            column=target,
            feat_dict=feat_dict,
            filter_na=child_filter_na,
        )
    else:
        shared = dict(
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=child_filter_na,
        )
        if isinstance(target, Join):
            left, mask = self.execute_join(join=target, **shared)
        else:
            assert isinstance(target, Aggregation)
            left, mask = self.execute_aggregation(aggr=target, **shared)

    if filter_na and value_is_na:
        mask = np.ones(len(left), dtype=bool)

    operator = condition.op
    if not isinstance(operator, RelOp):
        assert isinstance(operator, MemberOp)
        out = self.execute_member_op(
            left=left,
            op=operator,
            right=condition.value,
        )
    else:
        out = self.execute_rel_op(
            left=left,
            op=operator,
            right=condition.value,
        )

    return out, mask
286
+
287
def execute_bool_op(
    self,
    left: pd.Series,
    op: BoolOp,
    right: pd.Series | None,
) -> pd.Series:
    """Combine boolean series with ``AND``, ``OR`` or ``NOT``.

    Args:
        left: Left-hand (or sole) operand.
        op: Boolean operator to apply.
        right: Right-hand operand; required for ``AND`` and ``OR`` and
            unused for ``NOT``.

    Returns:
        The element-wise result of the operation.

    Raises:
        NotImplementedError: If ``op`` is not a supported operator.
    """
    # TODO Implement Kleene-Priest three-value logic.
    if op == BoolOp.NOT:
        return ~left

    if op == BoolOp.AND or op == BoolOp.OR:
        assert right is not None
        return left & right if op == BoolOp.AND else left | right

    raise NotImplementedError(f"Operator '{op}' not implemented")
305
+
306
def execute_logical_operation(
    self,
    logical_operation: LogicalOperation,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    filter_na: bool = True,
    num_forecasts: int = 1,
) -> Tuple[pd.Series, np.ndarray]:
    """Evaluate a (possibly nested) boolean combination of conditions.

    Both operands are evaluated without N/A filtering. The validity masks
    of the two sides are intersected, and filtering is applied once to the
    combined result.

    Args:
        logical_operation: Node whose operands are conditions or further
            logical operations; the right operand may be absent.
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        filter_na: If ``True``, entries outside the combined validity mask
            are dropped from the returned series.
        num_forecasts: Must be ``1``.

    Returns:
        The boolean result and the combined validity mask.

    Raises:
        NotImplementedError: If ``num_forecasts`` is greater than one.
    """
    if num_forecasts > 1:
        raise NotImplementedError("Forecasting not yet implemented for "
                                  "non-regression tasks")

    shared = dict(
        feat_dict=feat_dict,
        time_dict=time_dict,
        batch_dict=batch_dict,
        anchor_time=anchor_time,
        filter_na=False,
    )

    def _evaluate(node):  # type: ignore
        # Dispatch an operand to the matching executor.
        if isinstance(node, Condition):
            return self.execute_condition(condition=node, **shared)
        return self.execute_logical_operation(
            logical_operation=node,
            **shared,
        )

    lhs = logical_operation.left
    assert isinstance(lhs, (Condition, LogicalOperation))
    left, mask = _evaluate(lhs)

    rhs = logical_operation.right
    if isinstance(rhs, (Condition, LogicalOperation)):
        right, right_mask = _evaluate(rhs)
    else:
        right = right_mask = None

    out = self.execute_bool_op(left, logical_operation.bool_op, right)

    if right_mask is not None:
        mask &= right_mask

    if filter_na:
        out = out[mask].reset_index(drop=True)

    return out, mask
369
+
370
def execute_join(
    self,
    join: Join,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    filter_na: bool = True,
    num_forecasts: int = 1,
) -> Tuple[pd.Series, np.ndarray]:
    """Evaluate a join by executing the aggregation on its right-hand side.

    Args:
        join: Join node; only an ``Aggregation`` as ``rhs_target`` is
            supported.
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        filter_na: Forwarded to ``execute_aggregation``. Callers that pass
            ``False`` (e.g. ``execute_condition`` inside a logical
            operation, or when comparing against ``None``) rely on
            receiving an output aligned with the returned mask.
        num_forecasts: Forwarded to ``execute_aggregation``.

    Returns:
        The output of ``execute_aggregation`` for the nested aggregation.

    Raises:
        NotImplementedError: If ``rhs_target`` is not an ``Aggregation``.
    """
    if not isinstance(join.rhs_target, Aggregation):
        raise NotImplementedError(
            f'Unexpected {type(join.rhs_target)} nested in Join')

    return self.execute_aggregation(
        aggr=join.rhs_target,
        feat_dict=feat_dict,
        time_dict=time_dict,
        batch_dict=batch_dict,
        anchor_time=anchor_time,
        # Honor the caller's choice instead of always filtering; a
        # hard-coded `True` returned a compacted series to callers that
        # asked for an unfiltered, batch-aligned one.
        filter_na=filter_na,
        num_forecasts=num_forecasts,
    )
392
+
393
def execute_filter(
    self,
    filter: Filter,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    filter_na: bool = True,
) -> Tuple[pd.Series, np.ndarray]:
    """Select the rows of a column that satisfy the filter's condition.

    Args:
        filter: Filter node holding the target column and its condition
            (a ``Condition`` or a ``LogicalOperation``).
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        filter_na: If ``True``, rows whose target value is N/A are dropped
            in addition to rows failing the condition.

    Returns:
        The selected values with a fresh positional index, and a boolean
        array over all rows that is ``True`` where the condition holds and
        the target value is not N/A.
    """
    values, not_na = self.execute_column(
        column=filter.target,
        feat_dict=feat_dict,
        filter_na=False,
    )

    shared = dict(
        feat_dict=feat_dict,
        time_dict=time_dict,
        batch_dict=batch_dict,
        anchor_time=anchor_time,
        filter_na=False,
    )
    predicate = filter.condition
    if isinstance(predicate, Condition):
        result = self.execute_condition(condition=predicate, **shared)
    else:
        assert isinstance(predicate, LogicalOperation)
        result = self.execute_logical_operation(
            logical_operation=predicate,
            **shared,
        )
    passed = result[0].to_numpy()

    # The reported mask always accounts for N/A values; only the row
    # selection itself depends on `filter_na`.
    keep = passed & not_na
    selector = keep if filter_na else passed
    return values[selector].reset_index(drop=True), keep
430
+
431
def execute(
    self,
    query: ValidatedPredictiveQuery,
    feat_dict: Dict[str, pd.DataFrame],
    time_dict: Dict[str, pd.Series],
    batch_dict: Dict[str, np.ndarray],
    anchor_time: pd.Series,
    num_forecasts: int = 1,
) -> Tuple[pd.Series, np.ndarray]:
    """Execute a validated predictive query on in-memory pandas data.

    The entity expression determines which batch entries are eligible, the
    target expression produces the values, and the optional what-if
    (``ASSUMING``) expression narrows the eligible entries further.

    Args:
        query: Query exposing ``entity_ast``, ``target_ast`` and
            ``whatif_ast``.
        feat_dict: Mapping from table name to its feature frame.
        time_dict: Mapping from table name to the timestamp of each row.
        batch_dict: Mapping from table name to the batch index of each row.
        anchor_time: Anchor timestamp per batch entry.
        num_forecasts: Number of forecast steps, forwarded to the target
            and what-if executors.

    Returns:
        The target values of the entries that pass the entity, what-if and
        target masks, together with the combined boolean mask.

    Raises:
        NotImplementedError: If the target expression type is unsupported.
        ValueError: If the what-if expression type is unsupported.
    """
    # Entity side: only `mask` is used further down, `out` is overwritten
    # by the target below.
    if isinstance(query.entity_ast, Column):
        out, mask = self.execute_column(
            column=query.entity_ast,
            feat_dict=feat_dict,
            filter_na=True,
        )
    else:
        assert isinstance(query.entity_ast, Filter)
        out, mask = self.execute_filter(
            filter=query.entity_ast,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
        )
    # Target side: `out` is N/A-filtered, `_mask` marks the entries that
    # hold a target value.
    if isinstance(query.target_ast, Column):
        out, _mask = self.execute_column(
            column=query.target_ast,
            feat_dict=feat_dict,
            filter_na=True,
        )
    elif isinstance(query.target_ast, Condition):
        out, _mask = self.execute_condition(
            condition=query.target_ast,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=True,
            num_forecasts=num_forecasts,
        )
    elif isinstance(query.target_ast, Aggregation):
        out, _mask = self.execute_aggregation(
            aggr=query.target_ast,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=True,
            num_forecasts=num_forecasts,
        )
    elif isinstance(query.target_ast, Join):
        out, _mask = self.execute_join(
            join=query.target_ast,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=True,
            num_forecasts=num_forecasts,
        )
    elif isinstance(query.target_ast, LogicalOperation):
        out, _mask = self.execute_logical_operation(
            logical_operation=query.target_ast,
            feat_dict=feat_dict,
            time_dict=time_dict,
            batch_dict=batch_dict,
            anchor_time=anchor_time,
            filter_na=True,
            num_forecasts=num_forecasts,
        )
    else:
        raise NotImplementedError(
            f'{type(query.target_ast)} compilation missing.')
    # What-if side: the boolean outcome of the ASSUMING expression is
    # AND-ed into the entity mask in place.
    # NOTE(review): the outcome is a `pd.Series` evaluated with
    # `filter_na=True`, so it may be shorter than `mask` when N/A entries
    # were dropped -- confirm that lengths always agree here.
    if query.whatif_ast is not None:
        if isinstance(query.whatif_ast, Condition):
            mask &= self.execute_condition(
                condition=query.whatif_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )[0]
        elif isinstance(query.whatif_ast, LogicalOperation):
            mask &= self.execute_logical_operation(
                logical_operation=query.whatif_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )[0]
        else:
            raise ValueError(
                f'Unsupported ASSUMING condition {type(query.whatif_ast)}')

    # `out` only holds entries where `_mask` is set, so the entity mask is
    # restricted to those entries before selecting from it.
    out = out[mask[_mask]]
    mask &= _mask
    out = out.reset_index(drop=True)
    return out, mask