kumoai 2.13.0.dev202512040649__cp313-cp313-win_amd64.whl → 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +52 -91
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +21 -16
  10. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +102 -48
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  16. kumoai/experimental/rfm/base/__init__.py +26 -3
  17. kumoai/experimental/rfm/base/column.py +14 -12
  18. kumoai/experimental/rfm/base/column_expression.py +50 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +1 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +229 -0
  23. kumoai/experimental/rfm/base/table.py +173 -138
  24. kumoai/experimental/rfm/graph.py +302 -108
  25. kumoai/experimental/rfm/infer/__init__.py +6 -4
  26. kumoai/experimental/rfm/infer/dtype.py +3 -3
  27. kumoai/experimental/rfm/infer/pkey.py +4 -2
  28. kumoai/experimental/rfm/infer/stype.py +35 -0
  29. kumoai/experimental/rfm/infer/time_col.py +1 -2
  30. kumoai/experimental/rfm/pquery/executor.py +27 -27
  31. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  32. kumoai/experimental/rfm/rfm.py +299 -230
  33. kumoai/experimental/rfm/sagemaker.py +4 -4
  34. kumoai/kumolib.cp313-win_amd64.pyd +0 -0
  35. kumoai/pquery/predictive_query.py +10 -6
  36. kumoai/testing/snow.py +50 -0
  37. kumoai/utils/__init__.py +3 -2
  38. kumoai/utils/progress_logger.py +178 -12
  39. kumoai/utils/sql.py +3 -0
  40. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +3 -2
  41. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +44 -36
  42. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  43. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  44. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  45. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  46. {kumoai-2.13.0.dev202512040649.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
6
+ from collections.abc import Sequence
5
7
  from dataclasses import dataclass, field
6
- from importlib.util import find_spec
8
+ from itertools import chain
7
9
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
10
+ from typing import TYPE_CHECKING, Any, Union
9
11
 
10
12
  import pandas as pd
11
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,8 +15,13 @@ from kumoapi.table import TableDefinition
13
15
  from kumoapi.typing import Stype
14
16
  from typing_extensions import Self
15
17
 
16
- from kumoai import in_notebook
17
- from kumoai.experimental.rfm import Table
18
+ from kumoai import in_notebook, in_snowflake_notebook
19
+ from kumoai.experimental.rfm.base import (
20
+ ColumnExpressionSpec,
21
+ DataBackend,
22
+ SQLTable,
23
+ Table,
24
+ )
18
25
  from kumoai.graph import Edge
19
26
  from kumoai.mixin import CastMixin
20
27
 
@@ -26,8 +33,8 @@ if TYPE_CHECKING:
26
33
 
27
34
  @dataclass
28
35
  class SqliteConnectionConfig(CastMixin):
29
- uri: Union[str, Path]
30
- kwargs: Dict[str, Any] = field(default_factory=dict)
36
+ uri: str | Path
37
+ kwargs: dict[str, Any] = field(default_factory=dict)
31
38
 
32
39
 
33
40
  class Graph:
@@ -87,27 +94,38 @@ class Graph:
87
94
  def __init__(
88
95
  self,
89
96
  tables: Sequence[Table],
90
- edges: Optional[Sequence[Edge]] = None,
97
+ edges: Sequence[Edge] | None = None,
91
98
  ) -> None:
92
99
 
93
- self._tables: Dict[str, Table] = {}
94
- self._edges: List[Edge] = []
100
+ self._tables: dict[str, Table] = {}
101
+ self._edges: list[Edge] = []
95
102
 
96
103
  for table in tables:
97
104
  self.add_table(table)
98
105
 
99
106
  for table in tables:
107
+ if not isinstance(table, SQLTable):
108
+ continue
109
+ if '_source_column_dict' not in table.__dict__:
110
+ continue
100
111
  for fkey in table._source_foreign_key_dict.values():
101
- if fkey.name not in table or fkey.dst_table not in self:
112
+ if fkey.name not in table:
113
+ continue
114
+ if not table[fkey.name].is_physical:
115
+ continue
116
+ dst_table_names = [
117
+ table.name for table in self.tables.values()
118
+ if isinstance(table, SQLTable)
119
+ and table._source_name == fkey.dst_table
120
+ ]
121
+ if len(dst_table_names) != 1:
122
+ continue
123
+ dst_table = self[dst_table_names[0]]
124
+ if dst_table._primary_key != fkey.primary_key:
125
+ continue
126
+ if not dst_table[fkey.primary_key].is_physical:
102
127
  continue
103
- if self[fkey.dst_table].primary_key is None:
104
- self[fkey.dst_table].primary_key = fkey.primary_key
105
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
106
- raise ValueError(f"Found duplicate primary key definition "
107
- f"'{self[fkey.dst_table]._primary_key}' "
108
- f"and '{fkey.primary_key}' in table "
109
- f"'{fkey.dst_table}'.")
110
- self.link(table.name, fkey.name, fkey.dst_table)
128
+ self.link(table.name, fkey.name, dst_table.name)
111
129
 
112
130
  for edge in (edges or []):
113
131
  _edge = Edge._cast(edge)
@@ -118,8 +136,8 @@ class Graph:
118
136
  @classmethod
119
137
  def from_data(
120
138
  cls,
121
- df_dict: Dict[str, pd.DataFrame],
122
- edges: Optional[Sequence[Edge]] = None,
139
+ df_dict: dict[str, pd.DataFrame],
140
+ edges: Sequence[Edge] | None = None,
123
141
  infer_metadata: bool = True,
124
142
  verbose: bool = True,
125
143
  ) -> Self:
@@ -157,15 +175,17 @@ class Graph:
157
175
  verbose: Whether to print verbose output.
158
176
  """
159
177
  from kumoai.experimental.rfm.backend.local import LocalTable
160
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
161
178
 
162
- graph = cls(tables, edges=edges or [])
179
+ graph = cls(
180
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
181
+ edges=edges or [],
182
+ )
163
183
 
164
184
  if infer_metadata:
165
- graph.infer_metadata(False)
185
+ graph.infer_metadata(verbose=False)
166
186
 
167
187
  if edges is None:
168
- graph.infer_links(False)
188
+ graph.infer_links(verbose=False)
169
189
 
170
190
  if verbose:
171
191
  graph.print_metadata()
@@ -181,10 +201,10 @@ class Graph:
181
201
  SqliteConnectionConfig,
182
202
  str,
183
203
  Path,
184
- Dict[str, Any],
204
+ dict[str, Any],
185
205
  ],
186
- table_names: Optional[Sequence[str]] = None,
187
- edges: Optional[Sequence[Edge]] = None,
206
+ tables: Sequence[str | dict[str, Any]] | None = None,
207
+ edges: Sequence[Edge] | None = None,
188
208
  infer_metadata: bool = True,
189
209
  verbose: bool = True,
190
210
  ) -> Self:
@@ -200,17 +220,25 @@ class Graph:
200
220
  >>> # Create a graph from a SQLite database:
201
221
  >>> graph = rfm.Graph.from_sqlite('data.db')
202
222
 
223
+ >>> # Fine-grained control over table specification:
224
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
225
+ ... 'USERS',
226
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
227
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
228
+ ... ])
229
+
203
230
  Args:
204
231
  connection: An open connection from
205
232
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
206
233
  path to the database file.
207
- table_names: Set of table names to include. If ``None``, will add
208
- all tables present in the database.
234
+ tables: Set of table names or :class:`SQLiteTable` keyword
235
+ arguments to include. If ``None``, will add all tables present
236
+ in the database.
209
237
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
210
238
  add to the graph. If not provided, edges will be automatically
211
239
  inferred from the data in case ``infer_metadata=True``.
212
- infer_metadata: Whether to infer metadata for all tables in the
213
- graph.
240
+ infer_metadata: Whether to infer missing metadata for all tables in
241
+ the graph.
214
242
  verbose: Whether to print verbose output.
215
243
  """
216
244
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +247,41 @@ class Graph:
219
247
  connect,
220
248
  )
