kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (53)
  1. kumoai/__init__.py +22 -11
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +17 -16
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +37 -8
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +164 -46
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
  12. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  13. kumoai/experimental/rfm/backend/local/table.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
  20. kumoai/experimental/rfm/base/__init__.py +25 -0
  21. kumoai/experimental/rfm/base/column.py +66 -0
  22. kumoai/experimental/rfm/base/sampler.py +773 -0
  23. kumoai/experimental/rfm/base/source.py +19 -0
  24. kumoai/experimental/rfm/base/sql_sampler.py +60 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
  27. kumoai/experimental/rfm/infer/__init__.py +6 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +126 -0
  30. kumoai/experimental/rfm/infer/time_col.py +62 -0
  31. kumoai/experimental/rfm/infer/timestamp.py +7 -4
  32. kumoai/experimental/rfm/pquery/__init__.py +4 -4
  33. kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
  34. kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
  35. kumoai/experimental/rfm/rfm.py +669 -246
  36. kumoai/experimental/rfm/sagemaker.py +138 -0
  37. kumoai/jobs.py +1 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/trainer.py +12 -10
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +239 -4
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
  47. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -176
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -404
  50. kumoai/experimental/rfm/utils.py +0 -344
  51. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
  52. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
  53. {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py}

@@ -1,45 +1,69 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
-from kumoapi.rfm import PQueryDefinition
-from kumoapi.rfm.pquery import (
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
     Aggregation,
-    AggregationType,
-    BoolOp,
     Column,
     Condition,
+    Constant,
     Filter,
-    Float,
-    FloatList,
-    Int,
-    IntList,
+    Join,
     LogicalOperation,
-    MemberOp,
-    RelOp,
-    Str,
-    StrList,
 )
+from kumoapi.typing import AggregationType, BoolOp, MemberOp, RelOp
 
-from kumoai.experimental.rfm.pquery import PQueryBackend
+from kumoai.experimental.rfm.pquery import PQueryExecutor
 
 
-class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
-    def eval_aggregation_type(
+class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
+                                          np.ndarray]):
+    def execute_column(
+        self,
+        column: Column,
+        feat_dict: Dict[str, pd.DataFrame],
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        table_name, column_name = column.fqn.split(".")
+        if column_name == '*':
+            out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
+        else:
+            out = feat_dict[table_name][column_name]
+        out = out.reset_index(drop=True)
+
+        if pd.api.types.is_float_dtype(out):
+            out = out.astype('float32')
+
+        out.name = None
+        out.index.name = None
+
+        mask = out.notna().to_numpy()
+
+        if not filter_na:
+            return out, mask
+
+        out = out[mask].reset_index(drop=True)
+
+        # Cast to primitive dtype:
+        if pd.api.types.is_integer_dtype(out):
+            out = out.astype('int64')
+        elif pd.api.types.is_bool_dtype(out):
+            out = out.astype('bool')
+
+        return out, mask
+
+    def execute_aggregation_type(
         self,
         op: AggregationType,
-        feat: Optional[pd.Series],
+        feat: pd.Series,
         batch: np.ndarray,
         batch_size: int,
         filter_na: bool = True,
     ) -> Tuple[pd.Series, np.ndarray]:
 
-        if op != AggregationType.COUNT:
-            assert feat is not None
-
-        if feat is not None:
-            mask = feat.notna()
-            feat, batch = feat[mask], batch[mask]
+        mask = feat.notna()
+        feat, batch = feat[mask], batch[mask]
 
         if op == AggregationType.LIST_DISTINCT:
             df = pd.DataFrame(dict(feat=feat, batch=batch))
@@ -77,102 +101,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
 
         return out, mask
 
-    def eval_rel_op(
-        self,
-        left: pd.Series,
-        op: RelOp,
-        right: Union[Int, Float, Str, None],
-    ) -> pd.Series:
-
-        if right is None:
-            if op == RelOp.EQ:
-                return left.isna()
-            assert op == RelOp.NEQ
-            return left.notna()
-
-        value = pd.Series([right.value], dtype=left.dtype).iloc[0]
-
-        if op == RelOp.EQ:
-            return (left == value).fillna(False).astype(bool)
-        if op == RelOp.NEQ:
-            out = (left != value).fillna(False).astype(bool)
-            out[left.isna()] = False  # N/A != right should always be `False`.
-            return out
-        if op == RelOp.LEQ:
-            return (left <= value).fillna(False).astype(bool)
-        if op == RelOp.GEQ:
-            return (left >= value).fillna(False).astype(bool)
-        if op == RelOp.LT:
-            return (left < value).fillna(False).astype(bool)
-        if op == RelOp.GT:
-            return (left > value).fillna(False).astype(bool)
-
-        raise NotImplementedError(f"Operator '{op}' not implemented")
-
-    def eval_member_op(
-        self,
-        left: pd.Series,
-        op: MemberOp,
-        right: Union[IntList, FloatList, StrList],
-    ) -> pd.Series:
-
-        if op == MemberOp.IN:
-            ser = pd.Series(right.value, dtype=left.dtype)
-            return left.isin(ser).astype(bool)
-
-        raise NotImplementedError(f"Operator '{op}' not implemented")
-
-    def eval_bool_op(
-        self,
-        left: pd.Series,
-        op: BoolOp,
-        right: Optional[pd.Series],
-    ) -> pd.Series:
-
-        # TODO Implement Kleene-Priest three-value logic.
-        if op == BoolOp.AND:
-            assert right is not None
-            return left & right
-        if op == BoolOp.OR:
-            assert right is not None
-            return left | right
-        if op == BoolOp.NOT:
-            return ~left
-
-        raise NotImplementedError(f"Operator '{op}' not implemented")
-
-    def eval_column(
-        self,
-        column: Column,
-        feat_dict: Dict[str, pd.DataFrame],
-        filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
-
-        out = feat_dict[column.table_name][column.column_name]
-        out = out.reset_index(drop=True)
-
-        if pd.api.types.is_float_dtype(out):
-            out = out.astype('float32')
-
-        out.name = None
-        out.index.name = None
-
-        mask = out.notna().to_numpy()
-
-        if not filter_na:
-            return out, mask
-
-        out = out[mask].reset_index(drop=True)
-
-        # Cast to primitive dtype:
-        if pd.api.types.is_integer_dtype(out):
-            out = out.astype('int64')
-        elif pd.api.types.is_bool_dtype(out):
-            out = out.astype('bool')
-
-        return out, mask
-
-    def eval_aggregation(
+    def execute_aggregation(
         self,
         aggr: Aggregation,
         feat_dict: Dict[str, pd.DataFrame],
@@ -182,49 +111,47 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         filter_na: bool = True,
         num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:
-
-        target_table = aggr.column.table_name
+        target_table = aggr._get_target_column_name().split('.')[0]
         target_batch = batch_dict[target_table]
         target_time = time_dict[target_table]
+        if isinstance(aggr.target, Column):
+            target_feat, target_mask = self.execute_column(
+                column=aggr.target,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(aggr.target, Filter)
+            target_feat, target_mask = self.execute_filter(
+                filter=aggr.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+            )
 
         outs: List[pd.Series] = []
         masks: List[np.ndarray] = []
         for _ in range(num_forecasts):
-            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_time.iloc[target_batch]
             anchor_target_time = anchor_target_time.reset_index(drop=True)
 
-            target_mask = target_time <= anchor_target_time + aggr.end_offset
-
-            if aggr.start is not None:
-                start_offset = aggr.start * aggr.time_unit.to_offset()
-                target_mask &= target_time > anchor_target_time + start_offset
+            time_filter_mask = (target_time <= anchor_target_time +
+                                aggr.aggr_time_range.end_date_offset)
+            if aggr.aggr_time_range.start is not None:
+                start_offset = aggr.aggr_time_range.start_date_offset
+                time_filter_mask &= (target_time
+                                     > anchor_target_time + start_offset)
             else:
                 assert num_forecasts == 1
+            curr_target_mask = target_mask & time_filter_mask
 
-            if aggr.filter is not None:
-                target_mask &= self.eval_filter(
-                    filter=aggr.filter,
-                    feat_dict=feat_dict,
-                    time_dict=time_dict,
-                    batch_dict=batch_dict,
-                    anchor_time=anchor_time,
-                )
-
-            if (aggr.type == AggregationType.COUNT
-                    and aggr.column.column_name == '*'):
-                target_feat = None
-            else:
-                target_feat, _ = self.eval_column(
-                    aggr.column,
-                    feat_dict,
-                    filter_na=False,
-                )
-                target_feat = target_feat[target_mask]
-
-            out, mask = self.eval_aggregation_type(
-                aggr.type,
-                feat=target_feat,
-                batch=target_batch[target_mask],
+            out, mask = self.execute_aggregation_type(
+                aggr.aggr,
+                feat=target_feat[time_filter_mask[target_mask].reset_index(
+                    drop=True)],
+                batch=target_batch[curr_target_mask],
                 batch_size=len(anchor_time),
                 filter_na=False if num_forecasts > 1 else filter_na,
             )
@@ -232,8 +159,8 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
             masks.append(mask)
 
             if num_forecasts > 1:
-                anchor_time = anchor_time + aggr.end_offset
-
+                anchor_time = (anchor_time +
+                               aggr.aggr_time_range.end_date_offset)
         if len(outs) == 1:
             assert len(masks) == 1
             return outs[0], masks[0]
@@ -246,7 +173,57 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
 
         return out, mask
 
-    def eval_condition(
+    def execute_rel_op(
+        self,
+        left: pd.Series,
+        op: RelOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if right.typed_value() is None:
+            if op == RelOp.EQ:
+                return left.isna()
+            assert op == RelOp.NEQ
+            return left.notna()
+
+        # Promote left to float if right is a float to avoid lossy coercion.
+        right_value = right.typed_value()
+        if pd.api.types.is_integer_dtype(left) and isinstance(
+                right_value, float):
+            left = left.astype('float64')
+        value = pd.Series([right_value], dtype=left.dtype).iloc[0]
+
+        if op == RelOp.EQ:
+            return (left == value).fillna(False).astype(bool)
+        if op == RelOp.NEQ:
+            out = (left != value).fillna(False).astype(bool)
+            out[left.isna()] = False  # N/A != right should always be `False`.
+            return out
+        if op == RelOp.LEQ:
+            return (left <= value).fillna(False).astype(bool)
+        if op == RelOp.GEQ:
+            return (left >= value).fillna(False).astype(bool)
+        if op == RelOp.LT:
+            return (left < value).fillna(False).astype(bool)
+        if op == RelOp.GT:
+            return (left > value).fillna(False).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_member_op(
+        self,
+        left: pd.Series,
+        op: MemberOp,
+        right: Constant,
+    ) -> pd.Series:
+
+        if op == MemberOp.IN:
+            ser = pd.Series(right.typed_value(), dtype=left.dtype)
+            return left.isin(ser).astype(bool)
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_condition(
         self,
         condition: Condition,
         feat_dict: Dict[str, pd.DataFrame],
@@ -256,48 +233,77 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         filter_na: bool = True,
         num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:
-
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
 
-        if isinstance(condition.left, Column):
-            left, mask = self.eval_column(
-                column=condition.left,
+        assert isinstance(condition.value, Constant)
+        value_is_na = condition.value.typed_value() is None
+        if isinstance(condition.target, Column):
+            left, mask = self.execute_column(
+                column=condition.target,
                 feat_dict=feat_dict,
-                filter_na=filter_na if condition.right is not None else False,
+                filter_na=filter_na if not value_is_na else False,
+            )
+        elif isinstance(condition.target, Join):
+            left, mask = self.execute_join(
+                join=condition.target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=filter_na if not value_is_na else False,
             )
         else:
-            assert isinstance(condition.left, Aggregation)
-            left, mask = self.eval_aggregation(
-                aggr=condition.left,
+            assert isinstance(condition.target, Aggregation)
+            left, mask = self.execute_aggregation(
+                aggr=condition.target,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
-                filter_na=filter_na if condition.right is not None else False,
+                filter_na=filter_na if not value_is_na else False,
             )
 
-        if filter_na and condition.right is None:
+        if filter_na and value_is_na:
             mask = np.ones(len(left), dtype=bool)
 
         if isinstance(condition.op, RelOp):
-            out = self.eval_rel_op(
+            out = self.execute_rel_op(
                 left=left,
                 op=condition.op,
-                right=condition.right,
+                right=condition.value,
            )
         else:
             assert isinstance(condition.op, MemberOp)
-            out = self.eval_member_op(
+            out = self.execute_member_op(
                 left=left,
                 op=condition.op,
-                right=condition.right,
+                right=condition.value,
             )
 
         return out, mask
 
-    def eval_logical_operation(
+    def execute_bool_op(
+        self,
+        left: pd.Series,
+        op: BoolOp,
+        right: pd.Series | None,
+    ) -> pd.Series:
+
+        # TODO Implement Kleene-Priest three-value logic.
+        if op == BoolOp.AND:
+            assert right is not None
+            return left & right
+        if op == BoolOp.OR:
+            assert right is not None
+            return left | right
+        if op == BoolOp.NOT:
+            return ~left
+
+        raise NotImplementedError(f"Operator '{op}' not implemented")
+
+    def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
         feat_dict: Dict[str, pd.DataFrame],
@@ -307,13 +313,12 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
         filter_na: bool = True,
         num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:
-
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
 
         if isinstance(logical_operation.left, Condition):
-            left, mask = self.eval_condition(
+            left, mask = self.execute_condition(
                 condition=logical_operation.left,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -323,7 +328,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
             )
         else:
             assert isinstance(logical_operation.left, LogicalOperation)
-            left, mask = self.eval_logical_operation(
+            left, mask = self.execute_logical_operation(
                 logical_operation=logical_operation.left,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -334,7 +339,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
 
         right = right_mask = None
         if isinstance(logical_operation.right, Condition):
-            right, right_mask = self.eval_condition(
+            right, right_mask = self.execute_condition(
                 condition=logical_operation.right,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -343,7 +348,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 filter_na=False,
             )
         elif isinstance(logical_operation.right, LogicalOperation):
-            right, right_mask = self.eval_logical_operation(
+            right, right_mask = self.execute_logical_operation(
                 logical_operation=logical_operation.right,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -352,7 +357,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 filter_na=False,
             )
 
-        out = self.eval_bool_op(left, logical_operation.op, right)
+        out = self.execute_bool_op(left, logical_operation.bool_op, right)
 
         if right_mask is not None:
             mask &= right_mask
@@ -362,16 +367,45 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
 
         return out, mask
 
-    def eval_filter(
+    def execute_join(
+        self,
+        join: Join,
+        feat_dict: Dict[str, pd.DataFrame],
+        time_dict: Dict[str, pd.Series],
+        batch_dict: Dict[str, np.ndarray],
+        anchor_time: pd.Series,
+        filter_na: bool = True,
+        num_forecasts: int = 1,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        if isinstance(join.rhs_target, Aggregation):
+            return self.execute_aggregation(
+                aggr=join.rhs_target,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        raise NotImplementedError(
+            f'Unexpected {type(join.rhs_target)} nested in Join')
+
+    def execute_filter(
         self,
         filter: Filter,
         feat_dict: Dict[str, pd.DataFrame],
         time_dict: Dict[str, pd.Series],
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
-    ) -> np.ndarray:
+        filter_na: bool = True,
+    ) -> Tuple[pd.Series, np.ndarray]:
+        out, mask = self.execute_column(
+            column=filter.target,
+            feat_dict=feat_dict,
+            filter_na=False,
+        )
         if isinstance(filter.condition, Condition):
-            return self.eval_condition(
+            _mask = self.execute_condition(
                 condition=filter.condition,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -381,7 +415,7 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
             )[0].to_numpy()
         else:
             assert isinstance(filter.condition, LogicalOperation)
-            return self.eval_logical_operation(
+            _mask = self.execute_logical_operation(
                 logical_operation=filter.condition,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
@@ -389,58 +423,44 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 anchor_time=anchor_time,
                 filter_na=False,
             )[0].to_numpy()
+        if filter_na:
+            return out[_mask & mask].reset_index(drop=True), _mask & mask
+        else:
+            return out[_mask].reset_index(drop=True), mask & _mask
 
-    def eval_pquery(
+    def execute(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         feat_dict: Dict[str, pd.DataFrame],
         time_dict: Dict[str, pd.Series],
         batch_dict: Dict[str, np.ndarray],
         anchor_time: pd.Series,
         num_forecasts: int = 1,
     ) -> Tuple[pd.Series, np.ndarray]:
-
-        mask = np.ones(len(anchor_time), dtype=bool)
-
-        if query.entity.filter is not None:
-            mask &= self.eval_filter(
-                filter=query.entity.filter,
+        if isinstance(query.entity_ast, Column):
+            out, mask = self.execute_column(
+                column=query.entity_ast,
+                feat_dict=feat_dict,
+                filter_na=True,
+            )
+        else:
+            assert isinstance(query.entity_ast, Filter)
+            out, mask = self.execute_filter(
+                filter=query.entity_ast,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
             )
-
-        if getattr(query, 'assuming', None) is not None:
-            if isinstance(query.assuming, Condition):
-                mask &= self.eval_condition(
-                    condition=query.assuming,
-                    feat_dict=feat_dict,
-                    time_dict=time_dict,
-                    batch_dict=batch_dict,
-                    anchor_time=anchor_time,
-                    filter_na=False,
-                )[0].to_numpy()
-            else:
-                assert isinstance(query.assuming, LogicalOperation)
-                mask &= self.eval_logical_operation(
-                    logical_operation=query.assuming,
-                    feat_dict=feat_dict,
-                    time_dict=time_dict,
-                    batch_dict=batch_dict,
-                    anchor_time=anchor_time,
-                    filter_na=False,
-                )[0].to_numpy()
-
-        if isinstance(query.target, Column):
-            out, _mask = self.eval_column(
-                column=query.target,
+        if isinstance(query.target_ast, Column):
+            out, _mask = self.execute_column(
+                column=query.target_ast,
                 feat_dict=feat_dict,
                 filter_na=True,
             )
-        elif isinstance(query.target, Aggregation):
-            out, _mask = self.eval_aggregation(
-                aggr=query.target,
+        elif isinstance(query.target_ast, Condition):
+            out, _mask = self.execute_condition(
+                condition=query.target_ast,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
@@ -448,9 +468,9 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 filter_na=True,
                 num_forecasts=num_forecasts,
             )
-        elif isinstance(query.target, Condition):
-            out, _mask = self.eval_condition(
-                condition=query.target,
+        elif isinstance(query.target_ast, Aggregation):
+            out, _mask = self.execute_aggregation(
+                aggr=query.target_ast,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
@@ -458,10 +478,9 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 filter_na=True,
                 num_forecasts=num_forecasts,
             )
-        else:
-            assert isinstance(query.target, LogicalOperation)
-            out, _mask = self.eval_logical_operation(
-                logical_operation=query.target,
+        elif isinstance(query.target_ast, Join):
+            out, _mask = self.execute_join(
+                join=query.target_ast,
                 feat_dict=feat_dict,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
@@ -469,10 +488,45 @@ class PQueryPandasBackend(PQueryBackend[pd.DataFrame, pd.Series, np.ndarray]):
                 filter_na=True,
                 num_forecasts=num_forecasts,
             )
+        elif isinstance(query.target_ast, LogicalOperation):
+            out, _mask = self.execute_logical_operation(
+                logical_operation=query.target_ast,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                filter_na=True,
+                num_forecasts=num_forecasts,
+            )
+        else:
+            raise NotImplementedError(
+                f'{type(query.target_ast)} compilation missing.')
+        if query.whatif_ast is not None:
+            if isinstance(query.whatif_ast, Condition):
+                mask &= self.execute_condition(
+                    condition=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            elif isinstance(query.whatif_ast, LogicalOperation):
+                mask &= self.execute_logical_operation(
+                    logical_operation=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            else:
+                raise ValueError(
+                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
 
         out = out[mask[_mask]]
         mask &= _mask
-
         out = out.reset_index(drop=True)
-
         return out, mask
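
A minimal, standalone sketch of the null-handling comparison semantics that execute_rel_op applies in the diff above: comparisons against a constant never propagate NA (missing values compare as False on both EQ and NEQ), and comparing a column to None is rewritten as an isna()/notna() check. This uses plain pandas/NumPy only, does not import kumoai, and the variable names are illustrative:

import numpy as np
import pandas as pd

# Column of float values with a missing entry.
left = pd.Series([1.0, 2.0, np.nan], dtype='float32')

# RelOp.EQ / RelOp.NEQ against a constant: NA compares as False either way.
eq_two = (left == 2.0).fillna(False).astype(bool)    # [False, True, False]
neq_two = (left != 2.0).fillna(False).astype(bool)
neq_two[left.isna()] = False                         # N/A != right stays False

# Comparing against None becomes an isna()/notna() check instead.
is_null = left.isna()

print(eq_two.tolist(), neq_two.tolist(), is_null.tolist())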