kumoai 2.13.0.dev202511131731__cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of kumoai might be problematic; see the package's registry page for more details.

Files changed (98)
  1. kumoai/__init__.py +294 -0
  2. kumoai/_logging.py +29 -0
  3. kumoai/_singleton.py +25 -0
  4. kumoai/_version.py +1 -0
  5. kumoai/artifact_export/__init__.py +9 -0
  6. kumoai/artifact_export/config.py +209 -0
  7. kumoai/artifact_export/job.py +108 -0
  8. kumoai/client/__init__.py +5 -0
  9. kumoai/client/client.py +221 -0
  10. kumoai/client/connector.py +110 -0
  11. kumoai/client/endpoints.py +150 -0
  12. kumoai/client/graph.py +120 -0
  13. kumoai/client/jobs.py +447 -0
  14. kumoai/client/online.py +78 -0
  15. kumoai/client/pquery.py +203 -0
  16. kumoai/client/rfm.py +112 -0
  17. kumoai/client/source_table.py +53 -0
  18. kumoai/client/table.py +101 -0
  19. kumoai/client/utils.py +130 -0
  20. kumoai/codegen/__init__.py +19 -0
  21. kumoai/codegen/cli.py +100 -0
  22. kumoai/codegen/context.py +16 -0
  23. kumoai/codegen/edits.py +473 -0
  24. kumoai/codegen/exceptions.py +10 -0
  25. kumoai/codegen/generate.py +222 -0
  26. kumoai/codegen/handlers/__init__.py +4 -0
  27. kumoai/codegen/handlers/connector.py +118 -0
  28. kumoai/codegen/handlers/graph.py +71 -0
  29. kumoai/codegen/handlers/pquery.py +62 -0
  30. kumoai/codegen/handlers/table.py +109 -0
  31. kumoai/codegen/handlers/utils.py +42 -0
  32. kumoai/codegen/identity.py +114 -0
  33. kumoai/codegen/loader.py +93 -0
  34. kumoai/codegen/naming.py +94 -0
  35. kumoai/codegen/registry.py +121 -0
  36. kumoai/connector/__init__.py +31 -0
  37. kumoai/connector/base.py +153 -0
  38. kumoai/connector/bigquery_connector.py +200 -0
  39. kumoai/connector/databricks_connector.py +213 -0
  40. kumoai/connector/file_upload_connector.py +189 -0
  41. kumoai/connector/glue_connector.py +150 -0
  42. kumoai/connector/s3_connector.py +278 -0
  43. kumoai/connector/snowflake_connector.py +252 -0
  44. kumoai/connector/source_table.py +471 -0
  45. kumoai/connector/utils.py +1775 -0
  46. kumoai/databricks.py +14 -0
  47. kumoai/encoder/__init__.py +4 -0
  48. kumoai/exceptions.py +26 -0
  49. kumoai/experimental/__init__.py +0 -0
  50. kumoai/experimental/rfm/__init__.py +67 -0
  51. kumoai/experimental/rfm/authenticate.py +433 -0
  52. kumoai/experimental/rfm/infer/__init__.py +11 -0
  53. kumoai/experimental/rfm/infer/categorical.py +40 -0
  54. kumoai/experimental/rfm/infer/id.py +46 -0
  55. kumoai/experimental/rfm/infer/multicategorical.py +48 -0
  56. kumoai/experimental/rfm/infer/timestamp.py +41 -0
  57. kumoai/experimental/rfm/local_graph.py +810 -0
  58. kumoai/experimental/rfm/local_graph_sampler.py +184 -0
  59. kumoai/experimental/rfm/local_graph_store.py +359 -0
  60. kumoai/experimental/rfm/local_pquery_driver.py +689 -0
  61. kumoai/experimental/rfm/local_table.py +545 -0
  62. kumoai/experimental/rfm/pquery/__init__.py +7 -0
  63. kumoai/experimental/rfm/pquery/executor.py +102 -0
  64. kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
  65. kumoai/experimental/rfm/rfm.py +1130 -0
  66. kumoai/experimental/rfm/utils.py +344 -0
  67. kumoai/formatting.py +30 -0
  68. kumoai/futures.py +99 -0
  69. kumoai/graph/__init__.py +12 -0
  70. kumoai/graph/column.py +106 -0
  71. kumoai/graph/graph.py +948 -0
  72. kumoai/graph/table.py +838 -0
  73. kumoai/jobs.py +80 -0
  74. kumoai/kumolib.cpython-313-x86_64-linux-gnu.so +0 -0
  75. kumoai/mixin.py +28 -0
  76. kumoai/pquery/__init__.py +25 -0
  77. kumoai/pquery/prediction_table.py +287 -0
  78. kumoai/pquery/predictive_query.py +637 -0
  79. kumoai/pquery/training_table.py +424 -0
  80. kumoai/spcs.py +123 -0
  81. kumoai/testing/__init__.py +8 -0
  82. kumoai/testing/decorators.py +57 -0
  83. kumoai/trainer/__init__.py +42 -0
  84. kumoai/trainer/baseline_trainer.py +93 -0
  85. kumoai/trainer/config.py +2 -0
  86. kumoai/trainer/job.py +1192 -0
  87. kumoai/trainer/online_serving.py +258 -0
  88. kumoai/trainer/trainer.py +475 -0
  89. kumoai/trainer/util.py +103 -0
  90. kumoai/utils/__init__.py +10 -0
  91. kumoai/utils/datasets.py +83 -0
  92. kumoai/utils/forecasting.py +209 -0
  93. kumoai/utils/progress_logger.py +177 -0
  94. kumoai-2.13.0.dev202511131731.dist-info/METADATA +60 -0
  95. kumoai-2.13.0.dev202511131731.dist-info/RECORD +98 -0
  96. kumoai-2.13.0.dev202511131731.dist-info/WHEEL +6 -0
  97. kumoai-2.13.0.dev202511131731.dist-info/licenses/LICENSE +9 -0
  98. kumoai-2.13.0.dev202511131731.dist-info/top_level.txt +1 -0
