kumoai 2.12.0.dev202511031731__cp313-cp313-macosx_11_0_arm64.whl → 2.12.0.dev202511101731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of kumoai might be problematic.

kumoai/_version.py CHANGED
@@ -1 +1 @@
-__version__ = '2.12.0.dev202511031731'
+__version__ = '2.12.0.dev202511101731'
@@ -147,3 +147,4 @@ class RFMEndpoints:
     explain = Endpoint(f"{BASE}/explain", HTTPMethod.POST)
     evaluate = Endpoint(f"{BASE}/evaluate", HTTPMethod.POST)
     validate_query = Endpoint(f"{BASE}/validate_query", HTTPMethod.POST)
+    parse_query = Endpoint(f"{BASE}/parse_query", HTTPMethod.POST)
kumoai/client/rfm.py CHANGED
@@ -1,7 +1,11 @@
+from typing import Any
+
 from kumoapi.json_serde import to_json_dict
 from kumoapi.rfm import (
     RFMEvaluateResponse,
     RFMExplanationResponse,
+    RFMParseQueryRequest,
+    RFMParseQueryResponse,
     RFMPredictResponse,
     RFMValidateQueryRequest,
     RFMValidateQueryResponse,
@@ -26,25 +30,32 @@ class RFMAPI:
         Returns:
             RFMPredictResponse containing the predictions
         """
-        # Send binary data to the predict endpoint
         response = self._client._request(
-            RFMEndpoints.predict, data=request,
-            headers={'Content-Type': 'application/x-protobuf'})
+            RFMEndpoints.predict,
+            data=request,
+            headers={'Content-Type': 'application/x-protobuf'},
+        )
         raise_on_error(response)
         return parse_response(RFMPredictResponse, response)
 
-    def explain(self, request: bytes) -> RFMExplanationResponse:
+    def explain(
+        self,
+        request: bytes,
+        skip_summary: bool = False,
+    ) -> RFMExplanationResponse:
         """Explain the RFM model on the given context.
 
         Args:
             request: The predict request as serialized protobuf.
+            skip_summary: Whether to skip generating a human-readable summary
+                of the explanation.
 
         Returns:
             RFMPredictResponse containing the explanations
         """
-        # Send binary data to the explain endpoint
+        params: dict[str, Any] = {'generate_summary': not skip_summary}
         response = self._client._request(
-            RFMEndpoints.explain, data=request,
+            RFMEndpoints.explain, data=request, params=params,
             headers={'Content-Type': 'application/x-protobuf'})
         raise_on_error(response)
         return parse_response(RFMExplanationResponse, response)
@@ -58,7 +69,6 @@ class RFMAPI:
         Returns:
             RFMEvaluateResponse containing the computed metrics
         """
-        # Send binary data to the evaluate endpoint
        response = self._client._request(
             RFMEndpoints.evaluate, data=request,
             headers={'Content-Type': 'application/x-protobuf'})
@@ -82,3 +92,21 @@ class RFMAPI:
             json=to_json_dict(request))
         raise_on_error(response)
         return parse_response(RFMValidateQueryResponse, response)
+
+    def parse_query(
+        self,
+        request: RFMParseQueryRequest,
+    ) -> RFMParseQueryResponse:
+        """Validate a predictive query against a graph.
+
+        Args:
+            request: The request object containing
+                the query and graph definition
+
+        Returns:
+            RFMParseQueryResponse containing the QueryDefinition
+        """
+        response = self._client._request(RFMEndpoints.parse_query,
+                                         json=to_json_dict(request))
+        raise_on_error(response)
+        return parse_response(RFMParseQueryResponse, response)
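
For orientation, a minimal usage sketch of the two client-side additions above: the skip_summary flag on explain and the new parse_query method. How the RFMAPI instance and the request payloads are constructed is not part of this diff, so those appear only as hypothetical placeholders:

    from kumoapi.rfm import RFMParseQueryRequest

    api = ...                   # an initialized kumoai.client.rfm.RFMAPI (placeholder)
    proto_request: bytes = ...  # serialized protobuf explain context (placeholder)

    # Skip the human-readable summary when only the raw explanation is needed:
    explanation = api.explain(proto_request, skip_summary=True)

    # Parse a predictive query against a graph definition; the constructor
    # arguments of the request object are assumed, not shown in this diff:
    parsed = api.parse_query(RFMParseQueryRequest(...))
    # -> RFMParseQueryResponse containing the QueryDefinition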
@@ -36,7 +36,7 @@ import os
 import kumoai
 from .local_table import LocalTable
 from .local_graph import LocalGraph
-from .rfm import KumoRFM
+from .rfm import ExplainConfig, Explanation, KumoRFM
 from .authenticate import authenticate
 
 
@@ -60,6 +60,8 @@ __all__ = [
     'LocalTable',
     'LocalGraph',
     'KumoRFM',
+    'ExplainConfig',
+    'Explanation',
     'authenticate',
     'init',
 ]
@@ -1,23 +1,40 @@
 import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union
 
 import numpy as np
 import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.rfm import PQueryDefinition
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    ASTNode,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, DateOffset, Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasBackend
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
 
 _coverage_warned = False
 
 
+class SamplingSpec(NamedTuple):
+    edge_type: Tuple[str, str, str]
+    hop: int
+    start_offset: Optional[DateOffset]
+    end_offset: Optional[DateOffset]
+
+
 class LocalPQueryDriver:
     def __init__(
         self,
         graph_store: LocalGraphStore,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
@@ -33,7 +50,7 @@ class LocalPQueryDriver:
         if self._query.query_type == QueryType.TEMPORAL:
             assert exclude_node is None
 
-        table_name = self._query.entity.pkey.table_name
+        table_name = self._query.entity_table
         num_nodes = len(self._graph_store.df_dict[table_name])
         mask_dict = self._graph_store.mask_dict
 
@@ -66,7 +83,7 @@ class LocalPQueryDriver:
         anchor_time: pd.Timestamp,
     ) -> np.ndarray:
 
-        entity = self._query.entity.pkey.table_name
+        entity = self._query.entity_table
 
         # Filter out entities that do not exist yet in time:
         time_sec = self._graph_store.time_dict.get(entity)
@@ -124,8 +141,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
         y, mask = self(node, time)
@@ -223,8 +239,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
         y, mask = self(node, time)
@@ -249,7 +264,8 @@ class LocalPQueryDriver:
                 reached_end = True
                 break
             candidate_offset = 0
-            anchor_time = anchor_time - (self._query.target.end_offset *
+            time_frame = self._query.target_timeframe.timeframe
+            anchor_time = anchor_time - (time_frame *
                                          self._query.num_forecasts)
             if anchor_time < self._graph_store.min_time:
                 reached_end = True
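
The forecasting loop above now steps anchor times back by the query's target timeframe (times the number of forecasts) instead of the former target.end_offset. A tiny self-contained illustration of that date arithmetic with pandas; the monthly timeframe and the forecast count are made-up values:

    import pandas as pd

    anchor_time = pd.Timestamp('2024-06-01')
    time_frame = pd.DateOffset(months=1)  # stand-in for target_timeframe.timeframe
    num_forecasts = 3

    # Each step back covers all forecast horizons at once:
    anchor_time = anchor_time - time_frame * num_forecasts
    print(anchor_time)  # 2024-03-01 00:00:00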
@@ -304,26 +320,25 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')
 
-        if self._query.entity.filter is not None:
+        if isinstance(self._query.entity_ast, Filter):
             # Mask out via (temporal) entity filter:
-            backend = PQueryPandasBackend()
+            executor = PQueryPandasExecutor()
             masks: List[np.ndarray] = []
             for start in range(0, len(node), batch_size):
                 feat_dict, time_dict, batch_dict = self._sample(
                     node[start:start + batch_size],
                     time.iloc[start:start + batch_size],
                 )
-                _mask = backend.eval_filter(
-                    filter=self._query.entity.filter,
+                _mask = executor.execute_filter(
+                    filter=self._query.entity_ast,
                     feat_dict=feat_dict,
                     time_dict=time_dict,
                     batch_dict=batch_dict,
                     anchor_time=time.iloc[start:start + batch_size],
-                )
+                )[1]
                 masks.append(_mask)
 
             _mask = np.concatenate(masks)
@@ -334,6 +349,96 @@ class LocalPQueryDriver:
 
         return mask
 
+    def _get_sampling_specs(
+        self,
+        node: ASTNode,
+        hop: int,
+        seed_table_name: str,
+        edge_types: List[Tuple[str, str, str]],
+        num_forecasts: int = 1,
+    ) -> List[SamplingSpec]:
+        if isinstance(node, (Aggregation, Column)):
+            if isinstance(node, Column):
+                table_name = node.fqn.split('.')[0]
+                if seed_table_name == table_name:
+                    return []
+            else:
+                table_name = node._get_target_column_name().split('.')[0]
+
+            target_edge_types = [
+                edge_type for edge_type in edge_types if
+                edge_type[2] == seed_table_name and edge_type[0] == table_name
+            ]
+            if len(target_edge_types) != 1:
+                raise ValueError(
+                    f"Could not find a unique foreign key from table "
+                    f"'{seed_table_name}' to '{table_name}'")
+
+            if isinstance(node, Column):
+                return [
+                    SamplingSpec(
+                        edge_type=target_edge_types[0],
+                        hop=hop + 1,
+                        start_offset=None,
+                        end_offset=None,
+                    )
+                ]
+            spec = SamplingSpec(
+                edge_type=target_edge_types[0],
+                hop=hop + 1,
+                start_offset=node.aggr_time_range.start_date_offset,
+                end_offset=node.aggr_time_range.end_date_offset *
+                num_forecasts,
+            )
+            return [spec] + self._get_sampling_specs(
+                node.target, hop=hop + 1, seed_table_name=table_name,
+                edge_types=edge_types, num_forecasts=num_forecasts)
+        specs = []
+        for child in node.children:
+            specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                              edge_types, num_forecasts)
+        return specs
+
+    def get_sampling_specs(self) -> List[SamplingSpec]:
+        edge_types = self._graph_store.edge_types
+        specs = self._get_sampling_specs(
+            self._query.target_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types,
+            num_forecasts=self._query.num_forecasts)
+        specs += self._get_sampling_specs(
+            self._query.entity_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types)
+        if self._query.whatif_ast is not None:
+            specs += self._get_sampling_specs(
+                self._query.whatif_ast, hop=0,
+                seed_table_name=self._query.entity_table,
+                edge_types=edge_types)
+        # Group specs according to edge type and hop:
+        spec_dict: Dict[
+            Tuple[Tuple[str, str, str], int],
+            Tuple[Optional[DateOffset], Optional[DateOffset]],
+        ] = {}
+        for spec in specs:
+            if (spec.edge_type, spec.hop) not in spec_dict:
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    spec.start_offset,
+                    spec.end_offset,
+                )
+            else:
+                start_offset, end_offset = spec_dict[(
+                    spec.edge_type,
+                    spec.hop,
+                )]
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    min_date_offset(start_offset, spec.start_offset),
+                    max_date_offset(end_offset, spec.end_offset),
+                )
+
+        return [
+            SamplingSpec(edge, hop, start_offset, end_offset)
+            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+        ]
+
     def _sample(
         self,
         node: np.ndarray,
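
The grouping step in get_sampling_specs keeps one entry per (edge_type, hop) pair and widens its time window to cover every spec that maps to it. Below is a self-contained sketch of that bookkeeping, with SamplingSpec redefined locally, plain pandas.DateOffset values standing in for kumoapi's DateOffset, made-up table names, and the unbounded (None) case ignored:

    from typing import NamedTuple, Optional, Tuple

    import pandas as pd


    class SamplingSpec(NamedTuple):  # local stand-in for the class added above
        edge_type: Tuple[str, str, str]
        hop: int
        start_offset: Optional[pd.DateOffset]
        end_offset: Optional[pd.DateOffset]


    specs = [  # two specs over the same edge type and hop, different windows
        SamplingSpec(('orders', 'user_id', 'users'), 1,
                     pd.DateOffset(days=-30), pd.DateOffset(days=0)),
        SamplingSpec(('orders', 'user_id', 'users'), 1,
                     pd.DateOffset(days=-7), pd.DateOffset(days=14)),
    ]

    anchor = pd.Timestamp('2000-01-01')  # offsets are ranked via a fixed anchor
    grouped = {}
    for spec in specs:
        key = (spec.edge_type, spec.hop)
        if key not in grouped:
            grouped[key] = (spec.start_offset, spec.end_offset)
        else:
            start, end = grouped[key]
            grouped[key] = (
                min((start, spec.start_offset), key=lambda o: anchor + o),
                max((end, spec.end_offset), key=lambda o: anchor + o),
            )

    print(grouped)  # widest window wins: 30 days back, 14 days forward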
@@ -354,7 +459,7 @@ class LocalPQueryDriver:
         The feature dictionary, the time column dictionary and the batch
         dictionary.
         """
-        specs = self._query.get_sampling_specs(self._graph_store.edge_types)
+        specs = self.get_sampling_specs()
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
         time_offsets: Dict[
@@ -380,7 +485,7 @@ class LocalPQueryDriver:
 
         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
         node_types = list(
-            set([self._query.entity.pkey.table_name])
+            set([self._query.entity_table])
             | set(src for src, _, _ in edge_types)
             | set(dst for _, _, dst in edge_types))
 
@@ -412,21 +517,33 @@ class LocalPQueryDriver:
                 '__'.join(edge_type): np.array(values)
                 for edge_type, values in time_offsets.items()
             },
-            self._query.entity.pkey.table_name,
+            self._query.entity_table,
             node,
             anchor_time.astype(int).to_numpy() // 1000**3,
         )
 
         feat_dict: Dict[str, pd.DataFrame] = {}
         time_dict: Dict[str, pd.Series] = {}
-        column_dict = self._query.column_dict
-        time_tables = self._query.time_tables
+        column_dict: Dict[str, Set[str]] = {}
+        for col in self._query.all_query_columns:
+            table_name, col_name = col.split('.')
+            if table_name not in column_dict:
+                column_dict[table_name] = set()
+            if col_name != '*':
+                column_dict[table_name].add(col_name)
+        time_tables = self.find_time_tables()
         for table_name in set(list(column_dict.keys()) + time_tables):
             df = self._graph_store.df_dict[table_name]
             row_id = node_dict[table_name]
             df = df.iloc[row_id].reset_index(drop=True)
             if table_name in column_dict:
-                feat_dict[table_name] = df[list(column_dict[table_name])]
+                if len(column_dict[table_name]) == 0:
+                    # We are dealing with COUNT(table.*), insert a dummy col
+                    # to ensure we don't lose the information on node count
+                    feat_dict[table_name] = pd.DataFrame(
+                        {'ones': [1] * len(df)})
+                else:
+                    feat_dict[table_name] = df[list(column_dict[table_name])]
             if table_name in time_tables:
                 time_col = self._graph_store.time_column_dict[table_name]
                 time_dict[table_name] = df[time_col]
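
The dummy 'ones' column above matters for COUNT(table.*)-style targets: no real feature column is selected, yet the per-entity row count must survive aggregation. A minimal pandas illustration with made-up data:

    import pandas as pd

    # Neighboring rows sampled for two entities (batch ids 0 and 1); no feature
    # columns are needed for COUNT(*), so only the row count matters.
    batch = pd.Series([0, 0, 0, 1, 1])
    feat = pd.DataFrame({'ones': [1] * len(batch)})

    # Counting the dummy column per entity recovers the per-entity row counts.
    counts = feat['ones'].groupby(batch).count()
    print(counts.tolist())  # [3, 2]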
@@ -441,7 +558,7 @@ class LocalPQueryDriver:
 
         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)
 
-        y, mask = PQueryPandasBackend().eval_pquery(
+        y, mask = PQueryPandasExecutor().execute(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
@@ -452,6 +569,62 @@ class LocalPQueryDriver:
 
         return y, mask
 
+    def find_time_tables(self) -> List[str]:
+        def _find_time_tables(node: ASTNode) -> List[str]:
+            time_tables = []
+            if isinstance(node, Aggregation):
+                time_tables.append(
+                    node._get_target_column_name().split('.')[0])
+            for child in node.children:
+                time_tables += _find_time_tables(child)
+            return time_tables
+
+        time_tables = _find_time_tables(
+            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+        if self._query.whatif_ast is not None:
+            time_tables += _find_time_tables(self._query.whatif_ast)
+        return list(set(time_tables))
+
+    @staticmethod
+    def get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+
 
 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
@@ -492,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
         total_ns += scaled_value
 
     return total_ns
+
+
+def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: DateOffset) -> DateOffset:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
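
min_date_offset and max_date_offset rank offsets by applying them to a fixed anchor timestamp, since pandas.DateOffset objects of mixed units (weeks versus months, say) carry no direct ordering of their own. A small runnable check of that trick with made-up values:

    import pandas as pd

    offsets = [pd.DateOffset(weeks=5), pd.DateOffset(months=1)]

    # Anchor the offsets to a reference date, then compare the shifted dates:
    anchor = pd.Timestamp('2000-01-01')
    shifted = [anchor + off for off in offsets]

    smaller = offsets[min(range(len(shifted)), key=lambda i: shifted[i])]
    larger = offsets[max(range(len(shifted)), key=lambda i: shifted[i])]
    print(smaller, larger)  # one month lands earlier than five weeks here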
@@ -1,11 +1,7 @@
-from .backend import PQueryBackend
-from .pandas_backend import PQueryPandasBackend
 from .executor import PQueryExecutor
 from .pandas_executor import PQueryPandasExecutor
 
 __all__ = [
-    'PQueryBackend',
-    'PQueryPandasBackend',
     'PQueryExecutor',
     'PQueryPandasExecutor',
 ]
@@ -118,7 +118,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             target_feat, target_mask = self.execute_column(
                 column=aggr.target,
                 feat_dict=feat_dict,
-                filter_na=False,
+                filter_na=True,
             )
         else:
             assert isinstance(aggr.target, Filter)
@@ -128,7 +128,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
-                filter_na=False,
+                filter_na=True,
             )
 
         outs: List[pd.Series] = []
@@ -137,19 +137,20 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             anchor_target_time = anchor_time[target_batch]
             anchor_target_time = anchor_target_time.reset_index(drop=True)
 
-            curr_target_mask = target_mask & (
-                target_time
-                <= anchor_target_time + aggr.aggr_time_range.end_date_offset)
+            time_filter_mask = (target_time <= anchor_target_time +
+                                aggr.aggr_time_range.end_date_offset)
             if aggr.aggr_time_range.start is not None:
                 start_offset = aggr.aggr_time_range.start_date_offset
-                curr_target_mask &= (target_time
+                time_filter_mask &= (target_time
                                      > anchor_target_time + start_offset)
             else:
                 assert num_forecasts == 1
+            curr_target_mask = target_mask & time_filter_mask
 
             out, mask = self.execute_aggregation_type(
                 aggr.aggr,
-                feat=target_feat[curr_target_mask],
+                feat=target_feat[time_filter_mask[target_mask].reset_index(
+                    drop=True)],
                 batch=target_batch[curr_target_mask],
                 batch_size=len(anchor_time),
                 filter_na=False if num_forecasts > 1 else filter_na,
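
The new indexing in the hunk above appears to re-align the time window mask with a feature Series that execute_column (now called with filter_na=True) has already reduced to the rows passing target_mask: slicing the full-length mask with target_mask and resetting its index yields a mask of matching length, while the full-length target_batch keeps using the combined mask. A minimal pandas sketch of that alignment, with made-up data and under that assumption:

    import pandas as pd

    # Full-length columns over five candidate target rows (made-up data):
    target_mask = pd.Series([True, False, True, True, False])    # e.g. non-NA rows
    time_filter_mask = pd.Series([True, True, False, True, True])

    # Feature values after NA filtering: only rows where target_mask is True.
    target_feat = pd.Series([10.0, 30.0, 40.0])

    # Restrict the full-length time mask to the surviving rows and re-align it:
    aligned = time_filter_mask[target_mask].reset_index(drop=True)
    print(target_feat[aligned].tolist())  # [10.0, 40.0]

    # The batch indices stay full-length, so they use the combined mask instead:
    target_batch = pd.Series([0, 0, 1, 1, 2])
    print(target_batch[target_mask & time_filter_mask].tolist())  # [0, 1]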
@@ -499,7 +500,32 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             )
         else:
             raise NotImplementedError(
-                f'{type(query.target)} compilation missing.')
+                f'{type(query.target_ast)} compilation missing.')
+        if query.whatif_ast is not None:
+            if isinstance(query.whatif_ast, Condition):
+                mask &= self.execute_condition(
+                    condition=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            elif isinstance(query.whatif_ast, LogicalOperation):
+                mask &= self.execute_logical_operation(
+                    logical_operation=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            else:
+                raise ValueError(
+                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
+
         out = out[mask[_mask]]
         mask &= _mask
         out = out.reset_index(drop=True)