kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +31 -14
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +75 -23
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +71 -28
- kumoai/experimental/rfm/base/__init__.py +24 -3
- kumoai/experimental/rfm/base/column.py +6 -12
- kumoai/experimental/rfm/base/column_expression.py +16 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +113 -0
- kumoai/experimental/rfm/base/table.py +136 -105
- kumoai/experimental/rfm/graph.py +296 -89
- kumoai/experimental/rfm/infer/dtype.py +46 -59
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +4 -2
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +41 -34
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
+
import copy
|
|
2
3
|
import io
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
6
|
+
from collections.abc import Sequence
|
|
5
7
|
from dataclasses import dataclass, field
|
|
6
|
-
from
|
|
8
|
+
from itertools import chain
|
|
7
9
|
from pathlib import Path
|
|
8
|
-
from typing import TYPE_CHECKING, Any,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
9
11
|
|
|
10
12
|
import pandas as pd
|
|
11
13
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -13,8 +15,8 @@ from kumoapi.table import TableDefinition
|
|
|
13
15
|
from kumoapi.typing import Stype
|
|
14
16
|
from typing_extensions import Self
|
|
15
17
|
|
|
16
|
-
from kumoai import in_notebook
|
|
17
|
-
from kumoai.experimental.rfm import Table
|
|
18
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
+
from kumoai.experimental.rfm.base import DataBackend, SQLTable, Table
|
|
18
20
|
from kumoai.graph import Edge
|
|
19
21
|
from kumoai.mixin import CastMixin
|
|
20
22
|
|
|
@@ -26,8 +28,8 @@ if TYPE_CHECKING:
|
|
|
26
28
|
|
|
27
29
|
@dataclass
|
|
28
30
|
class SqliteConnectionConfig(CastMixin):
|
|
29
|
-
uri:
|
|
30
|
-
kwargs:
|
|
31
|
+
uri: str | Path
|
|
32
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
31
33
|
|
|
32
34
|
|
|
33
35
|
class Graph:
|
|
@@ -87,27 +89,34 @@ class Graph:
|
|
|
87
89
|
def __init__(
|
|
88
90
|
self,
|
|
89
91
|
tables: Sequence[Table],
|
|
90
|
-
edges:
|
|
92
|
+
edges: Sequence[Edge] | None = None,
|
|
91
93
|
) -> None:
|
|
92
94
|
|
|
93
|
-
self._tables:
|
|
94
|
-
self._edges:
|
|
95
|
+
self._tables: dict[str, Table] = {}
|
|
96
|
+
self._edges: list[Edge] = []
|
|
95
97
|
|
|
96
98
|
for table in tables:
|
|
97
99
|
self.add_table(table)
|
|
98
100
|
|
|
99
101
|
for table in tables:
|
|
102
|
+
if not isinstance(table, SQLTable):
|
|
103
|
+
continue
|
|
100
104
|
for fkey in table._source_foreign_key_dict.values():
|
|
101
|
-
if fkey.name not in table
|
|
105
|
+
if fkey.name not in table:
|
|
106
|
+
continue
|
|
107
|
+
# TODO Skip for non-physical table[fkey.name].
|
|
108
|
+
dst_table_names = [
|
|
109
|
+
table.name for table in self.tables.values()
|
|
110
|
+
if isinstance(table, SQLTable)
|
|
111
|
+
and table._source_name == fkey.dst_table
|
|
112
|
+
]
|
|
113
|
+
if len(dst_table_names) != 1:
|
|
102
114
|
continue
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
f"and '{fkey.primary_key}' in table "
|
|
109
|
-
f"'{fkey.dst_table}'.")
|
|
110
|
-
self.link(table.name, fkey.name, fkey.dst_table)
|
|
115
|
+
dst_table = self[dst_table_names[0]]
|
|
116
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
117
|
+
continue
|
|
118
|
+
# TODO Skip for non-physical dst_table.primary_key.
|
|
119
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
111
120
|
|
|
112
121
|
for edge in (edges or []):
|
|
113
122
|
_edge = Edge._cast(edge)
|
|
@@ -118,8 +127,8 @@ class Graph:
|
|
|
118
127
|
@classmethod
|
|
119
128
|
def from_data(
|
|
120
129
|
cls,
|
|
121
|
-
df_dict:
|
|
122
|
-
edges:
|
|
130
|
+
df_dict: dict[str, pd.DataFrame],
|
|
131
|
+
edges: Sequence[Edge] | None = None,
|
|
123
132
|
infer_metadata: bool = True,
|
|
124
133
|
verbose: bool = True,
|
|
125
134
|
) -> Self:
|
|
@@ -157,15 +166,17 @@ class Graph:
|
|
|
157
166
|
verbose: Whether to print verbose output.
|
|
158
167
|
"""
|
|
159
168
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
160
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
161
169
|
|
|
162
|
-
graph = cls(
|
|
170
|
+
graph = cls(
|
|
171
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
172
|
+
edges=edges or [],
|
|
173
|
+
)
|
|
163
174
|
|
|
164
175
|
if infer_metadata:
|
|
165
|
-
graph.infer_metadata(False)
|
|
176
|
+
graph.infer_metadata(verbose=False)
|
|
166
177
|
|
|
167
178
|
if edges is None:
|
|
168
|
-
graph.infer_links(False)
|
|
179
|
+
graph.infer_links(verbose=False)
|
|
169
180
|
|
|
170
181
|
if verbose:
|
|
171
182
|
graph.print_metadata()
|
|
@@ -181,10 +192,10 @@ class Graph:
|
|
|
181
192
|
SqliteConnectionConfig,
|
|
182
193
|
str,
|
|
183
194
|
Path,
|
|
184
|
-
|
|
195
|
+
dict[str, Any],
|
|
185
196
|
],
|
|
186
|
-
|
|
187
|
-
edges:
|
|
197
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
198
|
+
edges: Sequence[Edge] | None = None,
|
|
188
199
|
infer_metadata: bool = True,
|
|
189
200
|
verbose: bool = True,
|
|
190
201
|
) -> Self:
|
|
@@ -200,17 +211,25 @@ class Graph:
|
|
|
200
211
|
>>> # Create a graph from a SQLite database:
|
|
201
212
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
202
213
|
|
|
214
|
+
>>> # Fine-grained control over table specification:
|
|
215
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
216
|
+
... 'USERS',
|
|
217
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
218
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
219
|
+
... ])
|
|
220
|
+
|
|
203
221
|
Args:
|
|
204
222
|
connection: An open connection from
|
|
205
223
|
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
206
224
|
path to the database file.
|
|
207
|
-
|
|
208
|
-
all tables present
|
|
225
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
226
|
+
arguments to include. If ``None``, will add all tables present
|
|
227
|
+
in the database.
|
|
209
228
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
210
229
|
add to the graph. If not provided, edges will be automatically
|
|
211
230
|
inferred from the data in case ``infer_metadata=True``.
|
|
212
|
-
infer_metadata: Whether to infer metadata for all tables in
|
|
213
|
-
graph.
|
|
231
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
232
|
+
the graph.
|
|
214
233
|
verbose: Whether to print verbose output.
|
|
215
234
|
"""
|
|
216
235
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
@@ -219,27 +238,41 @@ class Graph:
|
|
|
219
238
|
connect,
|
|
220
239
|
)
|
|
221
240
|
|
|
241
|
+
internal_connection = False
|
|
222
242
|
if not isinstance(connection, Connection):
|
|
223
243
|
connection = SqliteConnectionConfig._cast(connection)
|
|
224
244
|
assert isinstance(connection, SqliteConnectionConfig)
|
|
225
245
|
connection = connect(connection.uri, **connection.kwargs)
|
|
246
|
+
internal_connection = True
|
|
226
247
|
assert isinstance(connection, Connection)
|
|
227
248
|
|
|
228
|
-
if
|
|
249
|
+
if tables is None:
|
|
229
250
|
with connection.cursor() as cursor:
|
|
230
251
|
cursor.execute("SELECT name FROM sqlite_master "
|
|
231
252
|
"WHERE type='table'")
|
|
232
|
-
|
|
253
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
233
254
|
|
|
234
|
-
|
|
255
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
256
|
+
for table in tables:
|
|
257
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
258
|
+
table_kwargs.append(kwargs)
|
|
259
|
+
|
|
260
|
+
graph = cls(
|
|
261
|
+
tables=[
|
|
262
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
263
|
+
for kwargs in table_kwargs
|
|
264
|
+
],
|
|
265
|
+
edges=edges or [],
|
|
266
|
+
)
|
|
235
267
|
|
|
236
|
-
|
|
268
|
+
if internal_connection:
|
|
269
|
+
graph._connection = connection # type: ignore
|
|
237
270
|
|
|
238
271
|
if infer_metadata:
|
|
239
|
-
graph.infer_metadata(False)
|
|
272
|
+
graph.infer_metadata(verbose=False)
|
|
240
273
|
|
|
241
274
|
if edges is None:
|
|
242
|
-
graph.infer_links(False)
|
|
275
|
+
graph.infer_links(verbose=False)
|
|
243
276
|
|
|
244
277
|
if verbose:
|
|
245
278
|
graph.print_metadata()
|
|
@@ -250,9 +283,11 @@ class Graph:
|
|
|
250
283
|
@classmethod
|
|
251
284
|
def from_snowflake(
|
|
252
285
|
cls,
|
|
253
|
-
connection: Union['SnowflakeConnection',
|
|
254
|
-
|
|
255
|
-
|
|
286
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
287
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
288
|
+
database: str | None = None,
|
|
289
|
+
schema: str | None = None,
|
|
290
|
+
edges: Sequence[Edge] | None = None,
|
|
256
291
|
infer_metadata: bool = True,
|
|
257
292
|
verbose: bool = True,
|
|
258
293
|
) -> Self:
|
|
@@ -267,7 +302,14 @@ class Graph:
|
|
|
267
302
|
>>> import kumoai.experimental.rfm as rfm
|
|
268
303
|
|
|
269
304
|
>>> # Create a graph directly in a Snowflake notebook:
|
|
270
|
-
>>> graph = rfm.Graph.from_snowflake()
|
|
305
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
306
|
+
|
|
307
|
+
>>> # Fine-grained control over table specification:
|
|
308
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
309
|
+
... 'USERS',
|
|
310
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
311
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
312
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
271
313
|
|
|
272
314
|
Args:
|
|
273
315
|
connection: An open connection from
|
|
@@ -276,8 +318,11 @@ class Graph:
|
|
|
276
318
|
connection. If ``None``, will re-use an active session in case
|
|
277
319
|
it exists, or create a new connection from credentials stored
|
|
278
320
|
in environment variables.
|
|
279
|
-
|
|
280
|
-
all tables present in the
|
|
321
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
322
|
+
to include. If ``None``, will add all tables present in the
|
|
323
|
+
current database and schema.
|
|
324
|
+
database: The database.
|
|
325
|
+
schema: The schema.
|
|
281
326
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
282
327
|
add to the graph. If not provided, edges will be automatically
|
|
283
328
|
inferred from the data in case ``infer_metadata=True``.
|
|
@@ -295,27 +340,50 @@ class Graph:
|
|
|
295
340
|
connection = connect(**(connection or {}))
|
|
296
341
|
assert isinstance(connection, Connection)
|
|
297
342
|
|
|
298
|
-
if
|
|
343
|
+
if database is None or schema is None:
|
|
299
344
|
with connection.cursor() as cursor:
|
|
300
345
|
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
301
|
-
|
|
302
|
-
|
|
346
|
+
result = cursor.fetchone()
|
|
347
|
+
database = database or result[0]
|
|
348
|
+
assert database is not None
|
|
349
|
+
schema = schema or result[1]
|
|
350
|
+
|
|
351
|
+
if tables is None:
|
|
352
|
+
if schema is None:
|
|
353
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
354
|
+
"Snowflake schema manually")
|
|
355
|
+
|
|
356
|
+
with connection.cursor() as cursor:
|
|
357
|
+
cursor.execute(f"""
|
|
303
358
|
SELECT TABLE_NAME
|
|
304
359
|
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
305
360
|
WHERE TABLE_SCHEMA = '{schema}'
|
|
306
|
-
"""
|
|
307
|
-
cursor.
|
|
308
|
-
table_names = [row[0] for row in cursor.fetchall()]
|
|
361
|
+
""")
|
|
362
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
309
363
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
364
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
365
|
+
for table in tables:
|
|
366
|
+
if isinstance(table, str):
|
|
367
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
368
|
+
else:
|
|
369
|
+
kwargs = copy.copy(table)
|
|
370
|
+
kwargs.setdefault('database', database)
|
|
371
|
+
kwargs.setdefault('schema', schema)
|
|
372
|
+
table_kwargs.append(kwargs)
|
|
373
|
+
|
|
374
|
+
graph = cls(
|
|
375
|
+
tables=[
|
|
376
|
+
SnowTable(connection=connection, **kwargs)
|
|
377
|
+
for kwargs in table_kwargs
|
|
378
|
+
],
|
|
379
|
+
edges=edges or [],
|
|
380
|
+
)
|
|
313
381
|
|
|
314
382
|
if infer_metadata:
|
|
315
|
-
graph.infer_metadata(False)
|
|
383
|
+
graph.infer_metadata(verbose=False)
|
|
316
384
|
|
|
317
385
|
if edges is None:
|
|
318
|
-
graph.infer_links(False)
|
|
386
|
+
graph.infer_links(verbose=False)
|
|
319
387
|
|
|
320
388
|
if verbose:
|
|
321
389
|
graph.print_metadata()
|
|
@@ -323,7 +391,124 @@ class Graph:
|
|
|
323
391
|
|
|
324
392
|
return graph
|
|
325
393
|
|
|
326
|
-
|
|
394
|
+
@classmethod
|
|
395
|
+
def from_snowflake_semantic_view(
|
|
396
|
+
cls,
|
|
397
|
+
semantic_view_name: str,
|
|
398
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
399
|
+
verbose: bool = True,
|
|
400
|
+
) -> Self:
|
|
401
|
+
import yaml
|
|
402
|
+
|
|
403
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
404
|
+
Connection,
|
|
405
|
+
SnowTable,
|
|
406
|
+
connect,
|
|
407
|
+
)
|
|
408
|
+
|
|
409
|
+
if not isinstance(connection, Connection):
|
|
410
|
+
connection = connect(**(connection or {}))
|
|
411
|
+
assert isinstance(connection, Connection)
|
|
412
|
+
|
|
413
|
+
with connection.cursor() as cursor:
|
|
414
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
415
|
+
f"'{semantic_view_name}')")
|
|
416
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
417
|
+
|
|
418
|
+
graph = cls(tables=[])
|
|
419
|
+
|
|
420
|
+
msgs = []
|
|
421
|
+
for table_cfg in cfg['tables']:
|
|
422
|
+
table_name = table_cfg['name']
|
|
423
|
+
source_table_name = table_cfg['base_table']['table']
|
|
424
|
+
database = table_cfg['base_table']['database']
|
|
425
|
+
schema = table_cfg['base_table']['schema']
|
|
426
|
+
|
|
427
|
+
primary_key: str | None = None
|
|
428
|
+
if 'primary_key' in table_cfg:
|
|
429
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
430
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
431
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
432
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
433
|
+
msgs.append(f"Failed to add primary key for table "
|
|
434
|
+
f"'{table_name}' since composite primary keys "
|
|
435
|
+
f"are not yet supported")
|
|
436
|
+
|
|
437
|
+
columns: list[str] = []
|
|
438
|
+
for column_cfg in chain(
|
|
439
|
+
table_cfg.get('dimensions', []),
|
|
440
|
+
table_cfg.get('time_dimensions', []),
|
|
441
|
+
table_cfg.get('facts', []),
|
|
442
|
+
):
|
|
443
|
+
# TODO Add support for derived columns.
|
|
444
|
+
columns.append(column_cfg['name'])
|
|
445
|
+
|
|
446
|
+
table = SnowTable(
|
|
447
|
+
connection,
|
|
448
|
+
name=table_name,
|
|
449
|
+
source_name=source_table_name,
|
|
450
|
+
database=database,
|
|
451
|
+
schema=schema,
|
|
452
|
+
columns=columns,
|
|
453
|
+
primary_key=primary_key,
|
|
454
|
+
)
|
|
455
|
+
|
|
456
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
457
|
+
table.infer_time_column(verbose=False)
|
|
458
|
+
|
|
459
|
+
graph.add_table(table)
|
|
460
|
+
|
|
461
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
462
|
+
name = relation_cfg['name']
|
|
463
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
464
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
465
|
+
f"composite key references are not yet supported")
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
left_table = relation_cfg['left_table']
|
|
469
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
470
|
+
right_table = relation_cfg['right_table']
|
|
471
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
472
|
+
|
|
473
|
+
if graph[right_table]._primary_key != right_key:
|
|
474
|
+
# Semantic view error - this should never be triggered:
|
|
475
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
476
|
+
f"referenced key '{right_key}' of table "
|
|
477
|
+
f"'{right_table}' is not a primary key")
|
|
478
|
+
continue
|
|
479
|
+
|
|
480
|
+
if graph[left_table]._primary_key == left_key:
|
|
481
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
482
|
+
f"referencing key '{left_key}' of table "
|
|
483
|
+
f"'{left_table}' is a primary key")
|
|
484
|
+
continue
|
|
485
|
+
|
|
486
|
+
if left_key not in graph[left_table]:
|
|
487
|
+
graph[left_table].add_column(left_key)
|
|
488
|
+
|
|
489
|
+
graph.link(left_table, left_key, right_table)
|
|
490
|
+
|
|
491
|
+
graph.validate()
|
|
492
|
+
|
|
493
|
+
if verbose:
|
|
494
|
+
graph.print_metadata()
|
|
495
|
+
graph.print_links()
|
|
496
|
+
|
|
497
|
+
if len(msgs) > 0:
|
|
498
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
499
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
500
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
501
|
+
|
|
502
|
+
return graph
|
|
503
|
+
|
|
504
|
+
# Backend #################################################################
|
|
505
|
+
|
|
506
|
+
@property
|
|
507
|
+
def backend(self) -> DataBackend | None:
|
|
508
|
+
backends = [table.backend for table in self._tables.values()]
|
|
509
|
+
return backends[0] if len(backends) > 0 else None
|
|
510
|
+
|
|
511
|
+
# Tables ##################################################################
|
|
327
512
|
|
|
328
513
|
def has_table(self, name: str) -> bool:
|
|
329
514
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -342,7 +527,7 @@ class Graph:
|
|
|
342
527
|
return self.tables[name]
|
|
343
528
|
|
|
344
529
|
@property
|
|
345
|
-
def tables(self) ->
|
|
530
|
+
def tables(self) -> dict[str, Table]:
|
|
346
531
|
r"""Returns the dictionary of table objects."""
|
|
347
532
|
return self._tables
|
|
348
533
|
|
|
@@ -362,13 +547,10 @@ class Graph:
|
|
|
362
547
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
363
548
|
f"this graph; table names must be globally unique.")
|
|
364
549
|
|
|
365
|
-
if
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
f"'{table.__class__.__name__}' to this "
|
|
370
|
-
f"graph since other tables are of type "
|
|
371
|
-
f"'{cls.__name__}'.")
|
|
550
|
+
if self.backend is not None and table.backend != self.backend:
|
|
551
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
552
|
+
f"'{table.backend}' to this graph since other "
|
|
553
|
+
f"tables have backend '{self.backend}'.")
|
|
372
554
|
|
|
373
555
|
self._tables[table.name] = table
|
|
374
556
|
|
|
@@ -430,9 +612,13 @@ class Graph:
|
|
|
430
612
|
|
|
431
613
|
def print_metadata(self) -> None:
|
|
432
614
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
433
|
-
if
|
|
615
|
+
if in_snowflake_notebook():
|
|
616
|
+
import streamlit as st
|
|
617
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
618
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
619
|
+
elif in_notebook():
|
|
434
620
|
from IPython.display import Markdown, display
|
|
435
|
-
display(Markdown(
|
|
621
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
436
622
|
df = self.metadata
|
|
437
623
|
try:
|
|
438
624
|
if hasattr(df.style, 'hide'):
|
|
@@ -466,7 +652,7 @@ class Graph:
|
|
|
466
652
|
# Edges ###################################################################
|
|
467
653
|
|
|
468
654
|
@property
|
|
469
|
-
def edges(self) ->
|
|
655
|
+
def edges(self) -> list[Edge]:
|
|
470
656
|
r"""Returns the edges of the graph."""
|
|
471
657
|
return self._edges
|
|
472
658
|
|
|
@@ -476,32 +662,42 @@ class Graph:
|
|
|
476
662
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
477
663
|
edges = sorted(edges)
|
|
478
664
|
|
|
479
|
-
if
|
|
665
|
+
if in_snowflake_notebook():
|
|
666
|
+
import streamlit as st
|
|
667
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
668
|
+
if len(edges) > 0:
|
|
669
|
+
st.markdown('\n'.join([
|
|
670
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
671
|
+
for edge in edges
|
|
672
|
+
]))
|
|
673
|
+
else:
|
|
674
|
+
st.markdown("*No links registered*")
|
|
675
|
+
elif in_notebook():
|
|
480
676
|
from IPython.display import Markdown, display
|
|
481
|
-
display(Markdown(
|
|
677
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
482
678
|
if len(edges) > 0:
|
|
483
679
|
display(
|
|
484
680
|
Markdown('\n'.join([
|
|
485
|
-
f
|
|
681
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
486
682
|
for edge in edges
|
|
487
683
|
])))
|
|
488
684
|
else:
|
|
489
|
-
display(Markdown(
|
|
685
|
+
display(Markdown("*No links registered*"))
|
|
490
686
|
else:
|
|
491
687
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
492
688
|
if len(edges) > 0:
|
|
493
689
|
print('\n'.join([
|
|
494
|
-
f
|
|
690
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
495
691
|
for edge in edges
|
|
496
692
|
]))
|
|
497
693
|
else:
|
|
498
|
-
print(
|
|
694
|
+
print("No links registered")
|
|
499
695
|
|
|
500
696
|
def link(
|
|
501
697
|
self,
|
|
502
|
-
src_table:
|
|
698
|
+
src_table: str | Table,
|
|
503
699
|
fkey: str,
|
|
504
|
-
dst_table:
|
|
700
|
+
dst_table: str | Table,
|
|
505
701
|
) -> Self:
|
|
506
702
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
507
703
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -562,9 +758,9 @@ class Graph:
|
|
|
562
758
|
|
|
563
759
|
def unlink(
|
|
564
760
|
self,
|
|
565
|
-
src_table:
|
|
761
|
+
src_table: str | Table,
|
|
566
762
|
fkey: str,
|
|
567
|
-
dst_table:
|
|
763
|
+
dst_table: str | Table,
|
|
568
764
|
) -> Self:
|
|
569
765
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
570
766
|
|
|
@@ -741,6 +937,10 @@ class Graph:
|
|
|
741
937
|
raise ValueError("At least one table needs to be added to the "
|
|
742
938
|
"graph")
|
|
743
939
|
|
|
940
|
+
backends = {table.backend for table in self._tables.values()}
|
|
941
|
+
if len(backends) != 1:
|
|
942
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
943
|
+
|
|
744
944
|
for edge in self.edges:
|
|
745
945
|
src_table, fkey, dst_table = edge
|
|
746
946
|
|
|
@@ -787,7 +987,7 @@ class Graph:
|
|
|
787
987
|
|
|
788
988
|
def visualize(
|
|
789
989
|
self,
|
|
790
|
-
path:
|
|
990
|
+
path: str | io.BytesIO | None = None,
|
|
791
991
|
show_columns: bool = True,
|
|
792
992
|
) -> 'graphviz.Graph':
|
|
793
993
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -812,33 +1012,33 @@ class Graph:
|
|
|
812
1012
|
|
|
813
1013
|
return True
|
|
814
1014
|
|
|
815
|
-
# Check basic dependency:
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
1015
|
+
try: # Check basic dependency:
|
|
1016
|
+
import graphviz
|
|
1017
|
+
except ImportError as e:
|
|
1018
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
1019
|
+
"visualization") from e
|
|
1020
|
+
|
|
1021
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
820
1022
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
821
1023
|
"executables are not installed. These "
|
|
822
1024
|
"dependencies are required in addition to the "
|
|
823
1025
|
"'graphviz' Python package. Please install "
|
|
824
1026
|
"them as described at "
|
|
825
1027
|
"https://graphviz.org/download/.")
|
|
826
|
-
else:
|
|
827
|
-
import graphviz
|
|
828
1028
|
|
|
829
|
-
format:
|
|
1029
|
+
format: str | None = None
|
|
830
1030
|
if isinstance(path, str):
|
|
831
1031
|
format = path.split('.')[-1]
|
|
832
1032
|
elif isinstance(path, io.BytesIO):
|
|
833
1033
|
format = 'svg'
|
|
834
1034
|
graph = graphviz.Graph(format=format)
|
|
835
1035
|
|
|
836
|
-
def left_align(keys:
|
|
1036
|
+
def left_align(keys: list[str]) -> str:
|
|
837
1037
|
if len(keys) == 0:
|
|
838
1038
|
return ""
|
|
839
1039
|
return '\\l'.join(keys) + '\\l'
|
|
840
1040
|
|
|
841
|
-
fkeys_dict:
|
|
1041
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
842
1042
|
for src_table_name, fkey_name, _ in self.edges:
|
|
843
1043
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
844
1044
|
|
|
@@ -908,6 +1108,9 @@ class Graph:
|
|
|
908
1108
|
graph.render(path, cleanup=True)
|
|
909
1109
|
elif isinstance(path, io.BytesIO):
|
|
910
1110
|
path.write(graph.pipe())
|
|
1111
|
+
elif in_snowflake_notebook():
|
|
1112
|
+
import streamlit as st
|
|
1113
|
+
st.graphviz_chart(graph)
|
|
911
1114
|
elif in_notebook():
|
|
912
1115
|
from IPython.display import display
|
|
913
1116
|
display(graph)
|
|
@@ -931,8 +1134,8 @@ class Graph:
|
|
|
931
1134
|
# Helpers #################################################################
|
|
932
1135
|
|
|
933
1136
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
934
|
-
tables:
|
|
935
|
-
col_groups:
|
|
1137
|
+
tables: dict[str, TableDefinition] = {}
|
|
1138
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
936
1139
|
for table_name, table in self.tables.items():
|
|
937
1140
|
tables[table_name] = table._to_api_table_definition()
|
|
938
1141
|
if table.primary_key is None:
|
|
@@ -975,3 +1178,7 @@ class Graph:
|
|
|
975
1178
|
f' tables={tables},\n'
|
|
976
1179
|
f' edges={edges},\n'
|
|
977
1180
|
f')')
|
|
1181
|
+
|
|
1182
|
+
def __del__(self) -> None:
|
|
1183
|
+
if hasattr(self, '_connection'):
|
|
1184
|
+
self._connection.close()
|