kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl

This diff shows the differences between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +31 -14
  10. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +75 -23
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +71 -28
  16. kumoai/experimental/rfm/base/__init__.py +24 -3
  17. kumoai/experimental/rfm/base/column.py +6 -12
  18. kumoai/experimental/rfm/base/column_expression.py +16 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +1 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +113 -0
  23. kumoai/experimental/rfm/base/table.py +136 -105
  24. kumoai/experimental/rfm/graph.py +296 -89
  25. kumoai/experimental/rfm/infer/dtype.py +46 -59
  26. kumoai/experimental/rfm/infer/pkey.py +4 -2
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +4 -2
  38. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +41 -34
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
6
+ from collections.abc import Sequence
5
7
  from dataclasses import dataclass, field
6
- from importlib.util import find_spec
8
+ from itertools import chain
7
9
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
10
+ from typing import TYPE_CHECKING, Any, Union
9
11
 
10
12
  import pandas as pd
11
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,8 +15,8 @@ from kumoapi.table import TableDefinition
13
15
  from kumoapi.typing import Stype
14
16
  from typing_extensions import Self
15
17
 
16
- from kumoai import in_notebook
17
- from kumoai.experimental.rfm import Table
18
+ from kumoai import in_notebook, in_snowflake_notebook
19
+ from kumoai.experimental.rfm.base import DataBackend, SQLTable, Table
18
20
  from kumoai.graph import Edge
19
21
  from kumoai.mixin import CastMixin
20
22
 
@@ -26,8 +28,8 @@ if TYPE_CHECKING:
26
28
 
27
29
  @dataclass
28
30
  class SqliteConnectionConfig(CastMixin):
29
- uri: Union[str, Path]
30
- kwargs: Dict[str, Any] = field(default_factory=dict)
31
+ uri: str | Path
32
+ kwargs: dict[str, Any] = field(default_factory=dict)
31
33
 
32
34
 
33
35
  class Graph:
@@ -87,27 +89,34 @@ class Graph:
87
89
  def __init__(
88
90
  self,
89
91
  tables: Sequence[Table],
90
- edges: Optional[Sequence[Edge]] = None,
92
+ edges: Sequence[Edge] | None = None,
91
93
  ) -> None:
92
94
 
93
- self._tables: Dict[str, Table] = {}
94
- self._edges: List[Edge] = []
95
+ self._tables: dict[str, Table] = {}
96
+ self._edges: list[Edge] = []
95
97
 
96
98
  for table in tables:
97
99
  self.add_table(table)
98
100
 
99
101
  for table in tables:
102
+ if not isinstance(table, SQLTable):
103
+ continue
100
104
  for fkey in table._source_foreign_key_dict.values():
101
- if fkey.name not in table or fkey.dst_table not in self:
105
+ if fkey.name not in table:
106
+ continue
107
+ # TODO Skip for non-physical table[fkey.name].
108
+ dst_table_names = [
109
+ table.name for table in self.tables.values()
110
+ if isinstance(table, SQLTable)
111
+ and table._source_name == fkey.dst_table
112
+ ]
113
+ if len(dst_table_names) != 1:
102
114
  continue
103
- if self[fkey.dst_table].primary_key is None:
104
- self[fkey.dst_table].primary_key = fkey.primary_key
105
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
106
- raise ValueError(f"Found duplicate primary key definition "
107
- f"'{self[fkey.dst_table]._primary_key}' "
108
- f"and '{fkey.primary_key}' in table "
109
- f"'{fkey.dst_table}'.")
110
- self.link(table.name, fkey.name, fkey.dst_table)
115
+ dst_table = self[dst_table_names[0]]
116
+ if dst_table._primary_key != fkey.primary_key:
117
+ continue
118
+ # TODO Skip for non-physical dst_table.primary_key.
119
+ self.link(table.name, fkey.name, dst_table.name)
111
120
 
112
121
  for edge in (edges or []):
113
122
  _edge = Edge._cast(edge)
@@ -118,8 +127,8 @@ class Graph:
118
127
  @classmethod
