kumoai 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
- kumoai/experimental/rfm/backend/snow/table.py +178 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +22 -4
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +696 -47
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +384 -207
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +359 -187
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +10 -5
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +770 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
|
@@ -1,11 +1,22 @@
|
|
|
1
1
|
import re
|
|
2
|
-
from
|
|
2
|
+
from collections import Counter
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import cast
|
|
3
5
|
|
|
4
6
|
import pandas as pd
|
|
7
|
+
from kumoapi.model_plan import MissingType
|
|
5
8
|
from kumoapi.typing import Dtype
|
|
6
9
|
|
|
7
10
|
from kumoai.experimental.rfm.backend.snow import Connection
|
|
8
|
-
from kumoai.experimental.rfm.base import
|
|
11
|
+
from kumoai.experimental.rfm.base import (
|
|
12
|
+
ColumnSpec,
|
|
13
|
+
ColumnSpecType,
|
|
14
|
+
DataBackend,
|
|
15
|
+
SourceColumn,
|
|
16
|
+
SourceForeignKey,
|
|
17
|
+
Table,
|
|
18
|
+
)
|
|
19
|
+
from kumoai.utils import quote_ident
|
|
9
20
|
|
|
10
21
|
|
|
11
22
|
class SnowTable(Table):
|
|
@@ -14,6 +25,8 @@ class SnowTable(Table):
|
|
|
14
25
|
Args:
|
|
15
26
|
connection: The connection to a :class:`snowflake` database.
|
|
16
27
|
name: The name of this table.
|
|
28
|
+
source_name: The source name of this table. If set to ``None``,
|
|
29
|
+
``name`` is being used.
|
|
17
30
|
database: The database.
|
|
18
31
|
schema: The schema.
|
|
19
32
|
columns: The selected columns of this table.
|
|
@@ -26,17 +39,27 @@ class SnowTable(Table):
|
|
|
26
39
|
self,
|
|
27
40
|
connection: Connection,
|
|
28
41
|
name: str,
|
|
42
|
+
source_name: str | None = None,
|
|
29
43
|
database: str | None = None,
|
|
30
44
|
schema: str | None = None,
|
|
31
|
-
columns:
|
|
32
|
-
primary_key:
|
|
33
|
-
time_column:
|
|
34
|
-
end_time_column:
|
|
45
|
+
columns: Sequence[ColumnSpecType] | None = None,
|
|
46
|
+
primary_key: MissingType | str | None = MissingType.VALUE,
|
|
47
|
+
time_column: str | None = None,
|
|
48
|
+
end_time_column: str | None = None,
|
|
35
49
|
) -> None:
|
|
36
50
|
|
|
37
|
-
if database is
|
|
38
|
-
|
|
39
|
-
|
|
51
|
+
if database is None or schema is None:
|
|
52
|
+
with connection.cursor() as cursor:
|
|
53
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
54
|
+
result = cursor.fetchone()
|
|
55
|
+
database = database or result[0]
|
|
56
|
+
assert database is not None
|
|
57
|
+
schema = schema or result[1]
|
|
58
|
+
|
|
59
|
+
if schema is None:
|
|
60
|
+
raise ValueError(f"Unspecified 'schema' for table "
|
|
61
|
+
f"'{source_name or name}' in database "
|
|
62
|
+
f"'{database}'")
|
|
40
63
|
|
|
41
64
|
self._connection = connection
|
|
42
65
|
self._database = database
|
|
@@ -44,6 +67,7 @@ class SnowTable(Table):
|
|
|
44
67
|
|
|
45
68
|
super().__init__(
|
|
46
69
|
name=name,
|
|
70
|
+
source_name=source_name,
|
|
47
71
|
columns=columns,
|
|
48
72
|
primary_key=primary_key,
|
|
49
73
|
time_column=time_column,
|
|
@@ -51,67 +75,171 @@ class SnowTable(Table):
|
|
|
51
75
|
)
|
|
52
76
|
|
|
53
77
|
@property
|
|
54
|
-
def
|
|
55
|
-
names
|
|
56
|
-
if self._database is not None:
|
|
57
|
-
assert self._schema is not None
|
|
58
|
-
names.extend([self._database, self._schema])
|
|
59
|
-
elif self._schema is not None:
|
|
60
|
-
names.append(self._schema)
|
|
61
|
-
names.append(self._name)
|
|
78
|
+
def source_name(self) -> str:
|
|
79
|
+
names = [self._database, self._schema, self._source_name]
|
|
62
80
|
return '.'.join(names)
|
|
63
81
|
|
|
64
|
-
|
|
65
|
-
|
|
82
|
+
@property
|
|
83
|
+
def _quoted_source_name(self) -> str:
|
|
84
|
+
names = [self._database, self._schema, self._source_name]
|
|
85
|
+
return '.'.join([quote_ident(name) for name in names])
|
|
86
|
+
|
|
87
|
+
@property
|
|
88
|
+
def backend(self) -> DataBackend:
|
|
89
|
+
return cast(DataBackend, DataBackend.SNOWFLAKE)
|
|
90
|
+
|
|
91
|
+
def _get_source_columns(self) -> list[SourceColumn]:
|
|
92
|
+
source_columns: list[SourceColumn] = []
|
|
66
93
|
with self._connection.cursor() as cursor:
|
|
67
94
|
try:
|
|
68
|
-
|
|
95
|
+
sql = f"DESCRIBE TABLE {self._quoted_source_name}"
|
|
96
|
+
cursor.execute(sql)
|
|
69
97
|
except Exception as e:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
|
|
98
|
+
raise ValueError(f"Table '{self.source_name}' does not exist "
|
|
99
|
+
f"in the remote data backend") from e
|
|
72
100
|
|
|
73
101
|
for row in cursor.fetchall():
|
|
74
|
-
column,
|
|
75
|
-
|
|
76
|
-
type = type.strip().upper()
|
|
77
|
-
if type.startswith('NUMBER'):
|
|
78
|
-
dtype = Dtype.int
|
|
79
|
-
elif type.startswith('VARCHAR'):
|
|
80
|
-
dtype = Dtype.string
|
|
81
|
-
elif type == 'FLOAT':
|
|
82
|
-
dtype = Dtype.float
|
|
83
|
-
elif type == 'BOOLEAN':
|
|
84
|
-
dtype = Dtype.bool
|
|
85
|
-
elif re.search('DATE|TIMESTAMP', type):
|
|
86
|
-
dtype = Dtype.date
|
|
87
|
-
else:
|
|
88
|
-
continue
|
|
102
|
+
column, dtype, _, null, _, is_pkey, is_unique, *_ = row
|
|
89
103
|
|
|
90
104
|
source_column = SourceColumn(
|
|
91
105
|
name=column,
|
|
92
|
-
dtype=dtype,
|
|
106
|
+
dtype=self._to_dtype(dtype),
|
|
93
107
|
is_primary_key=is_pkey.strip().upper() == 'Y',
|
|
94
108
|
is_unique_key=is_unique.strip().upper() == 'Y',
|
|
109
|
+
is_nullable=null.strip().upper() == 'Y',
|
|
95
110
|
)
|
|
96
111
|
source_columns.append(source_column)
|
|
97
112
|
|
|
98
113
|
return source_columns
|
|
99
114
|
|
|
100
|
-
def _get_source_foreign_keys(self) ->
|
|
101
|
-
|
|
115
|
+
def _get_source_foreign_keys(self) -> list[SourceForeignKey]:
|
|
116
|
+
source_foreign_keys: list[SourceForeignKey] = []
|
|
102
117
|
with self._connection.cursor() as cursor:
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
118
|
+
sql = f"SHOW IMPORTED KEYS IN TABLE {self._quoted_source_name}"
|
|
119
|
+
cursor.execute(sql)
|
|
120
|
+
rows = cursor.fetchall()
|
|
121
|
+
counts = Counter(row[13] for row in rows)
|
|
122
|
+
for row in rows:
|
|
123
|
+
if counts[row[13]] == 1:
|
|
124
|
+
source_foreign_key = SourceForeignKey(
|
|
125
|
+
name=row[8],
|
|
126
|
+
dst_table=f'{row[1]}.{row[2]}.{row[3]}',
|
|
127
|
+
primary_key=row[4],
|
|
128
|
+
)
|
|
129
|
+
source_foreign_keys.append(source_foreign_key)
|
|
130
|
+
return source_foreign_keys
|
|
131
|
+
|
|
132
|
+
def _get_source_sample_df(self) -> pd.DataFrame:
|
|
133
|
+
with self._connection.cursor() as cursor:
|
|
134
|
+
columns = [quote_ident(col) for col in self._source_column_dict]
|
|
135
|
+
sql = (f"SELECT {', '.join(columns)} "
|
|
136
|
+
f"FROM {self._quoted_source_name} "
|
|
137
|
+
f"LIMIT {self._NUM_SAMPLE_ROWS}")
|
|
138
|
+
cursor.execute(sql)
|
|
139
|
+
table = cursor.fetch_arrow_all()
|
|
108
140
|
|
|
109
|
-
|
|
141
|
+
if table is None:
|
|
142
|
+
raise RuntimeError(f"Table '{self.source_name}' is empty")
|
|
143
|
+
|
|
144
|
+
return self._sanitize(
|
|
145
|
+
df=table.to_pandas(types_mapper=pd.ArrowDtype),
|
|
146
|
+
dtype_dict={
|
|
147
|
+
column.name: column.dtype
|
|
148
|
+
for column in self._source_column_dict.values()
|
|
149
|
+
},
|
|
150
|
+
stype_dict=None,
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
def _get_num_rows(self) -> int | None:
|
|
110
154
|
with self._connection.cursor() as cursor:
|
|
111
|
-
|
|
112
|
-
|
|
155
|
+
quoted_source_name = quote_ident(self._source_name, char="'")
|
|
156
|
+
sql = (f"SHOW TABLES LIKE {quoted_source_name} "
|
|
157
|
+
f"IN SCHEMA {quote_ident(self._database)}."
|
|
158
|
+
f"{quote_ident(self._schema)}")
|
|
159
|
+
cursor.execute(sql)
|
|
160
|
+
num_rows = cursor.fetchone()[7]
|
|
161
|
+
|
|
162
|
+
if num_rows == 0:
|
|
163
|
+
raise RuntimeError("Table '{self.source_name}' is empty")
|
|
164
|
+
|
|
165
|
+
return num_rows
|
|
166
|
+
|
|
167
|
+
def _get_expr_sample_df(
|
|
168
|
+
self,
|
|
169
|
+
columns: Sequence[ColumnSpec],
|
|
170
|
+
) -> pd.DataFrame:
|
|
171
|
+
with self._connection.cursor() as cursor:
|
|
172
|
+
projections = [
|
|
173
|
+
f"{column.expr} AS {quote_ident(column.name)}"
|
|
174
|
+
for column in columns
|
|
175
|
+
]
|
|
176
|
+
sql = (f"SELECT {', '.join(projections)} "
|
|
177
|
+
f"FROM {self._quoted_source_name} "
|
|
178
|
+
f"LIMIT {self._NUM_SAMPLE_ROWS}")
|
|
179
|
+
cursor.execute(sql)
|
|
113
180
|
table = cursor.fetch_arrow_all()
|
|
114
|
-
return table.to_pandas(types_mapper=pd.ArrowDtype)
|
|
115
181
|
|
|
116
|
-
|
|
182
|
+
if table is None:
|
|
183
|
+
raise RuntimeError(f"Table '{self.source_name}' is empty")
|
|
184
|
+
|
|
185
|
+
return self._sanitize(
|
|
186
|
+
df=table.to_pandas(types_mapper=pd.ArrowDtype),
|
|
187
|
+
dtype_dict={column.name: column.dtype
|
|
188
|
+
for column in columns},
|
|
189
|
+
stype_dict=None,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
@staticmethod
|
|
193
|
+
def _to_dtype(dtype: str | None) -> Dtype | None:
|
|
194
|
+
if dtype is None:
|
|
195
|
+
return None
|
|
196
|
+
dtype = dtype.strip().upper()
|
|
197
|
+
if dtype.startswith('NUMBER'):
|
|
198
|
+
try: # Parse `scale` from 'NUMBER(precision, scale)':
|
|
199
|
+
scale = int(dtype.split(',')[-1].split(')')[0])
|
|
200
|
+
return Dtype.int if scale == 0 else Dtype.float
|
|
201
|
+
except Exception:
|
|
202
|
+
return Dtype.float
|
|
203
|
+
if dtype == 'FLOAT':
|
|
204
|
+
return Dtype.float
|
|
205
|
+
if dtype.startswith('VARCHAR'):
|
|
206
|
+
return Dtype.string
|
|
207
|
+
if dtype.startswith('BINARY'):
|
|
208
|
+
return Dtype.binary
|
|
209
|
+
if dtype == 'BOOLEAN':
|
|
210
|
+
return Dtype.bool
|
|
211
|
+
if dtype.startswith('DATE') or dtype.startswith('TIMESTAMP'):
|
|
212
|
+
return Dtype.date
|
|
213
|
+
if dtype.startswith('TIME'):
|
|
214
|
+
return Dtype.time
|
|
215
|
+
if dtype.startswith('VECTOR'):
|
|
216
|
+
try: # Parse element data type from 'VECTOR(dtype, dimension)':
|
|
217
|
+
dtype = dtype.split(',')[0].split('(')[1].strip()
|
|
218
|
+
if dtype == 'INT':
|
|
219
|
+
return Dtype.intlist
|
|
220
|
+
elif dtype == 'FLOAT':
|
|
221
|
+
return Dtype.floatlist
|
|
222
|
+
except Exception:
|
|
223
|
+
pass
|
|
224
|
+
return Dtype.unsupported
|
|
225
|
+
if dtype.startswith('ARRAY'):
|
|
226
|
+
try: # Parse element data type from 'ARRAY(dtype)':
|
|
227
|
+
dtype = dtype.split('(', maxsplit=1)[1]
|
|
228
|
+
dtype = dtype.rsplit(')', maxsplit=1)[0]
|
|
229
|
+
_dtype = SnowTable._to_dtype(dtype)
|
|
230
|
+
if _dtype is not None and _dtype.is_int():
|
|
231
|
+
return Dtype.intlist
|
|
232
|
+
elif _dtype is not None and _dtype.is_float():
|
|
233
|
+
return Dtype.floatlist
|
|
234
|
+
elif _dtype is not None and _dtype.is_string():
|
|
235
|
+
return Dtype.stringlist
|
|
236
|
+
except Exception:
|
|
237
|
+
pass
|
|
238
|
+
return Dtype.unsupported
|
|
239
|
+
# Unsupported data types:
|
|
240
|
+
if re.search(
|
|
241
|
+
'DECFLOAT|VARIANT|OBJECT|MAP|FILE|GEOGRAPHY|GEOMETRY',
|
|
242
|
+
dtype,
|
|
243
|
+
):
|
|
244
|
+
return Dtype.unsupported
|
|
117
245
|
return None
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
from pathlib import Path
|
|
2
|
-
from typing import Any, TypeAlias
|
|
2
|
+
from typing import Any, TypeAlias
|
|
3
3
|
|
|
4
4
|
try:
|
|
5
5
|
import adbc_driver_sqlite.dbapi as adbc
|
|
@@ -11,7 +11,7 @@ except ImportError:
|
|
|
11
11
|
Connection: TypeAlias = adbc.AdbcSqliteConnection
|
|
12
12
|
|
|
13
13
|
|
|
14
|
-
def connect(uri:
|
|
14
|
+
def connect(uri: str | Path | None = None, **kwargs: Any) -> Connection:
|
|
15
15
|
r"""Opens a connection to a :class:`sqlite` database.
|
|
16
16
|
|
|
17
17
|
uri: The path to the database file to be opened.
|
|
@@ -22,9 +22,11 @@ def connect(uri: Union[str, Path, None] = None, **kwargs: Any) -> Connection:
|
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
from .table import SQLiteTable # noqa: E402
|
|
25
|
+
from .sampler import SQLiteSampler # noqa: E402
|
|
25
26
|
|
|
26
27
|
__all__ = [
|
|
27
28
|
'connect',
|
|
28
29
|
'Connection',
|
|
29
30
|
'SQLiteTable',
|
|
31
|
+
'SQLiteSampler',
|
|
30
32
|
]
|