kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +21 -16
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +102 -48
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
- kumoai/experimental/rfm/base/__init__.py +26 -3
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +173 -138
- kumoai/experimental/rfm/graph.py +302 -108
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/rfm.py +299 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +44 -36
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
@@ -1,11 +1,13 @@
 import contextlib
+import copy
 import io
 import warnings
 from collections import defaultdict
+from collections.abc import Sequence
 from dataclasses import dataclass, field
-from
+from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, Union

 import pandas as pd
 from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,8 +15,13 @@ from kumoapi.table import TableDefinition
 from kumoapi.typing import Stype
 from typing_extensions import Self

-from kumoai import in_notebook
-from kumoai.experimental.rfm import
+from kumoai import in_notebook, in_snowflake_notebook
+from kumoai.experimental.rfm.base import (
+    ColumnExpressionSpec,
+    DataBackend,
+    SQLTable,
+    Table,
+)
 from kumoai.graph import Edge
 from kumoai.mixin import CastMixin

@@ -26,8 +33,8 @@ if TYPE_CHECKING:

 @dataclass
 class SqliteConnectionConfig(CastMixin):
-    uri:
-    kwargs:
+    uri: str | Path
+    kwargs: dict[str, Any] = field(default_factory=dict)


 class Graph:
@@ -87,27 +94,38 @@ class Graph:
     def __init__(
         self,
         tables: Sequence[Table],
-        edges:
+        edges: Sequence[Edge] | None = None,
     ) -> None:

-        self._tables:
-        self._edges:
+        self._tables: dict[str, Table] = {}
+        self._edges: list[Edge] = []

         for table in tables:
             self.add_table(table)

         for table in tables:
+            if not isinstance(table, SQLTable):
+                continue
+            if '_source_column_dict' not in table.__dict__:
+                continue
             for fkey in table._source_foreign_key_dict.values():
-                if fkey.name not in table
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_physical:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if isinstance(table, SQLTable)
+                    and table._source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
+                    continue
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_physical:
                     continue
-
-                    self[fkey.dst_table].primary_key = fkey.primary_key
-                elif self[fkey.dst_table]._primary_key != fkey.primary_key:
-                    raise ValueError(f"Found duplicate primary key definition "
-                                     f"'{self[fkey.dst_table]._primary_key}' "
-                                     f"and '{fkey.primary_key}' in table "
-                                     f"'{fkey.dst_table}'.")
-                self.link(table.name, fkey.name, fkey.dst_table)
+                self.link(table.name, fkey.name, dst_table.name)

         for edge in (edges or []):
             _edge = Edge._cast(edge)
@@ -118,8 +136,8 @@ class Graph:
     @classmethod
     def from_data(
         cls,
-        df_dict:
-        edges:
+        df_dict: dict[str, pd.DataFrame],
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -157,15 +175,17 @@ class Graph:
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.local import LocalTable
-        tables = [LocalTable(df, name) for name, df in df_dict.items()]

-        graph = cls(
+        graph = cls(
+            tables=[LocalTable(df, name) for name, df in df_dict.items()],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -181,10 +201,10 @@ class Graph:
             SqliteConnectionConfig,
             str,
             Path,
-
+            dict[str, Any],
         ],
-
-        edges:
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -200,17 +220,25 @@ class Graph:
             >>> # Create a graph from a SQLite database:
             >>> graph = rfm.Graph.from_sqlite('data.db')

+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', primary_key='ITEM_ID'),
+            ... ])
+
         Args:
             connection: An open connection from
                 :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
                 path to the database file.
-
-                all tables present
+            tables: Set of table names or :class:`SQLiteTable` keyword
+                arguments to include. If ``None``, will add all tables present
+                in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
-            infer_metadata: Whether to infer metadata for all tables in
-                graph.
+            infer_metadata: Whether to infer missing metadata for all tables in
+                the graph.
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +247,41 @@ class Graph:
             connect,
         )

+        internal_connection = False
         if not isinstance(connection, Connection):
             connection = SqliteConnectionConfig._cast(connection)
             assert isinstance(connection, SqliteConnectionConfig)
             connection = connect(connection.uri, **connection.kwargs)
+            internal_connection = True
         assert isinstance(connection, Connection)

-        if
+        if tables is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT name FROM sqlite_master "
                                "WHERE type='table'")
-
+                tables = [row[0] for row in cursor.fetchall()]