@@ -0,0 +1,532 @@
1
+ from typing import Dict, List, Tuple
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from kumoapi.pquery import ValidatedPredictiveQuery
6
+ from kumoapi.pquery.AST import (
7
+ Aggregation,
8
+ Column,
9
+ Condition,
10
+ Constant,
11
+ Filter,
12
+ Join,
13
+ LogicalOperation,
14
+ )
15
+ from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
16
+
17
+ from kumoai.experimental.rfm.pquery import PQueryExecutor
18
+
19
+
20
class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                                          np.ndarray]):
    r"""Evaluates predictive-query AST nodes on :mod:`pandas` data.

    Every ``execute_*`` method that takes ``filter_na`` returns a
    ``(values, mask)`` pair:

    * ``mask`` is a boolean :class:`numpy.ndarray` over all input rows (or,
      for aggregations, over all ``len(anchor_time)`` batch entries). It is
      ``True`` where a valid value exists.
    * With ``filter_na=True``, ``values`` contains only the rows selected by
      ``mask``.
    * With ``filter_na=False``, ``values`` keeps one entry per row, so that
      results of sibling sub-expressions can be combined element-wise (see
      :meth:`execute_logical_operation`).
    """
    def execute_column(
        self,
        column: Column,
        feat_dict: Dict[str, pd.DataFrame],
        filter_na: bool = True,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Look up a ``table.column`` reference in ``feat_dict``.

        A ``*`` column evaluates to an ``int64`` series of ones with one
        entry per table row. Float columns are down-cast to ``float32``.

        Args:
            column: Column reference whose ``fqn`` is ``"<table>.<column>"``.
            feat_dict: Feature frames keyed by table name.
            filter_na: Whether to drop missing values from the output.

        Returns:
            The column values (fresh ``RangeIndex``, unnamed) and the
            boolean non-missing mask over all rows of the table.
        """
        table_name, column_name = column.fqn.split(".")
        if column_name == '*':
            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
        else:
            out = feat_dict[table_name][column_name]
            out = out.reset_index(drop=True)

        if pd.api.types.is_float_dtype(out):
            out = out.astype('float32')

        out.name = None
        out.index.name = None

        mask = out.notna().to_numpy()

        if not filter_na:
            return out, mask

        out = out[mask].reset_index(drop=True)

        # Cast to primitive dtype (e.g. nullable `Int64`/`boolean` columns
        # can be represented with numpy dtypes once N/As are removed):
        if pd.api.types.is_integer_dtype(out):
            out = out.astype('int64')
        elif pd.api.types.is_bool_dtype(out):
            out = out.astype('bool')

        return out, mask

    def execute_aggregation_type(
        self,
        op: AggregationType,
        feat: pd.Series,
        batch: np.ndarray,
        batch_size: int,
        filter_na: bool = True,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Aggregate ``feat`` per batch index.

        Args:
            op: The aggregation to apply.
            feat: Values to aggregate, aligned position-wise with ``batch``.
            batch: For every value in ``feat``, the batch index in
                ``[0, batch_size)`` it belongs to.
            batch_size: Number of batch entries.
            filter_na: If ``True``, only return entries that received at
                least one value. Otherwise return one entry per batch index,
                with empty entries set to N/A.

        Returns:
            The aggregated values and a ``batch_size``-long boolean mask of
            entries that received at least one non-missing value. ``SUM``
            and ``COUNT`` always return all ``batch_size`` entries (empty
            groups become ``0``) together with an all-``True`` mask.
        """
        # Missing values never contribute to an aggregate:
        mask = feat.notna()
        feat, batch = feat[mask], batch[mask]

        if op == AggregationType.LIST_DISTINCT:
            df = pd.DataFrame(dict(feat=feat, batch=batch))
            df = df.drop_duplicates()
            out = df.groupby('batch')['feat'].agg(list)

        else:
            df = pd.DataFrame(dict(feat=feat, batch=batch))
            if op == AggregationType.AVG:
                agg = 'mean'
            elif op == AggregationType.COUNT:
                agg = 'size'
            else:
                agg = op.lower()
            out = df.groupby('batch')['feat'].agg(agg)

            # Numeric aggregates are stored as float32; datetime results
            # (e.g. MIN/MAX over timestamps) keep their dtype.
            if not pd.api.types.is_datetime64_any_dtype(out):
                out = out.astype('float32')

        out.name = None
        out.index.name = None

        if op in {AggregationType.SUM, AggregationType.COUNT}:
            # An empty group has a well-defined SUM/COUNT of zero.
            out = out.reindex(range(batch_size), fill_value=0)
            mask = np.ones(batch_size, dtype=bool)
            return out, mask

        mask = np.zeros(batch_size, dtype=bool)
        mask[batch] = True

        if filter_na:
            # `groupby` sorts by batch index, matching the order of `mask`.
            return out.reset_index(drop=True), mask

        out = out.reindex(range(batch_size), fill_value=pd.NA)

        return out, mask

    def execute_aggregation(
        self,
        aggr: Aggregation,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Evaluate a temporal aggregation for every anchor time.

        For an anchor time ``t``, target rows with a timestamp in
        ``(t + start_date_offset, t + end_date_offset]`` are aggregated.
        There is no lower bound when the range has no start.

        With ``num_forecasts > 1``, the window is shifted forward by
        ``end_date_offset`` after every step. Each output entry is then the
        list of per-step values. The mask is ``True`` wherever any step had
        a valid value.

        Returns:
            ``(values, mask)`` with ``mask`` of length ``len(anchor_time)``.
        """
        target_table = aggr._get_target_column_name().split('.')[0]
        target_batch = batch_dict[target_table]
        target_time = time_dict[target_table]
        if isinstance(aggr.target, Column):
            target_feat, target_mask = self.execute_column(
                column=aggr.target,
                feat_dict=feat_dict,
                filter_na=True,
            )
        else:
            assert isinstance(aggr.target, Filter)
            target_feat, target_mask = self.execute_filter(
                filter=aggr.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
            )

        outs: List[pd.Series] = []
        masks: List[np.ndarray] = []
        for _ in range(num_forecasts):
            # Anchor time of the batch entry each target row belongs to:
            anchor_target_time = anchor_time[target_batch]
            anchor_target_time = anchor_target_time.reset_index(drop=True)

            time_filter_mask = (target_time <= anchor_target_time +
                                aggr.aggr_time_range.end_date_offset)
            if aggr.aggr_time_range.start is not None:
                start_offset = aggr.aggr_time_range.start_date_offset
                time_filter_mask &= (target_time
                                     > anchor_target_time + start_offset)
            else:
                assert num_forecasts == 1
            curr_target_mask = target_mask & time_filter_mask

            out, mask = self.execute_aggregation_type(
                aggr.aggr,
                # `target_feat` is already N/A-filtered, so restrict the time
                # mask to the valid rows before selecting from it:
                feat=target_feat[time_filter_mask[target_mask].reset_index(
                    drop=True)],
                batch=target_batch[curr_target_mask],
                batch_size=len(anchor_time),
                # Keep all steps full-length so they can be zipped below:
                filter_na=False if num_forecasts > 1 else filter_na,
            )
            outs.append(out)
            masks.append(mask)

            if num_forecasts > 1:
                anchor_time = (anchor_time +
                               aggr.aggr_time_range.end_date_offset)
        if len(outs) == 1:
            assert len(masks) == 1
            return outs[0], masks[0]

        out = pd.Series([list(ser) for ser in zip(*outs)])
        mask = np.stack(masks, axis=-1).any(axis=-1)  # type: ignore

        if filter_na:
            out = out[mask].reset_index(drop=True)

        return out, mask

    def execute_rel_op(
        self,
        left: pd.Series,
        op: RelOp,
        right: Constant,
    ) -> pd.Series:
        """Compare ``left`` element-wise against a constant.

        A ``None`` constant turns ``EQ``/``NEQ`` into ``isna``/``notna``.
        Otherwise, comparisons involving missing values yield ``False``.
        """
        if right.typed_value() is None:
            if op == RelOp.EQ:
                return left.isna()
            assert op == RelOp.NEQ
            return left.notna()

        # Promote left to float if right is a float to avoid lossy coercion.
        right_value = right.typed_value()
        if pd.api.types.is_integer_dtype(left) and isinstance(
                right_value, float):
            left = left.astype('float64')
        value = pd.Series([right_value], dtype=left.dtype).iloc[0]

        if op == RelOp.EQ:
            return (left == value).fillna(False).astype(bool)
        if op == RelOp.NEQ:
            out = (left != value).fillna(False).astype(bool)
            out[left.isna()] = False  # N/A != right should always be `False`.
            return out
        if op == RelOp.LEQ:
            return (left <= value).fillna(False).astype(bool)
        if op == RelOp.GEQ:
            return (left >= value).fillna(False).astype(bool)
        if op == RelOp.LT:
            return (left < value).fillna(False).astype(bool)
        if op == RelOp.GT:
            return (left > value).fillna(False).astype(bool)

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_member_op(
        self,
        left: pd.Series,
        op: MemberOp,
        right: Constant,
    ) -> pd.Series:
        """Evaluate a membership test (``IN``) of ``left`` against a constant
        collection, cast to ``left``'s dtype before comparison."""
        if op == MemberOp.IN:
            ser = pd.Series(right.typed_value(), dtype=left.dtype)
            return left.isin(ser).astype(bool)

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_condition(
        self,
        condition: Condition,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Evaluate ``target <op> constant`` into a boolean series.

        When the constant is ``None`` (a null check), missing target values
        are never filtered, because they are exactly what is being tested.
        With ``filter_na=True`` the returned mask is then all ``True``.

        Raises:
            NotImplementedError: If ``num_forecasts > 1``.
        """
        if num_forecasts > 1:
            raise NotImplementedError("Forecasting not yet implemented for "
                                      "non-regression tasks")

        assert isinstance(condition.value, Constant)
        value_is_na = condition.value.typed_value() is None
        if isinstance(condition.target, Column):
            left, mask = self.execute_column(
                column=condition.target,
                feat_dict=feat_dict,
                filter_na=filter_na if not value_is_na else False,
            )
        elif isinstance(condition.target, Join):
            left, mask = self.execute_join(
                join=condition.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=filter_na if not value_is_na else False,
            )
        else:
            assert isinstance(condition.target, Aggregation)
            left, mask = self.execute_aggregation(
                aggr=condition.target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=filter_na if not value_is_na else False,
            )

        if filter_na and value_is_na:
            # `left` was not filtered above, so every row stays valid.
            mask = np.ones(len(left), dtype=bool)

        if isinstance(condition.op, RelOp):
            out = self.execute_rel_op(
                left=left,
                op=condition.op,
                right=condition.value,
            )
        else:
            assert isinstance(condition.op, MemberOp)
            out = self.execute_member_op(
                left=left,
                op=condition.op,
                right=condition.value,
            )

        return out, mask

    def execute_bool_op(
        self,
        left: pd.Series,
        op: BoolOp,
        right: pd.Series | None,
    ) -> pd.Series:
        """Combine boolean series with ``AND``/``OR``, or negate with ``NOT``
        (``right`` must be ``None`` only for ``NOT``)."""
        # TODO Implement Kleene-Priest three-value logic.
        if op == BoolOp.AND:
            assert right is not None
            return left & right
        if op == BoolOp.OR:
            assert right is not None
            return left | right
        if op == BoolOp.NOT:
            return ~left

        raise NotImplementedError(f"Operator '{op}' not implemented")

    def execute_logical_operation(
        self,
        logical_operation: LogicalOperation,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Evaluate a (possibly nested) boolean combination of conditions.

        Operands are always evaluated with ``filter_na=False``, so both sides
        have one entry per row and can be combined element-wise. The result
        mask is the intersection of the operand masks.

        Raises:
            NotImplementedError: If ``num_forecasts > 1``.
        """
        if num_forecasts > 1:
            raise NotImplementedError("Forecasting not yet implemented for "
                                      "non-regression tasks")

        if isinstance(logical_operation.left, Condition):
            left, mask = self.execute_condition(
                condition=logical_operation.left,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )
        else:
            assert isinstance(logical_operation.left, LogicalOperation)
            left, mask = self.execute_logical_operation(
                logical_operation=logical_operation.left,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )

        # `right` stays `None` for unary operators such as `NOT`.
        right = right_mask = None
        if isinstance(logical_operation.right, Condition):
            right, right_mask = self.execute_condition(
                condition=logical_operation.right,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )
        elif isinstance(logical_operation.right, LogicalOperation):
            right, right_mask = self.execute_logical_operation(
                logical_operation=logical_operation.right,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )

        out = self.execute_bool_op(left, logical_operation.bool_op, right)

        if right_mask is not None:
            mask &= right_mask

        if filter_na:
            out = out[mask].reset_index(drop=True)

        return out, mask

    def execute_join(
        self,
        join: Join,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
        num_forecasts: int = 1,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Evaluate a join whose right-hand side is an aggregation.

        ``filter_na`` is forwarded to the aggregation. Callers such as
        :meth:`execute_condition` request unfiltered (full-length) results
        for nested boolean logic and null checks, and a filtered result
        there would misalign with the other operands' masks.

        Raises:
            NotImplementedError: If the right-hand side is not an
                :class:`Aggregation`.
        """
        if isinstance(join.rhs_target, Aggregation):
            return self.execute_aggregation(
                aggr=join.rhs_target,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=filter_na,
                num_forecasts=num_forecasts,
            )
        raise NotImplementedError(
            f'Unexpected {type(join.rhs_target)} nested in Join')

    def execute_filter(
        self,
        filter: Filter,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        filter_na: bool = True,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Select the rows of ``filter.target`` that satisfy its condition.

        Returns:
            The selected values and a full-length mask of rows that both
            satisfy the condition and have a non-missing target value.
        """
        out, mask = self.execute_column(
            column=filter.target,
            feat_dict=feat_dict,
            filter_na=False,
        )
        if isinstance(filter.condition, Condition):
            _mask = self.execute_condition(
                condition=filter.condition,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )[0].to_numpy()
        else:
            assert isinstance(filter.condition, LogicalOperation)
            _mask = self.execute_logical_operation(
                logical_operation=filter.condition,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=False,
            )[0].to_numpy()
        if filter_na:
            return out[_mask & mask].reset_index(drop=True), _mask & mask
        else:
            # NOTE(review): unlike the other executors, the values returned
            # here are filtered by the condition (not full-length), while the
            # mask also excludes missing targets. No caller in this class
            # uses this branch; confirm the intended contract before relying
            # on it.
            return out[_mask].reset_index(drop=True), mask & _mask

    def execute(
        self,
        query: ValidatedPredictiveQuery,
        feat_dict: Dict[str, pd.DataFrame],
        time_dict: Dict[str, pd.Series],
        batch_dict: Dict[str, np.ndarray],
        anchor_time: pd.Series,
        num_forecasts: int = 1,
    ) -> Tuple[pd.Series, np.ndarray]:
        """Evaluate a validated predictive query.

        Steps:

        1. The entity clause yields a mask of participating entities.
        2. The optional ``ASSUMING`` clause narrows that mask further.
        3. The target clause yields the (N/A-filtered) target values.

        Returns:
            The target values of all entities that pass every mask, and the
            combined mask.

        Raises:
            NotImplementedError: For unsupported target AST node types.
            ValueError: For unsupported ``ASSUMING`` conditions.
        """
        # Only the entity mask is needed; the entity values are discarded.
        if isinstance(query.entity_ast, Column):
            _, mask = self.execute_column(
                column=query.entity_ast,
                feat_dict=feat_dict,
                filter_na=True,
            )
        else:
            assert isinstance(query.entity_ast, Filter)
            _, mask = self.execute_filter(
                filter=query.entity_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
            )
        if isinstance(query.target_ast, Column):
            out, _mask = self.execute_column(
                column=query.target_ast,
                feat_dict=feat_dict,
                filter_na=True,
            )
        elif isinstance(query.target_ast, Condition):
            out, _mask = self.execute_condition(
                condition=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, Aggregation):
            out, _mask = self.execute_aggregation(
                aggr=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, Join):
            out, _mask = self.execute_join(
                join=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        elif isinstance(query.target_ast, LogicalOperation):
            out, _mask = self.execute_logical_operation(
                logical_operation=query.target_ast,
                feat_dict=feat_dict,
                time_dict=time_dict,
                batch_dict=batch_dict,
                anchor_time=anchor_time,
                filter_na=True,
                num_forecasts=num_forecasts,
            )
        else:
            raise NotImplementedError(
                f'{type(query.target_ast)} compilation missing.')
        if query.whatif_ast is not None:
            # Evaluate the assumption WITHOUT dropping missing values. An
            # N/A-filtered result would be shorter than `mask` whenever the
            # condition has missing values, and could not be combined with
            # it element-wise.
            if isinstance(query.whatif_ast, Condition):
                whatif, _ = self.execute_condition(
                    condition=query.whatif_ast,
                    feat_dict=feat_dict,
                    time_dict=time_dict,
                    batch_dict=batch_dict,
                    anchor_time=anchor_time,
                    filter_na=False,
                    num_forecasts=num_forecasts,
                )
            elif isinstance(query.whatif_ast, LogicalOperation):
                whatif, _ = self.execute_logical_operation(
                    logical_operation=query.whatif_ast,
                    feat_dict=feat_dict,
                    time_dict=time_dict,
                    batch_dict=batch_dict,
                    anchor_time=anchor_time,
                    filter_na=False,
                    num_forecasts=num_forecasts,
                )
            else:
                raise ValueError(
                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
            mask &= whatif.to_numpy()

        # `out` only holds rows where `_mask` is True; keep those that also
        # pass the entity/ASSUMING mask.
        out = out[mask[_mask]]
        mask &= _mask
        out = out.reset_index(drop=True)
        return out, mask