kumoai 2.11.0.dev202510161830__py3-none-any.whl → 2.12.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
kumoai/__init__.py CHANGED
@@ -200,9 +200,11 @@ def init(
 
     logger = logging.getLogger('kumoai')
     log_level = logging.getLevelName(logger.getEffectiveLevel())
+
     logger.info(
-        "Successfully initialized the Kumo SDK against deployment %s, with "
-        "log level %s.", url, log_level)
+        f"Successfully initialized the Kumo SDK (version {__version__}) "
+        f"against deployment {url}, with "
+        f"log level {log_level}.")


 def set_log_level(level: str) -> None:
kumoai/_version.py CHANGED
@@ -1 +1 @@
-__version__ = '2.11.0.dev202510161830'
+__version__ = '2.12.1'
@@ -147,3 +147,4 @@ class RFMEndpoints:
    explain = Endpoint(f"{BASE}/explain", HTTPMethod.POST)
    evaluate = Endpoint(f"{BASE}/evaluate", HTTPMethod.POST)
    validate_query = Endpoint(f"{BASE}/validate_query", HTTPMethod.POST)
+   parse_query = Endpoint(f"{BASE}/parse_query", HTTPMethod.POST)
kumoai/client/rfm.py CHANGED
@@ -1,7 +1,11 @@
+from typing import Any
+
 from kumoapi.json_serde import to_json_dict
 from kumoapi.rfm import (
     RFMEvaluateResponse,
     RFMExplanationResponse,
+    RFMParseQueryRequest,
+    RFMParseQueryResponse,
     RFMPredictResponse,
     RFMValidateQueryRequest,
     RFMValidateQueryResponse,
@@ -26,25 +30,32 @@ class RFMAPI:
         Returns:
             RFMPredictResponse containing the predictions
         """
-        # Send binary data to the predict endpoint
         response = self._client._request(
-            RFMEndpoints.predict, data=request,
-            headers={'Content-Type': 'application/x-protobuf'})
+            RFMEndpoints.predict,
+            data=request,
+            headers={'Content-Type': 'application/x-protobuf'},
+        )
         raise_on_error(response)
         return parse_response(RFMPredictResponse, response)

-    def explain(self, request: bytes) -> RFMExplanationResponse:
+    def explain(
+        self,
+        request: bytes,
+        skip_summary: bool = False,
+    ) -> RFMExplanationResponse:
         """Explain the RFM model on the given context.

         Args:
             request: The predict request as serialized protobuf.
+            skip_summary: Whether to skip generating a human-readable summary
+                of the explanation.

         Returns:
             RFMPredictResponse containing the explanations
         """
-        # Send binary data to the explain endpoint
+        params: dict[str, Any] = {'generate_summary': not skip_summary}
         response = self._client._request(
-            RFMEndpoints.explain, data=request,
+            RFMEndpoints.explain, data=request, params=params,
             headers={'Content-Type': 'application/x-protobuf'})
         raise_on_error(response)
         return parse_response(RFMExplanationResponse, response)
@@ -58,7 +69,6 @@ class RFMAPI:
         Returns:
             RFMEvaluateResponse containing the computed metrics
         """
-        # Send binary data to the evaluate endpoint
         response = self._client._request(
             RFMEndpoints.evaluate, data=request,
             headers={'Content-Type': 'application/x-protobuf'})
@@ -82,3 +92,21 @@ class RFMAPI:
                                         json=to_json_dict(request))
         raise_on_error(response)
         return parse_response(RFMValidateQueryResponse, response)
+
+    def parse_query(
+        self,
+        request: RFMParseQueryRequest,
+    ) -> RFMParseQueryResponse:
+        """Validate a predictive query against a graph.
+
+        Args:
+            request: The request object containing
+                the query and graph definition
+
+        Returns:
+            RFMParseQueryResponse containing the QueryDefinition
+        """
+        response = self._client._request(RFMEndpoints.parse_query,
+                                         json=to_json_dict(request))
+        raise_on_error(response)
+        return parse_response(RFMParseQueryResponse, response)
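
For orientation, a minimal usage sketch of the two client-side additions above: the new parse_query call and the skip_summary flag on explain. The RFMParseQueryRequest field names used here (query, graph) are illustrative assumptions and are not shown in this diff.

    from kumoapi.rfm import RFMParseQueryRequest

    def sketch(api, pq_string, graph_def, proto_request):
        # api is an already-constructed RFMAPI; proto_request is the
        # serialized protobuf payload that predict/explain expect.
        parsed = api.parse_query(
            RFMParseQueryRequest(query=pq_string, graph=graph_def))  # hypothetical field names
        # Skip the human-readable summary of the explanation
        # (sent as generate_summary=False on the wire).
        explanation = api.explain(proto_request, skip_summary=True)
        return parsed, explanation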
@@ -36,7 +36,7 @@ import os
 import kumoai
 from .local_table import LocalTable
 from .local_graph import LocalGraph
-from .rfm import KumoRFM
+from .rfm import ExplainConfig, Explanation, KumoRFM
 from .authenticate import authenticate


@@ -60,6 +60,8 @@ __all__ = [
     'LocalTable',
     'LocalGraph',
     'KumoRFM',
+    'ExplainConfig',
+    'Explanation',
     'authenticate',
     'init',
 ]
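
Assuming the two hunks above come from kumoai/experimental/rfm/__init__.py (the file header is not shown here, but the relative imports and __all__ match the import paths used elsewhere in this diff), the net effect is two additional public names next to KumoRFM:

    from kumoai.experimental.rfm import ExplainConfig, Explanation, KumoRFM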
@@ -1,23 +1,40 @@
 import warnings
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Union

 import numpy as np
 import pandas as pd
-from kumoapi.pquery import QueryType
-from kumoapi.rfm import PQueryDefinition
+from kumoapi.pquery import QueryType, ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
+    Aggregation,
+    ASTNode,
+    Column,
+    Condition,
+    Filter,
+    Join,
+    LogicalOperation,
+)
+from kumoapi.task import TaskType
+from kumoapi.typing import AggregationType, DateOffset, Stype

 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-from kumoai.experimental.rfm.pquery import PQueryPandasBackend
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor

 _coverage_warned = False


+class SamplingSpec(NamedTuple):
+    edge_type: Tuple[str, str, str]
+    hop: int
+    start_offset: Optional[DateOffset]
+    end_offset: Optional[DateOffset]
+
+
 class LocalPQueryDriver:
     def __init__(
         self,
         graph_store: LocalGraphStore,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         random_seed: Optional[int] = None,
     ) -> None:
         self._graph_store = graph_store
@@ -33,7 +50,7 @@ class LocalPQueryDriver:
         if self._query.query_type == QueryType.TEMPORAL:
             assert exclude_node is None

-        table_name = self._query.entity.pkey.table_name
+        table_name = self._query.entity_table
         num_nodes = len(self._graph_store.df_dict[table_name])
         mask_dict = self._graph_store.mask_dict

@@ -66,7 +83,7 @@ class LocalPQueryDriver:
         anchor_time: pd.Timestamp,
     ) -> np.ndarray:

-        entity = self._query.entity.pkey.table_name
+        entity = self._query.entity_table

         # Filter out entities that do not exist yet in time:
         time_sec = self._graph_store.time_dict.get(entity)
@@ -124,8 +141,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

         y, mask = self(node, time)
@@ -223,8 +239,7 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

         y, mask = self(node, time)
@@ -249,7 +264,8 @@ class LocalPQueryDriver:
                 reached_end = True
                 break
             candidate_offset = 0
-            anchor_time = anchor_time - (self._query.target.end_offset *
+            time_frame = self._query.target_timeframe.timeframe
+            anchor_time = anchor_time - (time_frame *
                                          self._query.num_forecasts)
             if anchor_time < self._graph_store.min_time:
                 reached_end = True
@@ -304,26 +320,25 @@ class LocalPQueryDriver:
             time = time.astype('datetime64[ns]').reset_index(drop=True)
         else:
             assert anchor_time == 'entity'
-            time = self._graph_store.time_dict[
-                self._query.entity.pkey.table_name]
+            time = self._graph_store.time_dict[self._query.entity_table]
             time = pd.Series(time[node] * 1000**3, dtype='datetime64[ns]')

-        if self._query.entity.filter is not None:
+        if isinstance(self._query.entity_ast, Filter):
             # Mask out via (temporal) entity filter:
-            backend = PQueryPandasBackend()
+            executor = PQueryPandasExecutor()
             masks: List[np.ndarray] = []
             for start in range(0, len(node), batch_size):
                 feat_dict, time_dict, batch_dict = self._sample(
                     node[start:start + batch_size],
                     time.iloc[start:start + batch_size],
                 )
-                _mask = backend.eval_filter(
-                    filter=self._query.entity.filter,
+                _mask = executor.execute_filter(
+                    filter=self._query.entity_ast,
                     feat_dict=feat_dict,
                     time_dict=time_dict,
                     batch_dict=batch_dict,
                     anchor_time=time.iloc[start:start + batch_size],
-                )
+                )[1]
                 masks.append(_mask)

             _mask = np.concatenate(masks)
@@ -334,6 +349,96 @@ class LocalPQueryDriver:

         return mask

+    def _get_sampling_specs(
+        self,
+        node: ASTNode,
+        hop: int,
+        seed_table_name: str,
+        edge_types: List[Tuple[str, str, str]],
+        num_forecasts: int = 1,
+    ) -> List[SamplingSpec]:
+        if isinstance(node, (Aggregation, Column)):
+            if isinstance(node, Column):
+                table_name = node.fqn.split('.')[0]
+                if seed_table_name == table_name:
+                    return []
+            else:
+                table_name = node._get_target_column_name().split('.')[0]
+
+            target_edge_types = [
+                edge_type for edge_type in edge_types if
+                edge_type[2] == seed_table_name and edge_type[0] == table_name
+            ]
+            if len(target_edge_types) != 1:
+                raise ValueError(
+                    f"Could not find a unique foreign key from table "
+                    f"'{seed_table_name}' to '{table_name}'")
+
+            if isinstance(node, Column):
+                return [
+                    SamplingSpec(
+                        edge_type=target_edge_types[0],
+                        hop=hop + 1,
+                        start_offset=None,
+                        end_offset=None,
+                    )
+                ]
+            spec = SamplingSpec(
+                edge_type=target_edge_types[0],
+                hop=hop + 1,
+                start_offset=node.aggr_time_range.start_date_offset,
+                end_offset=node.aggr_time_range.end_date_offset *
+                num_forecasts,
+            )
+            return [spec] + self._get_sampling_specs(
+                node.target, hop=hop + 1, seed_table_name=table_name,
+                edge_types=edge_types, num_forecasts=num_forecasts)
+        specs = []
+        for child in node.children:
+            specs += self._get_sampling_specs(child, hop, seed_table_name,
+                                              edge_types, num_forecasts)
+        return specs
+
+    def get_sampling_specs(self) -> List[SamplingSpec]:
+        edge_types = self._graph_store.edge_types
+        specs = self._get_sampling_specs(
+            self._query.target_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types,
+            num_forecasts=self._query.num_forecasts)
+        specs += self._get_sampling_specs(
+            self._query.entity_ast, hop=0,
+            seed_table_name=self._query.entity_table, edge_types=edge_types)
+        if self._query.whatif_ast is not None:
+            specs += self._get_sampling_specs(
+                self._query.whatif_ast, hop=0,
+                seed_table_name=self._query.entity_table,
+                edge_types=edge_types)
+        # Group specs according to edge type and hop:
+        spec_dict: Dict[
+            Tuple[Tuple[str, str, str], int],
+            Tuple[Optional[DateOffset], Optional[DateOffset]],
+        ] = {}
+        for spec in specs:
+            if (spec.edge_type, spec.hop) not in spec_dict:
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    spec.start_offset,
+                    spec.end_offset,
+                )
+            else:
+                start_offset, end_offset = spec_dict[(
+                    spec.edge_type,
+                    spec.hop,
+                )]
+                spec_dict[(spec.edge_type, spec.hop)] = (
+                    min_date_offset(start_offset, spec.start_offset),
+                    max_date_offset(end_offset, spec.end_offset),
+                )
+
+        return [
+            SamplingSpec(edge, hop, start_offset, end_offset)
+            for (edge, hop), (start_offset, end_offset) in spec_dict.items()
+        ]
+
     def _sample(
         self,
         node: np.ndarray,
@@ -354,7 +459,7 @@ class LocalPQueryDriver:
            The feature dictionary, the time column dictionary and the batch
            dictionary.
         """
-        specs = self._query.get_sampling_specs(self._graph_store.edge_types)
+        specs = self.get_sampling_specs()
         num_hops = max([spec.hop for spec in specs] + [0])
         num_neighbors: Dict[Tuple[str, str, str], list[int]] = {}
         time_offsets: Dict[
@@ -380,7 +485,7 @@ class LocalPQueryDriver:

         edge_types = list(num_neighbors.keys()) + list(time_offsets.keys())
         node_types = list(
-            set([self._query.entity.pkey.table_name])
+            set([self._query.entity_table])
             | set(src for src, _, _ in edge_types)
             | set(dst for _, _, dst in edge_types))

@@ -412,21 +517,33 @@ class LocalPQueryDriver:
                '__'.join(edge_type): np.array(values)
                for edge_type, values in time_offsets.items()
            },
-           self._query.entity.pkey.table_name,
+           self._query.entity_table,
            node,
            anchor_time.astype(int).to_numpy() // 1000**3,
        )

        feat_dict: Dict[str, pd.DataFrame] = {}
        time_dict: Dict[str, pd.Series] = {}
-       column_dict = self._query.column_dict
-       time_tables = self._query.time_tables
+       column_dict: Dict[str, Set[str]] = {}
+       for col in self._query.all_query_columns:
+           table_name, col_name = col.split('.')
+           if table_name not in column_dict:
+               column_dict[table_name] = set()
+           if col_name != '*':
+               column_dict[table_name].add(col_name)
+       time_tables = self.find_time_tables()
        for table_name in set(list(column_dict.keys()) + time_tables):
            df = self._graph_store.df_dict[table_name]
            row_id = node_dict[table_name]
            df = df.iloc[row_id].reset_index(drop=True)
            if table_name in column_dict:
-               feat_dict[table_name] = df[list(column_dict[table_name])]
+               if len(column_dict[table_name]) == 0:
+                   # We are dealing with COUNT(table.*), insert a dummy col
+                   # to ensure we don't lose the information on node count
+                   feat_dict[table_name] = pd.DataFrame(
+                       {'ones': [1] * len(df)})
+               else:
+                   feat_dict[table_name] = df[list(column_dict[table_name])]
            if table_name in time_tables:
                time_col = self._graph_store.time_column_dict[table_name]
                time_dict[table_name] = df[time_col]
@@ -441,7 +558,7 @@ class LocalPQueryDriver:

         feat_dict, time_dict, batch_dict = self._sample(node, anchor_time)

-        y, mask = PQueryPandasBackend().eval_pquery(
+        y, mask = PQueryPandasExecutor().execute(
             query=self._query,
             feat_dict=feat_dict,
             time_dict=time_dict,
@@ -452,6 +569,62 @@ class LocalPQueryDriver:

         return y, mask

+    def find_time_tables(self) -> List[str]:
+        def _find_time_tables(node: ASTNode) -> List[str]:
+            time_tables = []
+            if isinstance(node, Aggregation):
+                time_tables.append(
+                    node._get_target_column_name().split('.')[0])
+            for child in node.children:
+                time_tables += _find_time_tables(child)
+            return time_tables
+
+        time_tables = _find_time_tables(
+            self._query.target_ast) + _find_time_tables(self._query.entity_ast)
+        if self._query.whatif_ast is not None:
+            time_tables += _find_time_tables(self._query.whatif_ast)
+        return list(set(time_tables))
+
+    @staticmethod
+    def get_task_type(
+        query: ValidatedPredictiveQuery,
+        edge_types: List[Tuple[str, str, str]],
+    ) -> TaskType:
+        if isinstance(query.target_ast, (Condition, LogicalOperation)):
+            return TaskType.BINARY_CLASSIFICATION
+
+        target = query.target_ast
+        if isinstance(target, Join):
+            target = target.rhs_target
+        if isinstance(target, Aggregation):
+            if target.aggr == AggregationType.LIST_DISTINCT:
+                table_name, col_name = target._get_target_column_name().split(
+                    '.')
+                target_edge_types = [
+                    edge_type for edge_type in edge_types
+                    if edge_type[0] == table_name and edge_type[1] == col_name
+                ]
+                if len(target_edge_types) != 1:
+                    raise NotImplementedError(
+                        f"Multilabel-classification queries based on "
+                        f"'LIST_DISTINCT' are not supported yet. If you "
+                        f"planned to write a link prediction query instead, "
+                        f"make sure to register '{col_name}' as a "
+                        f"foreign key.")
+                return TaskType.TEMPORAL_LINK_PREDICTION
+
+            return TaskType.REGRESSION
+
+        assert isinstance(target, Column)
+
+        if target.stype in {Stype.ID, Stype.categorical}:
+            return TaskType.MULTICLASS_CLASSIFICATION
+
+        if target.stype in {Stype.numerical}:
+            return TaskType.REGRESSION
+
+        raise NotImplementedError("Task type not yet supported")
+

 def date_offset_to_seconds(offset: pd.DateOffset) -> int:
     r"""Convert a :class:`pandas.DateOffset` into a maximum number of
@@ -492,3 +665,25 @@ def date_offset_to_seconds(offset: pd.DateOffset) -> int:
         total_ns += scaled_value

     return total_ns
+
+
+def min_date_offset(*args: Optional[DateOffset]) -> Optional[DateOffset]:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmin = min(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmin]
+
+
+def max_date_offset(*args: DateOffset) -> DateOffset:
+    if any(arg is None for arg in args):
+        return None
+
+    anchor = pd.Timestamp('2000-01-01')
+    timestamps = [anchor + arg for arg in args]
+    assert len(timestamps) > 0
+    argmax = max(range(len(timestamps)), key=lambda i: timestamps[i])
+    return args[argmax]
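
A small behavioural sketch of the two helpers added above, assuming they are in scope (their module path is not shown in this diff) and that DateOffset behaves like pandas.DateOffset, which is what the fixed-anchor comparison relies on:

    import pandas as pd

    week = pd.DateOffset(days=7)
    month = pd.DateOffset(months=1)

    # Offsets are ordered by anchoring them at a fixed timestamp:
    # 2000-01-01 + 7 days lands before 2000-01-01 + 1 month.
    assert min_date_offset(week, month) is week
    assert max_date_offset(week, month) is month

    # A None offset (no time-range constraint) short-circuits both helpers.
    assert min_date_offset(week, None) is None
    assert max_date_offset(week, None) is None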
@@ -1,7 +1,7 @@
-from .backend import PQueryBackend
-from .pandas_backend import PQueryPandasBackend
+from .executor import PQueryExecutor
+from .pandas_executor import PQueryPandasExecutor

 __all__ = [
-    'PQueryBackend',
-    'PQueryPandasBackend',
+    'PQueryExecutor',
+    'PQueryPandasExecutor',
 ]
@@ -1,23 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Optional, Tuple, TypeVar, Union
+from typing import Dict, Generic, Tuple, TypeVar

-from kumoapi.rfm import PQueryDefinition
-from kumoapi.rfm.pquery import (
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.pquery.AST import (
     Aggregation,
-    AggregationType,
-    BoolOp,
     Column,
     Condition,
     Filter,
-    Float,
-    FloatList,
-    Int,
-    IntList,
+    Join,
     LogicalOperation,
-    MemberOp,
-    RelOp,
-    Str,
-    StrList,
 )

 TableData = TypeVar('TableData')
@@ -25,58 +16,33 @@ ColumnData = TypeVar('ColumnData')
 IndexData = TypeVar('IndexData')


-class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
+class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     @abstractmethod
-    def eval_aggregation_type(
+    def execute_column(
         self,
-        op: AggregationType,
-        feat: Optional[ColumnData],
-        batch: IndexData,
-        batch_size: int,
+        column: Column,
+        feat_dict: Dict[str, TableData],
         filter_na: bool = True,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def eval_rel_op(
-        self,
-        left: ColumnData,
-        op: RelOp,
-        right: Union[Int, Float, Str, None],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_member_op(
-        self,
-        left: ColumnData,
-        op: MemberOp,
-        right: Union[IntList, FloatList, StrList],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_bool_op(
-        self,
-        left: ColumnData,
-        op: BoolOp,
-        right: Optional[ColumnData],
-    ) -> ColumnData:
-        pass
-
-    @abstractmethod
-    def eval_column(
+    def execute_aggregation(
         self,
-        column: Column,
+        aggr: Aggregation,
         feat_dict: Dict[str, TableData],
+        time_dict: Dict[str, ColumnData],
+        batch_dict: Dict[str, IndexData],
+        anchor_time: ColumnData,
         filter_na: bool = True,
+        num_forecasts: int = 1,
     ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def eval_aggregation(
+    def execute_condition(
         self,
-        aggr: Aggregation,
+        condition: Condition,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -87,9 +53,9 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass

     @abstractmethod
-    def eval_condition(
+    def execute_logical_operation(
         self,
-        condition: Condition,
+        logical_operation: LogicalOperation,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -100,9 +66,9 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass

     @abstractmethod
-    def eval_logical_operation(
+    def execute_join(
         self,
-        logical_operation: LogicalOperation,
+        join: Join,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
@@ -113,20 +79,20 @@ class PQueryBackend(Generic[TableData, ColumnData, IndexData], ABC):
         pass

     @abstractmethod
-    def eval_filter(
+    def execute_filter(
         self,
         filter: Filter,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> IndexData:
+    ) -> Tuple[ColumnData, IndexData]:
         pass

     @abstractmethod
-    def eval_pquery(
+    def execute(
         self,
-        query: PQueryDefinition,
+        query: ValidatedPredictiveQuery,
         feat_dict: Dict[str, TableData],
         time_dict: Dict[str, ColumnData],
         batch_dict: Dict[str, IndexData],