kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
- kumoai/experimental/rfm/backend/local/table.py +16 -21
- kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
- kumoai/experimental/rfm/backend/snow/table.py +67 -33
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
- kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
- kumoai/experimental/rfm/base/__init__.py +8 -0
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sql_sampler.py +31 -3
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +162 -143
- kumoai/experimental/rfm/graph.py +242 -95
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/rfm.py +86 -80
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/progress_logger.py +178 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/_version.py
CHANGED
@@ -1 +1 @@
-__version__ = '2.14.0.dev202512151351'
+__version__ = '2.14.0.dev202512211732'
kumoai/experimental/rfm/__init__.py
CHANGED
@@ -6,11 +6,11 @@ import socket
 import threading
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, Optional, Tuple
 from urllib.parse import urlparse
 
 import kumoai
 from kumoai.client.client import KumoClient
+from kumoai.spcs import _get_active_session
 
 from .authenticate import authenticate
 from .sagemaker import (
@@ -49,7 +49,8 @@ class InferenceBackend(str, Enum):
 
 
 def _detect_backend(
-    url: str
+    url: str, #
+) -> tuple[InferenceBackend, str | None, str | None]:
     parsed = urlparse(url)
 
     # Remote SageMaker
@@ -73,12 +74,27 @@ def _detect_backend(
     return InferenceBackend.REST, None, None
 
 
+def _get_snowflake_url(snowflake_application: str) -> str:
+    snowpark_session = _get_active_session()
+    if not snowpark_session:
+        raise ValueError(
+            "Client creation failed: snowflake_application is specified "
+            "without an active snowpark session. If running outside "
+            "a snowflake notebook, specify a URL and credentials.")
+    with snowpark_session.connection.cursor() as cur:
+        cur.execute(
+            f"DESCRIBE SERVICE {snowflake_application}.user_schema.rfm_service"
+            f" ->> SELECT \"dns_name\" from $1")
+        dns_name: str = cur.fetchone()[0]
+    return f"http://{dns_name}:8000/api"
+
+
 @dataclass
 class RfmGlobalState:
     _url: str = '__url_not_provided__'
     _backend: InferenceBackend = InferenceBackend.UNKNOWN
-    _region:
-    _endpoint_name:
+    _region: str | None = None
+    _endpoint_name: str | None = None
     _thread_local = threading.local()
 
     # Thread-safe init-once.
@@ -121,10 +137,10 @@ global_state = RfmGlobalState()
 
 
 def init(
-    url:
-    api_key:
-    snowflake_credentials:
-    snowflake_application:
+    url: str | None = None,
+    api_key: str | None = None,
+    snowflake_credentials: dict[str, str] | None = None,
+    snowflake_application: str | None = None,
     log_level: str = "INFO",
 ) -> None:
     with global_state._lock:
@@ -136,6 +152,15 @@ def init(
                 "supported.")
            return
 
+        if snowflake_application:
+            if url is not None:
+                raise ValueError(
+                    "Client creation failed: both snowflake_application and "
+                    "url are specified. If running from a snowflake notebook, "
+                    "specify only snowflake_application.")
+            url = _get_snowflake_url(snowflake_application)
+            api_key = "test:DISABLED"
+
         if url is None:
             url = os.getenv("RFM_API_URL", "https://kumorfm.ai/api")
 
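With this change, `init()` accepts `snowflake_application` and resolves the inference endpoint itself: `_get_snowflake_url()` asks the active Snowpark session for the Native App service's DNS name, and the API-key check is disabled for that path. The usage sketch below is illustrative only and not part of this diff; the application name `KUMO_APP` is a placeholder.

# Hedged usage sketch based on the init() signature shown above.
import kumoai.experimental.rfm as rfm

# Inside a Snowflake notebook with an active Snowpark session, pass only the
# application name; the service URL and API key are resolved automatically.
rfm.init(snowflake_application="KUMO_APP")

# Outside Snowflake, the existing path is unchanged: pass a URL and API key.
# rfm.init(url="https://kumorfm.ai/api", api_key="...")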
kumoai/experimental/rfm/authenticate.py
CHANGED
@@ -2,12 +2,11 @@ import logging
 import os
 import platform
 from datetime import datetime
-from typing import Optional
 
 from kumoai import in_notebook
 
 
-def authenticate(api_url:
+def authenticate(api_url: str | None = None) -> None:
     """Authenticates the user and sets the Kumo API key for the SDK.
 
     This function detects the current environment and launches the appropriate
@@ -65,11 +64,11 @@ def _authenticate_local(api_url: str, redirect_port: int = 8765) -> None:
     import webbrowser
     from getpass import getpass
    from socketserver import TCPServer
-    from typing import Any
+    from typing import Any
 
     logger = logging.getLogger('kumoai')
 
-    token_status:
+    token_status: dict[str, Any] = {
        'token': None,
        'token_name': None,
        'failed': False
kumoai/experimental/rfm/backend/local/graph_store.py
CHANGED
@@ -1,5 +1,5 @@
 import warnings
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import numpy as np
 import pandas as pd
@@ -7,7 +7,7 @@ from kumoapi.rfm.context import Subgraph
 from kumoapi.typing import Stype
 
 from kumoai.experimental.rfm.backend.local import LocalTable
-from kumoai.utils import
+from kumoai.utils import ProgressLogger
 
 try:
     import torch
@@ -23,12 +23,12 @@ class LocalGraphStore:
     def __init__(
         self,
         graph: 'Graph',
-        verbose:
+        verbose: bool | ProgressLogger = True,
     ) -> None:
 
         if not isinstance(verbose, ProgressLogger):
-            verbose =
-                "Materializing graph",
+            verbose = ProgressLogger.default(
+                msg="Materializing graph",
                 verbose=verbose,
             )
 
@@ -94,7 +94,7 @@ class LocalGraphStore:
     def sanitize(
         self,
         graph: 'Graph',
-    ) ->
+    ) -> tuple[dict[str, pd.DataFrame], dict[str, np.ndarray]]:
         r"""Sanitizes raw data according to table schema definition:
 
         In particular, it:
@@ -103,13 +103,13 @@ class LocalGraphStore:
         * drops duplicate primary keys
         * removes rows with missing primary keys or time values
         """
-        df_dict:
+        df_dict: dict[str, pd.DataFrame] = {}
         for table_name, table in graph.tables.items():
             assert isinstance(table, LocalTable)
             df = table._data
             df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
 
-        mask_dict:
+        mask_dict: dict[str, np.ndarray] = {}
         for table in graph.tables.values():
             for col in table.columns:
                 if col.stype == Stype.timestamp:
@@ -126,7 +126,7 @@ class LocalGraphStore:
                     ser = ser.dt.tz_localize(None)
                 df_dict[table.name][col.name] = ser
 
-            mask:
+            mask: np.ndarray | None = None
             if table._time_column is not None:
                 ser = df_dict[table.name][table._time_column]
                 mask = ser.notna().to_numpy()
@@ -144,8 +144,8 @@ class LocalGraphStore:
     def get_pkey_map_dict(
         self,
         graph: 'Graph',
-    ) ->
-        pkey_map_dict:
+    ) -> dict[str, pd.DataFrame]:
+        pkey_map_dict: dict[str, pd.DataFrame] = {}
 
         for table in graph.tables.values():
             if table._primary_key is None:
@@ -177,12 +177,12 @@ class LocalGraphStore:
     def get_time_data(
         self,
         graph: 'Graph',
-    ) ->
-
-
+    ) -> tuple[
+        dict[str, np.ndarray],
+        dict[str, tuple[pd.Timestamp, pd.Timestamp]],
     ]:
-        time_dict:
-        min_max_time_dict:
+        time_dict: dict[str, np.ndarray] = {}
+        min_max_time_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
         for table in graph.tables.values():
             if table._time_column is None:
                 continue
@@ -207,15 +207,15 @@ class LocalGraphStore:
     def get_csc(
         self,
         graph: 'Graph',
-    ) ->
-
-
+    ) -> tuple[
+        dict[tuple[str, str, str], np.ndarray],
+        dict[tuple[str, str, str], np.ndarray],
     ]:
         # A mapping from raw primary keys to node indices (0 to N-1):
-        map_dict:
+        map_dict: dict[str, pd.CategoricalDtype] = {}
         # A dictionary to manage offsets of node indices for invalid rows:
-        offset_dict:
-        for table_name in
+        offset_dict: dict[str, np.ndarray] = {}
+        for table_name in {edge.dst_table for edge in graph.edges}:
             ser = self.df_dict[table_name][graph[table_name]._primary_key]
             if table_name in self.mask_dict.keys():
                 mask = self.mask_dict[table_name]
@@ -224,8 +224,8 @@ class LocalGraphStore:
             map_dict[table_name] = pd.CategoricalDtype(ser, ordered=True)
 
         # Build CSC graph representation:
-        row_dict:
-        colptr_dict:
+        row_dict: dict[tuple[str, str, str], np.ndarray] = {}
+        colptr_dict: dict[tuple[str, str, str], np.ndarray] = {}
         for src_table, fkey, dst_table in graph.edges:
             src_df = self.df_dict[src_table]
             dst_df = self.df_dict[dst_table]
@@ -287,7 +287,7 @@ def _argsort(input: np.ndarray) -> np.ndarray:
     return torch.from_numpy(input).argsort().numpy()
 
 
-def _lexsort(inputs:
+def _lexsort(inputs: list[np.ndarray]) -> np.ndarray:
     assert len(inputs) >= 1
 
     if not WITH_TORCH:
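The `get_csc()` hunks above keep the same core technique for building the CSC representation: each destination table's primary keys become the categories of an ordered `pd.CategoricalDtype`, so casting a foreign-key column to that dtype yields contiguous node indices and `-1` for unmatched keys. A self-contained sketch of that mapping (illustrative only, not package code):

import pandas as pd

# Destination-table primary keys and a fact table's foreign-key column.
users = pd.Series([10, 20, 30, 40])
orders_fkey = pd.Series([20, 20, 40, 10, 99])

# The categories define the node index space (0..N-1), in primary-key order.
dtype = pd.CategoricalDtype(users, ordered=True)
codes = orders_fkey.astype(dtype).cat.codes.to_numpy()
print(codes)  # [1 1 3 0 -1] -- the unknown key 99 maps to -1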
kumoai/experimental/rfm/backend/local/table.py
CHANGED
@@ -1,14 +1,10 @@
 import warnings
-from typing import
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 
-from kumoai.experimental.rfm.base import
-    DataBackend,
-    SourceColumn,
-    SourceForeignKey,
-    Table,
-)
+from kumoai.experimental.rfm.base import DataBackend, SourceColumn, Table
 from kumoai.experimental.rfm.infer import infer_dtype
 
 
@@ -57,9 +53,9 @@ class LocalTable(Table):
         self,
         df: pd.DataFrame,
         name: str,
-        primary_key:
-        time_column:
-        end_time_column:
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if df.empty:
@@ -85,17 +81,19 @@ class LocalTable(Table):
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.LOCAL)
 
-    def _get_source_columns(self) ->
-        source_columns:
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
         for column in self._data.columns:
             ser = self._data[column]
             try:
                 dtype = infer_dtype(ser)
             except Exception:
-                warnings.warn(f"
-                              f"
-                              f"the data type of
-                              f"
+                warnings.warn(f"Encountered unsupported data type "
+                              f"'{ser.dtype}' for column '{column}' in table "
+                              f"'{self.name}'. Please change the data type of "
+                              f"the column in the `pandas.DataFrame` to use "
+                              f"it within this table, or remove it to "
+                              f"suppress this warning.")
                 continue
 
             source_column = SourceColumn(
@@ -109,11 +107,8 @@ class LocalTable(Table):
 
         return source_columns
 
-    def
-        return []
-
-    def _get_sample_df(self) -> pd.DataFrame:
+    def _get_source_sample_df(self) -> pd.DataFrame:
         return self._data
 
-    def _get_num_rows(self) ->
+    def _get_num_rows(self) -> int | None:
         return len(self._data)
kumoai/experimental/rfm/backend/snow/sampler.py
CHANGED
@@ -1,39 +1,27 @@
 import json
-from
+from collections.abc import Iterator
+from contextlib import contextmanager
 
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery
 
-from kumoai.experimental.rfm.backend.snow import
+from kumoai.experimental.rfm.backend.snow import Connection
 from kumoai.experimental.rfm.base import SQLSampler
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import
+from kumoai.utils import quote_ident
 
-if TYPE_CHECKING:
-    from kumoai.experimental.rfm import Graph
 
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
 
-class SnowSampler(SQLSampler):
-    def __init__(
-        self,
-        graph: 'Graph',
-        verbose: bool | ProgressLogger = True,
-    ) -> None:
-        super().__init__(graph=graph, verbose=verbose)
-
-        self._fqn_dict: dict[str, str] = {}
-        for table in graph.tables.values():
-            assert isinstance(table, SnowTable)
-            self._connection = table._connection
-            self._fqn_dict[table.name] = table.fqn
-
-    @property
-    def fqn_dict(self) -> dict[str, str]:
-        r"""The fully-qualified quoted names for all tables in the graph."""
-        return self._fqn_dict
 
+class SnowSampler(SQLSampler):
     def _get_min_max_time_dict(
         self,
         table_names: list[str],
@@ -42,7 +30,7 @@ class SnowSampler(SQLSampler):
         for table_name in table_names:
             time_column = self.time_column_dict[table_name]
             select = (f"SELECT\n"
-                      f"
+                      f" ? as table_name,\n"
                       f" MIN({quote_ident(time_column)}) as min_date,\n"
                      f" MAX({quote_ident(time_column)}) as max_date\n"
                      f"FROM {self.fqn_dict[table_name]}")
@@ -50,14 +38,14 @@ class SnowSampler(SQLSampler):
         sql = "\nUNION ALL\n".join(selects)
 
         out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, table_names)
             rows = cursor.fetchall()
-
-
-
-
-
+            for table_name, _min, _max in rows:
+                out_dict[table_name] = (
+                    pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                    pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+                )
 
         return out_dict
 
@@ -179,7 +167,7 @@ class SnowSampler(SQLSampler):
             sql += " f.value::FLOAT as ID\n"
         else:
             sql += " f.value::VARCHAR as ID\n"
-        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
                 f"SELECT TMP.BATCH as __BATCH__, "
                 f"{', '.join('ENT.' + quote_ident(col) for col in columns)}\n"
@@ -187,7 +175,7 @@ class SnowSampler(SQLSampler):
                 f"JOIN {self.fqn_dict[table_name]} ENT\n"
                 f" ON ENT.{quote_ident(pkey_name)} = TMP.ID")
 
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()
 
@@ -240,7 +228,7 @@ class SnowSampler(SQLSampler):
         if min_offset is not None:
             sql += ",\n f.value[2]::TIMESTAMP_NTZ as START_TIME"
         sql += (f"\n"
-                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
                 f"SELECT TMP.BATCH as __BATCH__, "
                 f"{', '.join('FACT.' + quote_ident(col) for col in columns)}\n"
@@ -251,7 +239,7 @@ class SnowSampler(SQLSampler):
         if min_offset is not None:
             sql += f"\n AND FACT.{quote_ident(time_column)} > TMP.START_TIME"
 
-        with self._connection.cursor() as cursor:
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
            cursor.execute(sql, (payload, ))
            table = cursor.fetch_arrow_all()
 
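The new `paramstyle()` helper temporarily switches the connection to `qmark` binding so the `?` placeholders in the generated Snowflake SQL are bound positionally, then restores the previous style. Below is a generic sketch of that attribute-override pattern; `Conn` is a stand-in class rather than the real connection type, and unlike the helper above the sketch restores the attribute in a `finally` block so an exception inside the `with` body cannot leak the override.

from collections.abc import Iterator
from contextlib import contextmanager


class Conn:
    """Stand-in for a DB-API connection with a mutable paramstyle."""
    _paramstyle = 'pyformat'


@contextmanager
def override_paramstyle(conn: Conn, style: str = 'qmark') -> Iterator[None]:
    previous = conn._paramstyle
    conn._paramstyle = style  # '?' placeholders bind positionally in this block
    try:
        yield
    finally:
        conn._paramstyle = previous  # always restore, even on error


conn = Conn()
with override_paramstyle(conn):
    assert conn._paramstyle == 'qmark'
assert conn._paramstyle == 'pyformat'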
kumoai/experimental/rfm/backend/snow/table.py
CHANGED
@@ -1,28 +1,35 @@
 import re
-from
+from collections.abc import Sequence
+from typing import cast
 
 import pandas as pd
+from kumoapi.model_plan import MissingType
 from kumoapi.typing import Dtype
 
 from kumoai.experimental.rfm.backend.snow import Connection
 from kumoai.experimental.rfm.base import (
+    ColumnExpressionSpec,
+    ColumnExpressionType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-
+    SQLTable,
 )
 from kumoai.utils import quote_ident
 
 
-class SnowTable(
+class SnowTable(SQLTable):
     r"""A table backed by a :class:`sqlite` database.
 
     Args:
         connection: The connection to a :class:`snowflake` database.
-        name: The name of this table.
+        name: The logical name of this table.
+        source_name: The physical name of this table in the database. If set to
+            ``None``, ``name`` is being used.
         database: The database.
         schema: The schema.
-        columns: The selected columns of this table.
+        columns: The selected physical columns of this table.
+        column_expressions: The logical columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -32,17 +39,20 @@ class SnowTable(Table):
         self,
         connection: Connection,
         name: str,
+        source_name: str | None = None,
         database: str | None = None,
         schema: str | None = None,
-        columns:
-
-
-
+        columns: Sequence[str] | None = None,
+        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
     ) -> None:
 
         if database is not None and schema is None:
-            raise ValueError(f"
-                             f"
+            raise ValueError(f"Unspecified 'schema' for table "
+                             f"'{source_name or name}' in database "
+                             f"'{database}'")
 
         self._connection = connection
         self._database = database
@@ -50,12 +60,32 @@ class SnowTable(Table):
 
         super().__init__(
             name=name,
+            source_name=source_name,
             columns=columns,
+            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
         )
 
+    @staticmethod
+    def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
+        if snowflake_dtype is None:
+            return None
+        snowflake_dtype = snowflake_dtype.strip().upper()
+        # TODO 'NUMBER(...)' is not always an integer!
+        if snowflake_dtype.startswith('NUMBER'):
+            return Dtype.int
+        elif snowflake_dtype.startswith('VARCHAR'):
+            return Dtype.string
+        elif snowflake_dtype == 'FLOAT':
+            return Dtype.float
+        elif snowflake_dtype == 'BOOLEAN':
+            return Dtype.bool
+        elif re.search('DATE|TIMESTAMP', snowflake_dtype):
+            return Dtype.date
+        return None
+
     @property
     def backend(self) -> DataBackend:
         return cast(DataBackend, DataBackend.SNOWFLAKE)
@@ -63,15 +93,15 @@ class SnowTable(Table):
     @property
     def fqn(self) -> str:
         r"""The fully-qualified quoted table name."""
-        names:
+        names: list[str] = []
         if self._database is not None:
             names.append(quote_ident(self._database))
         if self._schema is not None:
             names.append(quote_ident(self._schema))
-        return '.'.join(names + [quote_ident(self.
+        return '.'.join(names + [quote_ident(self._source_name)])
 
-    def _get_source_columns(self) ->
-        source_columns:
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
             try:
                 sql = f"DESCRIBE TABLE {self.fqn}"
@@ -82,24 +112,15 @@ class SnowTable(Table):
                    names.append(self._database)
                if self._schema is not None:
                    names.append(self._schema)
-
-                raise ValueError(f"Table '{
+                source_name = '.'.join(names + [self._source_name])
+                raise ValueError(f"Table '{source_name}' does not exist in "
+                                 f"the remote data backend") from e
 
            for row in cursor.fetchall():
                column, type, _, null, _, is_pkey, is_unique, *_ = row
 
-
-                if
-                    dtype = Dtype.int
-                elif type.startswith('VARCHAR'):
-                    dtype = Dtype.string
-                elif type == 'FLOAT':
-                    dtype = Dtype.float
-                elif type == 'BOOLEAN':
-                    dtype = Dtype.bool
-                elif re.search('DATE|TIMESTAMP', type):
-                    dtype = Dtype.date
-                else:
+                dtype = self.to_dtype(type)
+                if dtype is None:
                    continue
 
                source_column = SourceColumn(
@@ -113,8 +134,8 @@ class SnowTable(Table):
 
        return source_columns
 
-    def _get_source_foreign_keys(self) ->
-        source_fkeys:
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_fkeys: list[SourceForeignKey] = []
        with self._connection.cursor() as cursor:
            sql = f"SHOW IMPORTED KEYS IN TABLE {self.fqn}"
            cursor.execute(sql)
@@ -123,7 +144,7 @@ class SnowTable(Table):
            source_fkeys.append(SourceForeignKey(fkey, dst_table, pkey))
        return source_fkeys
 
-    def
+    def _get_source_sample_df(self) -> pd.DataFrame:
        with self._connection.cursor() as cursor:
            columns = [quote_ident(col) for col in self._source_column_dict]
            sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
@@ -131,5 +152,18 @@ class SnowTable(Table):
            table = cursor.fetch_arrow_all()
        return table.to_pandas(types_mapper=pd.ArrowDtype)
 
-    def _get_num_rows(self) ->
+    def _get_num_rows(self) -> int | None:
        return None
+
+    def _get_expression_sample_df(
+        self,
+        specs: Sequence[ColumnExpressionSpec],
+    ) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            columns = [
+                f"{spec.expr} AS {quote_ident(spec.name)}" for spec in specs
+            ]
+            sql = f"SELECT {', '.join(columns)} FROM {self.fqn} LIMIT 1000"
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+            return table.to_pandas(types_mapper=pd.ArrowDtype)
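The Snowflake-type-to-`Dtype` mapping that was previously inlined in `_get_source_columns()` is now the reusable `SnowTable.to_dtype()` static method. A hedged usage sketch follows: the type strings are illustrative examples, the expected results mirror the mapping shown in the diff, and the import path is assumed from the file location (note the TODO above: `NUMBER(p, s)` with a nonzero scale still maps to `Dtype.int`).

# Hedged sketch; assumes SnowTable is importable from the module path listed
# in this diff. The type strings are examples, not an exhaustive list.
from kumoapi.typing import Dtype

from kumoai.experimental.rfm.backend.snow.table import SnowTable

assert SnowTable.to_dtype('NUMBER(38,0)') == Dtype.int
assert SnowTable.to_dtype('varchar(64)') == Dtype.string   # case-insensitive
assert SnowTable.to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date
assert SnowTable.to_dtype('VARIANT') is None               # skipped by callers
assert SnowTable.to_dtype(None) is None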
kumoai/experimental/rfm/backend/sqlite/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, TypeAlias
+from typing import Any, TypeAlias
 
 try:
     import adbc_driver_sqlite.dbapi as adbc
@@ -11,7 +11,7 @@ except ImportError:
 Connection: TypeAlias = adbc.AdbcSqliteConnection
 
 
-def connect(uri:
+def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
     r"""Opens a connection to a :class:`sqlite` database.
 
     uri: The path to the database file to be opened.