kumoai 2.13.0.dev202511131731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +18 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +15 -13
- kumoai/client/jobs.py +24 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +191 -50
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +113 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
- kumoai/experimental/rfm/backend/snow/table.py +242 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
- kumoai/experimental/rfm/base/__init__.py +30 -0
- kumoai/experimental/rfm/base/column.py +152 -0
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +761 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +143 -0
- kumoai/experimental/rfm/base/table.py +753 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
- kumoai/experimental/rfm/infer/__init__.py +8 -0
- kumoai/experimental/rfm/infer/dtype.py +81 -0
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +322 -252
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +13 -2
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +50 -29
- kumoai/experimental/rfm/local_graph_sampler.py +0 -184
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/local_table.py +0 -545
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
+
import copy
|
|
2
3
|
import io
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from itertools import chain
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
7
11
|
|
|
8
12
|
import pandas as pd
|
|
9
13
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,20 +15,29 @@ from kumoapi.table import TableDefinition
|
|
|
11
15
|
from kumoapi.typing import Stype
|
|
12
16
|
from typing_extensions import Self
|
|
13
17
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
18
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
16
20
|
from kumoai.graph import Edge
|
|
21
|
+
from kumoai.mixin import CastMixin
|
|
17
22
|
|
|
18
23
|
if TYPE_CHECKING:
|
|
19
24
|
import graphviz
|
|
25
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
26
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
27
|
|
|
21
28
|
|
|
22
|
-
|
|
23
|
-
|
|
29
|
+
@dataclass
|
|
30
|
+
class SqliteConnectionConfig(CastMixin):
|
|
31
|
+
uri: str | Path
|
|
32
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class Graph:
|
|
36
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
37
|
tables in a relational database.
|
|
25
38
|
|
|
26
39
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
40
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
41
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
42
|
|
|
30
43
|
.. code-block:: python
|
|
@@ -44,7 +57,7 @@ class LocalGraph:
|
|
|
44
57
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
58
|
|
|
46
59
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
60
|
+
>>> graph = rfm.Graph({
|
|
48
61
|
... "table1": table1,
|
|
49
62
|
... "table2": table2,
|
|
50
63
|
... "table3": table3,
|
|
@@ -75,33 +88,55 @@ class LocalGraph:
|
|
|
75
88
|
|
|
76
89
|
def __init__(
|
|
77
90
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges:
|
|
91
|
+
tables: Sequence[Table],
|
|
92
|
+
edges: Sequence[Edge] | None = None,
|
|
80
93
|
) -> None:
|
|
81
94
|
|
|
82
|
-
self._tables:
|
|
83
|
-
self._edges:
|
|
95
|
+
self._tables: dict[str, Table] = {}
|
|
96
|
+
self._edges: list[Edge] = []
|
|
84
97
|
|
|
85
98
|
for table in tables:
|
|
86
99
|
self.add_table(table)
|
|
87
100
|
|
|
101
|
+
for table in tables: # Use links from source metadata:
|
|
102
|
+
if not any(column.is_source for column in table.columns):
|
|
103
|
+
continue
|
|
104
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
105
|
+
if fkey.name not in table:
|
|
106
|
+
continue
|
|
107
|
+
if not table[fkey.name].is_source:
|
|
108
|
+
continue
|
|
109
|
+
dst_table_names = [
|
|
110
|
+
table.name for table in self.tables.values()
|
|
111
|
+
if table.source_name == fkey.dst_table
|
|
112
|
+
]
|
|
113
|
+
if len(dst_table_names) != 1:
|
|
114
|
+
continue
|
|
115
|
+
dst_table = self[dst_table_names[0]]
|
|
116
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
117
|
+
continue
|
|
118
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
119
|
+
continue
|
|
120
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
121
|
+
|
|
88
122
|
for edge in (edges or []):
|
|
89
123
|
_edge = Edge._cast(edge)
|
|
90
124
|
assert _edge is not None
|
|
91
|
-
self.
|
|
125
|
+
if _edge not in self._edges:
|
|
126
|
+
self.link(*_edge)
|
|
92
127
|
|
|
93
128
|
@classmethod
|
|
94
129
|
def from_data(
|
|
95
130
|
cls,
|
|
96
|
-
df_dict:
|
|
97
|
-
edges:
|
|
131
|
+
df_dict: dict[str, pd.DataFrame],
|
|
132
|
+
edges: Sequence[Edge] | None = None,
|
|
98
133
|
infer_metadata: bool = True,
|
|
99
134
|
verbose: bool = True,
|
|
100
135
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
136
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
137
|
:class:`pandas.DataFrame` objects.
|
|
103
138
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
139
|
+
Automatically infers table metadata and links by default.
|
|
105
140
|
|
|
106
141
|
.. code-block:: python
|
|
107
142
|
|
|
@@ -115,59 +150,400 @@ class LocalGraph:
|
|
|
115
150
|
>>> df3 = pd.DataFrame(...)
|
|
116
151
|
|
|
117
152
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
153
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
154
|
... "table1": df1,
|
|
120
155
|
... "table2": df2,
|
|
121
156
|
... "table3": df3,
|
|
122
157
|
... })
|
|
123
158
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
159
|
Args:
|
|
132
160
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
161
|
of the tables and the values hold table data.
|
|
162
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
163
|
+
add to the graph. If not provided, edges will be automatically
|
|
164
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
165
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
166
|
graph.
|
|
167
|
+
verbose: Whether to print verbose output.
|
|
168
|
+
"""
|
|
169
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
170
|
+
|
|
171
|
+
graph = cls(
|
|
172
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
173
|
+
edges=edges or [],
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
if infer_metadata:
|
|
177
|
+
graph.infer_metadata(verbose=False)
|
|
178
|
+
|
|
179
|
+
if edges is None:
|
|
180
|
+
graph.infer_links(verbose=False)
|
|
181
|
+
|
|
182
|
+
if verbose:
|
|
183
|
+
graph.print_metadata()
|
|
184
|
+
graph.print_links()
|
|
185
|
+
|
|
186
|
+
return graph
|
|
187
|
+
|
|
188
|
+
@classmethod
|
|
189
|
+
def from_sqlite(
|
|
190
|
+
cls,
|
|
191
|
+
connection: Union[
|
|
192
|
+
'AdbcSqliteConnection',
|
|
193
|
+
SqliteConnectionConfig,
|
|
194
|
+
str,
|
|
195
|
+
Path,
|
|
196
|
+
dict[str, Any],
|
|
197
|
+
],
|
|
198
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
199
|
+
edges: Sequence[Edge] | None = None,
|
|
200
|
+
infer_metadata: bool = True,
|
|
201
|
+
verbose: bool = True,
|
|
202
|
+
) -> Self:
|
|
203
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
204
|
+
|
|
205
|
+
Automatically infers table metadata and links by default.
|
|
206
|
+
|
|
207
|
+
.. code-block:: python
|
|
208
|
+
|
|
209
|
+
>>> # doctest: +SKIP
|
|
210
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
211
|
+
|
|
212
|
+
>>> # Create a graph from a SQLite database:
|
|
213
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
214
|
+
|
|
215
|
+
>>> # Fine-grained control over table specification:
|
|
216
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
217
|
+
... 'USERS',
|
|
218
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
219
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
220
|
+
... ])
|
|
221
|
+
|
|
222
|
+
Args:
|
|
223
|
+
connection: An open connection from
|
|
224
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
225
|
+
path to the database file.
|
|
226
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
227
|
+
arguments to include. If ``None``, will add all tables present
|
|
228
|
+
in the database.
|
|
136
229
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
230
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
231
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
232
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
233
|
+
the graph.
|
|
139
234
|
verbose: Whether to print verbose output.
|
|
235
|
+
"""
|
|
236
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
237
|
+
Connection,
|
|
238
|
+
SQLiteTable,
|
|
239
|
+
connect,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
internal_connection = False
|
|
243
|
+
if not isinstance(connection, Connection):
|
|
244
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
245
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
246
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
247
|
+
internal_connection = True
|
|
248
|
+
assert isinstance(connection, Connection)
|
|
249
|
+
|
|
250
|
+
if tables is None:
|
|
251
|
+
with connection.cursor() as cursor:
|
|
252
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
253
|
+
"WHERE type='table'")
|
|
254
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
255
|
+
|
|
256
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
257
|
+
for table in tables:
|
|
258
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
259
|
+
table_kwargs.append(kwargs)
|
|
140
260
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
261
|
+
graph = cls(
|
|
262
|
+
tables=[
|
|
263
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
264
|
+
for kwargs in table_kwargs
|
|
265
|
+
],
|
|
266
|
+
edges=edges or [],
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
if internal_connection:
|
|
270
|
+
graph._connection = connection # type: ignore
|
|
271
|
+
|
|
272
|
+
if infer_metadata:
|
|
273
|
+
graph.infer_metadata(verbose=False)
|
|
274
|
+
|
|
275
|
+
if edges is None:
|
|
276
|
+
graph.infer_links(verbose=False)
|
|
277
|
+
|
|
278
|
+
if verbose:
|
|
279
|
+
graph.print_metadata()
|
|
280
|
+
graph.print_links()
|
|
281
|
+
|
|
282
|
+
return graph
|
|
283
|
+
|
|
284
|
+
@classmethod
|
|
285
|
+
def from_snowflake(
|
|
286
|
+
cls,
|
|
287
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
288
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
289
|
+
database: str | None = None,
|
|
290
|
+
schema: str | None = None,
|
|
291
|
+
edges: Sequence[Edge] | None = None,
|
|
292
|
+
infer_metadata: bool = True,
|
|
293
|
+
verbose: bool = True,
|
|
294
|
+
) -> Self:
|
|
295
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
296
|
+
schema.
|
|
297
|
+
|
|
298
|
+
Automatically infers table metadata and links by default.
|
|
299
|
+
|
|
300
|
+
.. code-block:: python
|
|
144
301
|
|
|
145
|
-
Example:
|
|
146
302
|
>>> # doctest: +SKIP
|
|
147
303
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
>>> df1 = pd.DataFrame(...)
|
|
149
|
-
>>> df2 = pd.DataFrame(...)
|
|
150
|
-
>>> df3 = pd.DataFrame(...)
|
|
151
|
-
>>> graph = rfm.LocalGraph.from_data(data={
|
|
152
|
-
... "table1": df1,
|
|
153
|
-
... "table2": df2,
|
|
154
|
-
... "table3": df3,
|
|
155
|
-
... })
|
|
156
|
-
>>> graph.validate()
|
|
157
|
-
"""
|
|
158
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
159
304
|
|
|
160
|
-
|
|
305
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
306
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
307
|
+
|
|
308
|
+
>>> # Fine-grained control over table specification:
|
|
309
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
310
|
+
... 'USERS',
|
|
311
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
312
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
313
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
314
|
+
|
|
315
|
+
Args:
|
|
316
|
+
connection: An open connection from
|
|
317
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
318
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
319
|
+
connection. If ``None``, will re-use an active session in case
|
|
320
|
+
it exists, or create a new connection from credentials stored
|
|
321
|
+
in environment variables.
|
|
322
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
323
|
+
to include. If ``None``, will add all tables present in the
|
|
324
|
+
current database and schema.
|
|
325
|
+
database: The database.
|
|
326
|
+
schema: The schema.
|
|
327
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
328
|
+
add to the graph. If not provided, edges will be automatically
|
|
329
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
330
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
331
|
+
graph.
|
|
332
|
+
verbose: Whether to print verbose output.
|
|
333
|
+
"""
|
|
334
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
335
|
+
Connection,
|
|
336
|
+
SnowTable,
|
|
337
|
+
connect,
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
if not isinstance(connection, Connection):
|
|
341
|
+
connection = connect(**(connection or {}))
|
|
342
|
+
assert isinstance(connection, Connection)
|
|
343
|
+
|
|
344
|
+
if database is None or schema is None:
|
|
345
|
+
with connection.cursor() as cursor:
|
|
346
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
347
|
+
result = cursor.fetchone()
|
|
348
|
+
database = database or result[0]
|
|
349
|
+
assert database is not None
|
|
350
|
+
schema = schema or result[1]
|
|
351
|
+
|
|
352
|
+
if tables is None:
|
|
353
|
+
if schema is None:
|
|
354
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
355
|
+
"Snowflake schema manually")
|
|
356
|
+
|
|
357
|
+
with connection.cursor() as cursor:
|
|
358
|
+
cursor.execute(f"""
|
|
359
|
+
SELECT TABLE_NAME
|
|
360
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
361
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
362
|
+
""")
|
|
363
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
364
|
+
|
|
365
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
366
|
+
for table in tables:
|
|
367
|
+
if isinstance(table, str):
|
|
368
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
369
|
+
else:
|
|
370
|
+
kwargs = copy.copy(table)
|
|
371
|
+
kwargs.setdefault('database', database)
|
|
372
|
+
kwargs.setdefault('schema', schema)
|
|
373
|
+
table_kwargs.append(kwargs)
|
|
374
|
+
|
|
375
|
+
graph = cls(
|
|
376
|
+
tables=[
|
|
377
|
+
SnowTable(connection=connection, **kwargs)
|
|
378
|
+
for kwargs in table_kwargs
|
|
379
|
+
],
|
|
380
|
+
edges=edges or [],
|
|
381
|
+
)
|
|
161
382
|
|
|
162
383
|
if infer_metadata:
|
|
163
|
-
graph.infer_metadata(verbose)
|
|
384
|
+
graph.infer_metadata(verbose=False)
|
|
164
385
|
|
|
165
386
|
if edges is None:
|
|
166
|
-
graph.infer_links(verbose)
|
|
387
|
+
graph.infer_links(verbose=False)
|
|
388
|
+
|
|
389
|
+
if verbose:
|
|
390
|
+
graph.print_metadata()
|
|
391
|
+
graph.print_links()
|
|
167
392
|
|
|
168
393
|
return graph
|
|
169
394
|
|
|
170
|
-
|
|
395
|
+
@classmethod
|
|
396
|
+
def from_snowflake_semantic_view(
|
|
397
|
+
cls,
|
|
398
|
+
semantic_view_name: str,
|
|
399
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
400
|
+
verbose: bool = True,
|
|
401
|
+
) -> Self:
|
|
402
|
+
import yaml
|
|
403
|
+
|
|
404
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
405
|
+
Connection,
|
|
406
|
+
SnowTable,
|
|
407
|
+
connect,
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
if not isinstance(connection, Connection):
|
|
411
|
+
connection = connect(**(connection or {}))
|
|
412
|
+
assert isinstance(connection, Connection)
|
|
413
|
+
|
|
414
|
+
with connection.cursor() as cursor:
|
|
415
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
416
|
+
f"'{semantic_view_name}')")
|
|
417
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
418
|
+
|
|
419
|
+
graph = cls(tables=[])
|
|
420
|
+
|
|
421
|
+
msgs = []
|
|
422
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
423
|
+
for table_cfg in cfg['tables']:
|
|
424
|
+
table_name = table_cfg['name']
|
|
425
|
+
source_table_name = table_cfg['base_table']['table']
|
|
426
|
+
database = table_cfg['base_table']['database']
|
|
427
|
+
schema = table_cfg['base_table']['schema']
|
|
428
|
+
|
|
429
|
+
primary_key: str | None = None
|
|
430
|
+
if 'primary_key' in table_cfg:
|
|
431
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
432
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
433
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
434
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
435
|
+
msgs.append(f"Failed to add primary key for table "
|
|
436
|
+
f"'{table_name}' since composite primary keys "
|
|
437
|
+
f"are not yet supported")
|
|
438
|
+
|
|
439
|
+
columns: list[ColumnSpec] = []
|
|
440
|
+
unsupported_columns: list[str] = []
|
|
441
|
+
for column_cfg in chain(
|
|
442
|
+
table_cfg.get('dimensions', []),
|
|
443
|
+
table_cfg.get('time_dimensions', []),
|
|
444
|
+
table_cfg.get('facts', []),
|
|
445
|
+
):
|
|
446
|
+
column_name = column_cfg['name']
|
|
447
|
+
column_expr = column_cfg.get('expr', None)
|
|
448
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
449
|
+
|
|
450
|
+
if column_expr is None:
|
|
451
|
+
columns.append(ColumnSpec(name=column_name))
|
|
452
|
+
continue
|
|
453
|
+
|
|
454
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
455
|
+
|
|
456
|
+
if column_expr == column_name:
|
|
457
|
+
columns.append(ColumnSpec(name=column_name))
|
|
458
|
+
continue
|
|
459
|
+
|
|
460
|
+
# Drop expressions that reference other tables (for now):
|
|
461
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
462
|
+
unsupported_columns.append(column_name)
|
|
463
|
+
continue
|
|
464
|
+
|
|
465
|
+
column = ColumnSpec(
|
|
466
|
+
name=column_name,
|
|
467
|
+
expr=column_expr,
|
|
468
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
469
|
+
)
|
|
470
|
+
columns.append(column)
|
|
471
|
+
|
|
472
|
+
if len(unsupported_columns) == 1:
|
|
473
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
474
|
+
f"of table '{table_name}' since its expression "
|
|
475
|
+
f"references other tables")
|
|
476
|
+
elif len(unsupported_columns) > 1:
|
|
477
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
478
|
+
f"of table '{table_name}' since their expressions "
|
|
479
|
+
f"reference other tables")
|
|
480
|
+
|
|
481
|
+
table = SnowTable(
|
|
482
|
+
connection,
|
|
483
|
+
name=table_name,
|
|
484
|
+
source_name=source_table_name,
|
|
485
|
+
database=database,
|
|
486
|
+
schema=schema,
|
|
487
|
+
columns=columns,
|
|
488
|
+
primary_key=primary_key,
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
492
|
+
table.infer_time_column(verbose=False)
|
|
493
|
+
|
|
494
|
+
graph.add_table(table)
|
|
495
|
+
|
|
496
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
497
|
+
name = relation_cfg['name']
|
|
498
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
499
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
500
|
+
f"composite key references are not yet supported")
|
|
501
|
+
continue
|
|
502
|
+
|
|
503
|
+
left_table = relation_cfg['left_table']
|
|
504
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
505
|
+
right_table = relation_cfg['right_table']
|
|
506
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
507
|
+
|
|
508
|
+
if graph[right_table]._primary_key != right_key:
|
|
509
|
+
# Semantic view error - this should never be triggered:
|
|
510
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
511
|
+
f"referenced key '{right_key}' of table "
|
|
512
|
+
f"'{right_table}' is not a primary key")
|
|
513
|
+
continue
|
|
514
|
+
|
|
515
|
+
if graph[left_table]._primary_key == left_key:
|
|
516
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
517
|
+
f"referencing key '{left_key}' of table "
|
|
518
|
+
f"'{left_table}' is a primary key")
|
|
519
|
+
continue
|
|
520
|
+
|
|
521
|
+
if left_key not in graph[left_table]:
|
|
522
|
+
graph[left_table].add_column(left_key)
|
|
523
|
+
|
|
524
|
+
graph.link(left_table, left_key, right_table)
|
|
525
|
+
|
|
526
|
+
graph.validate()
|
|
527
|
+
|
|
528
|
+
if verbose:
|
|
529
|
+
graph.print_metadata()
|
|
530
|
+
graph.print_links()
|
|
531
|
+
|
|
532
|
+
if len(msgs) > 0:
|
|
533
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
534
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
535
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
536
|
+
|
|
537
|
+
return graph
|
|
538
|
+
|
|
539
|
+
# Backend #################################################################
|
|
540
|
+
|
|
541
|
+
@property
|
|
542
|
+
def backend(self) -> DataBackend | None:
|
|
543
|
+
backends = [table.backend for table in self._tables.values()]
|
|
544
|
+
return backends[0] if len(backends) > 0 else None
|
|
545
|
+
|
|
546
|
+
# Tables ##################################################################
|
|
171
547
|
|
|
172
548
|
def has_table(self, name: str) -> bool:
|
|
173
549
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -175,7 +551,7 @@ class LocalGraph:
|
|
|
175
551
|
"""
|
|
176
552
|
return name in self.tables
|
|
177
553
|
|
|
178
|
-
def table(self, name: str) ->
|
|
554
|
+
def table(self, name: str) -> Table:
|
|
179
555
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
556
|
|
|
181
557
|
Raises:
|
|
@@ -186,11 +562,11 @@ class LocalGraph:
|
|
|
186
562
|
return self.tables[name]
|
|
187
563
|
|
|
188
564
|
@property
|
|
189
|
-
def tables(self) ->
|
|
565
|
+
def tables(self) -> dict[str, Table]:
|
|
190
566
|
r"""Returns the dictionary of table objects."""
|
|
191
567
|
return self._tables
|
|
192
568
|
|
|
193
|
-
def add_table(self, table:
|
|
569
|
+
def add_table(self, table: Table) -> Self:
|
|
194
570
|
r"""Adds a table to the graph.
|
|
195
571
|
|
|
196
572
|
Args:
|
|
@@ -199,11 +575,18 @@ class LocalGraph:
|
|
|
199
575
|
Raises:
|
|
200
576
|
KeyError: If a table with the same name already exists in the
|
|
201
577
|
graph.
|
|
578
|
+
ValueError: If the table belongs to a different backend than the
|
|
579
|
+
rest of the tables in the graph.
|
|
202
580
|
"""
|
|
203
581
|
if table.name in self._tables:
|
|
204
582
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
583
|
f"this graph; table names must be globally unique.")
|
|
206
584
|
|
|
585
|
+
if self.backend is not None and table.backend != self.backend:
|
|
586
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
587
|
+
f"'{table.backend}' to this graph since other "
|
|
588
|
+
f"tables have backend '{self.backend}'.")
|
|
589
|
+
|
|
207
590
|
self._tables[table.name] = table
|
|
208
591
|
|
|
209
592
|
return self
|
|
@@ -241,7 +624,7 @@ class LocalGraph:
|
|
|
241
624
|
Example:
|
|
242
625
|
>>> # doctest: +SKIP
|
|
243
626
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
627
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
628
|
>>> graph.metadata # doctest: +SKIP
|
|
246
629
|
name primary_key time_column end_time_column
|
|
247
630
|
0 users user_id - -
|
|
@@ -263,10 +646,14 @@ class LocalGraph:
|
|
|
263
646
|
})
|
|
264
647
|
|
|
265
648
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
267
|
-
if
|
|
649
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
650
|
+
if in_snowflake_notebook():
|
|
651
|
+
import streamlit as st
|
|
652
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
653
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
654
|
+
elif in_notebook():
|
|
268
655
|
from IPython.display import Markdown, display
|
|
269
|
-
display(Markdown(
|
|
656
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
270
657
|
df = self.metadata
|
|
271
658
|
try:
|
|
272
659
|
if hasattr(df.style, 'hide'):
|
|
@@ -287,7 +674,7 @@ class LocalGraph:
|
|
|
287
674
|
|
|
288
675
|
Note:
|
|
289
676
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
677
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
678
|
"""
|
|
292
679
|
for table in self.tables.values():
|
|
293
680
|
table.infer_metadata(verbose=False)
|
|
@@ -300,42 +687,52 @@ class LocalGraph:
|
|
|
300
687
|
# Edges ###################################################################
|
|
301
688
|
|
|
302
689
|
@property
|
|
303
|
-
def edges(self) ->
|
|
690
|
+
def edges(self) -> list[Edge]:
|
|
304
691
|
r"""Returns the edges of the graph."""
|
|
305
692
|
return self._edges
|
|
306
693
|
|
|
307
694
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
695
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
696
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
697
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
698
|
edges = sorted(edges)
|
|
312
699
|
|
|
313
|
-
if
|
|
700
|
+
if in_snowflake_notebook():
|
|
701
|
+
import streamlit as st
|
|
702
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
703
|
+
if len(edges) > 0:
|
|
704
|
+
st.markdown('\n'.join([
|
|
705
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
706
|
+
for edge in edges
|
|
707
|
+
]))
|
|
708
|
+
else:
|
|
709
|
+
st.markdown("*No links registered*")
|
|
710
|
+
elif in_notebook():
|
|
314
711
|
from IPython.display import Markdown, display
|
|
315
|
-
display(Markdown(
|
|
712
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
316
713
|
if len(edges) > 0:
|
|
317
714
|
display(
|
|
318
715
|
Markdown('\n'.join([
|
|
319
|
-
f
|
|
716
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
320
717
|
for edge in edges
|
|
321
718
|
])))
|
|
322
719
|
else:
|
|
323
|
-
display(Markdown(
|
|
720
|
+
display(Markdown("*No links registered*"))
|
|
324
721
|
else:
|
|
325
722
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
326
723
|
if len(edges) > 0:
|
|
327
724
|
print('\n'.join([
|
|
328
|
-
f
|
|
725
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
329
726
|
for edge in edges
|
|
330
727
|
]))
|
|
331
728
|
else:
|
|
332
|
-
print(
|
|
729
|
+
print("No links registered")
|
|
333
730
|
|
|
334
731
|
def link(
|
|
335
732
|
self,
|
|
336
|
-
src_table:
|
|
733
|
+
src_table: str | Table,
|
|
337
734
|
fkey: str,
|
|
338
|
-
dst_table:
|
|
735
|
+
dst_table: str | Table,
|
|
339
736
|
) -> Self:
|
|
340
737
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
738
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +755,11 @@ class LocalGraph:
|
|
|
358
755
|
table does not exist in the graph, if the source key does not
|
|
359
756
|
exist in the source table.
|
|
360
757
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
758
|
+
if isinstance(src_table, Table):
|
|
362
759
|
src_table = src_table.name
|
|
363
760
|
assert isinstance(src_table, str)
|
|
364
761
|
|
|
365
|
-
if isinstance(dst_table,
|
|
762
|
+
if isinstance(dst_table, Table):
|
|
366
763
|
dst_table = dst_table.name
|
|
367
764
|
assert isinstance(dst_table, str)
|
|
368
765
|
|
|
@@ -396,9 +793,9 @@ class LocalGraph:
|
|
|
396
793
|
|
|
397
794
|
def unlink(
|
|
398
795
|
self,
|
|
399
|
-
src_table:
|
|
796
|
+
src_table: str | Table,
|
|
400
797
|
fkey: str,
|
|
401
|
-
dst_table:
|
|
798
|
+
dst_table: str | Table,
|
|
402
799
|
) -> Self:
|
|
403
800
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
801
|
|
|
@@ -410,11 +807,11 @@ class LocalGraph:
|
|
|
410
807
|
Raises:
|
|
411
808
|
ValueError: if the edge is not present in the graph.
|
|
412
809
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
810
|
+
if isinstance(src_table, Table):
|
|
414
811
|
src_table = src_table.name
|
|
415
812
|
assert isinstance(src_table, str)
|
|
416
813
|
|
|
417
|
-
if isinstance(dst_table,
|
|
814
|
+
if isinstance(dst_table, Table):
|
|
418
815
|
dst_table = dst_table.name
|
|
419
816
|
assert isinstance(dst_table, str)
|
|
420
817
|
|
|
@@ -428,17 +825,37 @@ class LocalGraph:
|
|
|
428
825
|
return self
|
|
429
826
|
|
|
430
827
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
828
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
829
|
+
graph.
|
|
432
830
|
|
|
433
831
|
Args:
|
|
434
832
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
833
|
"""
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
834
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
835
|
+
|
|
836
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
837
|
+
if not any(column.is_source for column in table.columns):
|
|
838
|
+
continue
|
|
839
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
840
|
+
if fkey.name not in table:
|
|
841
|
+
continue
|
|
842
|
+
if not table[fkey.name].is_source:
|
|
843
|
+
continue
|
|
844
|
+
if (table.name, fkey.name) in known_edges:
|
|
845
|
+
continue
|
|
846
|
+
dst_table_names = [
|
|
847
|
+
table.name for table in self.tables.values()
|
|
848
|
+
if table.source_name == fkey.dst_table
|
|
849
|
+
]
|
|
850
|
+
if len(dst_table_names) != 1:
|
|
851
|
+
continue
|
|
852
|
+
dst_table = self[dst_table_names[0]]
|
|
853
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
854
|
+
continue
|
|
855
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
856
|
+
continue
|
|
857
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
858
|
+
known_edges.add((table.name, fkey.name))
|
|
442
859
|
|
|
443
860
|
# A list of primary key candidates (+score) for every column:
|
|
444
861
|
candidate_dict: dict[
|
|
@@ -463,6 +880,9 @@ class LocalGraph:
|
|
|
463
880
|
src_table_name = src_table.name.lower()
|
|
464
881
|
|
|
465
882
|
for src_key in src_table.columns:
|
|
883
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
884
|
+
continue
|
|
885
|
+
|
|
466
886
|
if src_key == src_table.primary_key:
|
|
467
887
|
continue # Cannot link to primary key.
|
|
468
888
|
|
|
@@ -528,19 +948,16 @@ class LocalGraph:
|
|
|
528
948
|
score += 1.0
|
|
529
949
|
|
|
530
950
|
# Cardinality ratio:
|
|
531
|
-
if
|
|
951
|
+
if (src_table._num_rows is not None
|
|
952
|
+
and dst_table._num_rows is not None
|
|
953
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
532
954
|
score += 1.0
|
|
533
955
|
|
|
534
956
|
if score < 5.0:
|
|
535
957
|
continue
|
|
536
958
|
|
|
537
|
-
candidate_dict[(
|
|
538
|
-
|
|
539
|
-
src_key.name,
|
|
540
|
-
)].append((
|
|
541
|
-
dst_table.name,
|
|
542
|
-
score,
|
|
543
|
-
))
|
|
959
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
960
|
+
(dst_table.name, score))
|
|
544
961
|
|
|
545
962
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
546
963
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -574,6 +991,10 @@ class LocalGraph:
|
|
|
574
991
|
raise ValueError("At least one table needs to be added to the "
|
|
575
992
|
"graph")
|
|
576
993
|
|
|
994
|
+
backends = {table.backend for table in self._tables.values()}
|
|
995
|
+
if len(backends) != 1:
|
|
996
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
997
|
+
|
|
577
998
|
for edge in self.edges:
|
|
578
999
|
src_table, fkey, dst_table = edge
|
|
579
1000
|
|
|
@@ -595,24 +1016,26 @@ class LocalGraph:
|
|
|
595
1016
|
f"either the primary key or the link before "
|
|
596
1017
|
f"before proceeding.")
|
|
597
1018
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
1019
|
+
if self.backend == DataBackend.LOCAL:
|
|
1020
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1021
|
+
assert src_key.dtype is not None
|
|
1022
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1023
|
+
src_string = src_key.dtype.is_string()
|
|
1024
|
+
assert dst_key.dtype is not None
|
|
1025
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1026
|
+
dst_string = dst_key.dtype.is_string()
|
|
1027
|
+
|
|
1028
|
+
if not src_number and not src_string:
|
|
1029
|
+
raise ValueError(
|
|
1030
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1031
|
+
f"or string (got '{src_key.dtype}'")
|
|
1032
|
+
|
|
1033
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1034
|
+
raise ValueError(
|
|
1035
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1036
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1037
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1038
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
616
1039
|
|
|
617
1040
|
return self
|
|
618
1041
|
|
|
@@ -620,7 +1043,7 @@ class LocalGraph:
|
|
|
620
1043
|
|
|
621
1044
|
def visualize(
|
|
622
1045
|
self,
|
|
623
|
-
path:
|
|
1046
|
+
path: str | io.BytesIO | None = None,
|
|
624
1047
|
show_columns: bool = True,
|
|
625
1048
|
) -> 'graphviz.Graph':
|
|
626
1049
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -645,33 +1068,33 @@ class LocalGraph:
|
|
|
645
1068
|
|
|
646
1069
|
return True
|
|
647
1070
|
|
|
648
|
-
# Check basic dependency:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
1071
|
+
try: # Check basic dependency:
|
|
1072
|
+
import graphviz
|
|
1073
|
+
except ImportError as e:
|
|
1074
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
1075
|
+
"visualization") from e
|
|
1076
|
+
|
|
1077
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
653
1078
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
654
1079
|
"executables are not installed. These "
|
|
655
1080
|
"dependencies are required in addition to the "
|
|
656
1081
|
"'graphviz' Python package. Please install "
|
|
657
1082
|
"them as described at "
|
|
658
1083
|
"https://graphviz.org/download/.")
|
|
659
|
-
else:
|
|
660
|
-
import graphviz
|
|
661
1084
|
|
|
662
|
-
format:
|
|
1085
|
+
format: str | None = None
|
|
663
1086
|
if isinstance(path, str):
|
|
664
1087
|
format = path.split('.')[-1]
|
|
665
1088
|
elif isinstance(path, io.BytesIO):
|
|
666
1089
|
format = 'svg'
|
|
667
1090
|
graph = graphviz.Graph(format=format)
|
|
668
1091
|
|
|
669
|
-
def left_align(keys:
|
|
1092
|
+
def left_align(keys: list[str]) -> str:
|
|
670
1093
|
if len(keys) == 0:
|
|
671
1094
|
return ""
|
|
672
1095
|
return '\\l'.join(keys) + '\\l'
|
|
673
1096
|
|
|
674
|
-
fkeys_dict:
|
|
1097
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
675
1098
|
for src_table_name, fkey_name, _ in self.edges:
|
|
676
1099
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
677
1100
|
|
|
@@ -741,6 +1164,9 @@ class LocalGraph:
|
|
|
741
1164
|
graph.render(path, cleanup=True)
|
|
742
1165
|
elif isinstance(path, io.BytesIO):
|
|
743
1166
|
path.write(graph.pipe())
|
|
1167
|
+
elif in_snowflake_notebook():
|
|
1168
|
+
import streamlit as st
|
|
1169
|
+
st.graphviz_chart(graph)
|
|
744
1170
|
elif in_notebook():
|
|
745
1171
|
from IPython.display import display
|
|
746
1172
|
display(graph)
|
|
@@ -764,8 +1190,8 @@ class LocalGraph:
|
|
|
764
1190
|
# Helpers #################################################################
|
|
765
1191
|
|
|
766
1192
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
767
|
-
tables:
|
|
768
|
-
col_groups:
|
|
1193
|
+
tables: dict[str, TableDefinition] = {}
|
|
1194
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
769
1195
|
for table_name, table in self.tables.items():
|
|
770
1196
|
tables[table_name] = table._to_api_table_definition()
|
|
771
1197
|
if table.primary_key is None:
|
|
@@ -790,7 +1216,7 @@ class LocalGraph:
|
|
|
790
1216
|
def __contains__(self, name: str) -> bool:
|
|
791
1217
|
return self.has_table(name)
|
|
792
1218
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
1219
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
1220
|
return self.table(name)
|
|
795
1221
|
|
|
796
1222
|
def __delitem__(self, name: str) -> None:
|
|
@@ -808,3 +1234,7 @@ class LocalGraph:
|
|
|
808
1234
|
f' tables={tables},\n'
|
|
809
1235
|
f' edges={edges},\n'
|
|
810
1236
|
f')')
|
|
1237
|
+
|
|
1238
|
+
def __del__(self) -> None:
|
|
1239
|
+
if hasattr(self, '_connection'):
|
|
1240
|
+
self._connection.close()
|