119
128
  def from_data(
120
129
  cls,
121
- df_dict: Dict[str, pd.DataFrame],
122
- edges: Optional[Sequence[Edge]] = None,
130
+ df_dict: dict[str, pd.DataFrame],
131
+ edges: Sequence[Edge] | None = None,
123
132
  infer_metadata: bool = True,
124
133
  verbose: bool = True,
125
134
  ) -> Self:
@@ -157,15 +166,17 @@ class Graph:
157
166
  verbose: Whether to print verbose output.
158
167
  """
159
168
  from kumoai.experimental.rfm.backend.local import LocalTable
160
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
161
169
 
162
- graph = cls(tables, edges=edges or [])
170
+ graph = cls(
171
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
172
+ edges=edges or [],
173
+ )
163
174
 
164
175
  if infer_metadata:
165
- graph.infer_metadata(False)
176
+ graph.infer_metadata(verbose=False)
166
177
 
167
178
  if edges is None:
168
- graph.infer_links(False)
179
+ graph.infer_links(verbose=False)
169
180
 
170
181
  if verbose:
171
182
  graph.print_metadata()
@@ -181,10 +192,10 @@ class Graph:
181
192
  SqliteConnectionConfig,
182
193
  str,
183
194
  Path,
184
- Dict[str, Any],
195
+ dict[str, Any],
185
196
  ],
186
- table_names: Optional[Sequence[str]] = None,
187
- edges: Optional[Sequence[Edge]] = None,
197
+ tables: Sequence[str | dict[str, Any]] | None = None,
198
+ edges: Sequence[Edge] | None = None,
188
199
  infer_metadata: bool = True,
189
200
  verbose: bool = True,
190
201
  ) -> Self:
@@ -200,17 +211,25 @@ class Graph:
200
211
  >>> # Create a graph from a SQLite database:
201
212
  >>> graph = rfm.Graph.from_sqlite('data.db')
202
213
 
214
+ >>> # Fine-grained control over table specification:
215
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
216
+ ... 'USERS',
217
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
218
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
219
+ ... ])
220
+
203
221
  Args:
204
222
  connection: An open connection from
205
223
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
206
224
  path to the database file.
207
- table_names: Set of table names to include. If ``None``, will add
208
- all tables present in the database.
225
+ tables: Set of table names or :class:`SQLiteTable` keyword
226
+ arguments to include. If ``None``, will add all tables present
227
+ in the database.
209
228
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
210
229
  add to the graph. If not provided, edges will be automatically
211
230
  inferred from the data in case ``infer_metadata=True``.
212
- infer_metadata: Whether to infer metadata for all tables in the
213
- graph.
231
+ infer_metadata: Whether to infer missing metadata for all tables in
232
+ the graph.
214
233
  verbose: Whether to print verbose output.
215
234
  """
216
235
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +238,41 @@ class Graph:
219
238
  connect,
220
239
  )
221
240
 
241
+ internal_connection = False
222
242
  if not isinstance(connection, Connection):
223
243
  connection = SqliteConnectionConfig._cast(connection)
224
244
  assert isinstance(connection, SqliteConnectionConfig)
225
245
  connection = connect(connection.uri, **connection.kwargs)
246
+ internal_connection = True
226
247
  assert isinstance(connection, Connection)
227
248
 
228
- if table_names is None:
249
+ if tables is None:
229
250
  with connection.cursor() as cursor:
230
251
  cursor.execute("SELECT name FROM sqlite_master "
231
252
  "WHERE type='table'")
232
- table_names = [row[0] for row in cursor.fetchall()]
253
+ tables = [row[0] for row in cursor.fetchall()]
233
254
 
234
- tables = [SQLiteTable(connection, name) for name in table_names]
255
+ table_kwargs: list[dict[str, Any]] = []
256
+ for table in tables:
257
+ kwargs = dict(name=table) if isinstance(table, str) else table
258
+ table_kwargs.append(kwargs)
259
+
260
+ graph = cls(
261
+ tables=[
262
+ SQLiteTable(connection=connection, **kwargs)
263
+ for kwargs in table_kwargs
264
+ ],
265
+ edges=edges or [],
266
+ )
235
267
 
236
- graph = cls(tables, edges=edges or [])
268
+ if internal_connection:
269
+ graph._connection = connection # type: ignore
237
270
 
238
271
  if infer_metadata:
