kumoai 2.12.0.dev202510231830__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/endpoints.py +1 -0
  5. kumoai/client/jobs.py +24 -0
  6. kumoai/client/pquery.py +6 -2
  7. kumoai/client/rfm.py +35 -7
  8. kumoai/connector/utils.py +23 -2
  9. kumoai/experimental/rfm/__init__.py +191 -48
  10. kumoai/experimental/rfm/authenticate.py +3 -4
  11. kumoai/experimental/rfm/backend/__init__.py +0 -0
  12. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  13. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  14. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  15. kumoai/experimental/rfm/backend/local/table.py +113 -0
  16. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  17. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  18. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  19. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  20. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  21. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  22. kumoai/experimental/rfm/base/__init__.py +30 -0
  23. kumoai/experimental/rfm/base/column.py +152 -0
  24. kumoai/experimental/rfm/base/expression.py +44 -0
  25. kumoai/experimental/rfm/base/sampler.py +761 -0
  26. kumoai/experimental/rfm/base/source.py +19 -0
  27. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  28. kumoai/experimental/rfm/base/table.py +735 -0
  29. kumoai/experimental/rfm/graph.py +1237 -0
  30. kumoai/experimental/rfm/infer/__init__.py +8 -0
  31. kumoai/experimental/rfm/infer/dtype.py +82 -0
  32. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  33. kumoai/experimental/rfm/infer/pkey.py +128 -0
  34. kumoai/experimental/rfm/infer/stype.py +35 -0
  35. kumoai/experimental/rfm/infer/time_col.py +61 -0
  36. kumoai/experimental/rfm/pquery/__init__.py +0 -4
  37. kumoai/experimental/rfm/pquery/executor.py +27 -27
  38. kumoai/experimental/rfm/pquery/pandas_executor.py +64 -40
  39. kumoai/experimental/rfm/relbench.py +76 -0
  40. kumoai/experimental/rfm/rfm.py +386 -276
  41. kumoai/experimental/rfm/sagemaker.py +138 -0
  42. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  43. kumoai/pquery/predictive_query.py +10 -6
  44. kumoai/spcs.py +1 -3
  45. kumoai/testing/decorators.py +1 -1
  46. kumoai/testing/snow.py +50 -0
  47. kumoai/trainer/distilled_trainer.py +175 -0
  48. kumoai/trainer/trainer.py +9 -10
  49. kumoai/utils/__init__.py +3 -2
  50. kumoai/utils/display.py +51 -0
  51. kumoai/utils/progress_logger.py +188 -16
  52. kumoai/utils/sql.py +3 -0
  53. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  54. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +57 -36
  55. kumoai/experimental/rfm/local_graph.py +0 -810
  56. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  57. kumoai/experimental/rfm/local_pquery_driver.py +0 -494
  58. kumoai/experimental/rfm/local_table.py +0 -545
  59. kumoai/experimental/rfm/pquery/backend.py +0 -136
  60. kumoai/experimental/rfm/pquery/pandas_backend.py +0 -478
  61. kumoai/experimental/rfm/utils.py +0 -344
  62. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  63. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  64. {kumoai-2.12.0.dev202510231830.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/infer/__init__.py
@@ -1,11 +1,19 @@
+from .dtype import infer_dtype
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
+from .stype import infer_stype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column
 
 __all__ = [
+    'infer_dtype',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
     'contains_multicategorical',
+    'infer_stype',
+    'infer_primary_key',
+    'infer_time_column',
 ]

kumoai/experimental/rfm/infer/dtype.py
@@ -0,0 +1,82 @@
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.typing import Dtype
+
+PANDAS_TO_DTYPE: dict[str, Dtype] = {
+    'bool': Dtype.bool,
+    'boolean': Dtype.bool,
+    'int8': Dtype.int,
+    'int16': Dtype.int,
+    'int32': Dtype.int,
+    'int64': Dtype.int,
+    'float': Dtype.float,
+    'double': Dtype.float,
+    'float16': Dtype.float,
+    'float32': Dtype.float,
+    'float64': Dtype.float,
+    'object': Dtype.string,
+    'string': Dtype.string,
+    'string[python]': Dtype.string,
+    'string[pyarrow]': Dtype.string,
+    'binary': Dtype.binary,
+    'binary[python]': Dtype.binary,
+    'binary[pyarrow]': Dtype.binary,
+}
+
+
+def infer_dtype(ser: pd.Series) -> Dtype:
+    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+
+    Returns:
+        The data type.
+    """
+    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+        return Dtype.date
+    if pd.api.types.is_timedelta64_dtype(ser.dtype):
+        return Dtype.timedelta
+    if isinstance(ser.dtype, pd.CategoricalDtype):
+        return Dtype.string
+
+    if (pd.api.types.is_object_dtype(ser.dtype)
+            and not isinstance(ser.dtype, pd.ArrowDtype)):
+        index = ser.iloc[:1000].first_valid_index()
+        if index is not None and pd.api.types.is_list_like(ser[index]):
+            pos = ser.index.get_loc(index)
+            assert isinstance(pos, int)
+            ser = ser.iloc[pos:pos + 1000].dropna()
+            arr = pa.array(ser.tolist())
+            ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
+
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        if (pa.types.is_list(ser.dtype.pyarrow_dtype)
+                or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
+            elem_dtype = ser.dtype.pyarrow_dtype.value_type
+            if pa.types.is_integer(elem_dtype):
+                return Dtype.intlist
+            if pa.types.is_floating(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_decimal(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_string(elem_dtype):
+                return Dtype.stringlist
+            if pa.types.is_null(elem_dtype):
+                return Dtype.floatlist
+
+    if isinstance(ser.dtype, np.dtype):
+        dtype_str = str(ser.dtype).lower()
+    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
+        dtype_str = ser.dtype.name.lower()
+        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
+    elif isinstance(ser.dtype, pa.DataType):
+        dtype_str = str(ser.dtype).lower()
+    else:
+        dtype_str = 'object'
+
+    if dtype_str not in PANDAS_TO_DTYPE:
+        raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+    return PANDAS_TO_DTYPE[dtype_str]
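
For orientation, a minimal usage sketch of the new infer_dtype helper; the sample data is illustrative, not taken from the package:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_dtype

# Scalar columns resolve through PANDAS_TO_DTYPE:
infer_dtype(pd.Series([1, 2, 3]))    # Dtype.int
infer_dtype(pd.Series(['a', 'b']))   # Dtype.string

# Object columns holding lists are re-read through pyarrow first and
# resolve to list dtypes:
infer_dtype(pd.Series([[1.0, 2.0], [3.0]]))  # Dtype.floatlist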

kumoai/experimental/rfm/infer/multicategorical.py
@@ -40,7 +40,7 @@ def contains_multicategorical(
     sep = max(candidates, key=candidates.get)  # type: ignore
     ser = ser.str.split(sep)
 
-    num_unique_multi = ser.explode().nunique()
+    num_unique_multi = ser.astype('object').explode().nunique()
 
     if dtype.is_list():
        return num_unique_multi <= MAX_CAT
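
A plausible motivation for the astype('object') cast (the diff itself does not state one): for Arrow-backed string columns, Series.str.split yields a list-typed extension array, and routing the result through object keeps explode().nunique() on the plain-Python path for both backends:

import pandas as pd

ser = pd.Series(['a,b', 'b,c'], dtype='string[pyarrow]')
tokens = ser.str.split(',')

# Mirrors the changed line above; works for object- and Arrow-backed input:
tokens.astype('object').explode().nunique()  # 3 -> {'a', 'b', 'c'}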

kumoai/experimental/rfm/infer/pkey.py
@@ -0,0 +1,128 @@
+import re
+import warnings
+
+import pandas as pd
+
+
+def infer_primary_key(
+    table_name: str,
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> str | None:
+    r"""Auto-detect potential primary key column.
+
+    Args:
+        table_name: The table name.
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected primary key, or ``None`` if not found.
+    """
+    if len(candidates) == 0:
+        return None
+
+    # A list of (potentially modified) table names that are eligible to match
+    # with a primary key, i.e.:
+    # - UserInfo -> User
+    # - snakecase <-> camelcase
+    # - camelcase <-> snakecase
+    # - plural <-> singular (users -> user, eligibilities -> eligibility)
+    # - verb -> noun (qualifying -> qualify)
+    _table_names = {table_name}
+    if table_name.lower().endswith('_info'):
+        _table_names.add(table_name[:-5])
+    elif table_name.lower().endswith('info'):
+        _table_names.add(table_name[:-4])
+
+    table_names = set()
+    for _table_name in _table_names:
+        table_names.add(_table_name.lower())
+        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
+        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
+        table_names.add(snakecase.lower())
+        camelcase = _table_name.replace('_', '')
+        table_names.add(camelcase.lower())
+        if _table_name.lower().endswith('s'):
+            table_names.add(_table_name.lower()[:-1])
+            table_names.add(snakecase.lower()[:-1])
+            table_names.add(camelcase.lower()[:-1])
+        else:
+            table_names.add(_table_name.lower() + 's')
+            table_names.add(snakecase.lower() + 's')
+            table_names.add(camelcase.lower() + 's')
+        if _table_name.lower().endswith('ies'):
+            table_names.add(_table_name.lower()[:-3] + 'y')
+            table_names.add(snakecase.lower()[:-3] + 'y')
+            table_names.add(camelcase.lower()[:-3] + 'y')
+        elif _table_name.lower().endswith('y'):
+            table_names.add(_table_name.lower()[:-1] + 'ies')
+            table_names.add(snakecase.lower()[:-1] + 'ies')
+            table_names.add(camelcase.lower()[:-1] + 'ies')
+        if _table_name.lower().endswith('ing'):
+            table_names.add(_table_name.lower()[:-3])
+            table_names.add(snakecase.lower()[:-3])
+            table_names.add(camelcase.lower()[:-3])
+
+    scores: list[tuple[str, int]] = []
+    for col_name in candidates:
+        col_name_lower = col_name.lower()
+
+        score = 0
+
+        if col_name_lower == 'id':
+            score += 4
+
+        for table_name_lower in table_names:
+
+            if col_name_lower == table_name_lower:
+                score += 4  # USER -> USER
+                break
+
+            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
+                if not col_name_lower.endswith(suffix):
+                    continue
+
+                if col_name_lower == f'{table_name_lower}_{suffix}':
+                    score += 5  # USER -> USER_ID
+                    break
+
+                if col_name_lower == f'{table_name_lower}{suffix}':
+                    score += 5  # User -> UserId
+                    break
+
+                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
+                    score += 2
+
+                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
+                    score += 2
+
+        # `rel-bench` hard-coding :(
+        if table_name == 'studies' and col_name == 'nct_id':
+            score += 1
+
+        ser = df[col_name].iloc[:1_000_000]
+        score += 3 * (ser.nunique() / len(ser))
+
+        scores.append((col_name, score))
+
+    scores = [x for x in scores if x[-1] >= 4]
+    scores.sort(key=lambda x: x[-1], reverse=True)
+
+    if len(scores) == 0:
+        return None
+
+    if len(scores) == 1:
+        return scores[0][0]
+
+    # In case of multiple candidates, only return one if its score is unique:
+    if scores[0][1] != scores[1][1]:
+        return scores[0][0]
+
+    max_score = max(scores, key=lambda x: x[1])
+    candidates = [col_name for col_name, score in scores if score == max_score]
+    warnings.warn(f"Found multiple potential primary keys in table "
+                  f"'{table_name}': {candidates}. Please specify the primary "
+                  f"key for this table manually.")
+
+    return None
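
A small, hypothetical example of how the scoring plays out: 'user_id' matches the singularized table name plus an 'id' suffix (+5) and is fully unique (up to +3 from the uniqueness term), clearing the >= 4 cutoff, while 'age' does not:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_primary_key

# Made-up data for illustration only:
df = pd.DataFrame({
    'user_id': [1, 2, 3],
    'age': [25, 30, 25],
})
infer_primary_key('users', df, candidates=['user_id', 'age'])  # 'user_id'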

kumoai/experimental/rfm/infer/stype.py
@@ -0,0 +1,35 @@
+import pandas as pd
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.experimental.rfm.infer import (
+    contains_categorical,
+    contains_id,
+    contains_multicategorical,
+    contains_timestamp,
+)
+
+
+def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
+    """Infers the :class:`Stype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+        column_name: The column name.
+        dtype: The data type.
+
+    Returns:
+        The semantic type.
+    """
+    if contains_id(ser, column_name, dtype):
+        return Stype.ID
+
+    if contains_timestamp(ser, column_name, dtype):
+        return Stype.timestamp
+
+    if contains_multicategorical(ser, column_name, dtype):
+        return Stype.multicategorical
+
+    if contains_categorical(ser, column_name, dtype):
+        return Stype.categorical
+
+    return dtype.default_stype
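
The checks run in a fixed priority order (ID, then timestamp, then multicategorical, then categorical) before falling back to the dtype's default. A hypothetical call chaining it with infer_dtype; the exact outcome depends on the contains_* heuristics, which are not shown here:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_dtype, infer_stype

ser = pd.Series(['red', 'blue', 'red', 'green'])
dtype = infer_dtype(ser)          # Dtype.string
infer_stype(ser, 'color', dtype)  # presumably Stype.categorical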

kumoai/experimental/rfm/infer/time_col.py
@@ -0,0 +1,61 @@
+import re
+import warnings
+
+import pandas as pd
+
+
+def infer_time_column(
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> str | None:
+    r"""Auto-detect potential time column.
+
+    Args:
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected time column, or ``None`` if not found.
+    """
+    candidates = [  # Exclude all candidates with `*last*` in column names:
+        col_name for col_name in candidates
+        if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
+    ]
+
+    if len(candidates) == 0:
+        return None
+
+    if len(candidates) == 1:
+        return candidates[0]
+
+    # If there exists a dedicated `create*` column, use it as time column:
+    create_candidates = [
+        candidate for candidate in candidates
+        if candidate.lower().startswith('create')
+    ]
+    if len(create_candidates) == 1:
+        return create_candidates[0]
+    if len(create_candidates) > 1:
+        candidates = create_candidates
+
+    # Find the most optimal time column. Usually, it is the one pointing to
+    # the oldest timestamps:
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', message='Could not infer format')
+        min_timestamp_dict = {
+            key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
+            for key in candidates
+        }
+    min_timestamp_dict = {
+        key: value.min().tz_localize(None)
+        for key, value in min_timestamp_dict.items()
+    }
+    min_timestamp_dict = {
+        key: value
+        for key, value in min_timestamp_dict.items() if not pd.isna(value)
+    }
+
+    if len(min_timestamp_dict) == 0:
+        return None
+
+    return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
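
A hypothetical example of the selection rules: 'last_login' is dropped by the `*last*` exclusion, and 'created_at' then wins as the single dedicated `create*` column, without any timestamp parsing needed:

import pandas as pd

from kumoai.experimental.rfm.infer import infer_time_column

# Made-up data for illustration only:
df = pd.DataFrame({
    'created_at': ['2020-01-01', '2020-02-01'],
    'updated_at': ['2021-01-01', '2021-02-01'],
    'last_login': ['2024-01-01', '2024-01-02'],
})
infer_time_column(df, ['created_at', 'updated_at', 'last_login'])
# 'created_at'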

kumoai/experimental/rfm/pquery/__init__.py
@@ -1,11 +1,7 @@
-from .backend import PQueryBackend
-from .pandas_backend import PQueryPandasBackend
 from .executor import PQueryExecutor
 from .pandas_executor import PQueryPandasExecutor
 
 __all__ = [
-    'PQueryBackend',
-    'PQueryPandasBackend',
     'PQueryExecutor',
     'PQueryPandasExecutor',
 ]

kumoai/experimental/rfm/pquery/executor.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Tuple, TypeVar
+from typing import Generic, TypeVar
 
 from kumoapi.pquery import ValidatedPredictiveQuery
 from kumoapi.pquery.AST import (
@@ -21,82 +21,82 @@ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, TableData],
+        feat_dict: dict[str, TableData],
         filter_na: bool = True,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
        num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
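
The interface stays generic over the backend's table, column, and index representations. The truncated class line in the hunk headers of the next file, together with the tuple[pd.Series, np.ndarray] return types, implies the pandas binding below; a different backend could bind its own representations:

class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                                          np.ndarray]):
    ...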

kumoai/experimental/rfm/pquery/pandas_executor.py
@@ -1,5 +1,3 @@
-from typing import Dict, List, Tuple
-
 import numpy as np
 import pandas as pd
 from kumoapi.pquery import ValidatedPredictiveQuery
@@ -22,9 +20,9 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, pd.DataFrame],
+        feat_dict: dict[str, pd.DataFrame],
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         table_name, column_name = column.fqn.split(".")
         if column_name == '*':
             out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
@@ -60,7 +58,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
         batch: np.ndarray,
         batch_size: int,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
 
         mask = feat.notna()
         feat, batch = feat[mask], batch[mask]
@@ -104,13 +102,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         target_table = aggr._get_target_column_name().split('.')[0]
         target_batch = batch_dict[target_table]
         target_time = time_dict[target_table]
@@ -118,7 +116,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             target_feat, target_mask = self.execute_column(
                 column=aggr.target,
                 feat_dict=feat_dict,
-                filter_na=False,
+                filter_na=True,
             )
         else:
             assert isinstance(aggr.target, Filter)
@@ -128,28 +126,29 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                 time_dict=time_dict,
                 batch_dict=batch_dict,
                 anchor_time=anchor_time,
-                filter_na=False,
+                filter_na=True,
             )
 
-        outs: List[pd.Series] = []
-        masks: List[np.ndarray] = []
+        outs: list[pd.Series] = []
+        masks: list[np.ndarray] = []
         for _ in range(num_forecasts):
-            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_time.iloc[target_batch]
             anchor_target_time = anchor_target_time.reset_index(drop=True)
 
-            curr_target_mask = target_mask & (
-                target_time
-                <= anchor_target_time + aggr.aggr_time_range.end_date_offset)
+            time_filter_mask = (target_time <= anchor_target_time +
+                                aggr.aggr_time_range.end_date_offset)
             if aggr.aggr_time_range.start is not None:
                 start_offset = aggr.aggr_time_range.start_date_offset
-                curr_target_mask &= (target_time
+                time_filter_mask &= (target_time
                                      > anchor_target_time + start_offset)
             else:
                 assert num_forecasts == 1
+            curr_target_mask = target_mask & time_filter_mask
 
             out, mask = self.execute_aggregation_type(
                 aggr.aggr,
-                feat=target_feat[curr_target_mask],
+                feat=target_feat[time_filter_mask[target_mask].reset_index(
+                    drop=True)],
                 batch=target_batch[curr_target_mask],
                 batch_size=len(anchor_time),
                 filter_na=False if num_forecasts > 1 else filter_na,
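
The re-indexed feat= argument is the subtle part of this hunk: with filter_na=True, target_feat now arrives already NA-filtered, so the full-length time mask must first be restricted to the surviving rows. A toy illustration of that alignment (made-up data, not the executor's actual code path):

import numpy as np
import pandas as pd

target_feat_full = pd.Series([1.0, np.nan, 3.0, 4.0])
target_mask = target_feat_full.notna().to_numpy()         # NA filter
time_filter_mask = pd.Series([True, True, False, True])   # time-range filter

# With filter_na=True, only the non-NA rows come back:
target_feat = target_feat_full[target_mask].reset_index(drop=True)

# Restrict the full-length time mask to those same rows before indexing:
target_feat[time_filter_mask[target_mask].reset_index(drop=True)]
# -> rows with values 1.0 and 4.0 survive both filters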
@@ -225,13 +224,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -305,13 +304,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -369,13 +368,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(join.rhs_target, Aggregation):
             return self.execute_aggregation(
                 aggr=join.rhs_target,
@@ -392,12 +391,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         out, mask = self.execute_column(
             column=filter.target,
             feat_dict=feat_dict,
@@ -430,12 +429,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(query.entity_ast, Column):
             out, mask = self.execute_column(
                 column=query.entity_ast,
@@ -499,7 +498,32 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             )
         else:
             raise NotImplementedError(
-                f'{type(query.target)} compilation missing.')
+                f'{type(query.target_ast)} compilation missing.')
+        if query.whatif_ast is not None:
+            if isinstance(query.whatif_ast, Condition):
+                mask &= self.execute_condition(
+                    condition=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            elif isinstance(query.whatif_ast, LogicalOperation):
+                mask &= self.execute_logical_operation(
+                    logical_operation=query.whatif_ast,
+                    feat_dict=feat_dict,
+                    time_dict=time_dict,
+                    batch_dict=batch_dict,
+                    anchor_time=anchor_time,
+                    filter_na=True,
+                    num_forecasts=num_forecasts,
+                )[0]
+            else:
+                raise ValueError(
+                    f'Unsupported ASSUMING condition {type(query.whatif_ast)}')
+
         out = out[mask[_mask]]
         mask &= _mask
         out = out.reset_index(drop=True)
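
The added block wires the query's ASSUMING (what-if) clause into label generation: the running mask is AND-ed with the boolean output of the extra condition, so only rows satisfying the hypothetical contribute labels. In spirit, an illustrative reduction (not the real call chain):

import numpy as np

mask = np.array([True, True, False, True])     # entity/target mask so far
whatif = np.array([True, False, True, True])   # hypothetical condition output
mask &= whatif                                 # -> [True, False, False, True]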