kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202601081732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +177 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +247 -0
- kumoai/experimental/rfm/base/table.py +404 -203
- kumoai/experimental/rfm/graph.py +374 -172
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +7 -4
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +762 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/RECORD +52 -41
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202601081732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
@@ -1,11 +1,15 @@
+from __future__ import annotations
+
 import contextlib
+import copy
 import io
 import warnings
 from collections import defaultdict
+from collections.abc import Sequence
 from dataclasses import dataclass, field
-from
+from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Union
 
 import pandas as pd
 from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,10 +17,11 @@ from kumoapi.table import TableDefinition
 from kumoapi.typing import Stype
 from typing_extensions import Self
 
-from kumoai import in_notebook
-from kumoai.experimental.rfm import Table
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
 from kumoai.graph import Edge
 from kumoai.mixin import CastMixin
+from kumoai.utils import display
 
 if TYPE_CHECKING:
     import graphviz
@@ -26,8 +31,8 @@ if TYPE_CHECKING:
 
 @dataclass
 class SqliteConnectionConfig(CastMixin):
-    uri:
-    kwargs:
+    uri: str | Path
+    kwargs: dict[str, Any] = field(default_factory=dict)
 
 
 class Graph:
@@ -87,27 +92,35 @@ class Graph:
     def __init__(
         self,
         tables: Sequence[Table],
-        edges:
+        edges: Sequence[Edge] | None = None,
     ) -> None:
 
-        self._tables:
-        self._edges:
+        self._tables: dict[str, Table] = {}
+        self._edges: list[Edge] = []
 
         for table in tables:
             self.add_table(table)
 
-        for table in tables:
+        for table in tables:  # Use links from source metadata:
+            if not any(column.is_source for column in table.columns):
+                continue
             for fkey in table._source_foreign_key_dict.values():
-                if fkey.name not in table
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_source:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if table.source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
                     continue
-
-
-
-
-
-
-                            f"'{fkey.dst_table}'.")
-                self.link(table.name, fkey.name, fkey.dst_table)
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_source:
+                    continue
+                self.link(table.name, fkey.name, dst_table.name)
 
         for edge in (edges or []):
             _edge = Edge._cast(edge)
@@ -118,8 +131,8 @@ class Graph:
     @classmethod
     def from_data(
         cls,
-        df_dict:
-        edges:
+        df_dict: dict[str, pd.DataFrame],
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -157,15 +170,17 @@ class Graph:
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.local import LocalTable
-        tables = [LocalTable(df, name) for name, df in df_dict.items()]
 
-        graph = cls(
+        graph = cls(
+            tables=[LocalTable(df, name) for name, df in df_dict.items()],
+            edges=edges or [],
+        )
 
         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)
 
         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)
 
         if verbose:
             graph.print_metadata()
@@ -181,10 +196,10 @@ class Graph:
             SqliteConnectionConfig,
             str,
             Path,
-
+            dict[str, Any],
         ],
-
-        edges:
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -200,17 +215,25 @@ class Graph:
             >>> # Create a graph from a SQLite database:
             >>> graph = rfm.Graph.from_sqlite('data.db')
 
+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', primary_key='ITEM_ID'),
+            ... ])
+
         Args:
             connection: An open connection from
                 :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
                 path to the database file.
-
-                all tables present
+            tables: Set of table names or :class:`SQLiteTable` keyword
+                arguments to include. If ``None``, will add all tables present
+                in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
-            infer_metadata: Whether to infer metadata for all tables in
-                graph.
+            infer_metadata: Whether to infer missing metadata for all tables in
+                the graph.
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +242,41 @@ class Graph:
             connect,
         )
 
+        internal_connection = False
         if not isinstance(connection, Connection):
            connection = SqliteConnectionConfig._cast(connection)
            assert isinstance(connection, SqliteConnectionConfig)
            connection = connect(connection.uri, **connection.kwargs)
+            internal_connection = True
         assert isinstance(connection, Connection)
 
-        if
+        if tables is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT name FROM sqlite_master "
                                "WHERE type='table'")
-
+                tables = [row[0] for row in cursor.fetchall()]
 
-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            kwargs = dict(name=table) if isinstance(table, str) else table
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SQLiteTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )
 