221
249
 
250
+ internal_connection = False
222
251
  if not isinstance(connection, Connection):
223
252
  connection = SqliteConnectionConfig._cast(connection)
224
253
  assert isinstance(connection, SqliteConnectionConfig)
225
254
  connection = connect(connection.uri, **connection.kwargs)
255
+ internal_connection = True
226
256
  assert isinstance(connection, Connection)
227
257
 
228
- if table_names is None:
258
+ if tables is None:
229
259
  with connection.cursor() as cursor:
230
260
  cursor.execute("SELECT name FROM sqlite_master "
231
261
  "WHERE type='table'")
232
- table_names = [row[0] for row in cursor.fetchall()]
262
+ tables = [row[0] for row in cursor.fetchall()]
233
263
 
234
- tables = [SQLiteTable(connection, name) for name in table_names]
264
+ table_kwargs: list[dict[str, Any]] = []
265
+ for table in tables:
266
+ kwargs = dict(name=table) if isinstance(table, str) else table
267
+ table_kwargs.append(kwargs)
268
+
269
+ graph = cls(
270
+ tables=[
271
+ SQLiteTable(connection=connection, **kwargs)
272
+ for kwargs in table_kwargs
273
+ ],
274
+ edges=edges or [],
275
+ )
235
276
 
236
- graph = cls(tables, edges=edges or [])
277
+ if internal_connection:
278
+ graph._connection = connection # type: ignore
237
279
 
