kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +35 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +44 -9
- kumoai/experimental/rfm/__init__.py +70 -68
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +67 -0
- kumoai/experimental/rfm/base/sampler.py +782 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +366 -0
- kumoai/experimental/rfm/base/table.py +741 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +82 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +775 -481
- kumoai/experimental/rfm/sagemaker.py +15 -7
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +190 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +10 -8
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +54 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
4
|
+
import copy
|
|
2
5
|
import io
|
|
3
6
|
import warnings
|
|
4
7
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from itertools import chain
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
7
13
|
|
|
8
14
|
import pandas as pd
|
|
9
15
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,20 +17,30 @@ from kumoapi.table import TableDefinition
|
|
|
11
17
|
from kumoapi.typing import Stype
|
|
12
18
|
from typing_extensions import Self
|
|
13
19
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
20
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
16
22
|
from kumoai.graph import Edge
|
|
23
|
+
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
17
25
|
|
|
18
26
|
if TYPE_CHECKING:
|
|
19
27
|
import graphviz
|
|
28
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
29
|
+
from snowflake.connector import SnowflakeConnection
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
|
|
33
|
+
class SqliteConnectionConfig(CastMixin):
|
|
34
|
+
uri: str | Path
|
|
35
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
20
36
|
|
|
21
37
|
|
|
22
|
-
class
|
|
23
|
-
r"""A graph of :class:`
|
|
38
|
+
class Graph:
|
|
39
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
40
|
tables in a relational database.
|
|
25
41
|
|
|
26
42
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
43
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
44
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
45
|
|
|
30
46
|
.. code-block:: python
|
|
@@ -44,7 +60,7 @@ class LocalGraph:
|
|
|
44
60
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
61
|
|
|
46
62
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
63
|
+
>>> graph = rfm.Graph({
|
|
48
64
|
... "table1": table1,
|
|
49
65
|
... "table2": table2,
|
|
50
66
|
... "table3": table3,
|
|
@@ -75,33 +91,55 @@ class LocalGraph:
|
|
|
75
91
|
|
|
76
92
|
def __init__(
|
|
77
93
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges:
|
|
94
|
+
tables: Sequence[Table],
|
|
95
|
+
edges: Sequence[Edge] | None = None,
|
|
80
96
|
) -> None:
|
|
81
97
|
|
|
82
|
-
self._tables:
|
|
83
|
-
self._edges:
|
|
98
|
+
self._tables: dict[str, Table] = {}
|
|
99
|
+
self._edges: list[Edge] = []
|
|
84
100
|
|
|
85
101
|
for table in tables:
|
|
86
102
|
self.add_table(table)
|
|
87
103
|
|
|
104
|
+
for table in tables: # Use links from source metadata:
|
|
105
|
+
if not any(column.is_source for column in table.columns):
|
|
106
|
+
continue
|
|
107
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
108
|
+
if fkey.name not in table:
|
|
109
|
+
continue
|
|
110
|
+
if not table[fkey.name].is_source:
|
|
111
|
+
continue
|
|
112
|
+
dst_table_names = [
|
|
113
|
+
table.name for table in self.tables.values()
|
|
114
|
+
if table.source_name == fkey.dst_table
|
|
115
|
+
]
|
|
116
|
+
if len(dst_table_names) != 1:
|
|
117
|
+
continue
|
|
118
|
+
dst_table = self[dst_table_names[0]]
|
|
119
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
120
|
+
continue
|
|
121
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
122
|
+
continue
|
|
123
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
124
|
+
|
|
88
125
|
for edge in (edges or []):
|
|
89
126
|
_edge = Edge._cast(edge)
|
|
90
127
|
assert _edge is not None
|
|
91
|
-
self.
|
|
128
|
+
if _edge not in self._edges:
|
|
129
|
+
self.link(*_edge)
|
|
92
130
|
|
|
93
131
|
@classmethod
|
|
94
132
|
def from_data(
|
|
95
133
|
cls,
|
|
96
|
-
df_dict:
|
|
97
|
-
edges:
|
|
134
|
+
df_dict: dict[str, pd.DataFrame],
|
|
135
|
+
edges: Sequence[Edge] | None = None,
|
|
98
136
|
infer_metadata: bool = True,
|
|
99
137
|
verbose: bool = True,
|
|
100
138
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
139
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
140
|
:class:`pandas.DataFrame` objects.
|
|
103
141
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
142
|
+
Automatically infers table metadata and links by default.
|
|
105
143
|
|
|
106
144
|
.. code-block:: python
|
|
107
145
|
|
|
@@ -115,59 +153,429 @@ class LocalGraph:
|
|
|
115
153
|
>>> df3 = pd.DataFrame(...)
|
|
116
154
|
|
|
117
155
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
156
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
157
|
... "table1": df1,
|
|
120
158
|
... "table2": df2,
|
|
121
159
|
... "table3": df3,
|
|
122
160
|
... })
|
|
123
161
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
162
|
Args:
|
|
132
163
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
164
|
of the tables and the values hold table data.
|
|
165
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
166
|
+
add to the graph. If not provided, edges will be automatically
|
|
167
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
168
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
169
|
graph.
|
|
170
|
+
verbose: Whether to print verbose output.
|
|
171
|
+
"""
|
|
172
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
173
|
+
|
|
174
|
+
graph = cls(
|
|
175
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
176
|
+
edges=edges or [],
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
if infer_metadata:
|
|
180
|
+
graph.infer_metadata(verbose=False)
|
|
181
|
+
|
|
182
|
+
if edges is None:
|
|
183
|
+
graph.infer_links(verbose=False)
|
|
184
|
+
|
|
185
|
+
if verbose:
|
|
186
|
+
graph.print_metadata()
|
|
187
|
+
graph.print_links()
|
|
188
|
+
|
|
189
|
+
return graph
|
|
190
|
+
|
|
191
|
+
@classmethod
|
|
192
|
+
def from_sqlite(
|
|
193
|
+
cls,
|
|
194
|
+
connection: Union[
|
|
195
|
+
'AdbcSqliteConnection',
|
|
196
|
+
SqliteConnectionConfig,
|
|
197
|
+
str,
|
|
198
|
+
Path,
|
|
199
|
+
dict[str, Any],
|
|
200
|
+
],
|
|
201
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
202
|
+
edges: Sequence[Edge] | None = None,
|
|
203
|
+
infer_metadata: bool = True,
|
|
204
|
+
verbose: bool = True,
|
|
205
|
+
) -> Self:
|
|
206
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
207
|
+
|
|
208
|
+
Automatically infers table metadata and links by default.
|
|
209
|
+
|
|
210
|
+
.. code-block:: python
|
|
211
|
+
|
|
212
|
+
>>> # doctest: +SKIP
|
|
213
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
214
|
+
|
|
215
|
+
>>> # Create a graph from a SQLite database:
|
|
216
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
217
|
+
|
|
218
|
+
>>> # Fine-grained control over table specification:
|
|
219
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
220
|
+
... 'USERS',
|
|
221
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
222
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
223
|
+
... ])
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
connection: An open connection from
|
|
227
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
228
|
+
path to the database file.
|
|
229
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
230
|
+
arguments to include. If ``None``, will add all tables present
|
|
231
|
+
in the database.
|
|
136
232
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
233
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
234
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
235
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
236
|
+
the graph.
|
|
139
237
|
verbose: Whether to print verbose output.
|
|
238
|
+
"""
|
|
239
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
240
|
+
Connection,
|
|
241
|
+
SQLiteTable,
|
|
242
|
+
connect,
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
internal_connection = False
|
|
246
|
+
if not isinstance(connection, Connection):
|
|
247
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
248
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
249
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
250
|
+
internal_connection = True
|
|
251
|
+
assert isinstance(connection, Connection)
|
|
252
|
+
|
|
253
|
+
if tables is None:
|
|
254
|
+
with connection.cursor() as cursor:
|
|
255
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
256
|
+
"WHERE type='table'")
|
|
257
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
258
|
+
|
|
259
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
260
|
+
for table in tables:
|
|
261
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
262
|
+
table_kwargs.append(kwargs)
|
|
140
263
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
264
|
+
graph = cls(
|
|
265
|
+
tables=[
|
|
266
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
267
|
+
for kwargs in table_kwargs
|
|
268
|
+
],
|
|
269
|
+
edges=edges or [],
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
if internal_connection:
|
|
273
|
+
graph._connection = connection # type: ignore
|
|
274
|
+
|
|
275
|
+
if infer_metadata:
|
|
276
|
+
graph.infer_metadata(verbose=False)
|
|
277
|
+
|
|
278
|
+
if edges is None:
|
|
279
|
+
graph.infer_links(verbose=False)
|
|
280
|
+
|
|
281
|
+
if verbose:
|
|
282
|
+
graph.print_metadata()
|
|
283
|
+
graph.print_links()
|
|
284
|
+
|
|
285
|
+
return graph
|
|
286
|
+
|
|
287
|
+
@classmethod
|
|
288
|
+
def from_snowflake(
|
|
289
|
+
cls,
|
|
290
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
291
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
292
|
+
database: str | None = None,
|
|
293
|
+
schema: str | None = None,
|
|
294
|
+
edges: Sequence[Edge] | None = None,
|
|
295
|
+
infer_metadata: bool = True,
|
|
296
|
+
verbose: bool = True,
|
|
297
|
+
) -> Self:
|
|
298
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
299
|
+
schema.
|
|
300
|
+
|
|
301
|
+
Automatically infers table metadata and links by default.
|
|
302
|
+
|
|
303
|
+
.. code-block:: python
|
|
144
304
|
|
|
145
|
-
Example:
|
|
146
305
|
>>> # doctest: +SKIP
|
|
147
306
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
>>> df1 = pd.DataFrame(...)
|
|
149
|
-
>>> df2 = pd.DataFrame(...)
|
|
150
|
-
>>> df3 = pd.DataFrame(...)
|
|
151
|
-
>>> graph = rfm.LocalGraph.from_data(data={
|
|
152
|
-
... "table1": df1,
|
|
153
|
-
... "table2": df2,
|
|
154
|
-
... "table3": df3,
|
|
155
|
-
... })
|
|
156
|
-
>>> graph.validate()
|
|
157
|
-
"""
|
|
158
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
159
307
|
|
|
160
|
-
|
|
308
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
309
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
310
|
+
|
|
311
|
+
>>> # Fine-grained control over table specification:
|
|
312
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
313
|
+
... 'USERS',
|
|
314
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
315
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
316
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
317
|
+
|
|
318
|
+
Args:
|
|
319
|
+
connection: An open connection from
|
|
320
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
321
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
322
|
+
connection. If ``None``, will re-use an active session in case
|
|
323
|
+
it exists, or create a new connection from credentials stored
|
|
324
|
+
in environment variables.
|
|
325
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
326
|
+
to include. If ``None``, will add all tables present in the
|
|
327
|
+
current database and schema.
|
|
328
|
+
database: The database.
|
|
329
|
+
schema: The schema.
|
|
330
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
331
|
+
add to the graph. If not provided, edges will be automatically
|
|
332
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
333
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
334
|
+
graph.
|
|
335
|
+
verbose: Whether to print verbose output.
|
|
336
|
+
"""
|
|
337
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
338
|
+
Connection,
|
|
339
|
+
SnowTable,
|
|
340
|
+
connect,
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
if not isinstance(connection, Connection):
|
|
344
|
+
connection = connect(**(connection or {}))
|
|
345
|
+
assert isinstance(connection, Connection)
|
|
346
|
+
|
|
347
|
+
if database is None or schema is None:
|
|
348
|
+
with connection.cursor() as cursor:
|
|
349
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
350
|
+
result = cursor.fetchone()
|
|
351
|
+
database = database or result[0]
|
|
352
|
+
assert database is not None
|
|
353
|
+
schema = schema or result[1]
|
|
354
|
+
|
|
355
|
+
if tables is None:
|
|
356
|
+
if schema is None:
|
|
357
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
358
|
+
"Snowflake schema manually")
|
|
359
|
+
|
|
360
|
+
with connection.cursor() as cursor:
|
|
361
|
+
cursor.execute(f"""
|
|
362
|
+
SELECT TABLE_NAME
|
|
363
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
364
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
365
|
+
""")
|
|
366
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
367
|
+
|
|
368
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
369
|
+
for table in tables:
|
|
370
|
+
if isinstance(table, str):
|
|
371
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
372
|
+
else:
|
|
373
|
+
kwargs = copy.copy(table)
|
|
374
|
+
kwargs.setdefault('database', database)
|
|
375
|
+
kwargs.setdefault('schema', schema)
|
|
376
|
+
table_kwargs.append(kwargs)
|
|
377
|
+
|
|
378
|
+
graph = cls(
|
|
379
|
+
tables=[
|
|
380
|
+
SnowTable(connection=connection, **kwargs)
|
|
381
|
+
for kwargs in table_kwargs
|
|
382
|
+
],
|
|
383
|
+
edges=edges or [],
|
|
384
|
+
)
|
|
161
385
|
|
|
162
386
|
if infer_metadata:
|
|
163
|
-
graph.infer_metadata(verbose)
|
|
387
|
+
graph.infer_metadata(verbose=False)
|
|
164
388
|
|
|
165
389
|
if edges is None:
|
|
166
|
-
graph.infer_links(verbose)
|
|
390
|
+
graph.infer_links(verbose=False)
|
|
391
|
+
|
|
392
|
+
if verbose:
|
|
393
|
+
graph.print_metadata()
|
|
394
|
+
graph.print_links()
|
|
167
395
|
|
|
168
396
|
return graph
|
|
169
397
|
|
|
170
|
-
|
|
398
|
+
@classmethod
|
|
399
|
+
def from_snowflake_semantic_view(
|
|
400
|
+
cls,
|
|
401
|
+
semantic_view_name: str,
|
|
402
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
403
|
+
verbose: bool = True,
|
|
404
|
+
) -> Self:
|
|
405
|
+
import yaml
|
|
406
|
+
|
|
407
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
408
|
+
Connection,
|
|
409
|
+
SnowTable,
|
|
410
|
+
connect,
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
if not isinstance(connection, Connection):
|
|
414
|
+
connection = connect(**(connection or {}))
|
|
415
|
+
assert isinstance(connection, Connection)
|
|
416
|
+
|
|
417
|
+
with connection.cursor() as cursor:
|
|
418
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
419
|
+
f"'{semantic_view_name}')")
|
|
420
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
421
|
+
|
|
422
|
+
graph = cls(tables=[])
|
|
423
|
+
|
|
424
|
+
msgs = []
|
|
425
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
426
|
+
for table_cfg in cfg['tables']:
|
|
427
|
+
table_name = table_cfg['name']
|
|
428
|
+
source_table_name = table_cfg['base_table']['table']
|
|
429
|
+
database = table_cfg['base_table']['database']
|
|
430
|
+
schema = table_cfg['base_table']['schema']
|
|
431
|
+
|
|
432
|
+
primary_key: str | None = None
|
|
433
|
+
if 'primary_key' in table_cfg:
|
|
434
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
435
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
436
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
437
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
438
|
+
msgs.append(f"Failed to add primary key for table "
|
|
439
|
+
f"'{table_name}' since composite primary keys "
|
|
440
|
+
f"are not yet supported")
|
|
441
|
+
|
|
442
|
+
columns: list[ColumnSpec] = []
|
|
443
|
+
unsupported_columns: list[str] = []
|
|
444
|
+
for column_cfg in chain(
|
|
445
|
+
table_cfg.get('dimensions', []),
|
|
446
|
+
table_cfg.get('time_dimensions', []),
|
|
447
|
+
table_cfg.get('facts', []),
|
|
448
|
+
):
|
|
449
|
+
column_name = column_cfg['name']
|
|
450
|
+
column_expr = column_cfg.get('expr', None)
|
|
451
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
452
|
+
|
|
453
|
+
if column_expr is None:
|
|
454
|
+
columns.append(ColumnSpec(name=column_name))
|
|
455
|
+
continue
|
|
456
|
+
|
|
457
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
458
|
+
|
|
459
|
+
if column_expr == column_name:
|
|
460
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
# Drop expressions that reference other tables (for now):
|
|
464
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
465
|
+
unsupported_columns.append(column_name)
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
column = ColumnSpec(
|
|
469
|
+
name=column_name,
|
|
470
|
+
expr=column_expr,
|
|
471
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
472
|
+
)
|
|
473
|
+
columns.append(column)
|
|
474
|
+
|
|
475
|
+
if len(unsupported_columns) == 1:
|
|
476
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
477
|
+
f"of table '{table_name}' since its expression "
|
|
478
|
+
f"references other tables")
|
|
479
|
+
elif len(unsupported_columns) > 1:
|
|
480
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
481
|
+
f"of table '{table_name}' since their expressions "
|
|
482
|
+
f"reference other tables")
|
|
483
|
+
|
|
484
|
+
table = SnowTable(
|
|
485
|
+
connection,
|
|
486
|
+
name=table_name,
|
|
487
|
+
source_name=source_table_name,
|
|
488
|
+
database=database,
|
|
489
|
+
schema=schema,
|
|
490
|
+
columns=columns,
|
|
491
|
+
primary_key=primary_key,
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
495
|
+
table.infer_time_column(verbose=False)
|
|
496
|
+
|
|
497
|
+
graph.add_table(table)
|
|
498
|
+
|
|
499
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
500
|
+
name = relation_cfg['name']
|
|
501
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
502
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
503
|
+
f"composite key references are not yet supported")
|
|
504
|
+
continue
|
|
505
|
+
|
|
506
|
+
left_table = relation_cfg['left_table']
|
|
507
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
508
|
+
right_table = relation_cfg['right_table']
|
|
509
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
510
|
+
|
|
511
|
+
if graph[right_table]._primary_key != right_key:
|
|
512
|
+
# Semantic view error - this should never be triggered:
|
|
513
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
514
|
+
f"referenced key '{right_key}' of table "
|
|
515
|
+
f"'{right_table}' is not a primary key")
|
|
516
|
+
continue
|
|
517
|
+
|
|
518
|
+
if graph[left_table]._primary_key == left_key:
|
|
519
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
520
|
+
f"referencing key '{left_key}' of table "
|
|
521
|
+
f"'{left_table}' is a primary key")
|
|
522
|
+
continue
|
|
523
|
+
|
|
524
|
+
if left_key not in graph[left_table]:
|
|
525
|
+
graph[left_table].add_column(left_key)
|
|
526
|
+
|
|
527
|
+
graph.link(left_table, left_key, right_table)
|
|
528
|
+
|
|
529
|
+
graph.validate()
|
|
530
|
+
|
|
531
|
+
if verbose:
|
|
532
|
+
graph.print_metadata()
|
|
533
|
+
graph.print_links()
|
|
534
|
+
|
|
535
|
+
if len(msgs) > 0:
|
|
536
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
537
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
538
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
539
|
+
|
|
540
|
+
return graph
|
|
541
|
+
|
|
542
|
+
@classmethod
|
|
543
|
+
def from_relbench(
|
|
544
|
+
cls,
|
|
545
|
+
dataset: str,
|
|
546
|
+
verbose: bool = True,
|
|
547
|
+
) -> Graph:
|
|
548
|
+
r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
|
|
549
|
+
:class:`Graph` instance.
|
|
550
|
+
|
|
551
|
+
.. code-block:: python
|
|
552
|
+
|
|
553
|
+
>>> # doctest: +SKIP
|
|
554
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
555
|
+
|
|
556
|
+
>>> graph = rfm.Graph.from_relbench("f1")
|
|
557
|
+
|
|
558
|
+
Args:
|
|
559
|
+
dataset: The RelBench dataset name.
|
|
560
|
+
verbose: Whether to print verbose output.
|
|
561
|
+
"""
|
|
562
|
+
from kumoai.experimental.rfm.relbench import from_relbench
|
|
563
|
+
graph = from_relbench(dataset, verbose=verbose)
|
|
564
|
+
|
|
565
|
+
if verbose:
|
|
566
|
+
graph.print_metadata()
|
|
567
|
+
graph.print_links()
|
|
568
|
+
|
|
569
|
+
return graph
|
|
570
|
+
|
|
571
|
+
# Backend #################################################################
|
|
572
|
+
|
|
573
|
+
@property
|
|
574
|
+
def backend(self) -> DataBackend | None:
|
|
575
|
+
backends = [table.backend for table in self._tables.values()]
|
|
576
|
+
return backends[0] if len(backends) > 0 else None
|
|
577
|
+
|
|
578
|
+
# Tables ##################################################################
|
|
171
579
|
|
|
172
580
|
def has_table(self, name: str) -> bool:
|
|
173
581
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -175,7 +583,7 @@ class LocalGraph:
|
|
|
175
583
|
"""
|
|
176
584
|
return name in self.tables
|
|
177
585
|
|
|
178
|
-
def table(self, name: str) ->
|
|
586
|
+
def table(self, name: str) -> Table:
|
|
179
587
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
588
|
|
|
181
589
|
Raises:
|
|
@@ -186,11 +594,11 @@ class LocalGraph:
|
|
|
186
594
|
return self.tables[name]
|
|
187
595
|
|
|
188
596
|
@property
|
|
189
|
-
def tables(self) ->
|
|
597
|
+
def tables(self) -> dict[str, Table]:
|
|
190
598
|
r"""Returns the dictionary of table objects."""
|
|
191
599
|
return self._tables
|
|
192
600
|
|
|
193
|
-
def add_table(self, table:
|
|
601
|
+
def add_table(self, table: Table) -> Self:
|
|
194
602
|
r"""Adds a table to the graph.
|
|
195
603
|
|
|
196
604
|
Args:
|
|
@@ -199,11 +607,18 @@ class LocalGraph:
|
|
|
199
607
|
Raises:
|
|
200
608
|
KeyError: If a table with the same name already exists in the
|
|
201
609
|
graph.
|
|
610
|
+
ValueError: If the table belongs to a different backend than the
|
|
611
|
+
rest of the tables in the graph.
|
|
202
612
|
"""
|
|
203
613
|
if table.name in self._tables:
|
|
204
614
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
615
|
f"this graph; table names must be globally unique.")
|
|
206
616
|
|
|
617
|
+
if self.backend is not None and table.backend != self.backend:
|
|
618
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
619
|
+
f"'{table.backend}' to this graph since other "
|
|
620
|
+
f"tables have backend '{self.backend}'.")
|
|
621
|
+
|
|
207
622
|
self._tables[table.name] = table
|
|
208
623
|
|
|
209
624
|
return self
|
|
@@ -234,28 +649,28 @@ class LocalGraph:
|
|
|
234
649
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
235
650
|
information about the tables in this graph.
|
|
236
651
|
|
|
237
|
-
The returned dataframe has columns ``
|
|
238
|
-
``
|
|
239
|
-
view of the properties of the tables of this graph.
|
|
652
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
653
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
654
|
+
aggregated view of the properties of the tables of this graph.
|
|
240
655
|
|
|
241
656
|
Example:
|
|
242
657
|
>>> # doctest: +SKIP
|
|
243
658
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
659
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
660
|
>>> graph.metadata # doctest: +SKIP
|
|
246
|
-
|
|
247
|
-
0 users
|
|
661
|
+
Name Primary Key Time Column End Time Column
|
|
662
|
+
0 users user_id - -
|
|
248
663
|
"""
|
|
249
664
|
tables = list(self.tables.values())
|
|
250
665
|
|
|
251
666
|
return pd.DataFrame({
|
|
252
|
-
'
|
|
667
|
+
'Name':
|
|
253
668
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
254
|
-
'
|
|
669
|
+
'Primary Key':
|
|
255
670
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
256
|
-
'
|
|
671
|
+
'Time Column':
|
|
257
672
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
258
|
-
'
|
|
673
|
+
'End Time Column':
|
|
259
674
|
pd.Series(
|
|
260
675
|
dtype=str,
|
|
261
676
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -263,21 +678,9 @@ class LocalGraph:
|
|
|
263
678
|
})
|
|
264
679
|
|
|
265
680
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
display(Markdown('### 🗂️ Graph Metadata'))
|
|
270
|
-
df = self.metadata
|
|
271
|
-
try:
|
|
272
|
-
if hasattr(df.style, 'hide'):
|
|
273
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
274
|
-
else:
|
|
275
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
276
|
-
except ImportError:
|
|
277
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
278
|
-
else:
|
|
279
|
-
print("🗂️ Graph Metadata:")
|
|
280
|
-
print(self.metadata.to_string(index=False))
|
|
681
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
281
684
|
|
|
282
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
283
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -287,7 +690,7 @@ class LocalGraph:
|
|
|
287
690
|
|
|
288
691
|
Note:
|
|
289
692
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
693
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
694
|
"""
|
|
292
695
|
for table in self.tables.values():
|
|
293
696
|
table.infer_metadata(verbose=False)
|
|
@@ -300,42 +703,33 @@ class LocalGraph:
|
|
|
300
703
|
# Edges ###################################################################
|
|
301
704
|
|
|
302
705
|
@property
|
|
303
|
-
def edges(self) ->
|
|
706
|
+
def edges(self) -> list[Edge]:
|
|
304
707
|
r"""Returns the edges of the graph."""
|
|
305
708
|
return self._edges
|
|
306
709
|
|
|
307
710
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
309
|
-
edges = [(
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
else:
|
|
323
|
-
display(Markdown('*No links registered*'))
|
|
711
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
324
725
|
else:
|
|
325
|
-
|
|
326
|
-
if len(edges) > 0:
|
|
327
|
-
print('\n'.join([
|
|
328
|
-
f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
|
|
329
|
-
for edge in edges
|
|
330
|
-
]))
|
|
331
|
-
else:
|
|
332
|
-
print('No links registered')
|
|
726
|
+
display.italic("No links registered")
|
|
333
727
|
|
|
334
728
|
def link(
|
|
335
729
|
self,
|
|
336
|
-
src_table:
|
|
730
|
+
src_table: str | Table,
|
|
337
731
|
fkey: str,
|
|
338
|
-
dst_table:
|
|
732
|
+
dst_table: str | Table,
|
|
339
733
|
) -> Self:
|
|
340
734
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
735
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +752,11 @@ class LocalGraph:
|
|
|
358
752
|
table does not exist in the graph, if the source key does not
|
|
359
753
|
exist in the source table.
|
|
360
754
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
755
|
+
if isinstance(src_table, Table):
|
|
362
756
|
src_table = src_table.name
|
|
363
757
|
assert isinstance(src_table, str)
|
|
364
758
|
|
|
365
|
-
if isinstance(dst_table,
|
|
759
|
+
if isinstance(dst_table, Table):
|
|
366
760
|
dst_table = dst_table.name
|
|
367
761
|
assert isinstance(dst_table, str)
|
|
368
762
|
|
|
@@ -396,9 +790,9 @@ class LocalGraph:
|
|
|
396
790
|
|
|
397
791
|
def unlink(
|
|
398
792
|
self,
|
|
399
|
-
src_table:
|
|
793
|
+
src_table: str | Table,
|
|
400
794
|
fkey: str,
|
|
401
|
-
dst_table:
|
|
795
|
+
dst_table: str | Table,
|
|
402
796
|
) -> Self:
|
|
403
797
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
798
|
|
|
@@ -410,11 +804,11 @@ class LocalGraph:
|
|
|
410
804
|
Raises:
|
|
411
805
|
ValueError: if the edge is not present in the graph.
|
|
412
806
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
807
|
+
if isinstance(src_table, Table):
|
|
414
808
|
src_table = src_table.name
|
|
415
809
|
assert isinstance(src_table, str)
|
|
416
810
|
|
|
417
|
-
if isinstance(dst_table,
|
|
811
|
+
if isinstance(dst_table, Table):
|
|
418
812
|
dst_table = dst_table.name
|
|
419
813
|
assert isinstance(dst_table, str)
|
|
420
814
|
|
|
@@ -428,17 +822,37 @@ class LocalGraph:
|
|
|
428
822
|
return self
|
|
429
823
|
|
|
430
824
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
825
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
826
|
+
graph.
|
|
432
827
|
|
|
433
828
|
Args:
|
|
434
829
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
830
|
"""
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
831
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
832
|
+
|
|
833
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
834
|
+
if not any(column.is_source for column in table.columns):
|
|
835
|
+
continue
|
|
836
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
837
|
+
if fkey.name not in table:
|
|
838
|
+
continue
|
|
839
|
+
if not table[fkey.name].is_source:
|
|
840
|
+
continue
|
|
841
|
+
if (table.name, fkey.name) in known_edges:
|
|
842
|
+
continue
|
|
843
|
+
dst_table_names = [
|
|
844
|
+
table.name for table in self.tables.values()
|
|
845
|
+
if table.source_name == fkey.dst_table
|
|
846
|
+
]
|
|
847
|
+
if len(dst_table_names) != 1:
|
|
848
|
+
continue
|
|
849
|
+
dst_table = self[dst_table_names[0]]
|
|
850
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
851
|
+
continue
|
|
852
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
853
|
+
continue
|
|
854
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
855
|
+
known_edges.add((table.name, fkey.name))
|
|
442
856
|
|
|
443
857
|
# A list of primary key candidates (+score) for every column:
|
|
444
858
|
candidate_dict: dict[
|
|
@@ -463,6 +877,9 @@ class LocalGraph:
|
|
|
463
877
|
src_table_name = src_table.name.lower()
|
|
464
878
|
|
|
465
879
|
for src_key in src_table.columns:
|
|
880
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
881
|
+
continue
|
|
882
|
+
|
|
466
883
|
if src_key == src_table.primary_key:
|
|
467
884
|
continue # Cannot link to primary key.
|
|
468
885
|
|
|
@@ -528,19 +945,16 @@ class LocalGraph:
|
|
|
528
945
|
score += 1.0
|
|
529
946
|
|
|
530
947
|
# Cardinality ratio:
|
|
531
|
-
if
|
|
948
|
+
if (src_table._num_rows is not None
|
|
949
|
+
and dst_table._num_rows is not None
|
|
950
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
532
951
|
score += 1.0
|
|
533
952
|
|
|
534
953
|
if score < 5.0:
|
|
535
954
|
continue
|
|
536
955
|
|
|
537
|
-
candidate_dict[(
|
|
538
|
-
|
|
539
|
-
src_key.name,
|
|
540
|
-
)].append((
|
|
541
|
-
dst_table.name,
|
|
542
|
-
score,
|
|
543
|
-
))
|
|
956
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
957
|
+
(dst_table.name, score))
|
|
544
958
|
|
|
545
959
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
546
960
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -574,6 +988,10 @@ class LocalGraph:
|
|
|
574
988
|
raise ValueError("At least one table needs to be added to the "
|
|
575
989
|
"graph")
|
|
576
990
|
|
|
991
|
+
backends = {table.backend for table in self._tables.values()}
|
|
992
|
+
if len(backends) != 1:
|
|
993
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
994
|
+
|
|
577
995
|
for edge in self.edges:
|
|
578
996
|
src_table, fkey, dst_table = edge
|
|
579
997
|
|
|
@@ -595,24 +1013,26 @@ class LocalGraph:
|
|
|
595
1013
|
f"either the primary key or the link before "
|
|
596
1014
|
f"before proceeding.")
|
|
597
1015
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
1016
|
+
if self.backend == DataBackend.LOCAL:
|
|
1017
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1018
|
+
assert src_key.dtype is not None
|
|
1019
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1020
|
+
src_string = src_key.dtype.is_string()
|
|
1021
|
+
assert dst_key.dtype is not None
|
|
1022
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1023
|
+
dst_string = dst_key.dtype.is_string()
|
|
1024
|
+
|
|
1025
|
+
if not src_number and not src_string:
|
|
1026
|
+
raise ValueError(
|
|
1027
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1028
|
+
f"or string (got '{src_key.dtype}'")
|
|
1029
|
+
|
|
1030
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1031
|
+
raise ValueError(
|
|
1032
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1033
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1034
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1035
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
616
1036
|
|
|
617
1037
|
return self
|
|
618
1038
|
|
|
@@ -620,7 +1040,7 @@ class LocalGraph:
|
|
|
620
1040
|
|
|
621
1041
|
def visualize(
|
|
622
1042
|
self,
|
|
623
|
-
path:
|
|
1043
|
+
path: str | io.BytesIO | None = None,
|
|
624
1044
|
show_columns: bool = True,
|
|
625
1045
|
) -> 'graphviz.Graph':
|
|
626
1046
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -645,33 +1065,33 @@ class LocalGraph:
|
|
|
645
1065
|
|
|
646
1066
|
return True
|
|
647
1067
|
|
|
648
|
-
# Check basic dependency:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1068
|
+
try: # Check basic dependency:
|
|
1069
|
+
import graphviz
|
|
1070
|
+
except ImportError as e:
|
|
1071
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
1072
|
+
"visualization") from e
|
|
1073
|
+
|
|
1074
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
653
1075
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
654
1076
|
"executables are not installed. These "
|
|
655
1077
|
"dependencies are required in addition to the "
|
|
656
1078
|
"'graphviz' Python package. Please install "
|
|
657
1079
|
"them as described at "
|
|
658
1080
|
"https://graphviz.org/download/.")
|
|
659
|
-
else:
|
|
660
|
-
import graphviz
|
|
661
1081
|
|
|
662
|
-
format:
|
|
1082
|
+
format: str | None = None
|
|
663
1083
|
if isinstance(path, str):
|
|
664
1084
|
format = path.split('.')[-1]
|
|
665
1085
|
elif isinstance(path, io.BytesIO):
|
|
666
1086
|
format = 'svg'
|
|
667
1087
|
graph = graphviz.Graph(format=format)
|
|
668
1088
|
|
|
669
|
-
def left_align(keys:
|
|
1089
|
+
def left_align(keys: list[str]) -> str:
|
|
670
1090
|
if len(keys) == 0:
|
|
671
1091
|
return ""
|
|
672
1092
|
return '\\l'.join(keys) + '\\l'
|
|
673
1093
|
|
|
674
|
-
fkeys_dict:
|
|
1094
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
675
1095
|
for src_table_name, fkey_name, _ in self.edges:
|
|
676
1096
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
677
1097
|
|
|
@@ -741,6 +1161,9 @@ class LocalGraph:
|
|
|
741
1161
|
graph.render(path, cleanup=True)
|
|
742
1162
|
elif isinstance(path, io.BytesIO):
|
|
743
1163
|
path.write(graph.pipe())
|
|
1164
|
+
elif in_snowflake_notebook():
|
|
1165
|
+
import streamlit as st
|
|
1166
|
+
st.graphviz_chart(graph)
|
|
744
1167
|
elif in_notebook():
|
|
745
1168
|
from IPython.display import display
|
|
746
1169
|
display(graph)
|
|
@@ -764,8 +1187,8 @@ class LocalGraph:
|
|
|
764
1187
|
# Helpers #################################################################
|
|
765
1188
|
|
|
766
1189
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
767
|
-
tables:
|
|
768
|
-
col_groups:
|
|
1190
|
+
tables: dict[str, TableDefinition] = {}
|
|
1191
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
769
1192
|
for table_name, table in self.tables.items():
|
|
770
1193
|
tables[table_name] = table._to_api_table_definition()
|
|
771
1194
|
if table.primary_key is None:
|
|
@@ -790,7 +1213,7 @@ class LocalGraph:
|
|
|
790
1213
|
def __contains__(self, name: str) -> bool:
|
|
791
1214
|
return self.has_table(name)
|
|
792
1215
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
1216
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
1217
|
return self.table(name)
|
|
795
1218
|
|
|
796
1219
|
def __delitem__(self, name: str) -> None:
|
|
@@ -808,3 +1231,7 @@ class LocalGraph:
|
|
|
808
1231
|
f' tables={tables},\n'
|
|
809
1232
|
f' edges={edges},\n'
|
|
810
1233
|
f')')
|
|
1234
|
+
|
|
1235
|
+
def __del__(self) -> None:
|
|
1236
|
+
if hasattr(self, '_connection'):
|
|
1237
|
+
self._connection.close()
|