kumoai 2.9.0.dev202509061830__cp311-cp311-macosx_11_0_arm64.whl → 2.12.0.dev202511031731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. kumoai/__init__.py +4 -2
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +10 -5
  4. kumoai/client/rfm.py +3 -2
  5. kumoai/connector/file_upload_connector.py +71 -102
  6. kumoai/connector/utils.py +1367 -236
  7. kumoai/experimental/rfm/__init__.py +2 -2
  8. kumoai/experimental/rfm/authenticate.py +8 -5
  9. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  10. kumoai/experimental/rfm/local_graph.py +90 -80
  11. kumoai/experimental/rfm/local_graph_sampler.py +16 -8
  12. kumoai/experimental/rfm/local_graph_store.py +22 -6
  13. kumoai/experimental/rfm/local_pquery_driver.py +129 -28
  14. kumoai/experimental/rfm/local_table.py +100 -22
  15. kumoai/experimental/rfm/pquery/__init__.py +4 -0
  16. kumoai/experimental/rfm/pquery/backend.py +4 -0
  17. kumoai/experimental/rfm/pquery/executor.py +102 -0
  18. kumoai/experimental/rfm/pquery/pandas_backend.py +71 -30
  19. kumoai/experimental/rfm/pquery/pandas_executor.py +506 -0
  20. kumoai/experimental/rfm/rfm.py +442 -94
  21. kumoai/jobs.py +1 -0
  22. kumoai/trainer/trainer.py +19 -10
  23. kumoai/utils/progress_logger.py +62 -0
  24. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/METADATA +4 -5
  25. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/RECORD +28 -26
  26. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/WHEEL +0 -0
  27. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/licenses/LICENSE +0 -0
  28. {kumoai-2.9.0.dev202509061830.dist-info → kumoai-2.12.0.dev202511031731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/pquery/pandas_backend.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -180,46 +180,71 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:

         target_table = aggr.column.table_name
         target_batch = batch_dict[target_table]
         target_time = time_dict[target_table]
-        anchor_target_time = anchor_time[target_batch].reset_index(drop=True)

-        target_mask = target_time <= anchor_target_time + aggr.end_offset
+        outs: List[pd.Series] = []
+        masks: List[np.ndarray] = []
+        for _ in range(num_forecasts):
+            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_target_time.reset_index(drop=True)

-        if aggr.start is not None:
-            start_offset = aggr.start * aggr.time_unit.to_offset()
-            target_mask &= target_time > anchor_target_time + start_offset
+            target_mask = target_time <= anchor_target_time + aggr.end_offset

-        if aggr.filter is not None:
-            target_mask &= self.eval_filter(
-                filter=aggr.filter,
-                feat_dict=feat_dict,
-                time_dict=time_dict,
-                batch_dict=batch_dict,
-                anchor_time=anchor_time,
-            )
+            if aggr.start is not None:
+                start_offset = aggr.start * aggr.time_unit.to_offset()
+                target_mask &= target_time > anchor_target_time + start_offset
+            else:
+                assert num_forecasts == 1

-        if (aggr.type == AggregationType.COUNT
-                and aggr.column.column_name == '*'):
-            target_feat = None
-        else:
-            target_feat, _ = self.eval_column(
-                aggr.column,
-                feat_dict,
-                filter_na=False,
+            if aggr.filter is not None:
+                target_mask &= self.eval_filter(
+                    filter=aggr.filter,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                )
+
+            if (aggr.type == AggregationType.COUNT
+                    and aggr.column.column_name == '*'):
+                target_feat = None
+            else:
+                target_feat, _ = self.eval_column(
+                    aggr.column,
+                    feat_dict,
+                    filter_na=False,
+                )
+                target_feat = target_feat[target_mask]
+
+            out, mask = self.eval_aggregation_type(
+                aggr.type,
+                feat=target_feat,
+                batch=target_batch[target_mask],
+                batch_size=len(anchor_time),
+                filter_na=False if num_forecasts > 1 else filter_na,
             )
-        target_feat = target_feat[target_mask]
+            outs.append(out)
+            masks.append(mask)
+
+            if num_forecasts > 1:
+                anchor_time = anchor_time + aggr.end_offset
+
+        if len(outs) == 1:
+            assert len(masks) == 1
+            return outs[0], masks[0]

-        return self.eval_aggregation_type(
-            aggr.type,
-            feat=target_feat,
-            batch=target_batch[target_mask],
-            batch_size=len(anchor_time),
-            filter_na=filter_na,
-        )
+        out = pd.Series([list(ser) for ser in zip(*outs)])
+        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask

     def eval_condition(
         self,
@@ -229,8 +254,13 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:

+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
         if isinstance(condition.left, Column):
             left, mask = self.eval_column(
                 column=condition.left,
@@ -275,8 +305,13 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:

+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
         if isinstance(logical_operation.left, Condition):
             left, mask = self.eval_condition(
                 condition=logical_operation.left,
@@ -362,6 +397,7 @@
         time_dict: Dict[str, pd.Series],
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
+        num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:

         mask = np.ones(len(anchor_time), dtype=bool)
@@ -410,6 +446,7 @@
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
                 filter_na=True,
+                num_forecasts=num_forecasts,
             )
         elif isinstance(query.target, Condition):
             out, _mask = self.eval_condition(
@@ -419,6 +456,7 @@
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
                 filter_na=True,
+                num_forecasts=num_forecasts,
             )
         else:
             assert isinstance(query.target, LogicalOperation)
@@ -429,9 +467,12 @@
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
                 filter_na=True,
+                num_forecasts=num_forecasts,
             )

         out = out[mask[_mask]]
         mask &= _mask

+        out = out.reset_index(drop=True)
+
         return out, mask
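The main behavioral change in pandas_backend.py above is the new `num_forecasts` parameter on `eval_aggregation`: instead of evaluating a single aggregation window, the backend now loops `num_forecasts` times, shifting the anchor time forward by the aggregation's end offset after every step and collecting the per-step results into one list-valued cell per entity (the per-step masks are OR-ed together). A minimal, self-contained sketch of that rolling-window idea, using illustrative data and names (`event_entity`, `window`) rather than the library's API:

```python
import numpy as np
import pandas as pd

# Illustrative events: timestamp plus owning entity (0 or 1).
event_time = pd.Series(pd.to_datetime([
    "2024-01-02", "2024-01-10", "2024-01-20", "2024-02-05",
]))
event_entity = np.array([0, 0, 1, 0])
anchor_time = pd.Series(pd.to_datetime(["2024-01-01", "2024-01-01"]))
window = pd.Timedelta(days=14)  # stands in for the aggregation's end offset
num_forecasts = 2

steps = []
anchor = anchor_time.copy()
for _ in range(num_forecasts):
    # Broadcast each entity's current anchor time onto its events.
    per_event_anchor = anchor.iloc[event_entity].reset_index(drop=True)
    in_window = ((event_time > per_event_anchor)
                 & (event_time <= per_event_anchor + window))
    # COUNT(*) per entity within the current window.
    counts = (pd.Series(event_entity[in_window.to_numpy()])
              .value_counts()
              .reindex(range(len(anchor)), fill_value=0))
    steps.append(counts)
    anchor = anchor + window  # roll the window forward for the next step

# One list of per-step values per entity, mirroring `zip(*outs)` above.
out = pd.Series([list(map(int, vals)) for vals in zip(*steps)])
print(out.tolist())  # [[2, 0], [0, 1]]
```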
kumoai/experimental/rfm/pquery/pandas_executor.py (new file)
@@ -0,0 +1,506 @@
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pandas as pd
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    Column,
+    Condition,
+    Constant,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
+
+from kumoai.experimental.rfm.pquery import PQueryExecutor
+
+
+class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
+                           np.ndarray]):
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: Dict[str, pd.DataFrame],
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        table_name, column_name = column.fqn.split(".")
+        if column_name == '*':
+            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
+        else:
+            out = feat_dict[table_name][column_name]
+            out = out.reset_index(drop=True)
+
+        if pd.api.types.is_float_dtype(out):
+            out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        mask = out.notna().to_numpy()
+
+        if not filter_na:
+            return out, mask
+
+        out = out[mask].reset_index(drop=True)
+
+        # Cast to primitive dtype:
+        if pd.api.types.is_integer_dtype(out):
+            out = out.astype('int64')
+        elif pd.api.types.is_bool_dtype(out):
+            out = out.astype('bool')
+
+        return out, mask
+
+    def execute_aggregation_type(
+        self,
+        op: AggregationType,
+        feat: pd.Series,
+        batch: np.ndarray,
+        batch_size: int,
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+
+        mask = feat.notna()
+        feat, batch = feat[mask], batch[mask]
+
+        if op == AggregationType.LIST_DISTINCT:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            df = df.drop_duplicates()
+            out = df.groupby('batch')['feat'].agg(list)
+
+        else:
+            df = pd.DataFrame(dict(feat=feat, batch=batch))
+            if op == AggregationType.AVG:
+                agg = 'mean'
+            elif op == AggregationType.COUNT:
+                agg = 'size'
+            else:
+                agg = op.lower()
+            out = df.groupby('batch')['feat'].agg(agg)
+
+            if not pd.api.types.is_datetime64_any_dtype(out):
+                out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        if op in {AggregationType.SUM, AggregationType.COUNT}:
+            out = out.reindex(range(batch_size), fill_value=0)
+            mask = np.ones(batch_size, dtype=bool)
+            return out, mask
+
+        mask = np.zeros(batch_size, dtype=bool)
+        mask[batch] = True
+
+        if filter_na:
+            return out.reset_index(drop=True), mask
+
+        out = out.reindex(range(batch_size), fill_value=pd.NA)
+
+        return out, mask
+
+    def execute_aggregation(
+        self,
+        aggr: Aggregation,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        target_table = aggr._get_target_column_name().split('.')[0]
+        target_batch = batch_dict[target_table]
+        target_time = time_dict[target_table]
+        if isinstance(aggr.target, Column):
+            target_feat, target_mask = self.execute_column(
+                column=aggr.target,
+                feat_dict=feat_dict,
+                filter_na=False,
+            )
+        else:
+            assert isinstance(aggr.target, Filter)
+            target_feat, target_mask = self.execute_filter(
+                filter=aggr.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        outs: List[pd.Series] = []
+        masks: List[np.ndarray] = []
+        for _ in range(num_forecasts):
+            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_target_time.reset_index(drop=True)
+
+            curr_target_mask = target_mask & (
+                target_time
+                <= anchor_target_time + aggr.aggr_time_range.end_date_offset)
+            if aggr.aggr_time_range.start is not None:
+                start_offset = aggr.aggr_time_range.start_date_offset
+                curr_target_mask &= (target_time
+                                     > anchor_target_time + start_offset)
+            else:
+                assert num_forecasts == 1
+
+            out, mask = self.execute_aggregation_type(
+                aggr.aggr,
+                feat=target_feat[curr_target_mask],
+                batch=target_batch[curr_target_mask],
+                batch_size=len(anchor_time),
+                filter_na=False if num_forecasts > 1 else filter_na,
+            )
+            outs.append(out)
+            masks.append(mask)
+
+            if num_forecasts > 1:
+                anchor_time = (anchor_time +
+                               aggr.aggr_time_range.end_date_offset)
+        if len(outs) == 1:
+            assert len(masks) == 1
+            return outs[0], masks[0]
+
+        out = pd.Series([list(ser) for ser in zip(*outs)])
+        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_rel_op(
+        self,
+        left: pd.Series,
+        op: RelOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if right.typed_value() is None:
+            if op == RelOp.EQ:
+                return left.isna()
+            assert op == RelOp.NEQ
+            return left.notna()
+
+        # Promote left to float if right is a float to avoid lossy coercion.
+        right_value = right.typed_value()
+        if pd.api.types.is_integer_dtype(left) and isinstance(
+                right_value, float):
+            left = left.astype('float64')
+        value = pd.Series([right_value], dtype=left.dtype).iloc[0]
+
+        if op == RelOp.EQ:
+            return (left == value).fillna(False).astype(bool)
+        if op == RelOp.NEQ:
+            out = (left != value).fillna(False).astype(bool)
+            out[left.isna()] = False  # N/A != right should always be `False`.
+            return out
+        if op == RelOp.LEQ:
+            return (left <= value).fillna(False).astype(bool)
+        if op == RelOp.GEQ:
+            return (left >= value).fillna(False).astype(bool)
+        if op == RelOp.LT:
+            return (left < value).fillna(False).astype(bool)
+        if op == RelOp.GT:
+            return (left > value).fillna(False).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_member_op(
+        self,
+        left: pd.Series,
+        op: MemberOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if op == MemberOp.IN:
+            ser = pd.Series(right.typed_value(), dtype=left.dtype)
+            return left.isin(ser).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_condition(
+        self,
+        condition: Condition,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        assert isinstance(condition.value, Constant)
+        value_is_na = condition.value.typed_value() is None
+        if isinstance(condition.target, Column):
+            left, mask = self.execute_column(
+                column=condition.target,
+                feat_dict=feat_dict,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        elif isinstance(condition.target, Join):
+            left, mask = self.execute_join(
+                join=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        else:
+            assert isinstance(condition.target, Aggregation)
+            left, mask = self.execute_aggregation(
+                aggr=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
+            )
+
+        if filter_na and value_is_na:
+            mask = np.ones(len(left), dtype=bool)
+
+        if isinstance(condition.op, RelOp):
+            out = self.execute_rel_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+        else:
+            assert isinstance(condition.op, MemberOp)
+            out = self.execute_member_op(
+                left=left,
+                op=condition.op,
+                right=condition.value,
+            )
+
+        return out, mask
+
+    def execute_bool_op(
+        self,
+        left: pd.Series,
+        op: BoolOp,
+        right: pd.Series | None,
+    ) -> pd.Series:
+
+        # TODO Implement Kleene-Priest three-value logic.
+        if op == BoolOp.AND:
+            assert right is not None
+            return left & right
+        if op == BoolOp.OR:
+            assert right is not None
+            return left | right
+        if op == BoolOp.NOT:
+            return ~left
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_logical_operation(
+        self,
+        logical_operation: LogicalOperation,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if num_forecasts > 1:
+            raise NotImplementedError("Forecasting not yet implemented for "
+                                      "non-regression tasks")
+
+        if isinstance(logical_operation.left, Condition):
+            left, mask = self.execute_condition(
+                condition=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        else:
+            assert isinstance(logical_operation.left, LogicalOperation)
+            left, mask = self.execute_logical_operation(
+                logical_operation=logical_operation.left,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        right = right_mask = None
+        if isinstance(logical_operation.right, Condition):
+            right, right_mask = self.execute_condition(
+                condition=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+        elif isinstance(logical_operation.right, LogicalOperation):
+            right, right_mask = self.execute_logical_operation(
+                logical_operation=logical_operation.right,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )
+
+        out = self.execute_bool_op(left, logical_operation.bool_op, right)
+
+        if right_mask is not None:
+            mask &= right_mask
+
+        if filter_na:
+            out = out[mask].reset_index(drop=True)
+
+        return out, mask
+
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if isinstance(join.rhs_target, Aggregation):
+            return self.execute_aggregation(
+                aggr=join.rhs_target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        raise NotImplementedError(
+            f'Unexpected {type(join.rhs_target)} nested in Join')
+
+    def execute_filter(
+        self,
+        filter: Filter,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        out, mask = self.execute_column(
+            column=filter.target,
+            feat_dict=feat_dict,
+            filter_na=False,
+        )
+        if isinstance(filter.condition, Condition):
+            _mask = self.execute_condition(
+                condition=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        else:
+            assert isinstance(filter.condition, LogicalOperation)
+            _mask = self.execute_logical_operation(
+                logical_operation=filter.condition,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=False,
+            )[0].to_numpy()
+        if filter_na:
+            return out[_mask & mask].reset_index(drop=True), _mask & mask
+        else:
+            return out[_mask].reset_index(drop=True), mask & _mask
+
+    def execute(
+        self,
+        query: ValidatedPredictiveQuery,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if isinstance(query.entity_ast, Column):
+            out, mask = self.execute_column(
+                column=query.entity_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(query.entity_ast, Filter)
+            out, mask = self.execute_filter(
+                filter=query.entity_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+            )
+        if isinstance(query.target_ast, Column):
+            out, _mask = self.execute_column(
+                column=query.target_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        elif isinstance(query.target_ast, Condition):
+            out, _mask = self.execute_condition(
+                condition=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Aggregation):
+            out, _mask = self.execute_aggregation(
+                aggr=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, Join):
+            out, _mask = self.execute_join(
+                join=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        elif isinstance(query.target_ast, LogicalOperation):
+            out, _mask = self.execute_logical_operation(
+                logical_operation=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        else:
+            raise NotImplementedError(
+                f'{type(query.target)} compilation missing.')
+        out = out[mask[_mask]]
+        mask &= _mask
+        out = out.reset_index(drop=True)
+        return out, mask
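One detail of the new PQueryPandasExecutor worth noting is the NULL handling in `execute_rel_op`: comparing against a NULL constant is rewritten to `isna()`/`notna()`, and for ordinary comparisons NA values never satisfy the predicate. In particular, float NaN would otherwise make `!=` evaluate to True, which is why the executor explicitly forces NA rows to False. A small standalone illustration (plain pandas, not the library API):

```python
import numpy as np
import pandas as pd

left = pd.Series([1.0, np.nan, 3.0])

# `col = NULL` maps to isna(), `col != NULL` maps to notna().
print(left.isna().tolist())   # [False, True, False]
print(left.notna().tolist())  # [True, False, True]

# `col != 3`: without the override, the NaN row would compare as True.
neq = (left != 3.0).fillna(False).astype(bool)
neq[left.isna()] = False      # NA never satisfies the predicate
print(neq.tolist())           # [True, False, False]
```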