238
280
  if infer_metadata:
239
- graph.infer_metadata(False)
281
+ graph.infer_metadata(verbose=False)
240
282
 
241
283
  if edges is None:
242
- graph.infer_links(False)
284
+ graph.infer_links(verbose=False)
243
285
 
244
286
  if verbose:
245
287
  graph.print_metadata()
@@ -250,9 +292,11 @@ class Graph:
250
292
  @classmethod
251
293
  def from_snowflake(
252
294
  cls,
253
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
254
- table_names: Optional[Sequence[str]] = None,
255
- edges: Optional[Sequence[Edge]] = None,
295
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
296
+ tables: Sequence[str | dict[str, Any]] | None = None,
297
+ database: str | None = None,
298
+ schema: str | None = None,
299
+ edges: Sequence[Edge] | None = None,
256
300
  infer_metadata: bool = True,
257
301
  verbose: bool = True,
258
302
  ) -> Self:
@@ -267,7 +311,14 @@ class Graph:
267
311
  >>> import kumoai.experimental.rfm as rfm
268
312
 
269
313
  >>> # Create a graph directly in a Snowflake notebook:
270
- >>> graph = rfm.Graph.from_snowflake()
314
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
315
+
316
+ >>> # Fine-grained control over table specification:
317
+ >>> graph = rfm.Graph.from_snowflake(tables=[
318
+ ... 'USERS',
319
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
320
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
321
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
271
322
 
272
323
  Args:
273
324
  connection: An open connection from
@@ -276,8 +327,11 @@ class Graph:
276
327
  connection. If ``None``, will re-use an active session in case
277
328
  it exists, or create a new connection from credentials stored
278
329
  in environment variables.
279
- table_names: Set of table names to include. If ``None``, will add
280
- all tables present in the database.
330
+ tables: Set of table names or :class:`SnowTable` keyword arguments
331
+ to include. If ``None``, will add all tables present in the
332
+ current database and schema.
333
+ database: The database.
334
+ schema: The schema.
281
335
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
282
336
  add to the graph. If not provided, edges will be automatically
283
337
  inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +349,50 @@ class Graph:
295
349
  connection = connect(**(connection or {}))
296
350
  assert isinstance(connection, Connection)
297
351
 
298
- if table_names is None:
352
+ if database is None or schema is None:
299
353
  with connection.cursor() as cursor:
300
354
  cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
301
- database, schema = cursor.fetchone()
302
- query = f"""
355
+ result = cursor.fetchone()
356
+ database = database or result[0]
357
+ assert database is not None
358
+ schema = schema or result[1]
359
+
360
+ if tables is None:
361
+ if schema is None:
362
+ raise ValueError("No current 'schema' set. Please specify the "
363
+ "Snowflake schema manually")
364
+
365
+ with connection.cursor() as cursor:
366
+ cursor.execute(f"""
303
367
  SELECT TABLE_NAME
304
368
  FROM {database}.INFORMATION_SCHEMA.TABLES
305
369
  WHERE TABLE_SCHEMA = '{schema}'
306
- """
307
- cursor.execute(query)
308
- table_names = [row[0] for row in cursor.fetchall()]
309
-
310
- tables = [SnowTable(connection, name) for name in table_names]
370
+ """)
371
+ tables = [row[0] for row in cursor.fetchall()]
311
372
 
312
- graph = cls(tables, edges=edges or [])
373
+ table_kwargs: list[dict[str, Any]] = []
374
+ for table in tables:
375
+ if isinstance(table, str):
376
+ kwargs = dict(name=table, database=database, schema=schema)
377
+ else:
378
+ kwargs = copy.copy(table)
379
+ kwargs.setdefault('database', database)
380
+ kwargs.setdefault('schema', schema)
381
+ table_kwargs.append(kwargs)
382
+
383
+ graph = cls(
384
+ tables=[
385
+ SnowTable(connection=connection, **kwargs)
386
+ for kwargs in table_kwargs
387
+ ],
388
+ edges=edges or [],
389
+ )
313
390
 
314
391
  if infer_metadata:
315
- graph.infer_metadata(False)
392
+ graph.infer_metadata(verbose=False)
316
393
 
317
394
  if edges is None:
318
- graph.infer_links(False)
395
+ graph.infer_links(verbose=False)
319
396
 
320
397
  if verbose:
321
398
  graph.print_metadata()
@@ -327,7 +404,7 @@ class Graph:
327
404
  def from_snowflake_semantic_view(
328
405
  cls,
329
406
  semantic_view_name: str,
330
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
407
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
331
408
  verbose: bool = True,
332
409
  ) -> Self:
333
410
  import yaml
@@ -345,43 +422,138 @@ class Graph:
345
422
  with connection.cursor() as cursor:
346
423
  cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
347
424
  f"'{semantic_view_name}')")
