kumoai 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl → 2.15.0.dev202601141731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +15 -2
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +40 -35
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import re
|
|
2
|
+
from collections import Counter
|
|
2
3
|
from collections.abc import Sequence
|
|
3
4
|
from typing import cast
|
|
4
5
|
|
|
@@ -8,28 +9,27 @@ from kumoapi.typing import Dtype
|
|
|
8
9
|
|
|
9
10
|
from kumoai.experimental.rfm.backend.snow import Connection
|
|
10
11
|
from kumoai.experimental.rfm.base import (
|
|
11
|
-
|
|
12
|
-
|
|
12
|
+
ColumnSpec,
|
|
13
|
+
ColumnSpecType,
|
|
13
14
|
DataBackend,
|
|
14
15
|
SourceColumn,
|
|
15
16
|
SourceForeignKey,
|
|
16
|
-
|
|
17
|
+
Table,
|
|
17
18
|
)
|
|
18
19
|
from kumoai.utils import quote_ident
|
|
19
20
|
|
|
20
21
|
|
|
21
|
-
class SnowTable(
|
|
22
|
+
class SnowTable(Table):
|
|
22
23
|
r"""A table backed by a :class:`sqlite` database.
|
|
23
24
|
|
|
24
25
|
Args:
|
|
25
26
|
connection: The connection to a :class:`snowflake` database.
|
|
26
|
-
name: The
|
|
27
|
-
source_name: The
|
|
28
|
-
``
|
|
27
|
+
name: The name of this table.
|
|
28
|
+
source_name: The source name of this table. If set to ``None``,
|
|
29
|
+
``name`` is being used.
|
|
29
30
|
database: The database.
|
|
30
31
|
schema: The schema.
|
|
31
|
-
columns: The selected
|
|
32
|
-
column_expressions: The logical columns of this table.
|
|
32
|
+
columns: The selected columns of this table.
|
|
33
33
|
primary_key: The name of the primary key of this table, if it exists.
|
|
34
34
|
time_column: The name of the time column of this table, if it exists.
|
|
35
35
|
end_time_column: The name of the end time column of this table, if it
|
|
@@ -42,14 +42,21 @@ class SnowTable(SQLTable):
|
|
|
42
42
|
source_name: str | None = None,
|
|
43
43
|
database: str | None = None,
|
|
44
44
|
schema: str | None = None,
|
|
45
|
-
columns: Sequence[
|
|
46
|
-
column_expressions: Sequence[ColumnExpressionType] | None = None,
|
|
45
|
+
columns: Sequence[ColumnSpecType] | None = None,
|
|
47
46
|
primary_key: MissingType | str | None = MissingType.VALUE,
|
|
48
47
|
time_column: str | None = None,
|
|
49
48
|
end_time_column: str | None = None,
|
|
50
49
|
) -> None:
|
|
51
50
|
|
|
52
|
-
if database is
|
|
51
|
+
if database is None or schema is None:
|
|
52
|
+
with connection.cursor() as cursor:
|
|
53
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
54
|
+
result = cursor.fetchone()
|
|
55
|
+
database = database or result[0]
|
|
56
|
+
assert database is not None
|
|
57
|
+
schema = schema or result[1]
|
|
58
|
+
|
|
59
|
+
if schema is None:
|
|
53
60
|
raise ValueError(f"Unspecified 'schema' for table "
|
|
54
61
|
f"'{source_name or name}' in database "
|
|
55
62
|
f"'{database}'")
|
|
@@ -62,70 +69,41 @@ class SnowTable(SQLTable):
|
|
|
62
69
|
name=name,
|
|
63
70
|
source_name=source_name,
|
|
64
71
|
columns=columns,
|
|
65
|
-
column_expressions=column_expressions,
|
|
66
72
|
primary_key=primary_key,
|
|
67
73
|
time_column=time_column,
|
|
68
74
|
end_time_column=end_time_column,
|
|
69
75
|
)
|
|
70
76
|
|
|
71
|
-
@
|
|
72
|
-
def
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
snowflake_dtype = snowflake_dtype.strip().upper()
|
|
76
|
-
# TODO 'NUMBER(...)' is not always an integer!
|
|
77
|
-
if snowflake_dtype.startswith('NUMBER'):
|
|
78
|
-
return Dtype.int
|
|
79
|
-
elif snowflake_dtype.startswith('VARCHAR'):
|
|
80
|
-
return Dtype.string
|
|
81
|
-
elif snowflake_dtype == 'FLOAT':
|
|
82
|
-
return Dtype.float
|
|
83
|
-
elif snowflake_dtype == 'BOOLEAN':
|
|
84
|
-
return Dtype.bool
|
|
85
|
-
elif re.search('DATE|TIMESTAMP', snowflake_dtype):
|
|
86
|
-
return Dtype.date
|
|
87
|
-
return None
|
|
77
|
+
@property
|
|
78
|
+
def source_name(self) -> str:
|
|
79
|
+
names = [self._database, self._schema, self._source_name]
|
|
80
|
+
return '.'.join(names)
|
|
88
81
|
|
|
89
82
|
@property
|
|
90
|
-
def
|
|
91
|
-
|
|
83
|
+
def _quoted_source_name(self) -> str:
|
|
84
|
+
names = [self._database, self._schema, self._source_name]
|
|
85
|
+
return '.'.join([quote_ident(name) for name in names])
|
|
92
86
|
|
|
93
87
|
@property
|
|
94
|
-
def
|
|
95
|
-
|
|
96
|
-
names: list[str] = []
|
|
97
|
-
if self._database is not None:
|
|
98
|
-
names.append(quote_ident(self._database))
|
|
99
|
-
if self._schema is not None:
|
|
100
|
-
names.append(quote_ident(self._schema))
|
|
101
|
-
return '.'.join(names + [quote_ident(self._source_name)])
|
|
88
|
+
def backend(self) -> DataBackend:
|
|
89
|
+
return cast(DataBackend, DataBackend.SNOWFLAKE)
|
|
102
90
|
|
|
103
91
|
def _get_source_columns(self) -> list[SourceColumn]:
|
|
104
92
|
source_columns: list[SourceColumn] = []
|
|
105
93
|
with self._connection.cursor() as cursor:
|
|
106
94
|
try:
|
|
107
|
-
sql = f"DESCRIBE TABLE {self.
|
|
95
|
+
sql = f"DESCRIBE TABLE {self._quoted_source_name}"
|
|
108
96
|
cursor.execute(sql)
|
|
109
97
|
except Exception as e:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
names.append(self._database)
|
|
113
|
-
if self._schema is not None:
|
|
114
|
-
names.append(self._schema)
|
|
115
|
-
source_name = '.'.join(names + [self._source_name])
|
|
116
|
-
raise ValueError(f"Table '{source_name}' does not exist in "
|
|
117
|
-
f"the remote data backend") from e
|
|
98
|
+
raise ValueError(f"Table '{self.source_name}' does not exist "
|
|
99
|
+
f"in the remote data backend") from e
|
|
118
100
|
|
|
119
101
|
for row in cursor.fetchall():
|
|
120
|
-
column,
|
|
121
|
-
|
|
122
|
-
dtype = self.to_dtype(type)
|
|
123
|
-
if dtype is None:
|
|
124
|
-
continue
|
|
102
|
+
column, dtype, _, null, _, is_pkey, is_unique, *_ = row
|
|
125
103
|
|
|
126
104
|
source_column = SourceColumn(
|
|
127
105
|
name=column,
|
|
128
|
-
dtype=dtype,
|
|
106
|
+
dtype=self._to_dtype(dtype),
|
|
129
107
|
is_primary_key=is_pkey.strip().upper() == 'Y',
|
|
130
108
|
is_unique_key=is_unique.strip().upper() == 'Y',
|
|
131
109
|
is_nullable=null.strip().upper() == 'Y',
|
|
@@ -135,35 +113,133 @@ class SnowTable(SQLTable):
|
|
|
135
113
|
return source_columns
|
|
136
114
|
|
|
137
115
|
def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
|
|
138
|
-
|
|
116
|
+
source_foreign_keys: list[SourceForeignKey] = []
|
|
139
117
|
with self._connection.cursor() as cursor:
|
|
140
|
-
sql = f"SHOW IMPORTED KEYS IN TABLE {self.
|
|
118
|
+
sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
|
|
141
119
|
cursor.execute(sql)
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
120
|
+
rows = cursor.fetchall()
|
|
121
|
+
counts = Counter(row[13] for row in rows)
|
|
122
|
+
for row in rows:
|
|
123
|
+
if counts[row[13]] == 1:
|
|
124
|
+
source_foreign_key = SourceForeignKey(
|
|
125
|
+
name=row[8],
|
|
126
|
+
dst_table=f'{row[1]}.{row[2]}.{row[3]}',
|
|
127
|
+
primary_key=row[4],
|
|
128
|
+
)
|
|
129
|
+
source_foreign_keys.append(source_foreign_key)
|
|
130
|
+
return source_foreign_keys
|
|
146
131
|
|
|
147
132
|
def _get_source_sample_df(self) -> pd.DataFrame:
|
|
148
133
|
with self._connection.cursor() as cursor:
|
|
149
134
|
columns = [quote_ident(col) for col in self._source_column_dict]
|
|
150
|
-
sql = f"SELECT {', '.join(columns)}
|
|
135
|
+
sql = (f"SELECT {', '.join(columns)} "
|
|
136
|
+
f"FROM {self._quoted_source_name} "
|
|
137
|
+
f"LIMIT {self._NUM_SAMPLE_ROWS}")
|
|
151
138
|
cursor.execute(sql)
|
|
152
139
|
table = cursor.fetch_arrow_all()
|
|
153
|
-
|
|
140
|
+
|
|
141
|
+
if table is None:
|
|
142
|
+
raise RuntimeError(f"Table '{self.source_name}' is empty")
|
|
143
|
+
|
|
144
|
+
return self._sanitize(
|
|
145
|
+
df=table.to_pandas(types_mapper=pd.ArrowDtype),
|
|
146
|
+
dtype_dict={
|
|
147
|
+
column.name: column.dtype
|
|
148
|
+
for column in self._source_column_dict.values()
|
|
149
|
+
},
|
|
150
|
+
stype_dict=None,
|
|
151
|
+
)
|
|
154
152
|
|
|
155
153
|
def _get_num_rows(self) -> int | None:
|
|
156
|
-
|
|
154
|
+
with self._connection.cursor() as cursor:
|
|
155
|
+
quoted_source_name = quote_ident(self._source_name, char="'")
|
|
156
|
+
sql = (f"SHOW TABLES LIKE {quoted_source_name} "
|
|
157
|
+
f"IN SCHEMA {quote_ident(self._database)}."
|
|
158
|
+
f"{quote_ident(self._schema)}")
|
|
159
|
+
cursor.execute(sql)
|
|
160
|
+
num_rows = cursor.fetchone()[7]
|
|
161
|
+
|
|
162
|
+
if num_rows == 0:
|
|
163
|
+
raise RuntimeError("Table '{self.source_name}' is empty")
|
|
157
164
|
|
|
158
|
-
|
|
165
|
+
return num_rows
|
|
166
|
+
|
|
167
|
+
def _get_expr_sample_df(
|
|
159
168
|
self,
|
|
160
|
-
|
|
169
|
+
columns: Sequence[ColumnSpec],
|
|
161
170
|
) -> pd.DataFrame:
|
|
162
171
|
with self._connection.cursor() as cursor:
|
|
163
|
-
|
|
164
|
-
f"{
|
|
172
|
+
projections = [
|
|
173
|
+
f"{column.expr} AS {quote_ident(column.name)}"
|
|
174
|
+
for column in columns
|
|
165
175
|
]
|
|
166
|
-
sql = f"SELECT {', '.join(
|
|
176
|
+
sql = (f"SELECT {', '.join(projections)} "
|
|
177
|
+
f"FROM {self._quoted_source_name} "
|
|
178
|
+
f"LIMIT {self._NUM_SAMPLE_ROWS}")
|
|
167
179
|
cursor.execute(sql)
|
|
168
180
|
table = cursor.fetch_arrow_all()
|
|
169
|
-
|
|
181
|
+
|
|
182
|
+
if table is None:
|
|
183
|
+
raise RuntimeError(f"Table '{self.source_name}' is empty")
|
|
184
|
+
|
|
185
|
+
return self._sanitize(
|
|
186
|
+
df=table.to_pandas(types_mapper=pd.ArrowDtype),
|
|
187
|
+
dtype_dict={column.name: column.dtype
|
|
188
|
+
for column in columns},
|
|
189
|
+
stype_dict=None,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
@staticmethod
|
|
193
|
+
def _to_dtype(dtype: str | None) -> Dtype | None:
|
|
194
|
+
if dtype is None:
|
|
195
|
+
return None
|
|
196
|
+
dtype = dtype.strip().upper()
|
|
197
|
+
if dtype.startswith('NUMBER'):
|
|
198
|
+
try: # Parse `scale` from 'NUMBER(precision, scale)':
|
|
199
|
+
scale = int(dtype.split(',')[-1].split(')')[0])
|
|
200
|
+
return Dtype.int if scale == 0 else Dtype.float
|
|
201
|
+
except Exception:
|
|
202
|
+
return Dtype.float
|
|
203
|
+
if dtype == 'FLOAT':
|
|
204
|
+
return Dtype.float
|
|
205
|
+
if dtype.startswith('VARCHAR'):
|
|
206
|
+
return Dtype.string
|
|
207
|
+
if dtype.startswith('BINARY'):
|
|
208
|
+
return Dtype.binary
|
|
209
|
+
if dtype == 'BOOLEAN':
|
|
210
|
+
return Dtype.bool
|
|
211
|
+
if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
|
|
212
|
+
return Dtype.date
|
|
213
|
+
if dtype.startswith('TIME'):
|
|
214
|
+
return Dtype.time
|
|
215
|
+
if dtype.startswith('VECTOR'):
|
|
216
|
+
try: # Parse element data type from 'VECTOR(dtype, dimension)':
|
|
217
|
+
dtype = dtype.split(',')[0].split('(')[1].strip()
|
|
218
|
+
if dtype == 'INT':
|
|
219
|
+
return Dtype.intlist
|
|
220
|
+
elif dtype == 'FLOAT':
|
|
221
|
+
return Dtype.floatlist
|
|
222
|
+
except Exception:
|
|
223
|
+
pass
|
|
224
|
+
return Dtype.unsupported
|
|
225
|
+
if dtype.startswith('ARRAY'):
|
|
226
|
+
try: # Parse element data type from 'ARRAY(dtype)':
|
|
227
|
+
dtype = dtype.split('(', maxsplit=1)[1]
|
|
228
|
+
dtype = dtype.rsplit(')', maxsplit=1)[0]
|
|
229
|
+
_dtype = SnowTable._to_dtype(dtype)
|
|
230
|
+
if _dtype is not None and _dtype.is_int():
|
|
231
|
+
return Dtype.intlist
|
|
232
|
+
elif _dtype is not None and _dtype.is_float():
|
|
233
|
+
return Dtype.floatlist
|
|
234
|
+
elif _dtype is not None and _dtype.is_string():
|
|
235
|
+
return Dtype.stringlist
|
|
236
|
+
except Exception:
|
|
237
|
+
pass
|
|
238
|
+
return Dtype.unsupported
|
|
239
|
+
# Unsupported data types:
|
|
240
|
+
if re.search(
|
|
241
|
+
'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
|
|
242
|
+
dtype,
|
|
243
|
+
):
|
|
244
|
+
return Dtype.unsupported
|
|
245
|
+
return None
|