-
+        if internal_connection:
+            graph._connection = connection  # type: ignore
 
         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)
 
         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)
 
         if verbose:
             graph.print_metadata()
@@ -250,9 +287,11 @@ class Graph:
     @classmethod
     def from_snowflake(
         cls,
-        connection: Union['SnowflakeConnection',
-
-
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -267,7 +306,14 @@ class Graph:
             >>> import kumoai.experimental.rfm as rfm
 
             >>> # Create a graph directly in a Snowflake notebook:
-            >>> graph = rfm.Graph.from_snowflake()
+            >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
+
+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_snowflake(tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', schema='OTHER_SCHEMA'),
+            ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
 
         Args:
             connection: An open connection from
@@ -276,8 +322,11 @@ class Graph:
                 connection. If ``None``, will re-use an active session in case
                 it exists, or create a new connection from credentials stored
                 in environment variables.
-
-                all tables present in the
+            tables: Set of table names or :class:`SnowTable` keyword arguments
+                to include. If ``None``, will add all tables present in the
+                current database and schema.
+            database: The database.
+            schema: The schema.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +344,50 @@ class Graph:
             connection = connect(**(connection or {}))
         assert isinstance(connection, Connection)
 
-        if
+        if database is None or schema is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
-
-
+                result = cursor.fetchone()
+            database = database or result[0]
+            assert database is not None
+            schema = schema or result[1]
+
+        if tables is None:
+            if schema is None:
+                raise ValueError("No current 'schema' set. Please specify the "
+                                 "Snowflake schema manually")
+
+            with connection.cursor() as cursor:
+                cursor.execute(f"""
                     SELECT TABLE_NAME
                     FROM {database}.INFORMATION_SCHEMA.TABLES
                     WHERE TABLE_SCHEMA = '{schema}'
-                """
-                cursor.
-                table_names = [row[0] for row in cursor.fetchall()]
+                """)
+                tables = [row[0] for row in cursor.fetchall()]
 
-
-
-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            if isinstance(table, str):
+                kwargs = dict(name=table, database=database, schema=schema)
+            else:
+                kwargs = copy.copy(table)
+                kwargs.setdefault('database', database)
+                kwargs.setdefault('schema', schema)
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SnowTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )
 
         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)
 
         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)
 
         if verbose:
             graph.print_metadata()
@@ -327,7 +399,7 @@ class Graph:
     def from_snowflake_semantic_view(
         cls,
         semantic_view_name: str,
-        connection: Union['SnowflakeConnection',
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
         verbose: bool = True,
     ) -> Self:
         import yaml
@@ -345,43 +417,165 @@ class Graph:
         with connection.cursor() as cursor:
             cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                            f"'{semantic_view_name}')")
-
+            cfg = yaml.safe_load(cursor.fetchone()[0])
 
         graph = cls(tables=[])
 
-
-
-
-
-
+        msgs = []
+        table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
+        for table_cfg in cfg['tables']:
+            table_name = table_cfg['name']
+            source_table_name = table_cfg['base_table']['table']
+            database = table_cfg['base_table']['database']
+            schema = table_cfg['base_table']['schema']
+
+            primary_key: str | None = None
+            if 'primary_key' in table_cfg:
+                primary_key_cfg = table_cfg['primary_key']
+                if len(primary_key_cfg['columns']) == 1:
+                    primary_key = primary_key_cfg['columns'][0]
+                elif len(primary_key_cfg['columns']) > 1:
+                    msgs.append(f"Failed to add primary key for table "
+                                f"'{table_name}' since composite primary keys "
+                                f"are not yet supported")
+
+            columns: list[ColumnSpec] = []
+            unsupported_columns: list[str] = []
+            for column_cfg in chain(
+                    table_cfg.get('dimensions', []),
+                    table_cfg.get('time_dimensions', []),
+                    table_cfg.get('facts', []),
+            ):
+                column_name = column_cfg['name']
+                column_expr = column_cfg.get('expr', None)
+                column_data_type = column_cfg.get('data_type', None)
+
+                if column_expr is None:
+                    columns.append(ColumnSpec(name=column_name))
+                    continue
+
+                column_expr = column_expr.replace(f'{table_name}.', '')
+
+                if column_expr == column_name:
+                    columns.append(ColumnSpec(name=column_name))
+                    continue
+
+                # Drop expressions that reference other tables (for now):
+                if any(f'{name}.' in column_expr for name in table_names):
+                    unsupported_columns.append(column_name)
+                    continue
+
+                column = ColumnSpec(
+                    name=column_name,
+                    expr=column_expr,
+                    dtype=SnowTable._to_dtype(column_data_type),
+                )
+                columns.append(column)
+
+            if len(unsupported_columns) == 1:
+                msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
+                            f"of table '{table_name}' since its expression "
+                            f"references other tables")
+            elif len(unsupported_columns) > 1:
+                msgs.append(f"Failed to add columns '{unsupported_columns}' "
+                            f"of table '{table_name}' since their expressions "
+                            f"reference other tables")
 
             table = SnowTable(
                 connection,
-                name=
-
-
+                name=table_name,
+                source_name=source_table_name,
+                database=database,
+                schema=schema,
+                columns=columns,
                 primary_key=primary_key,
             )
+
+            # TODO Add a way to register time columns without heuristic usage.
+            table.infer_time_column(verbose=False)
+
             graph.add_table(table)
 
-
+        for relation_cfg in cfg.get('relationships', []):
+            name = relation_cfg['name']
+            if len(relation_cfg['relationship_columns']) != 1:
+                msgs.append(f"Failed to add relationship '{name}' since "
+                            f"composite key references are not yet supported")
+                continue
 
-
-
-
-
-
-
-
-
+            left_table = relation_cfg['left_table']
+            left_key = relation_cfg['relationship_columns'][0]['left_column']
+            right_table = relation_cfg['right_table']
+            right_key = relation_cfg['relationship_columns'][0]['right_column']
+
+            if graph[right_table]._primary_key != right_key:
+                # Semantic view error - this should never be triggered:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referenced key '{right_key}' of table "
+                            f"'{right_table}' is not a primary key")
+                continue
+
+            if graph[left_table]._primary_key == left_key:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referencing key '{left_key}' of table "
+                            f"'{left_table}' is a primary key")
+                continue
+
+            if left_key not in graph[left_table]:
+                graph[left_table].add_column(left_key)
+
+            graph.link(left_table, left_key, right_table)
+
+        graph.validate()
 
         if verbose:
             graph.print_metadata()
             graph.print_links()
 
+        if len(msgs) > 0:
+            title = (f"Could not fully convert the semantic view definition "
+                     f"'{semantic_view_name}' into a graph:\n")
+            warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
+
         return graph
 
-
+    @classmethod
+    def from_relbench(
+        cls,
+        dataset: str,
+        verbose: bool = True,
+    ) -> Graph:
+        r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
+        :class:`Graph` instance.
+
+        .. code-block:: python
+
+            >>> # doctest: +SKIP
+            >>> import kumoai.experimental.rfm as rfm
+
+            >>> graph = rfm.Graph.from_relbench("f1")
+
+        Args:
+            dataset: The RelBench dataset name.
+            verbose: Whether to print verbose output.
+        """
+        from kumoai.experimental.rfm.relbench import from_relbench
+        graph = from_relbench(dataset, verbose=verbose)
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
+
+        return graph
+
+    # Backend #################################################################
+
+    @property
+    def backend(self) -> DataBackend | None:
+        backends = [table.backend for table in self._tables.values()]
+        return backends[0] if len(backends) > 0 else None
+
+    # Tables ##################################################################
 
     def has_table(self, name: str) -> bool:
         r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -400,7 +594,7 @@ class Graph:
         return self.tables[name]
 
     @property
-    def tables(self) ->
+    def tables(self) -> dict[str, Table]:
         r"""Returns the dictionary of table objects."""
         return self._tables
 
@@ -420,13 +614,10 @@ class Graph:
             raise KeyError(f"Cannot add table with name '{table.name}' to "
                            f"this graph; table names must be globally unique.")
 
-        if
-
-
-
-                             f"'{table.__class__.__name__}' to this "
-                             f"graph since other tables are of type "
-                             f"'{cls.__name__}'.")
+        if self.backend is not None and table.backend != self.backend:
+            raise ValueError(f"Cannot register a table with backend "
+                             f"'{table.backend}' to this graph since other "
+                             f"tables have backend '{self.backend}'.")
 
         self._tables[table.name] = table
 
@@ -458,28 +649,28 @@ class Graph:
         r"""Returns a :class:`pandas.DataFrame` object containing metadata
         information about the tables in this graph.
 
-        The returned dataframe has columns ``
-        ``
-        view of the properties of the tables of this graph.
+        The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
+        ``"Time Column"``, and ``"End Time Column"``, which provide an
+        aggregated view of the properties of the tables of this graph.
 
         Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
             >>> graph = rfm.Graph(tables=...).infer_metadata()
             >>> graph.metadata  # doctest: +SKIP
-
-            0  users
+                 Name Primary Key Time Column End Time Column
+            0  users     user_id           -               -
         """
         tables = list(self.tables.values())
 
         return pd.DataFrame({
-            '
+            'Name':
             pd.Series(dtype=str, data=[t.name for t in tables]),
-            '
+            'Primary Key':
            pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
-            '
+            'Time Column':
            pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
-            '
+            'End Time Column':
             pd.Series(
                 dtype=str,
                 data=[t._end_time_column or '-' for t in tables],
@@ -488,20 +679,8 @@ class Graph:
 
     def print_metadata(self) -> None:
         r"""Prints the :meth:`~Graph.metadata` of the graph."""
-
-
-            display(Markdown('### 🗂️ Graph Metadata'))
-            df = self.metadata
-            try:
-                if hasattr(df.style, 'hide'):
-                    display(df.style.hide(axis='index'))  # pandas=2
-                else:
-                    display(df.style.hide_index())  # pandas<1.3
-            except ImportError:
-                print(df.to_string(index=False))  # missing jinja2
-        else:
-            print("🗂️ Graph Metadata:")
-            print(self.metadata.to_string(index=False))
+        display.title("🗂️ Graph Metadata")
+        display.dataframe(self.metadata)
 
     def infer_metadata(self, verbose: bool = True) -> Self:
         r"""Infers metadata for all tables in the graph.
@@ -524,42 +703,33 @@ class Graph:
     # Edges ###################################################################
 
     @property
-    def edges(self) ->
+    def edges(self) -> list[Edge]:
         r"""Returns the edges of the graph."""
         return self._edges
 
     def print_links(self) -> None:
         r"""Prints the :meth:`~Graph.edges` of the graph."""
-        edges = [(
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            display(Markdown('*No links registered*'))
+        edges = sorted([(
+            edge.dst_table,
+            self[edge.dst_table]._primary_key,
+            edge.src_table,
+            edge.fkey,
+        ) for edge in self.edges])
+
+        display.title("🕸️ Graph Links (FK ↔️ PK)")
+        if len(edges) > 0:
+            display.unordered_list(items=[
+                f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
+                for edge in edges
+            ])
         else:
-
-            if len(edges) > 0:
-                print('\n'.join([
-                    f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
-                    for edge in edges
-                ]))
-            else:
-                print('No links registered')
+            display.italic("No links registered")
 
     def link(
         self,
-        src_table:
+        src_table: str | Table,
         fkey: str,
-        dst_table:
+        dst_table: str | Table,
     ) -> Self:
         r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
         key ``fkey`` in the source table to the primary key in the destination
@@ -620,9 +790,9 @@ class Graph:
 
     def unlink(
         self,
-        src_table:
+        src_table: str | Table,
         fkey: str,
-        dst_table:
+        dst_table: str | Table,
     ) -> Self:
         r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
 
@@ -660,6 +830,30 @@ class Graph:
         """
         known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
 
+        for table in self.tables.values():  # Use links from source metadata:
+            if not any(column.is_source for column in table.columns):
+                continue
+            for fkey in table._source_foreign_key_dict.values():
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_source:
+                    continue
+                if (table.name, fkey.name) in known_edges:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if table.source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
+                    continue
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_source:
+                    continue
+                self.link(table.name, fkey.name, dst_table.name)
+                known_edges.add((table.name, fkey.name))
+
         # A list of primary key candidates (+score) for every column:
         candidate_dict: dict[
             tuple[str, str],
@@ -759,13 +953,8 @@ class Graph:
                 if score < 5.0:
                     continue
 
-                candidate_dict[(
-
-                    src_key.name,
-                )].append((
-                    dst_table.name,
-                    score,
-                ))
+                candidate_dict[(src_table.name, src_key.name)].append(
+                    (dst_table.name, score))
 
         for (src_table_name, src_key_name), scores in candidate_dict.items():
             scores.sort(key=lambda x: x[-1], reverse=True)
@@ -799,6 +988,10 @@ class Graph:
             raise ValueError("At least one table needs to be added to the "
                              "graph")
 
+        backends = {table.backend for table in self._tables.values()}
+        if len(backends) != 1:
+            raise ValueError("Found multiple table backends in the graph")
+
         for edge in self.edges:
             src_table, fkey, dst_table = edge
 
@@ -820,24 +1013,26 @@ class Graph:
                                  f"either the primary key or the link before "
                                  f"before proceeding.")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.backend == DataBackend.LOCAL:
+                # Check that fkey/pkey have valid and consistent data types:
+                assert src_key.dtype is not None
+                src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
+                src_string = src_key.dtype.is_string()
+                assert dst_key.dtype is not None
+                dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
+                dst_string = dst_key.dtype.is_string()
+
+                if not src_number and not src_string:
+                    raise ValueError(
+                        f"{edge} is invalid as foreign key must be a number "
+                        f"or string (got '{src_key.dtype}'")
+
+                if src_number != dst_number or src_string != dst_string:
+                    raise ValueError(
+                        f"{edge} is invalid as foreign key '{fkey}' and "
+                        f"primary key '{dst_key.name}' have incompatible data "
+                        f"types (got foreign key data type '{src_key.dtype}' "
+                        f"and primary key data type '{dst_key.dtype}')")
 
         return self
 
@@ -845,7 +1040,7 @@ class Graph:
 
     def visualize(
         self,
-        path:
+        path: str | io.BytesIO | None = None,
         show_columns: bool = True,
     ) -> 'graphviz.Graph':
         r"""Visualizes the tables and edges in this graph using the
@@ -870,33 +1065,33 @@ class Graph:
 
             return True
 
-        # Check basic dependency:
-
-
-
-
+        try:  # Check basic dependency:
+            import graphviz
+        except ImportError as e:
+            raise ImportError("The 'graphviz' package is required for "
+                              "visualization") from e
+
+        if not in_snowflake_notebook() and not has_graphviz_executables():
             raise RuntimeError("Could not visualize graph as 'graphviz' "
                                "executables are not installed. These "
                                "dependencies are required in addition to the "
                                "'graphviz' Python package. Please install "
                                "them as described at "
                                "https://graphviz.org/download/.")
-        else:
-            import graphviz
 
-        format:
+        format: str | None = None
         if isinstance(path, str):
             format = path.split('.')[-1]
         elif isinstance(path, io.BytesIO):
             format = 'svg'
         graph = graphviz.Graph(format=format)
 
-        def left_align(keys:
+        def left_align(keys: list[str]) -> str:
             if len(keys) == 0:
                 return ""
             return '\\l'.join(keys) + '\\l'
 
-        fkeys_dict:
+        fkeys_dict: dict[str, list[str]] = defaultdict(list)
         for src_table_name, fkey_name, _ in self.edges:
             fkeys_dict[src_table_name].append(fkey_name)
 
@@ -966,6 +1161,9 @@ class Graph:
             graph.render(path, cleanup=True)
         elif isinstance(path, io.BytesIO):
             path.write(graph.pipe())
+        elif in_snowflake_notebook():
+            import streamlit as st
+            st.graphviz_chart(graph)
         elif in_notebook():
             from IPython.display import display
             display(graph)
@@ -989,8 +1187,8 @@ class Graph:
     # Helpers #################################################################
 
     def _to_api_graph_definition(self) -> GraphDefinition:
-        tables:
-        col_groups:
+        tables: dict[str, TableDefinition] = {}
+        col_groups: list[ColumnKeyGroup] = []
         for table_name, table in self.tables.items():
             tables[table_name] = table._to_api_table_definition()
             if table.primary_key is None:
@@ -1033,3 +1231,7 @@ class Graph:
                 f'  tables={tables},\n'
                 f'  edges={edges},\n'
                 f')')
+
+    def __del__(self) -> None:
+        if hasattr(self, '_connection'):
+            self._connection.close()