348
- view = yaml.safe_load(cursor.fetchone()[0])
425
+ cfg = yaml.safe_load(cursor.fetchone()[0])
349
426
 
350
427
  graph = cls(tables=[])
351
428
 
352
- for table_desc in view['tables']:
353
- primary_key: Optional[str] = None
354
- if ('primary_key' in table_desc # NOTE No composite keys yet.
355
- and len(table_desc['primary_key']['columns']) == 1):
356
- primary_key = table_desc['primary_key']['columns'][0]
429
+ msgs = []
430
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
431
+ for table_cfg in cfg['tables']:
432
+ table_name = table_cfg['name']
433
+ source_table_name = table_cfg['base_table']['table']
434
+ database = table_cfg['base_table']['database']
435
+ schema = table_cfg['base_table']['schema']
436
+
437
+ primary_key: str | None = None
438
+ if 'primary_key' in table_cfg:
439
+ primary_key_cfg = table_cfg['primary_key']
440
+ if len(primary_key_cfg['columns']) == 1:
441
+ primary_key = primary_key_cfg['columns'][0]
442
+ elif len(primary_key_cfg['columns']) > 1:
443
+ msgs.append(f"Failed to add primary key for table "
444
+ f"'{table_name}' since composite primary keys "
445
+ f"are not yet supported")
446
+
447
+ columns: list[str] = []
448
+ unsupported_columns: list[str] = []
449
+ column_expression_specs: list[ColumnExpressionSpec] = []
450
+ for column_cfg in chain(
451
+ table_cfg.get('dimensions', []),
452
+ table_cfg.get('time_dimensions', []),
453
+ table_cfg.get('facts', []),
454
+ ):
455
+ column_name = column_cfg['name']
456
+ column_expr = column_cfg.get('expr', None)
457
+ column_data_type = column_cfg.get('data_type', None)
458
+
459
+ if column_expr is None:
460
+ columns.append(column_name)
461
+ continue
462
+
463
+ column_expr = column_expr.replace(f'{table_name}.', '')
464
+
465
+ if column_expr == column_name:
466
+ columns.append(column_name)
467
+ continue
468
+
469
+ # Drop expressions that reference other tables (for now):
470
+ if any(f'{name}.' in column_expr for name in table_names):
471
+ unsupported_columns.append(column_name)
472
+ continue
473
+
474
+ spec = ColumnExpressionSpec(
475
+ name=column_name,
476
+ expr=column_expr,
477
+ dtype=SnowTable.to_dtype(column_data_type),
478
+ )
479
+ column_expression_specs.append(spec)
480
+
481
+ if len(unsupported_columns) == 1:
482
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
483
+ f"of table '{table_name}' since its expression "
484
+ f"references other tables")
485
+ elif len(unsupported_columns) > 1:
486
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
487
+ f"of table '{table_name}' since their expressions "
488
+ f"reference other tables")
357
489
 
