kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +41 -80
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +147 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
- kumoai/experimental/rfm/base/__init__.py +26 -2
- kumoai/experimental/rfm/base/column.py +6 -12
- kumoai/experimental/rfm/base/column_expression.py +16 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +113 -0
- kumoai/experimental/rfm/base/table.py +174 -76
- kumoai/experimental/rfm/graph.py +444 -84
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +77 -0
- kumoai/experimental/rfm/infer/pkey.py +128 -0
- kumoai/experimental/rfm/infer/time_col.py +61 -0
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -240
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
+
import copy
|
|
2
3
|
import io
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
6
|
+
from collections.abc import Sequence
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from itertools import chain
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
7
11
|
|
|
8
12
|
import pandas as pd
|
|
9
13
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,12 +15,21 @@ from kumoapi.table import TableDefinition
|
|
|
11
15
|
from kumoapi.typing import Stype
|
|
12
16
|
from typing_extensions import Self
|
|
13
17
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import Table
|
|
18
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
+
from kumoai.experimental.rfm.base import DataBackend, SQLTable, Table
|
|
16
20
|
from kumoai.graph import Edge
|
|
21
|
+
from kumoai.mixin import CastMixin
|
|
17
22
|
|
|
18
23
|
if TYPE_CHECKING:
|
|
19
24
|
import graphviz
|
|
25
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
26
|
+
from snowflake.connector import SnowflakeConnection
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
|
|
30
|
+
class SqliteConnectionConfig(CastMixin):
|
|
31
|
+
uri: str | Path
|
|
32
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
20
33
|
|
|
21
34
|
|
|
22
35
|
class Graph:
|
|
@@ -76,32 +89,53 @@ class Graph:
|
|
|
76
89
|
def __init__(
|
|
77
90
|
self,
|
|
78
91
|
tables: Sequence[Table],
|
|
79
|
-
edges:
|
|
92
|
+
edges: Sequence[Edge] | None = None,
|
|
80
93
|
) -> None:
|
|
81
94
|
|
|
82
|
-
self._tables:
|
|
83
|
-
self._edges:
|
|
95
|
+
self._tables: dict[str, Table] = {}
|
|
96
|
+
self._edges: list[Edge] = []
|
|
84
97
|
|
|
85
98
|
for table in tables:
|
|
86
99
|
self.add_table(table)
|
|
87
100
|
|
|
101
|
+
for table in tables:
|
|
102
|
+
if not isinstance(table, SQLTable):
|
|
103
|
+
continue
|
|
104
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
105
|
+
if fkey.name not in table:
|
|
106
|
+
continue
|
|
107
|
+
# TODO Skip for non-physical table[fkey.name].
|
|
108
|
+
dst_table_names = [
|
|
109
|
+
table.name for table in self.tables.values()
|
|
110
|
+
if isinstance(table, SQLTable)
|
|
111
|
+
and table._source_name == fkey.dst_table
|
|
112
|
+
]
|
|
113
|
+
if len(dst_table_names) != 1:
|
|
114
|
+
continue
|
|
115
|
+
dst_table = self[dst_table_names[0]]
|
|
116
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
117
|
+
continue
|
|
118
|
+
# TODO Skip for non-physical dst_table.primary_key.
|
|
119
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
120
|
+
|
|
88
121
|
for edge in (edges or []):
|
|
89
122
|
_edge = Edge._cast(edge)
|
|
90
123
|
assert _edge is not None
|
|
91
|
-
self.
|
|
124
|
+
if _edge not in self._edges:
|
|
125
|
+
self.link(*_edge)
|
|
92
126
|
|
|
93
127
|
@classmethod
|
|
94
128
|
def from_data(
|
|
95
129
|
cls,
|
|
96
|
-
df_dict:
|
|
97
|
-
edges:
|
|
130
|
+
df_dict: dict[str, pd.DataFrame],
|
|
131
|
+
edges: Sequence[Edge] | None = None,
|
|
98
132
|
infer_metadata: bool = True,
|
|
99
133
|
verbose: bool = True,
|
|
100
134
|
) -> Self:
|
|
101
135
|
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
136
|
:class:`pandas.DataFrame` objects.
|
|
103
137
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
138
|
+
Automatically infers table metadata and links by default.
|
|
105
139
|
|
|
106
140
|
.. code-block:: python
|
|
107
141
|
|
|
@@ -121,54 +155,360 @@ class Graph:
|
|
|
121
155
|
... "table3": df3,
|
|
122
156
|
... })
|
|
123
157
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
158
|
Args:
|
|
132
159
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
160
|
of the tables and the values hold table data.
|
|
161
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
162
|
+
add to the graph. If not provided, edges will be automatically
|
|
163
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
164
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
165
|
graph.
|
|
166
|
+
verbose: Whether to print verbose output.
|
|
167
|
+
"""
|
|
168
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
169
|
+
|
|
170
|
+
graph = cls(
|
|
171
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
172
|
+
edges=edges or [],
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
if infer_metadata:
|
|
176
|
+
graph.infer_metadata(verbose=False)
|
|
177
|
+
|
|
178
|
+
if edges is None:
|
|
179
|
+
graph.infer_links(verbose=False)
|
|
180
|
+
|
|
181
|
+
if verbose:
|
|
182
|
+
graph.print_metadata()
|
|
183
|
+
graph.print_links()
|
|
184
|
+
|
|
185
|
+
return graph
|
|
186
|
+
|
|
187
|
+
@classmethod
|
|
188
|
+
def from_sqlite(
|
|
189
|
+
cls,
|
|
190
|
+
connection: Union[
|
|
191
|
+
'AdbcSqliteConnection',
|
|
192
|
+
SqliteConnectionConfig,
|
|
193
|
+
str,
|
|
194
|
+
Path,
|
|
195
|
+
dict[str, Any],
|
|
196
|
+
],
|
|
197
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
198
|
+
edges: Sequence[Edge] | None = None,
|
|
199
|
+
infer_metadata: bool = True,
|
|
200
|
+
verbose: bool = True,
|
|
201
|
+
) -> Self:
|
|
202
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
203
|
+
|
|
204
|
+
Automatically infers table metadata and links by default.
|
|
205
|
+
|
|
206
|
+
.. code-block:: python
|
|
207
|
+
|
|
208
|
+
>>> # doctest: +SKIP
|
|
209
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
210
|
+
|
|
211
|
+
>>> # Create a graph from a SQLite database:
|
|
212
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
213
|
+
|
|
214
|
+
>>> # Fine-grained control over table specification:
|
|
215
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
216
|
+
... 'USERS',
|
|
217
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
218
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
219
|
+
... ])
|
|
220
|
+
|
|
221
|
+
Args:
|
|
222
|
+
connection: An open connection from
|
|
223
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
224
|
+
path to the database file.
|
|
225
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
226
|
+
arguments to include. If ``None``, will add all tables present
|
|
227
|
+
in the database.
|
|
136
228
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
229
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
230
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
231
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
232
|
+
the graph.
|
|
139
233
|
verbose: Whether to print verbose output.
|
|
234
|
+
"""
|
|
235
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
236
|
+
Connection,
|
|
237
|
+
SQLiteTable,
|
|
238
|
+
connect,
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
internal_connection = False
|
|
242
|
+
if not isinstance(connection, Connection):
|
|
243
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
244
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
245
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
246
|
+
internal_connection = True
|
|
247
|
+
assert isinstance(connection, Connection)
|
|
248
|
+
|
|
249
|
+
if tables is None:
|
|
250
|
+
with connection.cursor() as cursor:
|
|
251
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
252
|
+
"WHERE type='table'")
|
|
253
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
254
|
+
|
|
255
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
256
|
+
for table in tables:
|
|
257
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
258
|
+
table_kwargs.append(kwargs)
|
|
140
259
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
260
|
+
graph = cls(
|
|
261
|
+
tables=[
|
|
262
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
263
|
+
for kwargs in table_kwargs
|
|
264
|
+
],
|
|
265
|
+
edges=edges or [],
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
if internal_connection:
|
|
269
|
+
graph._connection = connection # type: ignore
|
|
270
|
+
|
|
271
|
+
if infer_metadata:
|
|
272
|
+
graph.infer_metadata(verbose=False)
|
|
273
|
+
|
|
274
|
+
if edges is None:
|
|
275
|
+
graph.infer_links(verbose=False)
|
|
276
|
+
|
|
277
|
+
if verbose:
|
|
278
|
+
graph.print_metadata()
|
|
279
|
+
graph.print_links()
|
|
280
|
+
|
|
281
|
+
return graph
|
|
282
|
+
|
|
283
|
+
@classmethod
|
|
284
|
+
def from_snowflake(
|
|
285
|
+
cls,
|
|
286
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
287
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
288
|
+
database: str | None = None,
|
|
289
|
+
schema: str | None = None,
|
|
290
|
+
edges: Sequence[Edge] | None = None,
|
|
291
|
+
infer_metadata: bool = True,
|
|
292
|
+
verbose: bool = True,
|
|
293
|
+
) -> Self:
|
|
294
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
295
|
+
schema.
|
|
296
|
+
|
|
297
|
+
Automatically infers table metadata and links by default.
|
|
298
|
+
|
|
299
|
+
.. code-block:: python
|
|
144
300
|
|
|
145
|
-
Example:
|
|
146
301
|
>>> # doctest: +SKIP
|
|
147
302
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
>>> df1 = pd.DataFrame(...)
|
|
149
|
-
>>> df2 = pd.DataFrame(...)
|
|
150
|
-
>>> df3 = pd.DataFrame(...)
|
|
151
|
-
>>> graph = rfm.Graph.from_data(data={
|
|
152
|
-
... "table1": df1,
|
|
153
|
-
... "table2": df2,
|
|
154
|
-
... "table3": df3,
|
|
155
|
-
... })
|
|
156
|
-
>>> graph.validate()
|
|
157
|
-
"""
|
|
158
|
-
from kumoai.experimental.rfm import LocalTable
|
|
159
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
303
|
|
|
161
|
-
|
|
304
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
305
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
306
|
+
|
|
307
|
+
>>> # Fine-grained control over table specification:
|
|
308
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
309
|
+
... 'USERS',
|
|
310
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
311
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
312
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
313
|
+
|
|
314
|
+
Args:
|
|
315
|
+
connection: An open connection from
|
|
316
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
317
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
318
|
+
connection. If ``None``, will re-use an active session in case
|
|
319
|
+
it exists, or create a new connection from credentials stored
|
|
320
|
+
in environment variables.
|
|
321
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
322
|
+
to include. If ``None``, will add all tables present in the
|
|
323
|
+
current database and schema.
|
|
324
|
+
database: The database.
|
|
325
|
+
schema: The schema.
|
|
326
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
327
|
+
add to the graph. If not provided, edges will be automatically
|
|
328
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
329
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
330
|
+
graph.
|
|
331
|
+
verbose: Whether to print verbose output.
|
|
332
|
+
"""
|
|
333
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
334
|
+
Connection,
|
|
335
|
+
SnowTable,
|
|
336
|
+
connect,
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
if not isinstance(connection, Connection):
|
|
340
|
+
connection = connect(**(connection or {}))
|
|
341
|
+
assert isinstance(connection, Connection)
|
|
342
|
+
|
|
343
|
+
if database is None or schema is None:
|
|
344
|
+
with connection.cursor() as cursor:
|
|
345
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
346
|
+
result = cursor.fetchone()
|
|
347
|
+
database = database or result[0]
|
|
348
|
+
assert database is not None
|
|
349
|
+
schema = schema or result[1]
|
|
350
|
+
|
|
351
|
+
if tables is None:
|
|
352
|
+
if schema is None:
|
|
353
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
354
|
+
"Snowflake schema manually")
|
|
355
|
+
|
|
356
|
+
with connection.cursor() as cursor:
|
|
357
|
+
cursor.execute(f"""
|
|
358
|
+
SELECT TABLE_NAME
|
|
359
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
360
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
361
|
+
""")
|
|
362
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
363
|
+
|
|
364
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
365
|
+
for table in tables:
|
|
366
|
+
if isinstance(table, str):
|
|
367
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
368
|
+
else:
|
|
369
|
+
kwargs = copy.copy(table)
|
|
370
|
+
kwargs.setdefault('database', database)
|
|
371
|
+
kwargs.setdefault('schema', schema)
|
|
372
|
+
table_kwargs.append(kwargs)
|
|
373
|
+
|
|
374
|
+
graph = cls(
|
|
375
|
+
tables=[
|
|
376
|
+
SnowTable(connection=connection, **kwargs)
|
|
377
|
+
for kwargs in table_kwargs
|
|
378
|
+
],
|
|
379
|
+
edges=edges or [],
|
|
380
|
+
)
|
|
162
381
|
|
|
163
382
|
if infer_metadata:
|
|
164
|
-
graph.infer_metadata(verbose)
|
|
383
|
+
graph.infer_metadata(verbose=False)
|
|
165
384
|
|
|
166
385
|
if edges is None:
|
|
167
|
-
graph.infer_links(verbose)
|
|
386
|
+
graph.infer_links(verbose=False)
|
|
387
|
+
|
|
388
|
+
if verbose:
|
|
389
|
+
graph.print_metadata()
|
|
390
|
+
graph.print_links()
|
|
168
391
|
|
|
169
392
|
return graph
|
|
170
393
|
|
|
171
|
-
|
|
394
|
+
@classmethod
|
|
395
|
+
def from_snowflake_semantic_view(
|
|
396
|
+
cls,
|
|
397
|
+
semantic_view_name: str,
|
|
398
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
399
|
+
verbose: bool = True,
|
|
400
|
+
) -> Self:
|
|
401
|
+
import yaml
|
|
402
|
+
|
|
403
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
404
|
+
Connection,
|
|
405
|
+
SnowTable,
|
|
406
|
+
connect,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
if not isinstance(connection, Connection):
|
|
410
|
+
connection = connect(**(connection or {}))
|
|
411
|
+
assert isinstance(connection, Connection)
|
|
412
|
+
|
|
413
|
+
with connection.cursor() as cursor:
|
|
414
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
415
|
+
f"'{semantic_view_name}')")
|
|
416
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
417
|
+
|
|
418
|
+
graph = cls(tables=[])
|
|
419
|
+
|
|
420
|
+
msgs = []
|
|
421
|
+
for table_cfg in cfg['tables']:
|
|
422
|
+
table_name = table_cfg['name']
|
|
423
|
+
source_table_name = table_cfg['base_table']['table']
|
|
424
|
+
database = table_cfg['base_table']['database']
|
|
425
|
+
schema = table_cfg['base_table']['schema']
|
|
426
|
+
|
|
427
|
+
primary_key: str | None = None
|
|
428
|
+
if 'primary_key' in table_cfg:
|
|
429
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
430
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
431
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
432
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
433
|
+
msgs.append(f"Failed to add primary key for table "
|
|
434
|
+
f"'{table_name}' since composite primary keys "
|
|
435
|
+
f"are not yet supported")
|
|
436
|
+
|
|
437
|
+
columns: list[str] = []
|
|
438
|
+
for column_cfg in chain(
|
|
439
|
+
table_cfg.get('dimensions', []),
|
|
440
|
+
table_cfg.get('time_dimensions', []),
|
|
441
|
+
table_cfg.get('facts', []),
|
|
442
|
+
):
|
|
443
|
+
# TODO Add support for derived columns.
|
|
444
|
+
columns.append(column_cfg['name'])
|
|
445
|
+
|
|
446
|
+
table = SnowTable(
|
|
447
|
+
connection,
|
|
448
|
+
name=table_name,
|
|
449
|
+
source_name=source_table_name,
|
|
450
|
+
database=database,
|
|
451
|
+
schema=schema,
|
|
452
|
+
columns=columns,
|
|
453
|
+
primary_key=primary_key,
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
457
|
+
table.infer_time_column(verbose=False)
|
|
458
|
+
|
|
459
|
+
graph.add_table(table)
|
|
460
|
+
|
|
461
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
462
|
+
name = relation_cfg['name']
|
|
463
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
464
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
465
|
+
f"composite key references are not yet supported")
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
left_table = relation_cfg['left_table']
|
|
469
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
470
|
+
right_table = relation_cfg['right_table']
|
|
471
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
472
|
+
|
|
473
|
+
if graph[right_table]._primary_key != right_key:
|
|
474
|
+
# Semantic view error - this should never be triggered:
|
|
475
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
476
|
+
f"referenced key '{right_key}' of table "
|
|
477
|
+
f"'{right_table}' is not a primary key")
|
|
478
|
+
continue
|
|
479
|
+
|
|
480
|
+
if graph[left_table]._primary_key == left_key:
|
|
481
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
482
|
+
f"referencing key '{left_key}' of table "
|
|
483
|
+
f"'{left_table}' is a primary key")
|
|
484
|
+
continue
|
|
485
|
+
|
|
486
|
+
if left_key not in graph[left_table]:
|
|
487
|
+
graph[left_table].add_column(left_key)
|
|
488
|
+
|
|
489
|
+
graph.link(left_table, left_key, right_table)
|
|
490
|
+
|
|
491
|
+
graph.validate()
|
|
492
|
+
|
|
493
|
+
if verbose:
|
|
494
|
+
graph.print_metadata()
|
|
495
|
+
graph.print_links()
|
|
496
|
+
|
|
497
|
+
if len(msgs) > 0:
|
|
498
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
499
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
500
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
501
|
+
|
|
502
|
+
return graph
|
|
503
|
+
|
|
504
|
+
# Backend #################################################################
|
|
505
|
+
|
|
506
|
+
@property
|
|
507
|
+
def backend(self) -> DataBackend | None:
|
|
508
|
+
backends = [table.backend for table in self._tables.values()]
|
|
509
|
+
return backends[0] if len(backends) > 0 else None
|
|
510
|
+
|
|
511
|
+
# Tables ##################################################################
|
|
172
512
|
|
|
173
513
|
def has_table(self, name: str) -> bool:
|
|
174
514
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -187,7 +527,7 @@ class Graph:
|
|
|
187
527
|
return self.tables[name]
|
|
188
528
|
|
|
189
529
|
@property
|
|
190
|
-
def tables(self) ->
|
|
530
|
+
def tables(self) -> dict[str, Table]:
|
|
191
531
|
r"""Returns the dictionary of table objects."""
|
|
192
532
|
return self._tables
|
|
193
533
|
|
|
@@ -207,13 +547,10 @@ class Graph:
|
|
|
207
547
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
208
548
|
f"this graph; table names must be globally unique.")
|
|
209
549
|
|
|
210
|
-
if
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
f"'{table.__class__.__name__}' to this "
|
|
215
|
-
f"graph since other tables are of type "
|
|
216
|
-
f"'{cls.__name__}'.")
|
|
550
|
+
if self.backend is not None and table.backend != self.backend:
|
|
551
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
552
|
+
f"'{table.backend}' to this graph since other "
|
|
553
|
+
f"tables have backend '{self.backend}'.")
|
|
217
554
|
|
|
218
555
|
self._tables[table.name] = table
|
|
219
556
|
|
|
@@ -275,9 +612,13 @@ class Graph:
|
|
|
275
612
|
|
|
276
613
|
def print_metadata(self) -> None:
|
|
277
614
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
278
|
-
if
|
|
615
|
+
if in_snowflake_notebook():
|
|
616
|
+
import streamlit as st
|
|
617
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
618
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
619
|
+
elif in_notebook():
|
|
279
620
|
from IPython.display import Markdown, display
|
|
280
|
-
display(Markdown(
|
|
621
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
281
622
|
df = self.metadata
|
|
282
623
|
try:
|
|
283
624
|
if hasattr(df.style, 'hide'):
|
|
@@ -311,7 +652,7 @@ class Graph:
|
|
|
311
652
|
# Edges ###################################################################
|
|
312
653
|
|
|
313
654
|
@property
|
|
314
|
-
def edges(self) ->
|
|
655
|
+
def edges(self) -> list[Edge]:
|
|
315
656
|
r"""Returns the edges of the graph."""
|
|
316
657
|
return self._edges
|
|
317
658
|
|
|
@@ -321,32 +662,42 @@ class Graph:
|
|
|
321
662
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
322
663
|
edges = sorted(edges)
|
|
323
664
|
|
|
324
|
-
if
|
|
665
|
+
if in_snowflake_notebook():
|
|
666
|
+
import streamlit as st
|
|
667
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
668
|
+
if len(edges) > 0:
|
|
669
|
+
st.markdown('\n'.join([
|
|
670
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
671
|
+
for edge in edges
|
|
672
|
+
]))
|
|
673
|
+
else:
|
|
674
|
+
st.markdown("*No links registered*")
|
|
675
|
+
elif in_notebook():
|
|
325
676
|
from IPython.display import Markdown, display
|
|
326
|
-
display(Markdown(
|
|
677
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
327
678
|
if len(edges) > 0:
|
|
328
679
|
display(
|
|
329
680
|
Markdown('\n'.join([
|
|
330
|
-
f
|
|
681
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
331
682
|
for edge in edges
|
|
332
683
|
])))
|
|
333
684
|
else:
|
|
334
|
-
display(Markdown(
|
|
685
|
+
display(Markdown("*No links registered*"))
|
|
335
686
|
else:
|
|
336
687
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
337
688
|
if len(edges) > 0:
|
|
338
689
|
print('\n'.join([
|
|
339
|
-
f
|
|
690
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
340
691
|
for edge in edges
|
|
341
692
|
]))
|
|
342
693
|
else:
|
|
343
|
-
print(
|
|
694
|
+
print("No links registered")
|
|
344
695
|
|
|
345
696
|
def link(
|
|
346
697
|
self,
|
|
347
|
-
src_table:
|
|
698
|
+
src_table: str | Table,
|
|
348
699
|
fkey: str,
|
|
349
|
-
dst_table:
|
|
700
|
+
dst_table: str | Table,
|
|
350
701
|
) -> Self:
|
|
351
702
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
352
703
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -407,9 +758,9 @@ class Graph:
|
|
|
407
758
|
|
|
408
759
|
def unlink(
|
|
409
760
|
self,
|
|
410
|
-
src_table:
|
|
761
|
+
src_table: str | Table,
|
|
411
762
|
fkey: str,
|
|
412
|
-
dst_table:
|
|
763
|
+
dst_table: str | Table,
|
|
413
764
|
) -> Self:
|
|
414
765
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
415
766
|
|
|
@@ -439,17 +790,13 @@ class Graph:
|
|
|
439
790
|
return self
|
|
440
791
|
|
|
441
792
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
442
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
793
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
794
|
+
graph.
|
|
443
795
|
|
|
444
796
|
Args:
|
|
445
797
|
verbose: Whether to print verbose output.
|
|
446
|
-
|
|
447
|
-
Note:
|
|
448
|
-
This function expects graph edges to be undefined upfront.
|
|
449
798
|
"""
|
|
450
|
-
|
|
451
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
452
|
-
return self
|
|
799
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
453
800
|
|
|
454
801
|
# A list of primary key candidates (+score) for every column:
|
|
455
802
|
candidate_dict: dict[
|
|
@@ -474,6 +821,9 @@ class Graph:
|
|
|
474
821
|
src_table_name = src_table.name.lower()
|
|
475
822
|
|
|
476
823
|
for src_key in src_table.columns:
|
|
824
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
825
|
+
continue
|
|
826
|
+
|
|
477
827
|
if src_key == src_table.primary_key:
|
|
478
828
|
continue # Cannot link to primary key.
|
|
479
829
|
|
|
@@ -539,10 +889,9 @@ class Graph:
|
|
|
539
889
|
score += 1.0
|
|
540
890
|
|
|
541
891
|
# Cardinality ratio:
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
and src_num_rows > dst_num_rows):
|
|
892
|
+
if (src_table._num_rows is not None
|
|
893
|
+
and dst_table._num_rows is not None
|
|
894
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
546
895
|
score += 1.0
|
|
547
896
|
|
|
548
897
|
if score < 5.0:
|
|
@@ -588,6 +937,10 @@ class Graph:
|
|
|
588
937
|
raise ValueError("At least one table needs to be added to the "
|
|
589
938
|
"graph")
|
|
590
939
|
|
|
940
|
+
backends = {table.backend for table in self._tables.values()}
|
|
941
|
+
if len(backends) != 1:
|
|
942
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
943
|
+
|
|
591
944
|
for edge in self.edges:
|
|
592
945
|
src_table, fkey, dst_table = edge
|
|
593
946
|
|
|
@@ -634,7 +987,7 @@ class Graph:
|
|
|
634
987
|
|
|
635
988
|
def visualize(
|
|
636
989
|
self,
|
|
637
|
-
path:
|
|
990
|
+
path: str | io.BytesIO | None = None,
|
|
638
991
|
show_columns: bool = True,
|
|
639
992
|
) -> 'graphviz.Graph':
|
|
640
993
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -659,33 +1012,33 @@ class Graph:
|
|
|
659
1012
|
|
|
660
1013
|
return True
|
|
661
1014
|
|
|
662
|
-
# Check basic dependency:
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
1015
|
+
try: # Check basic dependency:
|
|
1016
|
+
import graphviz
|
|
1017
|
+
except ImportError as e:
|
|
1018
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
1019
|
+
"visualization") from e
|
|
1020
|
+
|
|
1021
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
667
1022
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
668
1023
|
"executables are not installed. These "
|
|
669
1024
|
"dependencies are required in addition to the "
|
|
670
1025
|
"'graphviz' Python package. Please install "
|
|
671
1026
|
"them as described at "
|
|
672
1027
|
"https://graphviz.org/download/.")
|
|
673
|
-
else:
|
|
674
|
-
import graphviz
|
|
675
1028
|
|
|
676
|
-
format:
|
|
1029
|
+
format: str | None = None
|
|
677
1030
|
if isinstance(path, str):
|
|
678
1031
|
format = path.split('.')[-1]
|
|
679
1032
|
elif isinstance(path, io.BytesIO):
|
|
680
1033
|
format = 'svg'
|
|
681
1034
|
graph = graphviz.Graph(format=format)
|
|
682
1035
|
|
|
683
|
-
def left_align(keys:
|
|
1036
|
+
def left_align(keys: list[str]) -> str:
|
|
684
1037
|
if len(keys) == 0:
|
|
685
1038
|
return ""
|
|
686
1039
|
return '\\l'.join(keys) + '\\l'
|
|
687
1040
|
|
|
688
|
-
fkeys_dict:
|
|
1041
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
689
1042
|
for src_table_name, fkey_name, _ in self.edges:
|
|
690
1043
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
691
1044
|
|
|
@@ -755,6 +1108,9 @@ class Graph:
|
|
|
755
1108
|
graph.render(path, cleanup=True)
|
|
756
1109
|
elif isinstance(path, io.BytesIO):
|
|
757
1110
|
path.write(graph.pipe())
|
|
1111
|
+
elif in_snowflake_notebook():
|
|
1112
|
+
import streamlit as st
|
|
1113
|
+
st.graphviz_chart(graph)
|
|
758
1114
|
elif in_notebook():
|
|
759
1115
|
from IPython.display import display
|
|
760
1116
|
display(graph)
|
|
@@ -778,8 +1134,8 @@ class Graph:
|
|
|
778
1134
|
# Helpers #################################################################
|
|
779
1135
|
|
|
780
1136
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
781
|
-
tables:
|
|
782
|
-
col_groups:
|
|
1137
|
+
tables: dict[str, TableDefinition] = {}
|
|
1138
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
783
1139
|
for table_name, table in self.tables.items():
|
|
784
1140
|
tables[table_name] = table._to_api_table_definition()
|
|
785
1141
|
if table.primary_key is None:
|
|
@@ -822,3 +1178,7 @@ class Graph:
|
|
|
822
1178
|
f' tables={tables},\n'
|
|
823
1179
|
f' edges={edges},\n'
|
|
824
1180
|
f')')
|
|
1181
|
+
|
|
1182
|
+
def __del__(self) -> None:
|
|
1183
|
+
if hasattr(self, '_connection'):
|
|
1184
|
+
self._connection.close()
|