239
- graph.infer_metadata(False)
272
+ graph.infer_metadata(verbose=False)
240
273
 
241
274
  if edges is None:
242
- graph.infer_links(False)
275
+ graph.infer_links(verbose=False)
243
276
 
244
277
  if verbose:
245
278
  graph.print_metadata()
@@ -250,9 +283,11 @@ class Graph:
250
283
  @classmethod
251
284
  def from_snowflake(
252
285
  cls,
253
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
254
- table_names: Optional[Sequence[str]] = None,
255
- edges: Optional[Sequence[Edge]] = None,
286
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
287
+ tables: Sequence[str | dict[str, Any]] | None = None,
288
+ database: str | None = None,
289
+ schema: str | None = None,
290
+ edges: Sequence[Edge] | None = None,
256
291
  infer_metadata: bool = True,
257
292
  verbose: bool = True,
258
293
  ) -> Self:
@@ -267,7 +302,14 @@ class Graph:
267
302
  >>> import kumoai.experimental.rfm as rfm
268
303
 
269
304
  >>> # Create a graph directly in a Snowflake notebook:
270
- >>> graph = rfm.Graph.from_snowflake()
305
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
306
+
307
+ >>> # Fine-grained control over table specification:
308
+ >>> graph = rfm.Graph.from_snowflake(tables=[
309
+ ... 'USERS',
310
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
311
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
312
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
271
313
 
272
314
  Args:
273
315
  connection: An open connection from
@@ -276,8 +318,11 @@ class Graph:
276
318
  connection. If ``None``, will re-use an active session in case
277
319
  it exists, or create a new connection from credentials stored
278
320
  in environment variables.
279
- table_names: Set of table names to include. If ``None``, will add
280
- all tables present in the database.
321
+ tables: Set of table names or :class:`SnowTable` keyword arguments
322
+ to include. If ``None``, will add all tables present in the
323
+ current database and schema.
324
+ database: The database.
325
+ schema: The schema.
281
326
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
282
327
  add to the graph. If not provided, edges will be automatically
283
328
  inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +340,50 @@ class Graph:
295
340
  connection = connect(**(connection or {}))
296
341
  assert isinstance(connection, Connection)
297
342
 
298
- if table_names is None:
343
+ if database is None or schema is None:
299
344
  with connection.cursor() as cursor:
300
345
  cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
301
- database, schema = cursor.fetchone()
302
- query = f"""
346
+ result = cursor.fetchone()
347
+ database = database or result[0]
348
+ assert database is not None
349
+ schema = schema or result[1]
350
+
351
+ if tables is None:
352
+ if schema is None:
353
+ raise ValueError("No current 'schema' set. Please specify the "
354
+ "Snowflake schema manually")
355
+
356
+ with connection.cursor() as cursor:
357
+ cursor.execute(f"""
303
358
  SELECT TABLE_NAME
304
359
  FROM {database}.INFORMATION_SCHEMA.TABLES
305
360
  WHERE TABLE_SCHEMA = '{schema}'
306
- """
307
- cursor.execute(query)
308
- table_names = [row[0] for row in cursor.fetchall()]
361
+ """)
362
+ tables = [row[0] for row in cursor.fetchall()]
309
363
 
310
- tables = [SnowTable(connection, name) for name in table_names]
311
-
312
- graph = cls(tables, edges=edges or [])
364
+ table_kwargs: list[dict[str, Any]] = []
365
+ for table in tables:
366
+ if isinstance(table, str):
367
+ kwargs = dict(name=table, database=database, schema=schema)
368
+ else:
369
+ kwargs = copy.copy(table)
370
+ kwargs.setdefault('database', database)
371
+ kwargs.setdefault('schema', schema)
372
+ table_kwargs.append(kwargs)
373
+
374
+ graph = cls(
375
+ tables=[
376
+ SnowTable(connection=connection, **kwargs)
377
+ for kwargs in table_kwargs
378
+ ],
379
+ edges=edges or [],
380
+ )
313
381
 
314
382
  if infer_metadata:
315
- graph.infer_metadata(False)
383
+ graph.infer_metadata(verbose=False)
316
384
 
317
385
  if edges is None:
318
- graph.infer_links(False)
386
+ graph.infer_links(verbose=False)
319
387
 
320
388
  if verbose:
321
389
  graph.print_metadata()
@@ -323,7 +391,124 @@ class Graph:
323
391
 
324
392
  return graph
325
393
 
326
- # Tables ##############################################################
394
+ @classmethod
395
+ def from_snowflake_semantic_view(
396
+ cls,
397
+ semantic_view_name: str,
398
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
399
+ verbose: bool = True,
400
+ ) -> Self:
401
+ import yaml
402
+
403
+ from kumoai.experimental.rfm.backend.snow import (
404
+ Connection,
405
+ SnowTable,
406
+ connect,
407
+ )
408
+
409
+ if not isinstance(connection, Connection):
410
+ connection = connect(**(connection or {}))
411
+ assert isinstance(connection, Connection)
412
+
413
+ with connection.cursor() as cursor:
414
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
415
+ f"'{semantic_view_name}')")
416
+ cfg = yaml.safe_load(cursor.fetchone()[0])
417
+
418
+ graph = cls(tables=[])
419
+
420
+ msgs = []
421
+ for table_cfg in cfg['tables']:
422
+ table_name = table_cfg['name']
423
+ source_table_name = table_cfg['base_table']['table']
424
+ database = table_cfg['base_table']['database']
425
+ schema = table_cfg['base_table']['schema']
426
+
427
+ primary_key: str | None = None
428
+ if 'primary_key' in table_cfg:
429
+ primary_key_cfg = table_cfg['primary_key']
430
+ if len(primary_key_cfg['columns']) == 1:
431
+ primary_key = primary_key_cfg['columns'][0]
432
+ elif len(primary_key_cfg['columns']) > 1:
433
+ msgs.append(f"Failed to add primary key for table "
434
+ f"'{table_name}' since composite primary keys "
435
+ f"are not yet supported")
436
+
437
+ columns: list[str] = []
438
+ for column_cfg in chain(
439
+ table_cfg.get('dimensions', []),
440
+ table_cfg.get('time_dimensions', []),
441
+ table_cfg.get('facts', []),
442
+ ):
443
+ # TODO Add support for derived columns.
444
+ columns.append(column_cfg['name'])
445
+
446
+ table = SnowTable(
447
+ connection,
448
+ name=table_name,
449
+ source_name=source_table_name,
450
+ database=database,
451
+ schema=schema,
452
+ columns=columns,
453
+ primary_key=primary_key,
454
+ )
455
+
456
+ # TODO Add a way to register time columns without heuristic usage.
457
+ table.infer_time_column(verbose=False)
458
+
459
+ graph.add_table(table)
460
+
461
+ for relation_cfg in cfg.get('relationships', []):
462
+ name = relation_cfg['name']
463
+ if len(relation_cfg['relationship_columns']) != 1:
464
+ msgs.append(f"Failed to add relationship '{name}' since "
465
+ f"composite key references are not yet supported")
466
+ continue
467
+
468
+ left_table = relation_cfg['left_table']
469
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
470
+ right_table = relation_cfg['right_table']
471
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
472
+
473
+ if graph[right_table]._primary_key != right_key:
474
+ # Semantic view error - this should never be triggered:
475
+ msgs.append(f"Failed to add relationship '{name}' since the "
476
+ f"referenced key '{right_key}' of table "
477
+ f"'{right_table}' is not a primary key")
478
+ continue
479
+
480
+ if graph[left_table]._primary_key == left_key:
481
+ msgs.append(f"Failed to add relationship '{name}' since the "
482
+ f"referencing key '{left_key}' of table "
483
+ f"'{left_table}' is a primary key")
484
+ continue
485
+
486
+ if left_key not in graph[left_table]:
487
+ graph[left_table].add_column(left_key)
488
+
489
+ graph.link(left_table, left_key, right_table)
490
+
491
+ graph.validate()
492
+
493
+ if verbose:
494
+ graph.print_metadata()
495
+ graph.print_links()
496
+
497
+ if len(msgs) > 0:
498
+ title = (f"Could not fully convert the semantic view definition "
499
+ f"'{semantic_view_name}' into a graph:\n")
500
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
501
+
502
+ return graph
503
+
504
+ # Backend #################################################################
505
+
506
+ @property
507
+ def backend(self) -> DataBackend | None:
508
+ backends = [table.backend for table in self._tables.values()]
509
+ return backends[0] if len(backends) > 0 else None
510
+
511
+ # Tables ##################################################################
327
512
 