358
490
  table = SnowTable(
359
491
  connection,
360
- name=table_desc['base_table']['table'],
361
- database=table_desc['base_table']['database'],
362
- schema=table_desc['base_table']['schema'],
492
+ name=table_name,
493
+ source_name=source_table_name,
494
+ database=database,
495
+ schema=schema,
496
+ columns=columns,
497
+ column_expressions=column_expression_specs,
363
498
  primary_key=primary_key,
364
499
  )
500
+
501
+ # TODO Add a way to register time columns without heuristic usage.
502
+ table.infer_time_column(verbose=False)
503
+
365
504
  graph.add_table(table)
366
505
 
367
- # TODO Find a solution to register time columns!
506
+ for relation_cfg in cfg.get('relationships', []):
507
+ name = relation_cfg['name']
508
+ if len(relation_cfg['relationship_columns']) != 1:
509
+ msgs.append(f"Failed to add relationship '{name}' since "
510
+ f"composite key references are not yet supported")
511
+ continue
368
512
 
369
- for relations in view['relationships']:
370
- if len(relations['relationship_columns']) != 1:
371
- continue # NOTE No composite keys yet.
372
- graph.link(
373
- src_table=relations['left_table'],
374
- fkey=relations['relationship_columns'][0]['left_column'],
375
- dst_table=relations['right_table'],
376
- )
513
+ left_table = relation_cfg['left_table']
514
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
515
+ right_table = relation_cfg['right_table']
516
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
517
+
518
+ if graph[right_table]._primary_key != right_key:
519
+ # Semantic view error - this should never be triggered:
520
+ msgs.append(f"Failed to add relationship '{name}' since the "
521
+ f"referenced key '{right_key}' of table "
522
+ f"'{right_table}' is not a primary key")
523
+ continue
524
+
525
+ if graph[left_table]._primary_key == left_key:
526
+ msgs.append(f"Failed to add relationship '{name}' since the "
527
+ f"referencing key '{left_key}' of table "
528
+ f"'{left_table}' is a primary key")
529
+ continue
530
+
531
+ if left_key not in graph[left_table]:
532
+ graph[left_table].add_column(left_key)
533
+
534
+ graph.link(left_table, left_key, right_table)
535
+
536
+ graph.validate()
377
537
 
378
538
  if verbose:
379
539
  graph.print_metadata()
380
540
  graph.print_links()
381
541
 
542
+ if len(msgs) > 0:
543
+ title = (f"Could not fully convert the semantic view definition "
544
+ f"'{semantic_view_name}' into a graph:\n")
545
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
546
+
382
547
  return graph
383
548
 
384
- # Tables ##############################################################
549
+ # Backend #################################################################
550
+
551
+ @property
552
+ def backend(self) -> DataBackend | None:
553
+ backends = [table.backend for table in self._tables.values()]
554
+ return backends[0] if len(backends) > 0 else None
555
+
556
+ # Tables ##################################################################
385
557
 
386
558
  def has_table(self, name: str) -> bool:
