kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +184 -70
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +191 -86
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +26 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +182 -19
- kumoai/experimental/rfm/base/table.py +275 -109
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +530 -304
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +36 -33
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/backend/snow/sampler.py

@@ -1,16 +1,20 @@
 import json
 from collections.abc import Iterator
 from contextlib import contextmanager
+from typing import TYPE_CHECKING

 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.pquery import ValidatedPredictiveQuery

-from kumoai.experimental.rfm.backend.snow import Connection
-from kumoai.experimental.rfm.base import SQLSampler
+from kumoai.experimental.rfm.backend.snow import Connection, SnowTable
+from kumoai.experimental.rfm.base import SQLSampler, Table
 from kumoai.experimental.rfm.pquery import PQueryPandasExecutor
-from kumoai.utils import
+from kumoai.utils import ProgressLogger
+
+if TYPE_CHECKING:
+    from kumoai.experimental.rfm import Graph


 @contextmanager
@@ -22,18 +26,30 @@ def paramstyle(connection: Connection, style: str = 'qmark') -> Iterator[None]:


 class SnowSampler(SQLSampler):
+    def __init__(
+        self,
+        graph: 'Graph',
+        verbose: bool | ProgressLogger = True,
+    ) -> None:
+        super().__init__(graph=graph, verbose=verbose)
+
+        for table in graph.tables.values():
+            assert isinstance(table, SnowTable)
+        self._connection = table._connection
+
     def _get_min_max_time_dict(
         self,
         table_names: list[str],
     ) -> dict[str, tuple[pd.Timestamp, pd.Timestamp]]:
         selects: list[str] = []
         for table_name in table_names:
-
+            column = self.time_column_dict[table_name]
+            column_ref = self.table_column_ref_dict[table_name][column]
             select = (f"SELECT\n"
                       f" ? as table_name,\n"
-                      f" MIN({
-                      f" MAX({
-                      f"FROM {self.
+                      f" MIN({column_ref}) as min_date,\n"
+                      f" MAX({column_ref}) as max_date\n"
+                      f"FROM {self.source_name_dict[table_name]}")
             selects.append(select)
         sql = "\nUNION ALL\n".join(selects)

@@ -59,17 +75,27 @@ class SnowSampler(SQLSampler):
         # NOTE Snowflake does support `SEED` only as part of `SYSTEM` sampling.
         num_rows = min(num_rows, 1_000_000)  # Snowflake's upper limit.

+        source_table = self.source_table_dict[table_name]
         filters: list[str] = []
-        primary_key = self.primary_key_dict[table_name]
-        if self.source_table_dict[table_name][primary_key].is_nullable:
-            filters.append(f" {quote_ident(primary_key)} IS NOT NULL")
-        time_column = self.time_column_dict.get(table_name)
-        if (time_column is not None and
-                self.source_table_dict[table_name][time_column].is_nullable):
-            filters.append(f" {quote_ident(time_column)} IS NOT NULL")

-
-
+        key = self.primary_key_dict[table_name]
+        if key not in source_table or source_table[key].is_nullable:
+            key_ref = self.table_column_ref_dict[table_name][key]
+            filters.append(f" {key_ref} IS NOT NULL")
+
+        column = self.time_column_dict.get(table_name)
+        if column is None:
+            pass
+        elif column not in source_table or source_table[column].is_nullable:
+            column_ref = self.table_column_ref_dict[table_name][column]
+            filters.append(f" {column_ref} IS NOT NULL")
+
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
+        sql = (f"SELECT {', '.join(projections)}\n"
+               f"FROM {self.source_name_dict[table_name]}\n"
                f"SAMPLE ROW ({num_rows} ROWS)")
         if len(filters) > 0:
             sql += f"\nWHERE{' AND'.join(filters)}"
@@ -79,7 +105,11 @@ class SnowSampler(SQLSampler):
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()

-        return
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        )

     def _sample_target(
         self,
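Note on the sampling hunks above: the rewritten path assembles a plain Snowflake query around the SAMPLE ROW (<n> ROWS) clause and only appends IS NOT NULL guards for key/time columns that are nullable or missing from the source schema. A minimal standalone sketch of that string assembly follows; the table and column identifiers are hypothetical stand-ins for the sampler's source_name_dict/table_column_ref_dict lookups.

# Standalone sketch of the sampling query assembled above; table and column
# identifiers are hypothetical, and no Snowflake connection is needed to run it.
num_rows = min(5_000_000, 1_000_000)  # Snowflake caps row-based sampling at 1M rows.
projections = ['"USER_ID"', '"SIGNUP_TS"', '"COUNTRY"']  # assumed projected columns
filters = [' "USER_ID" IS NOT NULL', ' "SIGNUP_TS" IS NOT NULL']  # nullable columns only

sql = (f"SELECT {', '.join(projections)}\n"
       f"FROM MY_DB.MY_SCHEMA.USERS\n"
       f"SAMPLE ROW ({num_rows} ROWS)")
if len(filters) > 0:
    sql += f"\nWHERE{' AND'.join(filters)}"

print(sql)
# SELECT "USER_ID", "SIGNUP_TS", "COUNTRY"
# FROM MY_DB.MY_SCHEMA.USERS
# SAMPLE ROW (1000000 ROWS)
# WHERE "USER_ID" IS NOT NULL AND "SIGNUP_TS" IS NOT NULL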
@@ -114,11 +144,11 @@ class SnowSampler(SQLSampler):
             query.entity_table: np.arange(len(entity_df)),
         }
         for edge_type, (min_offset, max_offset) in time_offset_dict.items():
-            table_name,
+            table_name, foreign_key, _ = edge_type
             feat_dict[table_name], batch_dict[table_name] = self._by_time(
                 table_name=table_name,
-
-
+                foreign_key=foreign_key,
+                index=entity_df[self.primary_key_dict[query.entity_table]],
                 anchor_time=time,
                 min_offset=min_offset,
                 max_offset=max_offset,
@@ -149,104 +179,188 @@ class SnowSampler(SQLSampler):
     def _by_pkey(
         self,
         table_name: str,
-
+        index: pd.Series,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        key = self.primary_key_dict[table_name]
+        key_ref = self.table_column_ref_dict[table_name][key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]

-
-        source_table = self.source_table_dict[table_name]
-
-        payload = json.dumps(list(pkey))
+        payload = json.dumps(list(index))

         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value::NUMBER as
-        elif
-            sql += " f.value::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][key].is_int():
+            sql += " f.value::NUMBER as __KUMO_ID__\n"
+        elif self.table_dtype_dict[table_name][key].is_float():
+            sql += " f.value::FLOAT as __KUMO_ID__\n"
         else:
-            sql += " f.value::VARCHAR as
+            sql += " f.value::VARCHAR as __KUMO_ID__\n"
         sql += (f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__")

         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()

         # Remove any duplicated primary keys in post-processing:
-        tmp = table.append_column('
-        gb = tmp.group_by('
-        table = table.take(gb['
+        tmp = table.append_column('__KUMO_ID__', pa.array(range(len(table))))
+        gb = tmp.group_by('__KUMO_BATCH__').aggregate([('__KUMO_ID__', 'min')])
+        table = table.take(gb['__KUMO_ID___min'])
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
+
+    def _by_fkey(
+        self,
+        table_name: str,
+        foreign_key: str,
+        index: pd.Series,
+        num_neighbors: int,
+        anchor_time: pd.Series | None,
+        columns: set[str],
+    ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict.get(table_name)
+
+        if time_column is not None and anchor_time is not None:
+            anchor_time = anchor_time.dt.strftime("%Y-%m-%d %H:%M:%S")
+            payload = json.dumps(list(zip(index, anchor_time)))
+        else:
+            payload = json.dumps(list(zip(index)))

-
-
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]

-
+        sql = ("WITH TMP as (\n"
+               " SELECT\n"
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__"
+        else:
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__"
+        if time_column is not None and anchor_time is not None:
+            sql += (",\n"
+                    " f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__")
+        sql += (f"\n"
+                f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
+                f")\n"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
+                f"FROM TMP\n"
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n")
+        if time_column is not None and anchor_time is not None:
+            time_ref = self.table_column_ref_dict[table_name][time_column]
+            sql += f" AND {time_ref} <= TMP.__KUMO_TIME__\n"
+        sql += ("QUALIFY ROW_NUMBER() OVER (\n"
+                " PARTITION BY TMP.__KUMO_BATCH__\n")
+        if time_column is not None:
+            sql += f" ORDER BY {time_ref} DESC\n"
+        else:
+            sql += f" ORDER BY {key_ref}\n"
+        sql += f") <= {num_neighbors}"
+
+        with paramstyle(self._connection), self._connection.cursor() as cursor:
+            cursor.execute(sql, (payload, ))
+            table = cursor.fetch_arrow_all()
+
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)
+
+        return Table._sanitize(
+            df=table.to_pandas(),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch

     # Helper Methods ##########################################################

     def _by_time(
         self,
         table_name: str,
-
-
+        foreign_key: str,
+        index: pd.Series,
         anchor_time: pd.Series,
         min_offset: pd.DateOffset | None,
         max_offset: pd.DateOffset,
         columns: set[str],
     ) -> tuple[pd.DataFrame, np.ndarray]:
+        time_column = self.time_column_dict[table_name]

         end_time = anchor_time + max_offset
         end_time = end_time.dt.strftime("%Y-%m-%d %H:%M:%S")
         if min_offset is not None:
             start_time = anchor_time + min_offset
             start_time = start_time.dt.strftime("%Y-%m-%d %H:%M:%S")
-            payload = json.dumps(list(zip(
+            payload = json.dumps(list(zip(index, end_time, start_time)))
         else:
-            payload = json.dumps(list(zip(
-
-
-
-
-
+            payload = json.dumps(list(zip(index, end_time)))
+
+        key_ref = self.table_column_ref_dict[table_name][foreign_key]
+        time_ref = self.table_column_ref_dict[table_name][time_column]
+        projections = [
+            self.table_column_proj_dict[table_name][column]
+            for column in columns
+        ]
         sql = ("WITH TMP as (\n"
                " SELECT\n"
-               " f.index as
-        if
-            sql += " f.value[0]::NUMBER as
-        elif
-            sql += " f.value[0]::FLOAT as
+               " f.index as __KUMO_BATCH__,\n")
+        if self.table_dtype_dict[table_name][foreign_key].is_int():
+            sql += " f.value[0]::NUMBER as __KUMO_ID__,\n"
+        elif self.table_dtype_dict[table_name][foreign_key].is_float():
+            sql += " f.value[0]::FLOAT as __KUMO_ID__,\n"
         else:
-            sql += " f.value[0]::VARCHAR as
-        sql += " f.value[1]::TIMESTAMP_NTZ as
+            sql += " f.value[0]::VARCHAR as __KUMO_ID__,\n"
+        sql += " f.value[1]::TIMESTAMP_NTZ as __KUMO_END_TIME__"
         if min_offset is not None:
-            sql += ",\n f.value[2]::TIMESTAMP_NTZ as
+            sql += ",\n f.value[2]::TIMESTAMP_NTZ as __KUMO_START_TIME__"
         sql += (f"\n"
                 f" FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n"
                 f")\n"
-                f"SELECT
-                f"
+                f"SELECT "
+                f"TMP.__KUMO_BATCH__ as __KUMO_BATCH__, "
+                f"{', '.join(projections)}\n"
                 f"FROM TMP\n"
-                f"JOIN {self.
-                f" ON
-                f" AND
+                f"JOIN {self.source_name_dict[table_name]}\n"
+                f" ON {key_ref} = TMP.__KUMO_ID__\n"
+                f" AND {time_ref} <= TMP.__KUMO_END_TIME__")
         if min_offset is not None:
-            sql += f"\n AND
+            sql += f"\n AND {time_ref} > TMP.__KUMO_START_TIME__"

         with paramstyle(self._connection), self._connection.cursor() as cursor:
             cursor.execute(sql, (payload, ))
             table = cursor.fetch_arrow_all()

-        batch = table['
-
-
-        return self._sanitize(table_name, table), batch
+        batch = table['__KUMO_BATCH__'].cast(pa.int64()).to_numpy()
+        batch_index = table.schema.get_field_index('__KUMO_BATCH__')
+        table = table.remove_column(batch_index)

-
-
+        return Table._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict=self.table_dtype_dict[table_name],
+            stype_dict=self.table_stype_dict[table_name],
+        ), batch
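The _by_pkey, _by_fkey, and _by_time methods added above share a single lookup pattern: the client-side keys (and optional anchor timestamps) are serialized into one JSON array bound as a single '?' parameter, exploded server-side via FLATTEN(INPUT => PARSE_JSON(?)), joined back to the source table, and, for neighbor sampling, capped per entity with QUALIFY ROW_NUMBER(). A condensed standalone sketch of that query shape follows, with hypothetical table and column names in place of the source_name_dict, table_column_ref_dict, and table_dtype_dict lookups.

import json

# Hypothetical inputs: one lookup key plus one anchor timestamp per entity.
index = [17, 42, 99]
anchor_time = ['2024-01-01 00:00:00'] * len(index)
payload = json.dumps(list(zip(index, anchor_time)))  # bound as the single '?' below

num_neighbors = 8
projections = ['"USER_ID"', '"EVENT_TS"', '"AMOUNT"']  # assumed projected columns
sql = (
    'WITH TMP as (\n'
    '  SELECT\n'
    '    f.index as __KUMO_BATCH__,\n'
    '    f.value[0]::NUMBER as __KUMO_ID__,\n'
    '    f.value[1]::TIMESTAMP_NTZ as __KUMO_TIME__\n'
    '  FROM TABLE(FLATTEN(INPUT => PARSE_JSON(?))) f\n'
    ')\n'
    f"SELECT TMP.__KUMO_BATCH__ as __KUMO_BATCH__, {', '.join(projections)}\n"
    'FROM TMP\n'
    'JOIN MY_DB.MY_SCHEMA.EVENTS\n'
    '  ON "USER_ID" = TMP.__KUMO_ID__\n'
    '  AND "EVENT_TS" <= TMP.__KUMO_TIME__\n'
    'QUALIFY ROW_NUMBER() OVER (\n'
    '  PARTITION BY TMP.__KUMO_BATCH__\n'
    '  ORDER BY "EVENT_TS" DESC\n'
    f') <= {num_neighbors}'
)

print(sql)
# With a qmark-style snowflake-connector cursor, this would run roughly as
# cursor.execute(sql, (payload, )) followed by cursor.fetch_arrow_all();
# f.index then identifies which input entity each returned row belongs to.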
kumoai/experimental/rfm/backend/snow/table.py

@@ -1,4 +1,5 @@
 import re
+from collections import Counter
 from collections.abc import Sequence
 from typing import cast

@@ -8,28 +9,27 @@ from kumoapi.typing import Dtype

 from kumoai.experimental.rfm.backend.snow import Connection
 from kumoai.experimental.rfm.base import (
-
-
+    ColumnSpec,
+    ColumnSpecType,
     DataBackend,
     SourceColumn,
     SourceForeignKey,
-
+    Table,
 )
 from kumoai.utils import quote_ident


-class SnowTable(
+class SnowTable(Table):
     r"""A table backed by a :class:`sqlite` database.

     Args:
         connection: The connection to a :class:`snowflake` database.
-        name: The
-        source_name: The
-            ``
+        name: The name of this table.
+        source_name: The source name of this table. If set to ``None``,
+            ``name`` is being used.
         database: The database.
         schema: The schema.
-        columns: The selected
-        column_expressions: The logical columns of this table.
+        columns: The selected columns of this table.
         primary_key: The name of the primary key of this table, if it exists.
         time_column: The name of the time column of this table, if it exists.
         end_time_column: The name of the end time column of this table, if it
@@ -42,14 +42,21 @@ class SnowTable(SQLTable):
         source_name: str | None = None,
         database: str | None = None,
         schema: str | None = None,
-        columns: Sequence[
-        column_expressions: Sequence[ColumnExpressionType] | None = None,
+        columns: Sequence[ColumnSpecType] | None = None,
         primary_key: MissingType | str | None = MissingType.VALUE,
         time_column: str | None = None,
         end_time_column: str | None = None,
     ) -> None:

-        if database is
+        if database is None or schema is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                result = cursor.fetchone()
+                database = database or result[0]
+                assert database is not None
+                schema = schema or result[1]
+
+        if schema is None:
             raise ValueError(f"Unspecified 'schema' for table "
                              f"'{source_name or name}' in database "
                              f"'{database}'")
@@ -62,37 +69,22 @@ class SnowTable(SQLTable):
             name=name,
             source_name=source_name,
             columns=columns,
-            column_expressions=column_expressions,
             primary_key=primary_key,
             time_column=time_column,
             end_time_column=end_time_column,
         )

-    @staticmethod
-    def to_dtype(snowflake_dtype: str | None) -> Dtype | None:
-        if snowflake_dtype is None:
-            return None
-        snowflake_dtype = snowflake_dtype.strip().upper()
-        # TODO 'NUMBER(...)' is not always an integer!
-        if snowflake_dtype.startswith('NUMBER'):
-            return Dtype.int
-        elif snowflake_dtype.startswith('VARCHAR'):
-            return Dtype.string
-        elif snowflake_dtype == 'FLOAT':
-            return Dtype.float
-        elif snowflake_dtype == 'BOOLEAN':
-            return Dtype.bool
-        elif re.search('DATE|TIMESTAMP', snowflake_dtype):
-            return Dtype.date
-        return None
-
     @property
-    def
-
+    def source_name(self) -> str:
+        names: list[str] = []
+        if self._database is not None:
+            names.append(self._database)
+        if self._schema is not None:
+            names.append(self._schema)
+        return '.'.join(names + [self._source_name])

     @property
-    def
-        r"""The fully-qualified quoted table name."""
+    def _quoted_source_name(self) -> str:
         names: list[str] = []
         if self._database is not None:
             names.append(quote_ident(self._database))
@@ -100,32 +92,26 @@ class SnowTable(SQLTable):
             names.append(quote_ident(self._schema))
         return '.'.join(names + [quote_ident(self._source_name)])

+    @property
+    def backend(self) -> DataBackend:
+        return cast(DataBackend, DataBackend.SNOWFLAKE)
+
     def _get_source_columns(self) -> list[SourceColumn]:
         source_columns: list[SourceColumn] = []
         with self._connection.cursor() as cursor:
             try:
-                sql = f"DESCRIBE TABLE {self.
+                sql = f"DESCRIBE TABLE {self._quoted_source_name}"
                 cursor.execute(sql)
             except Exception as e:
-
-
-                names.append(self._database)
-                if self._schema is not None:
-                    names.append(self._schema)
-                source_name = '.'.join(names + [self._source_name])
-                raise ValueError(f"Table '{source_name}' does not exist in "
-                                 f"the remote data backend") from e
+                raise ValueError(f"Table '{self.source_name}' does not exist "
+                                 f"in the remote data backend") from e

             for row in cursor.fetchall():
-                column,
-
-                dtype = self.to_dtype(type)
-                if dtype is None:
-                    continue
+                column, dtype, _, null, _, is_pkey, is_unique, *_ = row

                 source_column = SourceColumn(
                     name=column,
-                    dtype=dtype,
+                    dtype=self._to_dtype(dtype),
                     is_primary_key=is_pkey.strip().upper() == 'Y',
                     is_unique_key=is_unique.strip().upper() == 'Y',
                     is_nullable=null.strip().upper() == 'Y',
@@ -135,35 +121,122 @@ class SnowTable(SQLTable):
         return source_columns

     def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
-
+        source_foreign_keys: list[SourceForeignKey] = []
         with self._connection.cursor() as cursor:
-            sql = f"SHOW IMPORTED KEYS IN TABLE {self.
+            sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
             cursor.execute(sql)
-
-
-
-
+            rows = cursor.fetchall()
+            counts = Counter(row[13] for row in rows)
+            for row in rows:
+                if counts[row[13]] == 1:
+                    source_foreign_key = SourceForeignKey(
+                        name=row[8],
+                        dst_table=f'{row[1]}.{row[2]}.{row[3]}',
+                        primary_key=row[4],
+                    )
+                    source_foreign_keys.append(source_foreign_key)
+        return source_foreign_keys

     def _get_source_sample_df(self) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
             columns = [quote_ident(col) for col in self._source_column_dict]
-            sql = f"SELECT {', '.join(columns)}
+            sql = (f"SELECT {', '.join(columns)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
-
+
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={
+                column.name: column.dtype
+                for column in self._source_column_dict.values()
+            },
+            stype_dict=None,
+        )

     def _get_num_rows(self) -> int | None:
         return None

-    def
+    def _get_expr_sample_df(
         self,
-
+        columns: Sequence[ColumnSpec],
     ) -> pd.DataFrame:
         with self._connection.cursor() as cursor:
-
-                f"{
+            projections = [
+                f"{column.expr} AS {quote_ident(column.name)}"
+                for column in columns
             ]
-            sql = f"SELECT {', '.join(
+            sql = (f"SELECT {', '.join(projections)} "
+                   f"FROM {self._quoted_source_name} "
+                   f"LIMIT {self._NUM_SAMPLE_ROWS}")
             cursor.execute(sql)
             table = cursor.fetch_arrow_all()
-
+
+        if table is None:
+            raise RuntimeError(f"Table '{self.source_name}' is empty")
+
+        return self._sanitize(
+            df=table.to_pandas(types_mapper=pd.ArrowDtype),
+            dtype_dict={column.name: column.dtype
+                        for column in columns},
+            stype_dict=None,
+        )
+
+    @staticmethod
+    def _to_dtype(dtype: str | None) -> Dtype | None:
+        if dtype is None:
+            return None
+        dtype = dtype.strip().upper()
+        if dtype.startswith('NUMBER'):
+            try:  # Parse `scale` from 'NUMBER(precision, scale)':
+                scale = int(dtype.split(',')[-1].split(')')[0])
+                return Dtype.int if scale == 0 else Dtype.float
+            except Exception:
+                return Dtype.float
+        if dtype == 'FLOAT':
+            return Dtype.float
+        if dtype.startswith('VARCHAR'):
+            return Dtype.string
+        if dtype.startswith('BINARY'):
+            return Dtype.binary
+        if dtype == 'BOOLEAN':
+            return Dtype.bool
+        if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
+            return Dtype.date
+        if dtype.startswith('TIME'):
+            return Dtype.time
+        if dtype.startswith('VECTOR'):
+            try:  # Parse element data type from 'VECTOR(dtype, dimension)':
+                dtype = dtype.split(',')[0].split('(')[1].strip()
+                if dtype == 'INT':
+                    return Dtype.intlist
+                elif dtype == 'FLOAT':
+                    return Dtype.floatlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        if dtype.startswith('ARRAY'):
+            try:  # Parse element data type from 'ARRAY(dtype)':
+                dtype = dtype.split('(', maxsplit=1)[1]
+                dtype = dtype.rsplit(')', maxsplit=1)[0]
+                _dtype = SnowTable._to_dtype(dtype)
+                if _dtype is not None and _dtype.is_int():
+                    return Dtype.intlist
+                elif _dtype is not None and _dtype.is_float():
+                    return Dtype.floatlist
+                elif _dtype is not None and _dtype.is_string():
+                    return Dtype.stringlist
+            except Exception:
+                pass
+            return Dtype.unsupported
+        # Unsupported data types:
+        if re.search(
+                'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
+                dtype,
+        ):
+            return Dtype.unsupported
+        return None