328
513
  def has_table(self, name: str) -> bool:
329
514
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -342,7 +527,7 @@ class Graph:
342
527
  return self.tables[name]
343
528
 
344
529
  @property
345
- def tables(self) -> Dict[str, Table]:
530
+ def tables(self) -> dict[str, Table]:
346
531
  r"""Returns the dictionary of table objects."""
347
532
  return self._tables
348
533
 
@@ -362,13 +547,10 @@ class Graph:
362
547
  raise KeyError(f"Cannot add table with name '{table.name}' to "
363
548
  f"this graph; table names must be globally unique.")
364
549
 
365
- if len(self._tables) > 0:
366
- cls = next(iter(self._tables.values())).__class__
367
- if table.__class__ != cls:
368
- raise ValueError(f"Cannot register a "
369
- f"'{table.__class__.__name__}' to this "
370
- f"graph since other tables are of type "
371
- f"'{cls.__name__}'.")
550
+ if self.backend is not None and table.backend != self.backend:
551
+ raise ValueError(f"Cannot register a table with backend "
552
+ f"'{table.backend}' to this graph since other "
553
+ f"tables have backend '{self.backend}'.")
372
554
 
373
555
  self._tables[table.name] = table
374
556
 
@@ -430,9 +612,13 @@ class Graph:
430
612
 