387
559
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -400,7 +572,7 @@ class Graph:
400
572
  return self.tables[name]
401
573
 
402
574
  @property
403
- def tables(self) -> Dict[str, Table]:
575
+ def tables(self) -> dict[str, Table]:
404
576
  r"""Returns the dictionary of table objects."""
405
577
  return self._tables
406
578
 
@@ -420,13 +592,10 @@ class Graph:
420
592
  raise KeyError(f"Cannot add table with name '{table.name}' to "
421
593
  f"this graph; table names must be globally unique.")
422
594
 
423
- if len(self._tables) > 0:
424
- cls = next(iter(self._tables.values())).__class__
425
- if table.__class__ != cls:
426
- raise ValueError(f"Cannot register a "
427
- f"'{table.__class__.__name__}' to this "
428
- f"graph since other tables are of type "
429
- f"'{cls.__name__}'.")
595
+ if self.backend is not None and table.backend != self.backend:
596
+ raise ValueError(f"Cannot register a table with backend "
597
+ f"'{table.backend}' to this graph since other "
598
+ f"tables have backend '{self.backend}'.")
430
599
 
431
600
  self._tables[table.name] = table
432
601
 
@@ -488,9 +657,13 @@ class Graph:
488
657
 
489
658
  def print_metadata(self) -> None:
490
659
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
491
- if in_notebook():
660
+ if in_snowflake_notebook():
661
+ import streamlit as st
662
+ st.markdown("### 🗂️ Graph Metadata")
663
+ st.dataframe(self.metadata, hide_index=True)
664
+ elif in_notebook():
492
665
  from IPython.display import Markdown, display
493
- display(Markdown('### 🗂️ Graph Metadata'))
666
+ display(Markdown("### 🗂️ Graph Metadata"))
494
667
  df = self.metadata
495
668
  try:
496
669
  if hasattr(df.style, 'hide'):
@@ -524,7 +697,7 @@ class Graph:
524
697
  # Edges ###################################################################
525
698
 
526
699
  @property
527
- def edges(self) -> List[Edge]:
700
+ def edges(self) -> list[Edge]:
528
701
  r"""Returns the edges of the graph."""
529
702
  return self._edges
530
703
 
@@ -534,32 +707,42 @@ class Graph:
534
707
  edge.src_table, edge.fkey) for edge in self.edges]
535
708
  edges = sorted(edges)
536
709
 
537
- if in_notebook():
710
+ if in_snowflake_notebook():
711
+ import streamlit as st
712
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
713
+ if len(edges) > 0:
714
+ st.markdown('\n'.join([
715
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
716
+ for edge in edges
717
+ ]))
718
+ else:
719
+ st.markdown("*No links registered*")
720
+ elif in_notebook():
538
721
  from IPython.display import Markdown, display
539
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
722
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
540
723
  if len(edges) > 0:
541
724
  display(
542
725
  Markdown('\n'.join([
543
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
726
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
544
727
  for edge in edges
545
728
  ])))
546
729
  else:
547
- display(Markdown('*No links registered*'))
730
+ display(Markdown("*No links registered*"))
548
731
  else:
549
732
  print("🕸️ Graph Links (FK ↔️ PK):")
550
733
  if len(edges) > 0:
551
734
  print('\n'.join([
552
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
735
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
553
736
  for edge in edges
554
737
  ]))
555
738
  else:
556
- print('No links registered')
739
+ print("No links registered")
557
740
 
558
741
  def link(
559
742
  self,
560
- src_table: Union[str, Table],
743
+ src_table: str | Table,
561
744
  fkey: str,
562
- dst_table: Union[str, Table],
745
+ dst_table: str | Table,
563
746
  ) -> Self:
564
747
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
565
748
  key ``fkey`` in the source table to the primary key in the destination
@@ -620,9 +803,9 @@ class Graph:
620
803
 
621
804
  def unlink(
622
805
  self,
623
- src_table: Union[str, Table],
806
+ src_table: str | Table,
624
807
  fkey: str,
625
- dst_table: Union[str, Table],
808
+ dst_table: str | Table,
626
809
  ) -> Self:
