kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +52 -52
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +753 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +81 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +313 -245
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +49 -29
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py
@@ -0,0 +1,297 @@
+import json
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.pquery import ValidatedPredictiveQuery
+
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
+from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph
+
+
+@contextmanager
+def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:
+    _style = connection._paramstyle
+    connection._paramstyle = style
+    yield
+    connection._paramstyle = _style
+
+
+class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+            self._connection = table._connection
+
+    def _get_min_max_time_dict(
+        self,
+        table_names: list[str],
+    ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
+        selects: list[str] = []
+        for table_name in table_names:
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
+            select = (f"SELECT\n"
+                      f" ? as table_name,\n"
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
+            selects.append(select)
+        sql = "\nUNION ALL\n".join(selects)
+
+        out_dict: dict[str, tuple[pd.Timestamp, pd.Timestamp]] = {}
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, table_names)
+            rows = cursor.fetchall()
+        for table_name, _min, _max in rows:
+            out_dict[table_name] = (
+                pd.Timestamp.max if _min is None else pd.Timestamp(_min),
+                pd.Timestamp.min if _max is None else pd.Timestamp(_max),
+            )
+
+        return out_dict
+
+    def _sample_entity_table(
+        self,
+        table_name: str,
+        columns: set[str],
+        num_rows: int,
+        random_seed: int | None = None,
+    ) -> pd.DataFrame:
+        # NOTE Snowflake supports `SEED` only as part of `SYSTEM` sampling.
+        num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.
+
+        source_table = self.source_table_dict[table_name]
+        filters: list[str] = []
+
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
+               f"SAMPLE ROW ({num_rows} ROWS)")
+        if len(filters) > 0:
+            sql += f"\nWHERE{' AND'.join(filters)}"
+
+        with self._connection.cursor() as cursor:
+            # NOTE This may return duplicate primary keys. This is okay.
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )
+
+    def _sample_target(
+        self,
+        query: ValidatedPredictiveQuery,
+        entity_df: pd.DataFrame,
+        train_index: np.ndarray,
+        train_time: pd.Series,
+        num_train_examples: int,
+        test_index: np.ndarray,
+        test_time: pd.Series,
+        num_test_examples: int,
+        columns_dict: dict[str, set[str]],
+        time_offset_dict: dict[
+            tuple[str, str, str],
+            tuple[pd.DateOffset | None, pd.DateOffset],
+        ],
+    ) -> tuple[pd.Series, np.ndarray, pd.Series, np.ndarray]:
+
+        # NOTE For Snowflake, we execute everything at once to pay minimal
+        # query initialization costs.
+        index = np.concatenate([train_index, test_index])
+        time = pd.concat([train_time, test_time], axis=0, ignore_index=True)
+
+        entity_df = entity_df.iloc[index].reset_index(drop=True)
+
+        feat_dict: dict[str, pd.DataFrame] = {query.entity_table: entity_df}
+        time_dict: dict[str, pd.Series] = {}
+        time_column = self.time_column_dict.get(query.entity_table)
+        if time_column in columns_dict[query.entity_table]:
+            time_dict[query.entity_table] = entity_df[time_column]
+        batch_dict: dict[str, np.ndarray] = {
+            query.entity_table: np.arange(len(entity_df)),
+        }
+        for edge_type, (min_offset, max_offset) in time_offset_dict.items():
+            table_name, fkey, _ = edge_type
+            feat_dict[table_name], batch_dict[table_name] = self._by_time(
+                table_name=table_name,
+                fkey=fkey,
+                pkey=entity_df[self.primary_key_dict[query.entity_table]],
+                anchor_time=time,
+                min_offset=min_offset,
+                max_offset=max_offset,
+                columns=columns_dict[table_name],
+            )
+            time_column = self.time_column_dict.get(table_name)
+            if time_column in columns_dict[table_name]:
+                time_dict[table_name] = feat_dict[table_name][time_column]
+
+        y, mask = PQueryPandasExecutor().execute(
+            query=query,
+            feat_dict=feat_dict,
+            time_dict=time_dict,
+            batch_dict=batch_dict,
+            anchor_time=time,
+            num_forecasts=query.num_forecasts,
+        )
+
+        train_mask = mask[:len(train_index)]
+        test_mask = mask[len(train_index):]
+
+        boundary = int(train_mask.sum())
+        train_y = y.iloc[:boundary]
+        test_y = y.iloc[boundary:].reset_index(drop=True)
+
+        return train_y, train_mask, test_y, test_mask
+
+    def _by_pkey(
+        self,
+        table_name: str,
+        pkey: pd.Series,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+
+        payload = json.dumps(list(pkey))
+
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
+        else:
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
+        sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]} ENT\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        # Remove any duplicated primary keys in post-processing:
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    # Helper Methods ##########################################################
+
+    def _by_time(
+        self,
+        table_name: str,
+        fkey: str,
+        pkey: pd.Series,
+        anchor_time: pd.Series,
+        min_offset: pd.DateOffset | None,
+        max_offset: pd.DateOffset,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]
+
+        end_time = anchor_time + max_offset
+        end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+        if min_offset is not None:
+            start_time = anchor_time + min_offset
+            start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(pkey, end_time, start_time)))
+        else:
+            payload = json.dumps(list(zip(pkey, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][fkey]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][fkey].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][fkey].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
+        if min_offset is not None:
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]} FACT\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
+        if min_offset is not None:
+            sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
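The `paramstyle` helper above temporarily switches the Snowflake connection into `qmark` binding so the sampler's `?` placeholders (for example in the `FLATTEN(INPUT => PARSE_JSON(?))` queries) can be bound positionally. A minimal sketch of the same pattern, assuming an already-open Snowflake connection named `conn` (the connection object and the query are illustrative, not part of the diff):

from contextlib import contextmanager

@contextmanager
def qmark(conn):
    # The Snowflake Python connector keeps the active binding style on the
    # (private) `_paramstyle` attribute, which the sampler above also uses.
    saved = conn._paramstyle
    conn._paramstyle = 'qmark'
    try:
        yield
    finally:
        conn._paramstyle = saved  # restore the previous binding style

# Hypothetical usage with an existing connection `conn`:
# with qmark(conn), conn.cursor() as cursor:
#     cursor.execute("SELECT ?, ?", ("a", "b"))
#     print(cursor.fetchall())

Unlike the helper in the diff, this sketch restores the previous style in a `finally` block, so an exception inside the block cannot leave the connection stuck in `qmark` mode.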
kumoai/experimental/rfm/backend/snow/table.py
@@ -0,0 +1,242 @@
+import re
+from collections import Counter
+from collections.abc import Sequence
+from typing import cast
+
+import pandas as pd
+from kumoapi.model_plan import MissingType
+from kumoapi.typing import Dtype
+
+from kumoai.experimental.rfm.backend.snow import Connection
+from kumoai.experimental.rfm.base import (
+    ColumnSpec,
+    ColumnSpecType,
+    DataBackend,
+    SourceColumn,
+    SourceForeignKey,
+    Table,
+)
+from kumoai.utils import quote_ident
+
+
+class SnowTable(Table):
+    r"""A table backed by a :class:`snowflake` database.
+
+    Args:
+        connection: The connection to a :class:`snowflake` database.
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is used.
+        database: The database.
+        schema: The schema.
+        columns: The selected columns of this table.
+        primary_key: The name of the primary key of this table, if it exists.
+        time_column: The name of the time column of this table, if it exists.
+        end_time_column: The name of the end time column of this table, if it
+            exists.
+    """
+    def __init__(
+        self,
+        connection: Connection,
+        name: str,
+        source_name: str | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
+        primary_key: MissingType | str | None = MissingType.VALUE,
+        time_column: str | None = None,
+        end_time_column: str | None = None,
+    ) -> None:
+
+        if database is None or schema is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                result = cursor.fetchone()
+            database = database or result[0]
+            assert database is not None
+            schema = schema or result[1]
+
+        if schema is None:
+            raise ValueError(f"Unspecified 'schema' for table "
+                             f"'{source_name or name}' in database "
+                             f"'{database}'")
+
+        self._connection = connection
+        self._database = database
+        self._schema = schema
+
+        super().__init__(
+            name=name,
+            source_name=source_name,
+            columns=columns,
+            primary_key=primary_key,
+            time_column=time_column,
+            end_time_column=end_time_column,
+        )
+
+    @property
+    def source_name(self) -> str:
+        names: list[str] = []
+        if self._database is not None:
+            names.append(self._database)
+        if self._schema is not None:
+            names.append(self._schema)
+        return '.'.join(names + [self._source_name])
+
+    @property
+    def _quoted_source_name(self) -> str:
+        names: list[str] = []
+        if self._database is not None:
+            names.append(quote_ident(self._database))
+        if self._schema is not None:
+            names.append(quote_ident(self._schema))
+        return '.'.join(names + [quote_ident(self._source_name)])
+
+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.SNOWFLAKE)
+
+    def _get_source_columns(self) -> list[SourceColumn]:
+        source_columns: list[SourceColumn] = []
+        with self._connection.cursor() as cursor:
+            try:
+                sql = f"DESCRIBE TABLE {self._quoted_source_name}"
+                cursor.execute(sql)
+            except Exception as e:
+                raise ValueError(f"Table '{self.source_name}' does not exist "
+                                 f"in the remote data backend") from e
+
+            for row in cursor.fetchall():
+                column, dtype, _, null, _, is_pkey, is_unique, *_ = row
+
+                source_column = SourceColumn(
+                    name=column,
+                    dtype=self._to_dtype(dtype),
+                    is_primary_key=is_pkey.strip().upper() == 'Y',
+                    is_unique_key=is_unique.strip().upper() == 'Y',
+                    is_nullable=null.strip().upper() == 'Y',
+                )
+                source_columns.append(source_column)
+
+        return source_columns
+
+    def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
+        source_foreign_keys: list[SourceForeignKey] = []
+        with self._connection.cursor() as cursor:
+            sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
+            cursor.execute(sql)
+            rows = cursor.fetchall()
+            counts = Counter(row[13] for row in rows)
+            for row in rows:
+                if counts[row[13]] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=row[8],
+                        dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                        primary_key=row[4],
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys
+
+    def _get_source_sample_df(self) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            columns = [quote_ident(col) for col in self._source_column_dict]
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )
+
+    def _get_num_rows(self) -> int | None:
+        return None
+
+    def _get_expr_sample_df(
+        self,
+        columns: Sequence[ColumnSpec],
+    ) -> pd.DataFrame:
+        with self._connection.cursor() as cursor:
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
+            ]
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
+            cursor.execute(sql)
+            table = cursor.fetch_arrow_all()
+
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if dtype.startswith('NUMBER'):
+            try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                scale = int(dtype.split(',')[-1].split(')')[0])
+                return Dtype.int if scale == 0 else Dtype.float
+            except Exception:
+                return Dtype.float
+        if dtype == 'FLOAT':
+            return Dtype.float
+        if dtype.startswith('VARCHAR'):
+            return Dtype.string
+        if dtype.startswith('BINARY'):
+            return Dtype.binary
+        if dtype == 'BOOLEAN':
+            return Dtype.bool
+        if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+            return Dtype.date
+        if dtype.startswith('TIME'):
+            return Dtype.time
+        if dtype.startswith('VECTOR'):
+            try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                dtype = dtype.split(',')[0].split('(')[1].strip()
+                if dtype == 'INT':
+                    return Dtype.intlist
+                elif dtype == 'FLOAT':
+                    return Dtype.floatlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        if dtype.startswith('ARRAY'):
+            try:  # Parse element data type from 'ARRAY(dtype)':
+                dtype = dtype.split('(', maxsplit=1)[1]
+                dtype = dtype.rsplit(')', maxsplit=1)[0]
+                _dtype = SnowTable._to_dtype(dtype)
+                if _dtype is not None and _dtype.is_int():
+                    return Dtype.intlist
+                elif _dtype is not None and _dtype.is_float():
+                    return Dtype.floatlist
+                elif _dtype is not None and _dtype.is_string():
+                    return Dtype.stringlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        # Unsupported data types:
+        if re.search(
+                'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                dtype,
+        ):
+            return Dtype.unsupported
+        return None
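To make the `_to_dtype` mapping above concrete, here is a small illustrative sketch of expected outputs for a few Snowflake type strings. It assumes an environment where the Snowflake extension is installed so that `SnowTable` can be imported; the expected values simply follow the branches of the code above:

from kumoapi.typing import Dtype
from kumoai.experimental.rfm.backend.snow import SnowTable

# NUMBER maps to int when the scale is zero, otherwise to float:
assert SnowTable._to_dtype('NUMBER(38,0)') == Dtype.int
assert SnowTable._to_dtype('NUMBER(10,2)') == Dtype.float
# Character, boolean, and temporal types:
assert SnowTable._to_dtype('VARCHAR(16777216)') == Dtype.string
assert SnowTable._to_dtype('BOOLEAN') == Dtype.bool
assert SnowTable._to_dtype('TIMESTAMP_NTZ(9)') == Dtype.date
# VECTOR columns map to list types based on their element type:
assert SnowTable._to_dtype('VECTOR(FLOAT, 768)') == Dtype.floatlist
# Semi-structured and spatial types are unsupported:
assert SnowTable._to_dtype('VARIANT') == Dtype.unsupported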
kumoai/experimental/rfm/backend/sqlite/__init__.py
@@ -0,0 +1,32 @@
+from pathlib import Path
+from typing import Any, TypeAlias
+
+try:
+    import adbc_driver_sqlite.dbapi as adbc
+except ImportError:
+    raise ImportError("No module named 'adbc_driver_sqlite'. Please install "
+                      "Kumo SDK with the 'sqlite' extension via "
+                      "`pip install kumoai[sqlite]`.")
+
+Connection: TypeAlias = adbc.AdbcSqliteConnection
+
+
+def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
+    r"""Opens a connection to a :class:`sqlite` database.
+
+    uri: The path to the database file to be opened.
+    kwargs: Additional connection arguments, following the
+        :class:`adbc_driver_sqlite` protocol.
+    """
+    return adbc.connect(uri, **kwargs)
+
+
+from .table import SQLiteTable  # noqa: E402
+from .sampler import SQLiteSampler  # noqa: E402
+
+__all__ = [
+    'connect',
+    'Connection',
+    'SQLiteTable',
+    'SQLiteSampler',
+]
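For reference, a minimal usage sketch of the new SQLite backend entry point. The database path and query are placeholders; wiring the connection into `SQLiteTable` and a graph is covered by the other new files in this diff:

from kumoai.experimental.rfm.backend import sqlite

# Requires the 'sqlite' extra: `pip install kumoai[sqlite]`.
# 'example.db' is a hypothetical path; with no argument, the ADBC driver
# falls back to its default (an in-memory database).
conn = sqlite.connect('example.db')

cursor = conn.cursor()
cursor.execute('SELECT sqlite_version()')
print(cursor.fetchone())
cursor.close()
conn.close()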