431
613
  def print_metadata(self) -> None:
432
614
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
433
- if in_notebook():
615
+ if in_snowflake_notebook():
616
+ import streamlit as st
617
+ st.markdown("### 🗂️ Graph Metadata")
618
+ st.dataframe(self.metadata, hide_index=True)
619
+ elif in_notebook():
434
620
  from IPython.display import Markdown, display
435
- display(Markdown('### 🗂️ Graph Metadata'))
621
+ display(Markdown("### 🗂️ Graph Metadata"))
436
622
  df = self.metadata
437
623
  try:
438
624
  if hasattr(df.style, 'hide'):
@@ -466,7 +652,7 @@ class Graph:
466
652
  # Edges ###################################################################
467
653
 
468
654
  @property
469
- def edges(self) -> List[Edge]:
655
+ def edges(self) -> list[Edge]:
470
656
  r"""Returns the edges of the graph."""
471
657
  return self._edges
472
658
 
@@ -476,32 +662,42 @@ class Graph:
476
662
  edge.src_table, edge.fkey) for edge in self.edges]
477
663
  edges = sorted(edges)
478
664
 
479
- if in_notebook():
665
+ if in_snowflake_notebook():
666
+ import streamlit as st
667
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
668
+ if len(edges) > 0:
669
+ st.markdown('\n'.join([
670
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
671
+ for edge in edges
672
+ ]))
673
+ else:
674
+ st.markdown("*No links registered*")
675
+ elif in_notebook():
480
676
  from IPython.display import Markdown, display
481
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
677
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
482
678
  if len(edges) > 0:
483
679
  display(
484
680
  Markdown('\n'.join([
485
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
681
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
486
682
  for edge in edges
487
683
  ])))
488
684
  else:
489
- display(Markdown('*No links registered*'))
685
+ display(Markdown("*No links registered*"))
490
686
  else:
491
687
  print("🕸️ Graph Links (FK ↔️ PK):")
492
688
  if len(edges) > 0:
493
689
  print('\n'.join([
494
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
690
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
495
691
  for edge in edges
496
692
  ]))
497
693
  else:
498
- print('No links registered')
694
+ print("No links registered")
499
695
 
500
696
  def link(
501
697
  self,
502
- src_table: Union[str, Table],
698
+ src_table: str | Table,
503
699
  fkey: str,
504
- dst_table: Union[str, Table],
700
+ dst_table: str | Table,
505
701
  ) -> Self:
506
702
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
507
703
  key ``fkey`` in the source table to the primary key in the destination
@@ -562,9 +758,9 @@ class Graph:
562
758
 