627
810
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
628
811
 
@@ -799,6 +982,10 @@ class Graph:
799
982
  raise ValueError("At least one table needs to be added to the "
800
983
  "graph")
801
984
 
985
+ backends = {table.backend for table in self._tables.values()}
986
+ if len(backends) != 1:
987
+ raise ValueError("Found multiple table backends in the graph")
988
+
802
989
  for edge in self.edges:
803
990
  src_table, fkey, dst_table = edge
804
991
 
@@ -845,7 +1032,7 @@ class Graph:
845
1032
 
846
1033
  def visualize(
847
1034
  self,
848
- path: Optional[Union[str, io.BytesIO]] = None,
1035
+ path: str | io.BytesIO | None = None,
849
1036
  show_columns: bool = True,
850
1037
  ) -> 'graphviz.Graph':
851
1038
  r"""Visualizes the tables and edges in this graph using the
@@ -870,33 +1057,33 @@ class Graph:
870
1057
 
871
1058
  return True
872
1059
 
873
- # Check basic dependency:
874
- if not find_spec('graphviz'):
875
- raise ModuleNotFoundError("The 'graphviz' package is required for "
876
- "visualization")
877
- elif not has_graphviz_executables():
1060
+ try: # Check basic dependency:
1061
+ import graphviz
1062
+ except ImportError as e:
1063
+ raise ImportError("The 'graphviz' package is required for "
1064
+ "visualization") from e
1065
+
1066
+ if not in_snowflake_notebook() and not has_graphviz_executables():
878
1067
  raise RuntimeError("Could not visualize graph as 'graphviz' "
879
1068
  "executables are not installed. These "
880
1069
  "dependencies are required in addition to the "
881
1070
  "'graphviz' Python package. Please install "
882
1071
  "them as described at "
883
1072
  "https://graphviz.org/download/.")
884
- else:
885
- import graphviz
886
1073
 
887
- format: Optional[str] = None
1074
+ format: str | None = None
888
1075
  if isinstance(path, str):
889
1076
  format = path.split('.')[-1]
890
1077
  elif isinstance(path, io.BytesIO):
891
1078
  format = 'svg'
892
1079
  graph = graphviz.Graph(format=format)
893
1080
 
894
- def left_align(keys: List[str]) -> str:
1081
+ def left_align(keys: list[str]) -> str:
895
1082
  if len(keys) == 0:
896
1083
  return ""
897
1084
  return '\\l'.join(keys) + '\\l'
898
1085
 
899
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1086
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
900
1087
  for src_table_name, fkey_name, _ in self.edges:
901
1088
  fkeys_dict[src_table_name].append(fkey_name)
902
1089
 
@@ -966,6 +1153,9 @@ class Graph:
966
1153
  graph.render(path, cleanup=True)
967
1154
  elif isinstance(path, io.BytesIO):
968
1155
  path.write(graph.pipe())
1156
+ elif in_snowflake_notebook():
1157
+ import streamlit as st
1158
+ st.graphviz_chart(graph)
969
1159
  elif in_notebook():
970
1160
  from IPython.display import display
971
1161
  display(graph)
@@ -989,8 +1179,8 @@ class Graph:
989
1179
  # Helpers #################################################################
990
1180
 
991
1181
  def _to_api_graph_definition(self) -> GraphDefinition:
992
- tables: Dict[str, TableDefinition] = {}
993
- col_groups: List[ColumnKeyGroup] = []
1182
+ tables: dict[str, TableDefinition] = {}
1183
+ col_groups: list[ColumnKeyGroup] = []
994
1184
  for table_name, table in self.tables.items():
995
1185
  tables[table_name] = table._to_api_table_definition()
996
1186
  if table.primary_key is None:
@@ -1033,3 +1223,7 @@ class Graph:
1033
1223
  f' tables={tables},\n'
1034
1224
  f' edges={edges},\n'
1035
1225
  f')')
1226
+
1227
+ def __del__(self) -> None:
1228
+ if hasattr(self, '_connection'):
1229
+ self._connection.close()