-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            kwargs = dict(name=table) if isinstance(table, str) else table
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SQLiteTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

-
+        if internal_connection:
+            graph._connection = connection  # type: ignore

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -250,9 +292,11 @@ class Graph:
     @classmethod
     def from_snowflake(
         cls,
-        connection: Union['SnowflakeConnection',
-
-
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -267,7 +311,14 @@ class Graph:
             >>> import kumoai.experimental.rfm as rfm

             >>> # Create a graph directly in a Snowflake notebook:
-            >>> graph = rfm.Graph.from_snowflake()
+            >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
+
+            >>> # Fine-grained control over table specification:
+            >>> graph = rfm.Graph.from_snowflake(tables=[
+            ...     'USERS',
+            ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+            ...     dict(name='ITEMS', schema='OTHER_SCHEMA'),
+            ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')

         Args:
             connection: An open connection from
@@ -276,8 +327,11 @@ class Graph:
                 connection. If ``None``, will re-use an active session in case
                 it exists, or create a new connection from credentials stored
                 in environment variables.
-
-                all tables present in the
+            tables: Set of table names or :class:`SnowTable` keyword arguments
+                to include. If ``None``, will add all tables present in the
+                current database and schema.
+            database: The database.
+            schema: The schema.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +349,50 @@ class Graph:
             connection = connect(**(connection or {}))
         assert isinstance(connection, Connection)

-        if
+        if database is None or schema is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
-
-
+                result = cursor.fetchone()
+                database = database or result[0]
+                assert database is not None
+                schema = schema or result[1]
+
+        if tables is None:
+            if schema is None:
+                raise ValueError("No current 'schema' set. Please specify the "
+                                 "Snowflake schema manually")
+
+            with connection.cursor() as cursor:
+                cursor.execute(f"""
                     SELECT TABLE_NAME
                     FROM {database}.INFORMATION_SCHEMA.TABLES
                     WHERE TABLE_SCHEMA = '{schema}'
-                """
-                cursor.
-            table_names = [row[0] for row in cursor.fetchall()]
-
-        tables = [SnowTable(connection, name) for name in table_names]
+                """)
+                tables = [row[0] for row in cursor.fetchall()]

-
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            if isinstance(table, str):
+                kwargs = dict(name=table, database=database, schema=schema)
+            else:
+                kwargs = copy.copy(table)
+                kwargs.setdefault('database', database)
+                kwargs.setdefault('schema', schema)
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SnowTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -327,7 +404,7 @@ class Graph:
     def from_snowflake_semantic_view(
         cls,
         semantic_view_name: str,
-        connection: Union['SnowflakeConnection',
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
         verbose: bool = True,
     ) -> Self:
         import yaml
@@ -345,43 +422,138 @@ class Graph:
         with connection.cursor() as cursor:
             cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                            f"'{semantic_view_name}')")
-
+            cfg = yaml.safe_load(cursor.fetchone()[0])

         graph = cls(tables=[])

-
-
-
-
-
+        msgs = []
+        table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
+        for table_cfg in cfg['tables']:
+            table_name = table_cfg['name']
+            source_table_name = table_cfg['base_table']['table']
+            database = table_cfg['base_table']['database']
+            schema = table_cfg['base_table']['schema']
+
+            primary_key: str | None = None
+            if 'primary_key' in table_cfg:
+                primary_key_cfg = table_cfg['primary_key']
+                if len(primary_key_cfg['columns']) == 1:
+                    primary_key = primary_key_cfg['columns'][0]
+                elif len(primary_key_cfg['columns']) > 1:
+                    msgs.append(f"Failed to add primary key for table "
+                                f"'{table_name}' since composite primary keys "
+                                f"are not yet supported")
+
+            columns: list[str] = []
+            unsupported_columns: list[str] = []
+            column_expression_specs: list[ColumnExpressionSpec] = []
+            for column_cfg in chain(
+                    table_cfg.get('dimensions', []),
+                    table_cfg.get('time_dimensions', []),
+                    table_cfg.get('facts', []),
+            ):
+                column_name = column_cfg['name']
+                column_expr = column_cfg.get('expr', None)
+                column_data_type = column_cfg.get('data_type', None)
+
+                if column_expr is None:
+                    columns.append(column_name)
+                    continue
+
+                column_expr = column_expr.replace(f'{table_name}.', '')
+
+                if column_expr == column_name:
+                    columns.append(column_name)
+                    continue
+
+                # Drop expressions that reference other tables (for now):
+                if any(f'{name}.' in column_expr for name in table_names):
+                    unsupported_columns.append(column_name)
+                    continue
+
+                spec = ColumnExpressionSpec(
+                    name=column_name,
+                    expr=column_expr,
+                    dtype=SnowTable.to_dtype(column_data_type),
+                )
+                column_expression_specs.append(spec)
+
+            if len(unsupported_columns) == 1:
+                msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
+                            f"of table '{table_name}' since its expression "
+                            f"references other tables")
+            elif len(unsupported_columns) > 1:
+                msgs.append(f"Failed to add columns '{unsupported_columns}' "
+                            f"of table '{table_name}' since their expressions "
+                            f"reference other tables")

             table = SnowTable(
                 connection,
-                name=
-
-
+                name=table_name,
+                source_name=source_table_name,
+                database=database,
+                schema=schema,
+                columns=columns,
+                column_expressions=column_expression_specs,
                 primary_key=primary_key,
             )
+
+            # TODO Add a way to register time columns without heuristic usage.
+            table.infer_time_column(verbose=False)
+
             graph.add_table(table)

-
+        for relation_cfg in cfg.get('relationships', []):
+            name = relation_cfg['name']
+            if len(relation_cfg['relationship_columns']) != 1:
+                msgs.append(f"Failed to add relationship '{name}' since "
+                            f"composite key references are not yet supported")
+                continue

-
-
-
-
-
-
-
-
+            left_table = relation_cfg['left_table']
+            left_key = relation_cfg['relationship_columns'][0]['left_column']
+            right_table = relation_cfg['right_table']
+            right_key = relation_cfg['relationship_columns'][0]['right_column']
+
+            if graph[right_table]._primary_key != right_key:
+                # Semantic view error - this should never be triggered:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referenced key '{right_key}' of table "
+                            f"'{right_table}' is not a primary key")
+                continue
+
+            if graph[left_table]._primary_key == left_key:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referencing key '{left_key}' of table "
+                            f"'{left_table}' is a primary key")
+                continue
+
+            if left_key not in graph[left_table]:
+                graph[left_table].add_column(left_key)
+
+            graph.link(left_table, left_key, right_table)
+
+        graph.validate()

         if verbose:
             graph.print_metadata()
             graph.print_links()

+        if len(msgs) > 0:
+            title = (f"Could not fully convert the semantic view definition "
+                     f"'{semantic_view_name}' into a graph:\n")
+            warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
+
         return graph

-    #
+    # Backend #################################################################
+
+    @property
+    def backend(self) -> DataBackend | None:
+        backends = [table.backend for table in self._tables.values()]
+        return backends[0] if len(backends) > 0 else None
+
+    # Tables ##################################################################

     def has_table(self, name: str) -> bool:
         r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -400,7 +572,7 @@ class Graph:
         return self.tables[name]

     @property
-    def tables(self) ->
+    def tables(self) -> dict[str, Table]:
         r"""Returns the dictionary of table objects."""
         return self._tables

@@ -420,13 +592,10 @@ class Graph:
             raise KeyError(f"Cannot add table with name '{table.name}' to "
                            f"this graph; table names must be globally unique.")

-        if
-
-
-
-                             f"'{table.__class__.__name__}' to this "
-                             f"graph since other tables are of type "
-                             f"'{cls.__name__}'.")
+        if self.backend is not None and table.backend != self.backend:
+            raise ValueError(f"Cannot register a table with backend "
+                             f"'{table.backend}' to this graph since other "
+                             f"tables have backend '{self.backend}'.")

         self._tables[table.name] = table

@@ -488,9 +657,13 @@ class Graph:

     def print_metadata(self) -> None:
         r"""Prints the :meth:`~Graph.metadata` of the graph."""
-        if
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.markdown("### 🗂️ Graph Metadata")
+            st.dataframe(self.metadata, hide_index=True)
+        elif in_notebook():
             from IPython.display import Markdown, display
-            display(Markdown(
+            display(Markdown("### 🗂️ Graph Metadata"))
             df = self.metadata
             try:
                 if hasattr(df.style, 'hide'):
@@ -524,7 +697,7 @@ class Graph:
     # Edges ###################################################################

     @property
-    def edges(self) ->
+    def edges(self) -> list[Edge]:
         r"""Returns the edges of the graph."""
         return self._edges

@@ -534,32 +707,42 @@ class Graph:
             edge.src_table, edge.fkey) for edge in self.edges]
         edges = sorted(edges)

-        if
+        if in_snowflake_notebook():
+            import streamlit as st
+            st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
+            if len(edges) > 0:
+                st.markdown('\n'.join([
+                    f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
+                    for edge in edges
+                ]))
+            else:
+                st.markdown("*No links registered*")
+        elif in_notebook():
             from IPython.display import Markdown, display
-            display(Markdown(
+            display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
             if len(edges) > 0:
                 display(
                     Markdown('\n'.join([
-                        f
+                        f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
                         for edge in edges
                     ])))
             else:
-                display(Markdown(
+                display(Markdown("*No links registered*"))
         else:
             print("🕸️ Graph Links (FK ↔️ PK):")
             if len(edges) > 0:
                 print('\n'.join([
-                    f
+                    f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
                     for edge in edges
                 ]))
             else:
-                print(
+                print("No links registered")

     def link(
         self,
-        src_table:
+        src_table: str | Table,
         fkey: str,
-        dst_table:
+        dst_table: str | Table,
     ) -> Self:
         r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
         key ``fkey`` in the source table to the primary key in the destination
@@ -620,9 +803,9 @@ class Graph:

     def unlink(
         self,
-        src_table:
+        src_table: str | Table,
         fkey: str,
-        dst_table:
+        dst_table: str | Table,
     ) -> Self:
         r"""Removes an :class:`~kumoai.graph.Edge` from the graph.

@@ -799,6 +982,10 @@ class Graph:
             raise ValueError("At least one table needs to be added to the "
                              "graph")

+        backends = {table.backend for table in self._tables.values()}
+        if len(backends) != 1:
+            raise ValueError("Found multiple table backends in the graph")
+
         for edge in self.edges:
             src_table, fkey, dst_table = edge

@@ -845,7 +1032,7 @@ class Graph:

     def visualize(
         self,
-        path:
+        path: str | io.BytesIO | None = None,
         show_columns: bool = True,
     ) -> 'graphviz.Graph':
         r"""Visualizes the tables and edges in this graph using the
@@ -870,33 +1057,33 @@ class Graph:

             return True

-        # Check basic dependency:
-
-
-
-
+        try:  # Check basic dependency:
+            import graphviz
+        except ImportError as e:
+            raise ImportError("The 'graphviz' package is required for "
+                              "visualization") from e
+
+        if not in_snowflake_notebook() and not has_graphviz_executables():
             raise RuntimeError("Could not visualize graph as 'graphviz' "
                                "executables are not installed. These "
                                "dependencies are required in addition to the "
                                "'graphviz' Python package. Please install "
                                "them as described at "
                                "https://graphviz.org/download/.")
-        else:
-            import graphviz

-        format:
+        format: str | None = None
         if isinstance(path, str):
             format = path.split('.')[-1]
         elif isinstance(path, io.BytesIO):
             format = 'svg'
         graph = graphviz.Graph(format=format)

-        def left_align(keys:
+        def left_align(keys: list[str]) -> str:
             if len(keys) == 0:
                 return ""
             return '\\l'.join(keys) + '\\l'

-        fkeys_dict:
+        fkeys_dict: dict[str, list[str]] = defaultdict(list)
         for src_table_name, fkey_name, _ in self.edges:
             fkeys_dict[src_table_name].append(fkey_name)

@@ -966,6 +1153,9 @@ class Graph:
             graph.render(path, cleanup=True)
         elif isinstance(path, io.BytesIO):
             path.write(graph.pipe())
+        elif in_snowflake_notebook():
+            import streamlit as st
+            st.graphviz_chart(graph)
         elif in_notebook():
             from IPython.display import display
             display(graph)
@@ -989,8 +1179,8 @@ class Graph:
     # Helpers #################################################################

     def _to_api_graph_definition(self) -> GraphDefinition:
-        tables:
-        col_groups:
+        tables: dict[str, TableDefinition] = {}
+        col_groups: list[ColumnKeyGroup] = []
         for table_name, table in self.tables.items():
             tables[table_name] = table._to_api_table_definition()
             if table.primary_key is None:
@@ -1033,3 +1223,7 @@ class Graph:
             f'  tables={tables},\n'
             f'  edges={edges},\n'
             f')')
+
+    def __del__(self) -> None:
+        if hasattr(self, '_connection'):
+            self._connection.close()