kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +21 -16
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +102 -48
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
- kumoai/experimental/rfm/base/__init__.py +26 -3
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +173 -138
- kumoai/experimental/rfm/graph.py +302 -108
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +44 -36
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/table.py

```diff
@@ -1,20 +1,35 @@
 import re
-from
+from collections.abc import Sequence
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 from kumoapi.typing import Dtype
 
-from kumoai.experimental.rfm.backend.
-from kumoai.experimental.rfm.base import
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.base import (
+    ColumnExpressionSpec,
+    ColumnExpressionType,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    SQLTable,
+)
+from kumoai.utils import quote_ident
 
 
-class SnowTable(
+class SnowTable(SQLTable):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`snowflake` database.
-        name: The name of this table.
-
+        name: The logical name of this table.
+        source_name: The physical name of this table in the database. If set to
+            ``None``, ``name`` is being used.
+        database: The database.
+        schema: The schema.
+        columns: The selected physical columns of this table.
+        column_expressions: The logical columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -24,17 +39,20 @@ class SnowTable(Table):
         self,
         connection: Connection,
         name: str,
+        source_name: str | None = None,
         database: str | None = None,
         schema: str | None = None,
-        columns:
-
-
-
+        columns: Sequence[str] | None = None,
+        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if database is not None and schema is None:
-            raise ValueError(f"
-                             f"
+            raise ValueError(f"Unspecified 'schema' for table "
+                             f"'{source_name or name}' in database "
+                             f"'{database}'")
 
         self._connection = connection
         self._database = database
@@ -42,47 +60,67 @@ class SnowTable(Table):
 
         super().__init__(
             name=name,
+            source_name=source_name,
             columns=columns,
+            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
         )
 
+    @staticmethod
+    def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
+        if snowflake_dtype is None:
+            return None
+        snowflake_dtype = snowflake_dtype.strip().upper()
+        # TODO 'NUMBER(...)' is not always an integer!
+        if snowflake_dtype.startswith('NUMBER'):
+            return Dtype.int
+        elif snowflake_dtype.startswith('VARCHAR'):
+            return Dtype.string
+        elif snowflake_dtype == 'FLOAT':
+            return Dtype.float
+        elif snowflake_dtype == 'BOOLEAN':
+            return Dtype.bool
+        elif re.search('DATE|TIMESTAMP', snowflake_dtype):
+            return Dtype.date
+        return None
+
+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.SNOWFLAKE)
+
     @property
-    def
-
+    def fqn(self) -> str:
+        r"""The fully-qualified quoted table name."""
+        names: list[str] = []
         if self._database is not None:
-
-
-
-
-
-
-
-    def _get_source_columns(self) -> List[SourceColumn]:
-        source_columns: List[SourceColumn] = []
+            names.append(quote_ident(self._database))
+        if self._schema is not None:
+            names.append(quote_ident(self._schema))
+        return '.'.join(names + [quote_ident(self._source_name)])
+
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
             try:
-
+                sql = f"DESCRIBE TABLE {self.fqn}"
+                cursor.execute(sql)
             except Exception as e:
-
-
+                names: list[str] = []
+                if self._database is not None:
+                    names.append(self._database)
+                if self._schema is not None:
+                    names.append(self._schema)
+                source_name = '.'.join(names + [self._source_name])
+                raise ValueError(f"Table '{source_name}' does not exist in "
+                                 f"the remote data backend") from e
 
             for row in cursor.fetchall():
-                column, type, _,
-
-
-                if
-                    dtype = Dtype.int
-                elif type.startswith('VARCHAR'):
-                    dtype = Dtype.string
-                elif type == 'FLOAT':
-                    dtype = Dtype.float
-                elif type == 'BOOLEAN':
-                    dtype = Dtype.bool
-                elif re.search('DATE|TIMESTAMP', type):
-                    dtype = Dtype.date
-                else:
+                column, type, _, null, _, is_pkey, is_unique, *_ = row
+
+                dtype = self.to_dtype(type)
+                if dtype is None:
                     continue
 
                 source_column = SourceColumn(
@@ -90,26 +128,42 @@ class SnowTable(Table):
                     dtype=dtype,
                     is_primary_key=is_pkey.strip().upper() == 'Y',
                     is_unique_key=is_unique.strip().upper() == 'Y',
+                    is_nullable=null.strip().upper() == 'Y',
                 )
                 source_columns.append(source_column)
 
        return source_columns
 
-    def _get_source_foreign_keys(self) ->
-        source_fkeys:
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_fkeys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-
+            sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
+            cursor.execute(sql)
             for row in cursor.fetchall():
-                _, _, _, dst_table, pkey, _, _, _, fkey = row
+                _, _, _, dst_table, pkey, _, _, _, fkey, *_ = row
                 source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
         return source_fkeys
 
-    def
+    def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-            columns =
-
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+            cursor.execute(sql)
             table = cursor.fetch_arrow_all()
             return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-    def _get_num_rows(self) ->
+    def _get_num_rows(self) -> int | None:
         return None
+
+    def _get_expression_sample_df(
+        self,
+        specs: Sequence[ColumnExpressionSpec],
+    ) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            columns = [
+                f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+            ]
+            sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+            return table.to_pandas(types_mapper=pd.ArrowDtype)
```
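For orientation, the newly added `SnowTable.to_dtype` static helper above normalizes Snowflake type strings to Kumo dtypes before column ingestion. A minimal sketch of the mapping, written against the code shown in the diff; the module import path and the example type strings are assumptions, not documented API:

```python
from kumoapi.typing import Dtype

# Hypothetical import path; the class is defined in
# kumoai/experimental/rfm/backend/snow/table.py as shown in the diff above.
from kumoai.experimental.rfm.backend.snow.table import SnowTable

# Type strings are stripped and upper-cased before matching:
assert SnowTable.to_dtype('NUMBER(38,0)') == Dtype.int       # TODO in source: not always an int
assert SnowTable.to_dtype('VARCHAR(16)') == Dtype.string
assert SnowTable.to_dtype('FLOAT') == Dtype.float
assert SnowTable.to_dtype('BOOLEAN') == Dtype.bool
assert SnowTable.to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date  # DATE|TIMESTAMP -> date
assert SnowTable.to_dtype('GEOGRAPHY') is None               # unsupported types are skipped
```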
kumoai/experimental/rfm/backend/sqlite/__init__.py

```diff
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, TypeAlias
+from typing import Any, TypeAlias
 
 try:
     import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
     Connection: TypeAlias = adbc.AdbcSqliteConnection
 
 
-def connect(uri:
+def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
     r"""Opens a connection to a :class:`sqlite` database.
 
         uri: The path to the database file to be opened.
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
 
 
 from .table import SQLiteTable  # noqa: E402
+from .sampler import SQLiteSampler  # noqa: E402
 
 __all__ = [
     'connect',
     'Connection',
     'SQLiteTable',
+    'SQLiteSampler',
 ]
```
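The updated `connect()` signature can be exercised roughly as follows. This is a sketch under assumptions: the `adbc_driver_sqlite` dependency must be installed, the database file name is made up, and cursors are used as context managers in the same way the new sampler code below uses them.

```python
from pathlib import Path

from kumoai.experimental.rfm.backend.sqlite import connect

# Open a connection to a local SQLite database via the ADBC driver.
# `uri` may be a str or a Path (default: None), per the updated signature.
conn = connect(Path("data.db"))  # hypothetical file name
with conn.cursor() as cursor:
    cursor.execute("SELECT sqlite_version()")
    print(cursor.fetchone())
conn.close()
```

The newly exported `SQLiteSampler` (added in `sampler.py` below) is then constructed as `SQLiteSampler(graph, verbose=True, optimize=True)` from a `Graph` whose tables are backed by such a connection; building that graph is not part of this diff.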
kumoai/experimental/rfm/backend/sqlite/sampler.py (new file)

```diff
@@ -0,0 +1,349 @@
+import warnings
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+from kumoapi.typing import Stype
+
+from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger, quote_ident
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+class SQLiteSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+        optimize: bool = False,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        if optimize:
+            with self._connection.cursor() as cursor:
+                cursor.execute("PRAGMA temp_store = MEMORY")
+                cursor.execute("PRAGMA cache_size = -2000000")  # 2 GB
+
+        # Collect database indices to speed-up sampling:
+        index_dict: dict[str, set[tuple[str, ...]]] = defaultdict(set)
+        for table_name, primary_key in self.primary_key_dict.items():
+            source_table = self.source_table_dict[table_name]
+            if not source_table[primary_key].is_unique_key:
+                index_dict[table_name].add((primary_key, ))
+        for src_table_name, foreign_key, _ in graph.edges:
+            source_table = self.source_table_dict[src_table_name]
+            if source_table[foreign_key].is_unique_key:
+                pass
+            elif time_column := self.time_column_dict.get(src_table_name):
+                index_dict[src_table_name].add((foreign_key, time_column))
+            else:
+                index_dict[src_table_name].add((foreign_key, ))
+
+        # Only maintain missing indices:
+        with self._connection.cursor() as cursor:
+            for table_name in list(index_dict.keys()):
+                indices = index_dict[table_name]
+                sql = f"PRAGMA index_list({self.fqn_dict[table_name]})"
+                cursor.execute(sql)
+                for _, index_name, *_ in cursor.fetchall():
+                    sql = f"PRAGMA index_info({quote_ident(index_name)})"
+                    cursor.execute(sql)
+                    index = tuple(info[2] for info in sorted(
+                        cursor.fetchall(), key=lambda x: x[0]))
+                    indices.discard(index)
+                if len(indices) == 0:
+                    del index_dict[table_name]
+
+        num = sum(len(indices) for indices in index_dict.values())
+        index_repr = '1 index' if num == 1 else f'{num} indices'
+        num = len(index_dict)
+        table_repr = '1 table' if num == 1 else f'{num} tables'
+
+        if optimize and len(index_dict) > 0:
+            if not isinstance(verbose, ProgressLogger):
+                verbose = ProgressLogger.default(
+                    msg="Optimizing SQLite database",
+                    verbose=verbose,
+                )
+
+            with verbose as logger, self._connection.cursor() as cursor:
+                for table_name, indices in index_dict.items():
+                    for index in indices:
+                        name = f"kumo_index_{table_name}_{'_'.join(index)}"
+                        name = quote_ident(name)
+                        columns = ', '.join(quote_ident(v) for v in index)
+                        columns += ' DESC' if len(index) > 1 else ''
+                        sql = (f"CREATE INDEX IF NOT EXISTS {name}\n"
+                               f"ON {self.fqn_dict[table_name]}({columns})")
+                        cursor.execute(sql)
+                self._connection.commit()
+                logger.log(f"Created {index_repr} in {table_repr}")
+
+        elif len(index_dict) > 0:
+            warnings.warn(f"Missing {index_repr} in {table_repr} for optimal "
+                          f"database querying. For improving runtime, we "
+                          f"strongly suggest to create these indices by "
+                          f"instantiating KumoRFM via "
+                          f"`KumoRFM(graph, optimize=True)`.")
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            time_column = self.time_column_dict[table_name]
+            select = (f"SELECT\n"
+                      f" ? as table_name,\n"
+                      f" MIN({quote_ident(time_column)}) as min_date,\n"
+                      f" MAX({quote_ident(time_column)}) as max_date\n"
+                      f"FROM {self.fqn_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            for table_name, _min, _max in cursor.fetchall():
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE SQLite does not natively support passing a `random_seed`.
+
+        filters: list[str] = []
+        primary_key = self.primary_key_dict[table_name]
+        if self.source_table_dict[table_name][primary_key].is_nullable:
+            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
+        time_column = self.time_column_dict.get(table_name)
+        if (time_column is not None and
+                self.source_table_dict[table_name][time_column].is_nullable):
+            filters.append(f" {quote_ident(time_column)} IS NOT NULL")
+
+        # TODO Make this query more efficient - it does full table scan.
+        sql = (f"SELECT {', '.join(quote_ident(col) for col in columns)}\n"
+               f"FROM {self.fqn_dict[table_name]}")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+        sql += f"\nORDER BY RANDOM() LIMIT {num_rows}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        return self._sanitize(table_name, table)
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+        train_y, train_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=train_index,
+            anchor_time=train_time,
+            num_examples=num_train_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        test_y, test_mask = self._sample_target_set(
+            query=query,
+            entity_df=entity_df,
+            index=test_index,
+            anchor_time=test_time,
+            num_examples=num_test_examples,
+            columns_dict=columns_dict,
+            time_offset_dict=time_offset_dict,
+        )
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        pkey: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        pkey_name = self.primary_key_dict[table_name]
+
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        tmp_name = f'tmp_{table_name}_{pkey_name}_{id(tmp)}'
+
+        if self.source_table_dict[table_name][pkey_name].is_unique_key:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f" ON ent.{quote_ident(pkey_name)} = tmp.id")
+        else:
+            sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+                   f"{', '.join('ent.' + quote_ident(c) for c in columns)}\n"
+                   f"FROM {quote_ident(tmp_name)} tmp\n"
+                   f"JOIN {self.fqn_dict[table_name]} ent\n"
+                   f" ON ent.rowid = (\n"
+                   f" SELECT rowid FROM {self.fqn_dict[table_name]}\n"
+                   f" WHERE {quote_ident(pkey_name)} == tmp.id\n"
+                   f" LIMIT 1\n"
+                   f")")
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return table.to_pandas(), batch  # TODO Use `self._sanitize`.
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        fkey: str,
+        pkey: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        # NOTE SQLite does not have a native datetime format. Currently, we
+        # assume timestamps are given as `TEXT` in `ISO-8601 UTC`:
+        tmp = pa.table([pa.array(pkey)], names=['id'])
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        tmp = tmp.append_column('end', pa.array(end_time))
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            tmp = tmp.append_column('start', pa.array(start_time))
+        tmp_name = f'tmp_{table_name}_{fkey}_{id(tmp)}'
+
+        time_column = self.time_column_dict[table_name]
+        sql = (f"SELECT tmp.rowid - 1 as __batch__, "
+               f"{', '.join('fact.' + quote_ident(col) for col in columns)}\n"
+               f"FROM {quote_ident(tmp_name)} tmp\n"
+               f"JOIN {self.fqn_dict[table_name]} fact\n"
+               f" ON fact.{quote_ident(fkey)} = tmp.id\n"
+               f" AND fact.{quote_ident(time_column)} <= tmp.end")
+        if min_offset is not None:
+            sql += f"\n AND fact.{quote_ident(time_column)} > tmp.start"
+
+        with self._connection.cursor() as cursor:
+            cursor.adbc_ingest(tmp_name, tmp, mode='replace')
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_table()
+
+        batch = table['__batch__'].to_numpy()
+        table = table.remove_column(table.schema.get_field_index('__batch__'))
+
+        return self._sanitize(table_name, table), batch
+
+    def _sample_target_set(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        index: np.ndarray,
+        anchor_time: pd.Series,
+        num_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+        batch_size: int = 10_000,
+    ) -> tuple[pd.Series, np.ndarray]:
+
+        count = 0
+        ys: list[pd.Series] = []
+        mask = np.full(len(index), False, dtype=bool)
+        for start in range(0, len(index), batch_size):
+            df = entity_df.iloc[index[start:start + batch_size]]
+            time = anchor_time.iloc[start:start + batch_size]
+
+            feat_dict: dict[str, pd.DataFrame] = {query.entity_table: df}
+            time_dict: dict[str, pd.Series] = {}
+            time_column = self.time_column_dict.get(query.entity_table)
+            if time_column in columns_dict[query.entity_table]:
+                time_dict[query.entity_table] = df[time_column]
+            batch_dict: dict[str, np.ndarray] = {
+                query.entity_table: np.arange(len(df)),
+            }
+            for edge_type, (_min, _max) in time_offset_dict.items():
+                table_name, fkey, _ = edge_type
+                feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                    table_name=table_name,
+                    fkey=fkey,
+                    pkey=df[self.primary_key_dict[query.entity_table]],
+                    anchor_time=time,
+                    min_offset=_min,
+                    max_offset=_max,
+                    columns=columns_dict[table_name],
+                )
+                time_column = self.time_column_dict.get(table_name)
+                if time_column in columns_dict[table_name]:
+                    time_dict[table_name] = feat_dict[table_name][time_column]
+
+            y, _mask = PQueryPandasExecutor().execute(
+                query=query,
+                feat_dict=feat_dict,
+                time_dict=time_dict,
+                batch_dict=batch_dict,
+                anchor_time=anchor_time,
+                num_forecasts=query.num_forecasts,
+            )
+            ys.append(y)
+            mask[start:start + batch_size] = _mask
+
+            count += len(y)
+            if count >= num_examples:
+                break
+
+        if len(ys) == 0:
+            y = pd.Series([], dtype=float)
+        elif len(ys) == 1:
+            y = ys[0]
+        else:
+            y = pd.concat(ys, axis=0, ignore_index=True)
+
+        return y, mask
+
+    def _sanitize(self, table_name: str, table: pa.table) -> pd.DataFrame:
+        df = table.to_pandas(types_mapper=pd.ArrowDtype)
+
+        stype_dict = self.table_stype_dict[table_name]
+        for column_name in df.columns:
+            if stype_dict.get(column_name) == Stype.timestamp:
+                df[column_name] = pd.to_datetime(df[column_name])
+
+        return df
```