kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +49 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +32 -14
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +186 -39
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +380 -185
- kumoai/experimental/rfm/graph.py +404 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +52 -60
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +283 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
```diff
--- a/kumoai/experimental/rfm/infer/__init__.py
+++ b/kumoai/experimental/rfm/infer/__init__.py
@@ -1,17 +1,19 @@
 from .dtype import infer_dtype
-from .pkey import infer_primary_key
-from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
+from .stype import infer_stype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column
 
 __all__ = [
     'infer_dtype',
-    'infer_primary_key',
-    'infer_time_column',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
     'contains_multicategorical',
+    'infer_stype',
+    'infer_primary_key',
+    'infer_time_column',
 ]
```
```diff
--- a/kumoai/experimental/rfm/infer/dtype.py
+++ b/kumoai/experimental/rfm/infer/dtype.py
@@ -1,36 +1,27 @@
-from typing import Any, Dict
-
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.typing import Dtype
 
-PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
-    ...
-    pa.string(): Dtype.string,
-    pa.binary(): Dtype.binary,
-    np.dtype('datetime64[ns]'): Dtype.date,
-    np.dtype('timedelta64[ns]'): Dtype.timedelta,
-    pa.list_(pa.float32()): Dtype.floatlist,
-    pa.list_(pa.int64()): Dtype.intlist,
-    pa.list_(pa.string()): Dtype.stringlist,
+PANDAS_TO_DTYPE: dict[str, Dtype] = {
+    'bool': Dtype.bool,
+    'boolean': Dtype.bool,
+    'int8': Dtype.int,
+    'int16': Dtype.int,
+    'int32': Dtype.int,
+    'int64': Dtype.int,
+    'float': Dtype.float,
+    'double': Dtype.float,
+    'float16': Dtype.float,
+    'float32': Dtype.float,
+    'float64': Dtype.float,
+    'object': Dtype.string,
+    'string': Dtype.string,
+    'string[python]': Dtype.string,
+    'string[pyarrow]': Dtype.string,
+    'binary': Dtype.binary,
+    'binary[python]': Dtype.binary,
+    'binary[pyarrow]': Dtype.binary,
 }
 
 
@@ -45,46 +36,47 @@ def infer_dtype(ser: pd.Series) -> Dtype:
     """
     if pd.api.types.is_datetime64_any_dtype(ser.dtype):
         return Dtype.date
-
+    if pd.api.types.is_timedelta64_dtype(ser.dtype):
+        return Dtype.timedelta
     if isinstance(ser.dtype, pd.CategoricalDtype):
         return Dtype.string
 
-    if pd.api.types.is_object_dtype(ser.dtype):
+    if (pd.api.types.is_object_dtype(ser.dtype)
+            and not isinstance(ser.dtype, pd.ArrowDtype)):
         index = ser.iloc[:1000].first_valid_index()
         if index is not None and pd.api.types.is_list_like(ser[index]):
             pos = ser.index.get_loc(index)
             assert isinstance(pos, int)
             ser = ser.iloc[pos:pos + 1000].dropna()
-            ...
-            # Infer unique data types in this series:
-            dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
-                np.array(x).dtype, Dtype.string)).unique().tolist()
-
-            invalid_dtypes = set(dtypes) - {
-                Dtype.string,
-                Dtype.int,
-                Dtype.float,
-            }
-            if len(invalid_dtypes) > 0:
-                raise ValueError(f"Data contains unsupported list data types: "
-                                 f"{list(invalid_dtypes)}")
-
-            if Dtype.string in dtypes:
-                return Dtype.stringlist
-
-            if dtypes == [Dtype.int]:
+            arr = pa.array(ser.tolist())
+            ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
+
+    if isinstance(ser.dtype, pd.ArrowDtype):
+        if (pa.types.is_list(ser.dtype.pyarrow_dtype)
+                or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
+            elem_dtype = ser.dtype.pyarrow_dtype.value_type
+            if pa.types.is_integer(elem_dtype):
                 return Dtype.intlist
-            ...
+            if pa.types.is_floating(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_decimal(elem_dtype):
+                return Dtype.floatlist
+            if pa.types.is_string(elem_dtype):
+                return Dtype.stringlist
+            if pa.types.is_null(elem_dtype):
+                return Dtype.floatlist
+
+    if isinstance(ser.dtype, np.dtype):
+        dtype_str = str(ser.dtype).lower()
+    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
+        dtype_str = ser.dtype.name.lower()
+        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
+    elif isinstance(ser.dtype, pa.DataType):
+        dtype_str = str(ser.dtype).lower()
+    else:
+        dtype_str = 'object'
+
+    if dtype_str not in PANDAS_TO_DTYPE:
         raise ValueError(f"Unsupported data type '{ser.dtype}'")
 
-    return PANDAS_TO_DTYPE[ser.dtype]
+    return PANDAS_TO_DTYPE[dtype_str]
```
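The rewritten `infer_dtype` no longer keys `PANDAS_TO_DTYPE` on dtype objects; it normalizes every pandas dtype to a lowercase string before the lookup. A standalone sketch of that normalization step (the helper name is hypothetical; runnable without kumoai):

```python
import numpy as np
import pandas as pd


def normalized_dtype_str(ser: pd.Series) -> str:
    """Mirror of the dtype-name normalization in the new infer_dtype."""
    if isinstance(ser.dtype, np.dtype):
        return str(ser.dtype).lower()
    if isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        # 'Int64' -> 'int64'; 'int64[pyarrow]' -> 'int64'
        return ser.dtype.name.lower().split('[')[0]
    return 'object'


print(normalized_dtype_str(pd.Series([1, 2])))                 # int64
print(normalized_dtype_str(pd.Series([1, 2], dtype='Int64')))  # int64
print(normalized_dtype_str(pd.Series(['a'], dtype='string')))  # string
```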
```diff
--- a/kumoai/experimental/rfm/infer/multicategorical.py
+++ b/kumoai/experimental/rfm/infer/multicategorical.py
@@ -40,7 +40,7 @@ def contains_multicategorical(
     sep = max(candidates, key=candidates.get)  # type: ignore
     ser = ser.str.split(sep)
 
-    num_unique_multi = ser.explode().nunique()
+    num_unique_multi = ser.astype('object').explode().nunique()
 
     if dtype.is_list():
         return num_unique_multi <= MAX_CAT
```
```diff
--- a/kumoai/experimental/rfm/infer/pkey.py
+++ b/kumoai/experimental/rfm/infer/pkey.py
@@ -1,6 +1,5 @@
 import re
 import warnings
-from typing import Optional
 
 import pandas as pd
 
@@ -9,7 +8,7 @@ def infer_primary_key(
     table_name: str,
     df: pd.DataFrame,
     candidates: list[str],
-) -> Optional[str]:
+) -> str | None:
     r"""Auto-detect potential primary key column.
 
     Args:
@@ -20,6 +19,9 @@ def infer_primary_key(
     Returns:
         The name of the detected primary key, or ``None`` if not found.
     """
+    if len(candidates) == 0:
+        return None
+
     # A list of (potentially modified) table names that are eligible to match
     # with a primary key, i.e.:
     # - UserInfo -> User
```
```diff
--- /dev/null
+++ b/kumoai/experimental/rfm/infer/stype.py
@@ -0,0 +1,35 @@
+import pandas as pd
+from kumoapi.typing import Dtype, Stype
+
+from kumoai.experimental.rfm.infer import (
+    contains_categorical,
+    contains_id,
+    contains_multicategorical,
+    contains_timestamp,
+)
+
+
+def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
+    """Infers the :class:`Stype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+        column_name: The column name.
+        dtype: The data type.
+
+    Returns:
+        The semantic type.
+    """
+    if contains_id(ser, column_name, dtype):
+        return Stype.ID
+
+    if contains_timestamp(ser, column_name, dtype):
+        return Stype.timestamp
+
+    if contains_multicategorical(ser, column_name, dtype):
+        return Stype.multicategorical
+
+    if contains_categorical(ser, column_name, dtype):
+        return Stype.categorical
+
+    return dtype.default_stype
```
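`infer_stype` composes the existing `contains_*` checks into a single entry point. A hypothetical usage sketch, assuming a kumoai installation that ships these experimental modules:

```python
import pandas as pd

from kumoai.experimental.rfm.infer import infer_dtype, infer_stype

ser = pd.Series(['red', 'blue', 'red', 'green'])
dtype = infer_dtype(ser)                  # Dtype.string
stype = infer_stype(ser, 'color', dtype)  # plausibly Stype.categorical
```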
```diff
--- a/kumoai/experimental/rfm/infer/time_col.py
+++ b/kumoai/experimental/rfm/infer/time_col.py
@@ -1,6 +1,5 @@
 import re
 import warnings
-from typing import Optional
 
 import pandas as pd
 
@@ -8,7 +7,7 @@ import pandas as pd
 def infer_time_column(
     df: pd.DataFrame,
     candidates: list[str],
-) -> Optional[str]:
+) -> str | None:
     r"""Auto-detect potential time column.
 
     Args:
```
```diff
--- a/kumoai/experimental/rfm/pquery/executor.py
+++ b/kumoai/experimental/rfm/pquery/executor.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, Generic, Tuple, TypeVar
+from typing import Generic, TypeVar
 
 from kumoapi.pquery import ValidatedPredictiveQuery
 from kumoapi.pquery.AST import (
@@ -21,82 +21,82 @@ class PQueryExecutor(Generic[TableData, ColumnData, IndexData], ABC):
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, TableData],
+        feat_dict: dict[str, TableData],
         filter_na: bool = True,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
 
     @abstractmethod
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, TableData],
-        time_dict: Dict[str, ColumnData],
-        batch_dict: Dict[str, IndexData],
+        feat_dict: dict[str, TableData],
+        time_dict: dict[str, ColumnData],
+        batch_dict: dict[str, IndexData],
         anchor_time: ColumnData,
         num_forecasts: int = 1,
-    ) -> Tuple[ColumnData, IndexData]:
+    ) -> tuple[ColumnData, IndexData]:
         pass
```
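The change above is purely cosmetic (`Dict`/`Tuple` to builtin generics): `PQueryExecutor` remains generic over `TableData`, `ColumnData`, and `IndexData`, which a backend binds when subclassing. A sketch of the pandas binding, as visible in the next file's hunk headers:

```python
import numpy as np
import pandas as pd

from kumoai.experimental.rfm.pquery.executor import PQueryExecutor


# pandas backend: TableData=pd.DataFrame, ColumnData=pd.Series,
# IndexData=np.ndarray.
class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
                                          np.ndarray]):
    ...
```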
```diff
--- a/kumoai/experimental/rfm/pquery/pandas_executor.py
+++ b/kumoai/experimental/rfm/pquery/pandas_executor.py
@@ -1,5 +1,3 @@
-from typing import Dict, List, Tuple
-
 import numpy as np
 import pandas as pd
 from kumoapi.pquery import ValidatedPredictiveQuery
@@ -22,9 +20,9 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_column(
         self,
         column: Column,
-        feat_dict: Dict[str, pd.DataFrame],
+        feat_dict: dict[str, pd.DataFrame],
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         table_name, column_name = column.fqn.split(".")
         if column_name == '*':
             out = pd.Series(np.ones(len(feat_dict[table_name]), dtype='int64'))
@@ -60,7 +58,7 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
         batch: np.ndarray,
         batch_size: int,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
 
         mask = feat.notna()
         feat, batch = feat[mask], batch[mask]
@@ -104,13 +102,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_aggregation(
         self,
         aggr: Aggregation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         target_table = aggr._get_target_column_name().split('.')[0]
         target_batch = batch_dict[target_table]
         target_time = time_dict[target_table]
@@ -131,10 +129,10 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
             filter_na=True,
         )
 
-        outs: List[pd.Series] = []
-        masks: List[np.ndarray] = []
+        outs: list[pd.Series] = []
+        masks: list[np.ndarray] = []
         for _ in range(num_forecasts):
-            anchor_target_time = anchor_time[target_batch]
+            anchor_target_time = anchor_time.iloc[target_batch]
             anchor_target_time = anchor_target_time.reset_index(drop=True)
 
             time_filter_mask = (target_time <= anchor_target_time +
@@ -226,13 +224,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_condition(
         self,
         condition: Condition,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -306,13 +304,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_logical_operation(
         self,
         logical_operation: LogicalOperation,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if num_forecasts > 1:
             raise NotImplementedError("Forecasting not yet implemented for "
                                       "non-regression tasks")
@@ -370,13 +368,13 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_join(
         self,
         join: Join,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(join.rhs_target, Aggregation):
             return self.execute_aggregation(
                 aggr=join.rhs_target,
@@ -393,12 +391,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute_filter(
         self,
         filter: Filter,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         filter_na: bool = True,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         out, mask = self.execute_column(
             column=filter.target,
             feat_dict=feat_dict,
@@ -431,12 +429,12 @@ class PQueryPandasExecutor(PQueryExecutor[pd.DataFrame, pd.Series,
     def execute(
         self,
         query: ValidatedPredictiveQuery,
-        feat_dict: Dict[str, pd.DataFrame],
-        time_dict: Dict[str, pd.Series],
-        batch_dict: Dict[str, np.ndarray],
+        feat_dict: dict[str, pd.DataFrame],
+        time_dict: dict[str, pd.Series],
+        batch_dict: dict[str, np.ndarray],
         anchor_time: pd.Series,
         num_forecasts: int = 1,
-    ) -> Tuple[pd.Series, np.ndarray]:
+    ) -> tuple[pd.Series, np.ndarray]:
         if isinstance(query.entity_ast, Column):
             out, mask = self.execute_column(
                 column=query.entity_ast,
```
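One change in this file is behavioral rather than cosmetic: `anchor_time[target_batch]` became `anchor_time.iloc[target_batch]`. Because `target_batch` holds positions, label-based indexing only works while the Series carries a default `RangeIndex`; a small illustration with invented data:

```python
import numpy as np
import pandas as pd

anchor_time = pd.Series(
    pd.to_datetime(['2024-01-01', '2024-01-02']),
    index=[10, 20],  # any non-default index breaks label-based lookup
)
target_batch = np.array([0, 0, 1])

# anchor_time[target_batch] would raise a KeyError here (no labels 0/1);
# .iloc selects by position, which is what the batch vector encodes.
print(anchor_time.iloc[target_batch])
```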
```diff
--- /dev/null
+++ b/kumoai/experimental/rfm/relbench.py
@@ -0,0 +1,76 @@
+import difflib
+import json
+from functools import lru_cache
+from urllib.request import urlopen
+
+import pooch
+import pyarrow as pa
+
+from kumoai.experimental.rfm import Graph
+from kumoai.experimental.rfm.backend.local import LocalTable
+
+PREFIX = 'rel-'
+CACHE_DIR = pooch.os_cache('relbench')
+HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
+            'relbench/datasets/hashes.json')
+
+
+@lru_cache
+def get_registry() -> pooch.Pooch:
+    with urlopen(HASH_URL) as r:
+        hashes = json.load(r)
+
+    return pooch.create(
+        path=CACHE_DIR,
+        base_url='https://relbench.stanford.edu/download/',
+        registry=hashes,
+    )
+
+
+def from_relbench(dataset: str, verbose: bool = True) -> Graph:
+    dataset = dataset.lower()
+    if dataset.startswith(PREFIX):
+        dataset = dataset[len(PREFIX):]
+
+    registry = get_registry()
+
+    datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
+    if dataset not in datasets:
+        matches = difflib.get_close_matches(dataset, datasets, n=1)
+        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
+        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
+                         f"datasets are {str(datasets)[1:-1]}.")
+
+    registry.fetch(
+        f'{PREFIX}{dataset}/db.zip',
+        processor=pooch.Unzip(extract_dir='.'),
+        progressbar=verbose,
+    )
+
+    graph = Graph(tables=[])
+    edges: list[tuple[str, str, str]] = []
+    for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
+        data = pa.parquet.read_table(path)
+        metadata = {
+            key.decode('utf-8'): json.loads(value.decode('utf-8'))
+            for key, value in data.schema.metadata.items()
+            if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
+        }
+
+        table = LocalTable(
+            df=data.to_pandas(),
+            name=path.stem,
+            primary_key=metadata['pkey_col'],
+            time_column=metadata['time_col'],
+        )
+        graph.add_table(table)
+
+        edges.extend([
+            (path.stem, fkey, dst_table)
+            for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
+        ])
+
+    for edge in edges:
+        graph.link(*edge)
+
+    return graph
```
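A hypothetical call, assuming network access and a kumoai installation ('rel-f1' is one of the published RelBench datasets):

```python
from kumoai.experimental.rfm.relbench import from_relbench

# Downloads and caches the dataset on first use, then assembles a Graph
# whose links come from the parquet foreign-key metadata.
graph = from_relbench('rel-f1')  # the 'rel-' prefix is optional
```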