kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +49 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +32 -14
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +186 -39
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
- kumoai/experimental/rfm/base/__init__.py +23 -3
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +380 -185
- kumoai/experimental/rfm/graph.py +404 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +52 -60
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +283 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +51 -0
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED

@@ -1,11 +1,15 @@
+from __future__ import annotations
+
 import contextlib
+import copy
 import io
 import warnings
 from collections import defaultdict
+from collections.abc import Sequence
 from dataclasses import dataclass, field
-from …
+from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, …
+from typing import TYPE_CHECKING, Any, Union

 import pandas as pd
 from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,10 +17,11 @@ from kumoapi.table import TableDefinition
 from kumoapi.typing import Stype
 from typing_extensions import Self

-from kumoai import in_notebook
-from kumoai.experimental.rfm import Table
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
 from kumoai.graph import Edge
 from kumoai.mixin import CastMixin
+from kumoai.utils import display

 if TYPE_CHECKING:
     import graphviz
@@ -26,8 +31,8 @@ if TYPE_CHECKING:

 @dataclass
 class SqliteConnectionConfig(CastMixin):
-    uri: …
-    kwargs: …
+    uri: str | Path
+    kwargs: dict[str, Any] = field(default_factory=dict)


 class Graph:
@@ -87,27 +92,35 @@ class Graph:
     def __init__(
         self,
         tables: Sequence[Table],
-        edges: …
+        edges: Sequence[Edge] | None = None,
     ) -> None:

-        self._tables: …
-        self._edges: …
+        self._tables: dict[str, Table] = {}
+        self._edges: list[Edge] = []

         for table in tables:
             self.add_table(table)

-        for table in tables:
+        for table in tables:  # Use links from source metadata:
+            if not any(column.is_source for column in table.columns):
+                continue
             for fkey in table._source_foreign_key_dict.values():
-                if fkey.name not in table …
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_source:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if table.source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
+                    continue
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_source:
                     continue
-
-                    self[fkey.dst_table].primary_key = fkey.primary_key
-                elif self[fkey.dst_table]._primary_key != fkey.primary_key:
-                    raise ValueError(f"Found duplicate primary key definition "
-                                     f"'{self[fkey.dst_table]._primary_key}' "
-                                     f"and '{fkey.primary_key}' in table "
-                                     f"'{fkey.dst_table}'.")
-                self.link(table.name, fkey.name, fkey.dst_table)
+                self.link(table.name, fkey.name, dst_table.name)

         for edge in (edges or []):
             _edge = Edge._cast(edge)
@@ -118,8 +131,8 @@ class Graph:
     @classmethod
     def from_data(
         cls,
-        df_dict: …
-        edges: …
+        df_dict: dict[str, pd.DataFrame],
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -157,15 +170,17 @@
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.local import LocalTable
-        tables = [LocalTable(df, name) for name, df in df_dict.items()]

-        graph = cls(…
+        graph = cls(
+            tables=[LocalTable(df, name) for name, df in df_dict.items()],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -181,10 +196,10 @@
             SqliteConnectionConfig,
             str,
             Path,
-            …
+            dict[str, Any],
         ],
-        …
-        edges: …
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -200,17 +215,25 @@
             >>> # Create a graph from a SQLite database:
             >>> graph = rfm.Graph.from_sqlite('data.db')

+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', primary_key='ITEM_ID'),
+            ... ])
+
         Args:
             connection: An open connection from
                 :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
                 path to the database file.
-
-                … all tables present …
+            tables: Set of table names or :class:`SQLiteTable` keyword
+                arguments to include. If ``None``, will add all tables present
+                in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
-            infer_metadata: Whether to infer metadata for all tables in
-                graph.
+            infer_metadata: Whether to infer missing metadata for all tables in
+                the graph.
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +242,41 @@ class Graph:
             connect,
         )

+        internal_connection = False
         if not isinstance(connection, Connection):
             connection = SqliteConnectionConfig._cast(connection)
             assert isinstance(connection, SqliteConnectionConfig)
             connection = connect(connection.uri, **connection.kwargs)
+            internal_connection = True
         assert isinstance(connection, Connection)

-        if …
+        if tables is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT name FROM sqlite_master "
                                "WHERE type='table'")
-
+                tables = [row[0] for row in cursor.fetchall()]

-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            kwargs = dict(name=table) if isinstance(table, str) else table
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SQLiteTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

-
+        if internal_connection:
+            graph._connection = connection  # type: ignore

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -250,9 +287,11 @@
     @classmethod
     def from_snowflake(
         cls,
-        connection: Union['SnowflakeConnection', …
-
-
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -267,7 +306,14 @@
             >>> import kumoai.experimental.rfm as rfm

             >>> # Create a graph directly in a Snowflake notebook:
-            >>> graph = rfm.Graph.from_snowflake()
+            >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
+
+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_snowflake(tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', schema='OTHER_SCHEMA'),
+            ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')

         Args:
             connection: An open connection from
@@ -276,8 +322,11 @@
                 connection. If ``None``, will re-use an active session in case
                 it exists, or create a new connection from credentials stored
                 in environment variables.
-
-                … all tables present in the …
+            tables: Set of table names or :class:`SnowTable` keyword arguments
+                to include. If ``None``, will add all tables present in the
+                current database and schema.
+            database: The database.
+            schema: The schema.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +344,50 @@
             connection = connect(**(connection or {}))
         assert isinstance(connection, Connection)

-        if …
+        if database is None or schema is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
-
-
+                result = cursor.fetchone()
+                database = database or result[0]
+                assert database is not None
+                schema = schema or result[1]
+
+        if tables is None:
+            if schema is None:
+                raise ValueError("No current 'schema' set. Please specify the "
+                                 "Snowflake schema manually")
+
+            with connection.cursor() as cursor:
+                cursor.execute(f"""
                     SELECT TABLE_NAME
                     FROM {database}.INFORMATION_SCHEMA.TABLES
                     WHERE TABLE_SCHEMA = '{schema}'
-                """ …
-                cursor. …
-            table_names = [row[0] for row in cursor.fetchall()]
-
-        tables = [SnowTable(connection, name) for name in table_names]
+                """)
+                tables = [row[0] for row in cursor.fetchall()]

-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            if isinstance(table, str):
+                kwargs = dict(name=table, database=database, schema=schema)
+            else:
+                kwargs = copy.copy(table)
+                kwargs.setdefault('database', database)
+                kwargs.setdefault('schema', schema)
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SnowTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -323,7 +395,187 @@

         return graph

-
+    @classmethod
+    def from_snowflake_semantic_view(
+        cls,
+        semantic_view_name: str,
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
+        verbose: bool = True,
+    ) -> Self:
+        import yaml
+
+        from kumoai.experimental.rfm.backend.snow import (
+            Connection,
+            SnowTable,
+            connect,
+        )
+
+        if not isinstance(connection, Connection):
+            connection = connect(**(connection or {}))
+        assert isinstance(connection, Connection)
+
+        with connection.cursor() as cursor:
+            cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
+                           f"'{semantic_view_name}')")
+            cfg = yaml.safe_load(cursor.fetchone()[0])
+
+        graph = cls(tables=[])
+
+        msgs = []
+        table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
+        for table_cfg in cfg['tables']:
+            table_name = table_cfg['name']
+            source_table_name = table_cfg['base_table']['table']
+            database = table_cfg['base_table']['database']
+            schema = table_cfg['base_table']['schema']
+
+            primary_key: str | None = None
+            if 'primary_key' in table_cfg:
+                primary_key_cfg = table_cfg['primary_key']
+                if len(primary_key_cfg['columns']) == 1:
+                    primary_key = primary_key_cfg['columns'][0]
+                elif len(primary_key_cfg['columns']) > 1:
+                    msgs.append(f"Failed to add primary key for table "
+                                f"'{table_name}' since composite primary keys "
+                                f"are not yet supported")
+
+            columns: list[ColumnSpec] = []
+            unsupported_columns: list[str] = []
+            for column_cfg in chain(
+                    table_cfg.get('dimensions', []),
+                    table_cfg.get('time_dimensions', []),
+                    table_cfg.get('facts', []),
+            ):
+                column_name = column_cfg['name']
+                column_expr = column_cfg.get('expr', None)
+                column_data_type = column_cfg.get('data_type', None)
+
+                if column_expr is None:
+                    columns.append(ColumnSpec(name=column_name))
+                    continue
+
+                column_expr = column_expr.replace(f'{table_name}.', '')
+
+                if column_expr == column_name:
+                    columns.append(ColumnSpec(name=column_name))
+                    continue
+
+                # Drop expressions that reference other tables (for now):
+                if any(f'{name}.' in column_expr for name in table_names):
+                    unsupported_columns.append(column_name)
+                    continue
+
+                column = ColumnSpec(
+                    name=column_name,
+                    expr=column_expr,
+                    dtype=SnowTable._to_dtype(column_data_type),
+                )
+                columns.append(column)
+
+            if len(unsupported_columns) == 1:
+                msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
+                            f"of table '{table_name}' since its expression "
+                            f"references other tables")
+            elif len(unsupported_columns) > 1:
+                msgs.append(f"Failed to add columns '{unsupported_columns}' "
+                            f"of table '{table_name}' since their expressions "
+                            f"reference other tables")
+
+            table = SnowTable(
+                connection,
+                name=table_name,
+                source_name=source_table_name,
+                database=database,
+                schema=schema,
+                columns=columns,
+                primary_key=primary_key,
+            )
+
+            # TODO Add a way to register time columns without heuristic usage.
+            table.infer_time_column(verbose=False)
+
+            graph.add_table(table)
+
+        for relation_cfg in cfg.get('relationships', []):
+            name = relation_cfg['name']
+            if len(relation_cfg['relationship_columns']) != 1:
+                msgs.append(f"Failed to add relationship '{name}' since "
+                            f"composite key references are not yet supported")
+                continue
+
+            left_table = relation_cfg['left_table']
+            left_key = relation_cfg['relationship_columns'][0]['left_column']
+            right_table = relation_cfg['right_table']
+            right_key = relation_cfg['relationship_columns'][0]['right_column']
+
+            if graph[right_table]._primary_key != right_key:
+                # Semantic view error - this should never be triggered:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referenced key '{right_key}' of table "
+                            f"'{right_table}' is not a primary key")
+                continue
+
+            if graph[left_table]._primary_key == left_key:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referencing key '{left_key}' of table "
+                            f"'{left_table}' is a primary key")
+                continue
+
+            if left_key not in graph[left_table]:
+                graph[left_table].add_column(left_key)
+
+            graph.link(left_table, left_key, right_table)
+
+        graph.validate()
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
+
+        if len(msgs) > 0:
+            title = (f"Could not fully convert the semantic view definition "
+                     f"'{semantic_view_name}' into a graph:\n")
+            warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
+
+        return graph
+
+    @classmethod
+    def from_relbench(
+        cls,
+        dataset: str,
+        verbose: bool = True,
+    ) -> Graph:
+        r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
+        :class:`Graph` instance.
+
+        .. code-block:: python
+
+            >>> # doctest: +SKIP
+            >>> import kumoai.experimental.rfm as rfm
+
+            >>> graph = rfm.Graph.from_relbench("f1")
+
+        Args:
+            dataset: The RelBench dataset name.
+            verbose: Whether to print verbose output.
+        """
+        from kumoai.experimental.rfm.relbench import from_relbench
+        graph = from_relbench(dataset, verbose=verbose)
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
+
+        return graph
+
+    # Backend #################################################################
+
+    @property
+    def backend(self) -> DataBackend | None:
+        backends = [table.backend for table in self._tables.values()]
+        return backends[0] if len(backends) > 0 else None
+
+    # Tables ##################################################################

     def has_table(self, name: str) -> bool:
         r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -342,7 +594,7 @@
         return self.tables[name]

     @property
-    def tables(self) -> …
+    def tables(self) -> dict[str, Table]:
         r"""Returns the dictionary of table objects."""
         return self._tables

@@ -362,13 +614,10 @@
             raise KeyError(f"Cannot add table with name '{table.name}' to "
                            f"this graph; table names must be globally unique.")

-        if …
-
-
-
-                             f"'{table.__class__.__name__}' to this "
-                             f"graph since other tables are of type "
-                             f"'{cls.__name__}'.")
+        if self.backend is not None and table.backend != self.backend:
+            raise ValueError(f"Cannot register a table with backend "
+                             f"'{table.backend}' to this graph since other "
+                             f"tables have backend '{self.backend}'.")

         self._tables[table.name] = table
@@ -430,20 +679,8 @@

     def print_metadata(self) -> None:
         r"""Prints the :meth:`~Graph.metadata` of the graph."""
-
-
-            display(Markdown('### 🗂️ Graph Metadata'))
-            df = self.metadata
-            try:
-                if hasattr(df.style, 'hide'):
-                    display(df.style.hide(axis='index'))  # pandas=2
-                else:
-                    display(df.style.hide_index())  # pandas<1.3
-            except ImportError:
-                print(df.to_string(index=False))  # missing jinja2
-        else:
-            print("🗂️ Graph Metadata:")
-            print(self.metadata.to_string(index=False))
+        display.title("🗂️ Graph Metadata")
+        display.dataframe(self.metadata)

     def infer_metadata(self, verbose: bool = True) -> Self:
         r"""Infers metadata for all tables in the graph.
@@ -466,42 +703,33 @@
     # Edges ###################################################################

     @property
-    def edges(self) -> …
+    def edges(self) -> list[Edge]:
         r"""Returns the edges of the graph."""
         return self._edges

     def print_links(self) -> None:
         r"""Prints the :meth:`~Graph.edges` of the graph."""
-        edges = [( …
-
-
-
-
-
-
-
-
-
-
-
-            else:
-                display(Markdown('*No links registered*'))
+        edges = sorted([(
+            edge.dst_table,
+            self[edge.dst_table]._primary_key,
+            edge.src_table,
+            edge.fkey,
+        ) for edge in self.edges])
+
+        display.title("🕸️ Graph Links (FK ↔️ PK)")
+        if len(edges) > 0:
+            display.unordered_list(items=[
+                f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
+                for edge in edges
+            ])
         else:
-
-            if len(edges) > 0:
-                print('\n'.join([
-                    f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
-                    for edge in edges
-                ]))
-            else:
-                print('No links registered')
+            display.italic("No links registered")

     def link(
         self,
-        src_table: …
+        src_table: str | Table,
         fkey: str,
-        dst_table: …
+        dst_table: str | Table,
     ) -> Self:
         r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
         key ``fkey`` in the source table to the primary key in the destination
@@ -562,9 +790,9 @@

     def unlink(
         self,
-        src_table: …
+        src_table: str | Table,
         fkey: str,
-        dst_table: …
+        dst_table: str | Table,
     ) -> Self:
         r"""Removes an :class:`~kumoai.graph.Edge` from the graph.

@@ -602,6 +830,30 @@
         """
         known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}

+        for table in self.tables.values():  # Use links from source metadata:
+            if not any(column.is_source for column in table.columns):
+                continue
+            for fkey in table._source_foreign_key_dict.values():
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_source:
+                    continue
+                if (table.name, fkey.name) in known_edges:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if table.source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
+                    continue
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_source:
+                    continue
+                self.link(table.name, fkey.name, dst_table.name)
+                known_edges.add((table.name, fkey.name))
+
         # A list of primary key candidates (+score) for every column:
         candidate_dict: dict[
             tuple[str, str],
@@ -701,13 +953,8 @@
                 if score < 5.0:
                     continue

-                candidate_dict[( …
-
-                    src_key.name,
-                )].append((
-                    dst_table.name,
-                    score,
-                ))
+                candidate_dict[(src_table.name, src_key.name)].append(
+                    (dst_table.name, score))

         for (src_table_name, src_key_name), scores in candidate_dict.items():
             scores.sort(key=lambda x: x[-1], reverse=True)
@@ -741,6 +988,10 @@
             raise ValueError("At least one table needs to be added to the "
                              "graph")

+        backends = {table.backend for table in self._tables.values()}
+        if len(backends) != 1:
+            raise ValueError("Found multiple table backends in the graph")
+
         for edge in self.edges:
             src_table, fkey, dst_table = edge

@@ -762,24 +1013,26 @@
                                 f"either the primary key or the link before "
                                 f"before proceeding.")

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.backend == DataBackend.LOCAL:
+                # Check that fkey/pkey have valid and consistent data types:
+                assert src_key.dtype is not None
+                src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
+                src_string = src_key.dtype.is_string()
+                assert dst_key.dtype is not None
+                dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
+                dst_string = dst_key.dtype.is_string()
+
+                if not src_number and not src_string:
+                    raise ValueError(
+                        f"{edge} is invalid as foreign key must be a number "
+                        f"or string (got '{src_key.dtype}'")
+
+                if src_number != dst_number or src_string != dst_string:
+                    raise ValueError(
+                        f"{edge} is invalid as foreign key '{fkey}' and "
+                        f"primary key '{dst_key.name}' have incompatible data "
+                        f"types (got foreign key data type '{src_key.dtype}' "
+                        f"and primary key data type '{dst_key.dtype}')")

         return self

@@ -787,7 +1040,7 @@

     def visualize(
         self,
-        path: …
+        path: str | io.BytesIO | None = None,
         show_columns: bool = True,
     ) -> 'graphviz.Graph':
         r"""Visualizes the tables and edges in this graph using the
@@ -812,33 +1065,33 @@

             return True

-        # Check basic dependency: …
-
-
-
-
+        try:  # Check basic dependency:
+            import graphviz
+        except ImportError as e:
+            raise ImportError("The 'graphviz' package is required for "
+                              "visualization") from e
+
+        if not in_snowflake_notebook() and not has_graphviz_executables():
             raise RuntimeError("Could not visualize graph as 'graphviz' "
                                "executables are not installed. These "
                                "dependencies are required in addition to the "
                                "'graphviz' Python package. Please install "
                                "them as described at "
                                "https://graphviz.org/download/.")
-        else:
-            import graphviz

-        format: …
+        format: str | None = None
         if isinstance(path, str):
             format = path.split('.')[-1]
         elif isinstance(path, io.BytesIO):
             format = 'svg'
         graph = graphviz.Graph(format=format)

-        def left_align(keys: …
+        def left_align(keys: list[str]) -> str:
             if len(keys) == 0:
                 return ""
             return '\\l'.join(keys) + '\\l'

-        fkeys_dict: …
+        fkeys_dict: dict[str, list[str]] = defaultdict(list)
         for src_table_name, fkey_name, _ in self.edges:
             fkeys_dict[src_table_name].append(fkey_name)

@@ -908,6 +1161,9 @@
             graph.render(path, cleanup=True)
         elif isinstance(path, io.BytesIO):
             path.write(graph.pipe())
+        elif in_snowflake_notebook():
+            import streamlit as st
+            st.graphviz_chart(graph)
         elif in_notebook():
             from IPython.display import display
             display(graph)
@@ -931,8 +1187,8 @@
     # Helpers #################################################################

     def _to_api_graph_definition(self) -> GraphDefinition:
-        tables: …
-        col_groups: …
+        tables: dict[str, TableDefinition] = {}
+        col_groups: list[ColumnKeyGroup] = []
         for table_name, table in self.tables.items():
             tables[table_name] = table._to_api_table_definition()
             if table.primary_key is None:
@@ -975,3 +1231,7 @@
                 f'    tables={tables},\n'
                 f'    edges={edges},\n'
                 f')')
+
+    def __del__(self) -> None:
+        if hasattr(self, '_connection'):
+            self._connection.close()
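Taken together, the new constructors added in this file read roughly as follows in use. This is only a sketch: the `from_relbench("f1")` call is copied from the new docstring above, while the semantic-view name passed to `from_snowflake_semantic_view` is purely illustrative and assumes an active Snowflake session or credentials in the environment (the `connection=None` default shown in the diff).

    import kumoai.experimental.rfm as rfm

    # From the docstring above: load a RelBench dataset into a Graph.
    graph = rfm.Graph.from_relbench("f1")

    # Illustrative only: build a Graph from a Snowflake semantic view,
    # reusing the active session / environment credentials.
    graph = rfm.Graph.from_snowflake_semantic_view(
        "MY_SEMANTIC_VIEW",  # hypothetical semantic view name
        verbose=True,
    )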