563
759
  def unlink(
564
760
  self,
565
- src_table: Union[str, Table],
761
+ src_table: str | Table,
566
762
  fkey: str,
567
- dst_table: Union[str, Table],
763
+ dst_table: str | Table,
568
764
  ) -> Self:
569
765
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
570
766
 
@@ -741,6 +937,10 @@ class Graph:
741
937
  raise ValueError("At least one table needs to be added to the "
742
938
  "graph")
743
939
 
940
+ backends = {table.backend for table in self._tables.values()}
941
+ if len(backends) != 1:
942
+ raise ValueError("Found multiple table backends in the graph")
943
+
744
944
  for edge in self.edges:
745
945
  src_table, fkey, dst_table = edge
746
946
 
@@ -787,7 +987,7 @@ class Graph:
787
987
 
788
988
  def visualize(
789
989
  self,
790
- path: Optional[Union[str, io.BytesIO]] = None,
990
+ path: str | io.BytesIO | None = None,
791
991
  show_columns: bool = True,
792
992
  ) -> 'graphviz.Graph':
793
993
  r"""Visualizes the tables and edges in this graph using the
@@ -812,33 +1012,33 @@ class Graph:
812
1012
 
813
1013
  return True
814
1014
 
815
- # Check basic dependency:
816
- if not find_spec('graphviz'):
817
- raise ModuleNotFoundError("The 'graphviz' package is required for "
818
- "visualization")
819
- elif not has_graphviz_executables():
1015
+ try: # Check basic dependency:
1016
+ import graphviz
1017
+ except ImportError as e:
1018
+ raise ImportError("The 'graphviz' package is required for "
1019
+ "visualization") from e
1020
+
1021
+ if not in_snowflake_notebook() and not has_graphviz_executables():
820
1022
  raise RuntimeError("Could not visualize graph as 'graphviz' "
821
1023
  "executables are not installed. These "
822
1024
  "dependencies are required in addition to the "
823
1025
  "'graphviz' Python package. Please install "
824
1026
  "them as described at "
825
1027
  "https://graphviz.org/download/.")
826
- else:
827
- import graphviz
828
1028
 
829
- format: Optional[str] = None
1029
+ format: str | None = None
830
1030
  if isinstance(path, str):
831
1031
  format = path.split('.')[-1]
832
1032
  elif isinstance(path, io.BytesIO):
833
1033
  format = 'svg'
834
1034
  graph = graphviz.Graph(format=format)
835
1035
 
836
- def left_align(keys: List[str]) -> str:
1036
+ def left_align(keys: list[str]) -> str:
837
1037
  if len(keys) == 0:
838
1038
  return ""
839
1039
  return '\\l'.join(keys) + '\\l'
840
1040
 
841
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1041
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
842
1042
  for src_table_name, fkey_name, _ in self.edges:
843
1043
  fkeys_dict[src_table_name].append(fkey_name)
844
1044
 
@@ -908,6 +1108,9 @@ class Graph:
908
1108
  graph.render(path, cleanup=True)
909
1109
  elif isinstance(path, io.BytesIO):
910
1110
  path.write(graph.pipe())
1111
+ elif in_snowflake_notebook():
1112
+ import streamlit as st
1113
+ st.graphviz_chart(graph)
911
1114
  elif in_notebook():
912
1115
  from IPython.display import display
913
1116
  display(graph)
@@ -931,8 +1134,8 @@ class Graph:
931
1134
  # Helpers #################################################################
932
1135
 
933
1136
  def _to_api_graph_definition(self) -> GraphDefinition:
934
- tables: Dict[str, TableDefinition] = {}
935
- col_groups: List[ColumnKeyGroup] = []
1137
+ tables: dict[str, TableDefinition] = {}
1138
+ col_groups: list[ColumnKeyGroup] = []
936
1139
  for table_name, table in self.tables.items():
937
1140
  tables[table_name] = table._to_api_table_definition()
938
1141
  if table.primary_key is None:
@@ -975,3 +1178,7 @@ class Graph:
975
1178
  f' tables={tables},\n'
976
1179
  f' edges={edges},\n'
977
1180
  f')')
1181
+
1182
+ def __del__(self) -> None:
1183
+ if hasattr(self, '_connection'):
1184
+ self._connection.close()