kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/experimental/rfm/__init__.py +33 -8
  5. kumoai/experimental/rfm/authenticate.py +3 -4
  6. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +53 -107
  8. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  9. kumoai/experimental/rfm/backend/local/table.py +41 -80
  10. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  11. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  12. kumoai/experimental/rfm/backend/snow/table.py +147 -0
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +11 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  15. kumoai/experimental/rfm/backend/sqlite/table.py +108 -88
  16. kumoai/experimental/rfm/base/__init__.py +26 -2
  17. kumoai/experimental/rfm/base/column.py +6 -12
  18. kumoai/experimental/rfm/base/column_expression.py +16 -0
  19. kumoai/experimental/rfm/base/sampler.py +773 -0
  20. kumoai/experimental/rfm/base/source.py +19 -0
  21. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  22. kumoai/experimental/rfm/base/sql_table.py +113 -0
  23. kumoai/experimental/rfm/base/table.py +174 -76
  24. kumoai/experimental/rfm/graph.py +444 -84
  25. kumoai/experimental/rfm/infer/__init__.py +6 -0
  26. kumoai/experimental/rfm/infer/dtype.py +77 -0
  27. kumoai/experimental/rfm/infer/pkey.py +128 -0
  28. kumoai/experimental/rfm/infer/time_col.py +61 -0
  29. kumoai/experimental/rfm/pquery/executor.py +27 -27
  30. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  31. kumoai/experimental/rfm/rfm.py +299 -240
  32. kumoai/experimental/rfm/sagemaker.py +4 -4
  33. kumoai/pquery/predictive_query.py +10 -6
  34. kumoai/testing/snow.py +50 -0
  35. kumoai/utils/__init__.py +3 -2
  36. kumoai/utils/progress_logger.py +178 -12
  37. kumoai/utils/sql.py +3 -0
  38. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/METADATA +6 -2
  39. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/RECORD +42 -30
  40. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  41. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  42. kumoai/experimental/rfm/utils.py +0 -344
  43. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/WHEEL +0 -0
  44. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/licenses/LICENSE +0 -0
  45. {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.14.0.dev202512181731.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
6
+ from collections.abc import Sequence
7
+ from dataclasses import dataclass, field
8
+ from itertools import chain
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any, Union
7
11
 
8
12
  import pandas as pd
9
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,12 +15,21 @@ from kumoapi.table import TableDefinition
11
15
  from kumoapi.typing import Stype
12
16
  from typing_extensions import Self
13
17
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import Table
18
+ from kumoai import in_notebook, in_snowflake_notebook
19
+ from kumoai.experimental.rfm.base import DataBackend, SQLTable, Table
16
20
  from kumoai.graph import Edge
21
+ from kumoai.mixin import CastMixin
17
22
 
18
23
  if TYPE_CHECKING:
19
24
  import graphviz
25
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
26
+ from snowflake.connector import SnowflakeConnection
27
+
28
+
29
+ @dataclass
30
+ class SqliteConnectionConfig(CastMixin):
31
+ uri: str | Path
32
+ kwargs: dict[str, Any] = field(default_factory=dict)
20
33
 
21
34
 
22
35
  class Graph:
@@ -76,32 +89,53 @@ class Graph:
76
89
  def __init__(
77
90
  self,
78
91
  tables: Sequence[Table],
79
- edges: Optional[Sequence[Edge]] = None,
92
+ edges: Sequence[Edge] | None = None,
80
93
  ) -> None:
81
94
 
82
- self._tables: Dict[str, Table] = {}
83
- self._edges: List[Edge] = []
95
+ self._tables: dict[str, Table] = {}
96
+ self._edges: list[Edge] = []
84
97
 
85
98
  for table in tables:
86
99
  self.add_table(table)
87
100
 
101
+ for table in tables:
102
+ if not isinstance(table, SQLTable):
103
+ continue
104
+ for fkey in table._source_foreign_key_dict.values():
105
+ if fkey.name not in table:
106
+ continue
107
+ # TODO Skip for non-physical table[fkey.name].
108
+ dst_table_names = [
109
+ table.name for table in self.tables.values()
110
+ if isinstance(table, SQLTable)
111
+ and table._source_name == fkey.dst_table
112
+ ]
113
+ if len(dst_table_names) != 1:
114
+ continue
115
+ dst_table = self[dst_table_names[0]]
116
+ if dst_table._primary_key != fkey.primary_key:
117
+ continue
118
+ # TODO Skip for non-physical dst_table.primary_key.
119
+ self.link(table.name, fkey.name, dst_table.name)
120
+
88
121
  for edge in (edges or []):
89
122
  _edge = Edge._cast(edge)
90
123
  assert _edge is not None
91
- self.link(*_edge)
124
+ if _edge not in self._edges:
125
+ self.link(*_edge)
92
126
 
93
127
  @classmethod
94
128
  def from_data(
95
129
  cls,
96
- df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[Sequence[Edge]] = None,
130
+ df_dict: dict[str, pd.DataFrame],
131
+ edges: Sequence[Edge] | None = None,
98
132
  infer_metadata: bool = True,
99
133
  verbose: bool = True,
100
134
  ) -> Self:
101
135
  r"""Creates a :class:`Graph` from a dictionary of
102
136
  :class:`pandas.DataFrame` objects.
103
137
 
104
- Automatically infers table metadata and links.
138
+ Automatically infers table metadata and links by default.
105
139
 
106
140
  .. code-block:: python
107
141
 
@@ -121,54 +155,360 @@ class Graph:
121
155
  ... "table3": df3,
122
156
  ... })
123
157
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
158
  Args:
132
159
  df_dict: A dictionary of data frames, where the keys are the names
133
160
  of the tables and the values hold table data.
161
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
162
+ add to the graph. If not provided, edges will be automatically
163
+ inferred from the data in case ``infer_metadata=True``.
134
164
  infer_metadata: Whether to infer metadata for all tables in the
135
165
  graph.
166
+ verbose: Whether to print verbose output.
167
+ """
168
+ from kumoai.experimental.rfm.backend.local import LocalTable
169
+
170
+ graph = cls(
171
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
172
+ edges=edges or [],
173
+ )
174
+
175
+ if infer_metadata:
176
+ graph.infer_metadata(verbose=False)
177
+
178
+ if edges is None:
179
+ graph.infer_links(verbose=False)
180
+
181
+ if verbose:
182
+ graph.print_metadata()
183
+ graph.print_links()
184
+
185
+ return graph
186
+
187
+ @classmethod
188
+ def from_sqlite(
189
+ cls,
190
+ connection: Union[
191
+ 'AdbcSqliteConnection',
192
+ SqliteConnectionConfig,
193
+ str,
194
+ Path,
195
+ dict[str, Any],
196
+ ],
197
+ tables: Sequence[str | dict[str, Any]] | None = None,
198
+ edges: Sequence[Edge] | None = None,
199
+ infer_metadata: bool = True,
200
+ verbose: bool = True,
201
+ ) -> Self:
202
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
203
+
204
+ Automatically infers table metadata and links by default.
205
+
206
+ .. code-block:: python
207
+
208
+ >>> # doctest: +SKIP
209
+ >>> import kumoai.experimental.rfm as rfm
210
+
211
+ >>> # Create a graph from a SQLite database:
212
+ >>> graph = rfm.Graph.from_sqlite('data.db')
213
+
214
+ >>> # Fine-grained control over table specification:
215
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
216
+ ... 'USERS',
217
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
218
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
219
+ ... ])
220
+
221
+ Args:
222
+ connection: An open connection from
223
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
224
+ path to the database file.
225
+ tables: Set of table names or :class:`SQLiteTable` keyword
226
+ arguments to include. If ``None``, will add all tables present
227
+ in the database.
136
228
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
229
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
230
+ inferred from the data in case ``infer_metadata=True``.
231
+ infer_metadata: Whether to infer missing metadata for all tables in
232
+ the graph.
139
233
  verbose: Whether to print verbose output.
234
+ """
235
+ from kumoai.experimental.rfm.backend.sqlite import (
236
+ Connection,
237
+ SQLiteTable,
238
+ connect,
239
+ )
240
+
241
+ internal_connection = False
242
+ if not isinstance(connection, Connection):
243
+ connection = SqliteConnectionConfig._cast(connection)
244
+ assert isinstance(connection, SqliteConnectionConfig)
245
+ connection = connect(connection.uri, **connection.kwargs)
246
+ internal_connection = True
247
+ assert isinstance(connection, Connection)
248
+
249
+ if tables is None:
250
+ with connection.cursor() as cursor:
251
+ cursor.execute("SELECT name FROM sqlite_master "
252
+ "WHERE type='table'")
253
+ tables = [row[0] for row in cursor.fetchall()]
254
+
255
+ table_kwargs: list[dict[str, Any]] = []
256
+ for table in tables:
257
+ kwargs = dict(name=table) if isinstance(table, str) else table
258
+ table_kwargs.append(kwargs)
140
259
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
260
+ graph = cls(
261
+ tables=[
262
+ SQLiteTable(connection=connection, **kwargs)
263
+ for kwargs in table_kwargs
264
+ ],
265
+ edges=edges or [],
266
+ )
267
+
268
+ if internal_connection:
269
+ graph._connection = connection # type: ignore
270
+
271
+ if infer_metadata:
272
+ graph.infer_metadata(verbose=False)
273
+
274
+ if edges is None:
275
+ graph.infer_links(verbose=False)
276
+
277
+ if verbose:
278
+ graph.print_metadata()
279
+ graph.print_links()
280
+
281
+ return graph
282
+
283
+ @classmethod
284
+ def from_snowflake(
285
+ cls,
286
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
287
+ tables: Sequence[str | dict[str, Any]] | None = None,
288
+ database: str | None = None,
289
+ schema: str | None = None,
290
+ edges: Sequence[Edge] | None = None,
291
+ infer_metadata: bool = True,
292
+ verbose: bool = True,
293
+ ) -> Self:
294
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
295
+ schema.
296
+
297
+ Automatically infers table metadata and links by default.
298
+
299
+ .. code-block:: python
144
300
 
145
- Example:
146
301
  >>> # doctest: +SKIP
147
302
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.Graph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
157
- """
158
- from kumoai.experimental.rfm import LocalTable
159
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
303
 
161
- graph = cls(tables, edges=edges or [])
304
+ >>> # Create a graph directly in a Snowflake notebook:
305
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
306
+
307
+ >>> # Fine-grained control over table specification:
308
+ >>> graph = rfm.Graph.from_snowflake(tables=[
309
+ ... 'USERS',
310
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
311
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
312
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
313
+
314
+ Args:
315
+ connection: An open connection from
316
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
317
+ :class:`snowflake` connector keyword arguments to open a new
318
+ connection. If ``None``, will re-use an active session in case
319
+ it exists, or create a new connection from credentials stored
320
+ in environment variables.
321
+ tables: Set of table names or :class:`SnowTable` keyword arguments
322
+ to include. If ``None``, will add all tables present in the
323
+ current database and schema.
324
+ database: The database.
325
+ schema: The schema.
326
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
327
+ add to the graph. If not provided, edges will be automatically
328
+ inferred from the data in case ``infer_metadata=True``.
329
+ infer_metadata: Whether to infer metadata for all tables in the
330
+ graph.
331
+ verbose: Whether to print verbose output.
332
+ """
333
+ from kumoai.experimental.rfm.backend.snow import (
334
+ Connection,
335
+ SnowTable,
336
+ connect,
337
+ )
338
+
339
+ if not isinstance(connection, Connection):
340
+ connection = connect(**(connection or {}))
341
+ assert isinstance(connection, Connection)
342
+
343
+ if database is None or schema is None:
344
+ with connection.cursor() as cursor:
345
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
346
+ result = cursor.fetchone()
347
+ database = database or result[0]
348
+ assert database is not None
349
+ schema = schema or result[1]
350
+
351
+ if tables is None:
352
+ if schema is None:
353
+ raise ValueError("No current 'schema' set. Please specify the "
354
+ "Snowflake schema manually")
355
+
356
+ with connection.cursor() as cursor:
357
+ cursor.execute(f"""
358
+ SELECT TABLE_NAME
359
+ FROM {database}.INFORMATION_SCHEMA.TABLES
360
+ WHERE TABLE_SCHEMA = '{schema}'
361
+ """)
362
+ tables = [row[0] for row in cursor.fetchall()]
363
+
364
+ table_kwargs: list[dict[str, Any]] = []
365
+ for table in tables:
366
+ if isinstance(table, str):
367
+ kwargs = dict(name=table, database=database, schema=schema)
368
+ else:
369
+ kwargs = copy.copy(table)
370
+ kwargs.setdefault('database', database)
371
+ kwargs.setdefault('schema', schema)
372
+ table_kwargs.append(kwargs)
373
+
374
+ graph = cls(
375
+ tables=[
376
+ SnowTable(connection=connection, **kwargs)
377
+ for kwargs in table_kwargs
378
+ ],
379
+ edges=edges or [],
380
+ )
162
381
 
163
382
  if infer_metadata:
164
- graph.infer_metadata(verbose)
383
+ graph.infer_metadata(verbose=False)
165
384
 
166
385
  if edges is None:
167
- graph.infer_links(verbose)
386
+ graph.infer_links(verbose=False)
387
+
388
+ if verbose:
389
+ graph.print_metadata()
390
+ graph.print_links()
168
391
 
169
392
  return graph
170
393
 
171
- # Tables ##############################################################
394
+ @classmethod
395
+ def from_snowflake_semantic_view(
396
+ cls,
397
+ semantic_view_name: str,
398
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
399
+ verbose: bool = True,
400
+ ) -> Self:
401
+ import yaml
402
+
403
+ from kumoai.experimental.rfm.backend.snow import (
404
+ Connection,
405
+ SnowTable,
406
+ connect,
407
+ )
408
+
409
+ if not isinstance(connection, Connection):
410
+ connection = connect(**(connection or {}))
411
+ assert isinstance(connection, Connection)
412
+
413
+ with connection.cursor() as cursor:
414
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
415
+ f"'{semantic_view_name}')")
416
+ cfg = yaml.safe_load(cursor.fetchone()[0])
417
+
418
+ graph = cls(tables=[])
419
+
420
+ msgs = []
421
+ for table_cfg in cfg['tables']:
422
+ table_name = table_cfg['name']
423
+ source_table_name = table_cfg['base_table']['table']
424
+ database = table_cfg['base_table']['database']
425
+ schema = table_cfg['base_table']['schema']
426
+
427
+ primary_key: str | None = None
428
+ if 'primary_key' in table_cfg:
429
+ primary_key_cfg = table_cfg['primary_key']
430
+ if len(primary_key_cfg['columns']) == 1:
431
+ primary_key = primary_key_cfg['columns'][0]
432
+ elif len(primary_key_cfg['columns']) > 1:
433
+ msgs.append(f"Failed to add primary key for table "
434
+ f"'{table_name}' since composite primary keys "
435
+ f"are not yet supported")
436
+
437
+ columns: list[str] = []
438
+ for column_cfg in chain(
439
+ table_cfg.get('dimensions', []),
440
+ table_cfg.get('time_dimensions', []),
441
+ table_cfg.get('facts', []),
442
+ ):
443
+ # TODO Add support for derived columns.
444
+ columns.append(column_cfg['name'])
445
+
446
+ table = SnowTable(
447
+ connection,
448
+ name=table_name,
449
+ source_name=source_table_name,
450
+ database=database,
451
+ schema=schema,
452
+ columns=columns,
453
+ primary_key=primary_key,
454
+ )
455
+
456
+ # TODO Add a way to register time columns without heuristic usage.
457
+ table.infer_time_column(verbose=False)
458
+
459
+ graph.add_table(table)
460
+
461
+ for relation_cfg in cfg.get('relationships', []):
462
+ name = relation_cfg['name']
463
+ if len(relation_cfg['relationship_columns']) != 1:
464
+ msgs.append(f"Failed to add relationship '{name}' since "
465
+ f"composite key references are not yet supported")
466
+ continue
467
+
468
+ left_table = relation_cfg['left_table']
469
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
470
+ right_table = relation_cfg['right_table']
471
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
472
+
473
+ if graph[right_table]._primary_key != right_key:
474
+ # Semantic view error - this should never be triggered:
475
+ msgs.append(f"Failed to add relationship '{name}' since the "
476
+ f"referenced key '{right_key}' of table "
477
+ f"'{right_table}' is not a primary key")
478
+ continue
479
+
480
+ if graph[left_table]._primary_key == left_key:
481
+ msgs.append(f"Failed to add relationship '{name}' since the "
482
+ f"referencing key '{left_key}' of table "
483
+ f"'{left_table}' is a primary key")
484
+ continue
485
+
486
+ if left_key not in graph[left_table]:
487
+ graph[left_table].add_column(left_key)
488
+
489
+ graph.link(left_table, left_key, right_table)
490
+
491
+ graph.validate()
492
+
493
+ if verbose:
494
+ graph.print_metadata()
495
+ graph.print_links()
496
+
497
+ if len(msgs) > 0:
498
+ title = (f"Could not fully convert the semantic view definition "
499
+ f"'{semantic_view_name}' into a graph:\n")
500
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
501
+
502
+ return graph
503
+
504
+ # Backend #################################################################
505
+
506
+ @property
507
+ def backend(self) -> DataBackend | None:
508
+ backends = [table.backend for table in self._tables.values()]
509
+ return backends[0] if len(backends) > 0 else None
510
+
511
+ # Tables ##################################################################
172
512
 
173
513
  def has_table(self, name: str) -> bool:
174
514
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -187,7 +527,7 @@ class Graph:
187
527
  return self.tables[name]
188
528
 
189
529
  @property
190
- def tables(self) -> Dict[str, Table]:
530
+ def tables(self) -> dict[str, Table]:
191
531
  r"""Returns the dictionary of table objects."""
192
532
  return self._tables
193
533
 
@@ -207,13 +547,10 @@ class Graph:
207
547
  raise KeyError(f"Cannot add table with name '{table.name}' to "
208
548
  f"this graph; table names must be globally unique.")
209
549
 
210
- if len(self._tables) > 0:
211
- cls = next(iter(self._tables.values())).__class__
212
- if table.__class__ != cls:
213
- raise ValueError(f"Cannot register a "
214
- f"'{table.__class__.__name__}' to this "
215
- f"graph since other tables are of type "
216
- f"'{cls.__name__}'.")
550
+ if self.backend is not None and table.backend != self.backend:
551
+ raise ValueError(f"Cannot register a table with backend "
552
+ f"'{table.backend}' to this graph since other "
553
+ f"tables have backend '{self.backend}'.")
217
554
 
218
555
  self._tables[table.name] = table
219
556
 
@@ -275,9 +612,13 @@ class Graph:
275
612
 
276
613
  def print_metadata(self) -> None:
277
614
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
278
- if in_notebook():
615
+ if in_snowflake_notebook():
616
+ import streamlit as st
617
+ st.markdown("### 🗂️ Graph Metadata")
618
+ st.dataframe(self.metadata, hide_index=True)
619
+ elif in_notebook():
279
620
  from IPython.display import Markdown, display
280
- display(Markdown('### 🗂️ Graph Metadata'))
621
+ display(Markdown("### 🗂️ Graph Metadata"))
281
622
  df = self.metadata
282
623
  try:
283
624
  if hasattr(df.style, 'hide'):
@@ -311,7 +652,7 @@ class Graph:
311
652
  # Edges ###################################################################
312
653
 
313
654
  @property
314
- def edges(self) -> List[Edge]:
655
+ def edges(self) -> list[Edge]:
315
656
  r"""Returns the edges of the graph."""
316
657
  return self._edges
317
658
 
@@ -321,32 +662,42 @@ class Graph:
321
662
  edge.src_table, edge.fkey) for edge in self.edges]
322
663
  edges = sorted(edges)
323
664
 
324
- if in_notebook():
665
+ if in_snowflake_notebook():
666
+ import streamlit as st
667
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
668
+ if len(edges) > 0:
669
+ st.markdown('\n'.join([
670
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
671
+ for edge in edges
672
+ ]))
673
+ else:
674
+ st.markdown("*No links registered*")
675
+ elif in_notebook():
325
676
  from IPython.display import Markdown, display
326
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
677
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
327
678
  if len(edges) > 0:
328
679
  display(
329
680
  Markdown('\n'.join([
330
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
681
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
331
682
  for edge in edges
332
683
  ])))
333
684
  else:
334
- display(Markdown('*No links registered*'))
685
+ display(Markdown("*No links registered*"))
335
686
  else:
336
687
  print("🕸️ Graph Links (FK ↔️ PK):")
337
688
  if len(edges) > 0:
338
689
  print('\n'.join([
339
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
690
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
340
691
  for edge in edges
341
692
  ]))
342
693
  else:
343
- print('No links registered')
694
+ print("No links registered")
344
695
 
345
696
  def link(
346
697
  self,
347
- src_table: Union[str, Table],
698
+ src_table: str | Table,
348
699
  fkey: str,
349
- dst_table: Union[str, Table],
700
+ dst_table: str | Table,
350
701
  ) -> Self:
351
702
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
352
703
  key ``fkey`` in the source table to the primary key in the destination
@@ -407,9 +758,9 @@ class Graph:
407
758
 
408
759
  def unlink(
409
760
  self,
410
- src_table: Union[str, Table],
761
+ src_table: str | Table,
411
762
  fkey: str,
412
- dst_table: Union[str, Table],
763
+ dst_table: str | Table,
413
764
  ) -> Self:
414
765
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
415
766
 
@@ -439,17 +790,13 @@ class Graph:
439
790
  return self
440
791
 
441
792
  def infer_links(self, verbose: bool = True) -> Self:
442
- r"""Infers links for the tables and adds them as edges to the graph.
793
+ r"""Infers missing links for the tables and adds them as edges to the
794
+ graph.
443
795
 
444
796
  Args:
445
797
  verbose: Whether to print verbose output.
446
-
447
- Note:
448
- This function expects graph edges to be undefined upfront.
449
798
  """
450
- if len(self.edges) > 0:
451
- warnings.warn("Cannot infer links if graph edges already exist")
452
- return self
799
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
453
800
 
454
801
  # A list of primary key candidates (+score) for every column:
455
802
  candidate_dict: dict[
@@ -474,6 +821,9 @@ class Graph:
474
821
  src_table_name = src_table.name.lower()
475
822
 
476
823
  for src_key in src_table.columns:
824
+ if (src_table.name, src_key.name) in known_edges:
825
+ continue
826
+
477
827
  if src_key == src_table.primary_key:
478
828
  continue # Cannot link to primary key.
479
829
 
@@ -539,10 +889,9 @@ class Graph:
539
889
  score += 1.0
540
890
 
541
891
  # Cardinality ratio:
542
- src_num_rows = src_table._num_rows()
543
- dst_num_rows = dst_table._num_rows()
544
- if (src_num_rows is not None and dst_num_rows is not None
545
- and src_num_rows > dst_num_rows):
892
+ if (src_table._num_rows is not None
893
+ and dst_table._num_rows is not None
894
+ and src_table._num_rows > dst_table._num_rows):
546
895
  score += 1.0
547
896
 
548
897
  if score < 5.0:
@@ -588,6 +937,10 @@ class Graph:
588
937
  raise ValueError("At least one table needs to be added to the "
589
938
  "graph")
590
939
 
940
+ backends = {table.backend for table in self._tables.values()}
941
+ if len(backends) != 1:
942
+ raise ValueError("Found multiple table backends in the graph")
943
+
591
944
  for edge in self.edges:
592
945
  src_table, fkey, dst_table = edge
593
946
 
@@ -634,7 +987,7 @@ class Graph:
634
987
 
635
988
  def visualize(
636
989
  self,
637
- path: Optional[Union[str, io.BytesIO]] = None,
990
+ path: str | io.BytesIO | None = None,
638
991
  show_columns: bool = True,
639
992
  ) -> 'graphviz.Graph':
640
993
  r"""Visualizes the tables and edges in this graph using the
@@ -659,33 +1012,33 @@ class Graph:
659
1012
 
660
1013
  return True
661
1014
 
662
- # Check basic dependency:
663
- if not find_spec('graphviz'):
664
- raise ModuleNotFoundError("The 'graphviz' package is required for "
665
- "visualization")
666
- elif not has_graphviz_executables():
1015
+ try: # Check basic dependency:
1016
+ import graphviz
1017
+ except ImportError as e:
1018
+ raise ImportError("The 'graphviz' package is required for "
1019
+ "visualization") from e
1020
+
1021
+ if not in_snowflake_notebook() and not has_graphviz_executables():
667
1022
  raise RuntimeError("Could not visualize graph as 'graphviz' "
668
1023
  "executables are not installed. These "
669
1024
  "dependencies are required in addition to the "
670
1025
  "'graphviz' Python package. Please install "
671
1026
  "them as described at "
672
1027
  "https://graphviz.org/download/.")
673
- else:
674
- import graphviz
675
1028
 
676
- format: Optional[str] = None
1029
+ format: str | None = None
677
1030
  if isinstance(path, str):
678
1031
  format = path.split('.')[-1]
679
1032
  elif isinstance(path, io.BytesIO):
680
1033
  format = 'svg'
681
1034
  graph = graphviz.Graph(format=format)
682
1035
 
683
- def left_align(keys: List[str]) -> str:
1036
+ def left_align(keys: list[str]) -> str:
684
1037
  if len(keys) == 0:
685
1038
  return ""
686
1039
  return '\\l'.join(keys) + '\\l'
687
1040
 
688
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1041
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
689
1042
  for src_table_name, fkey_name, _ in self.edges:
690
1043
  fkeys_dict[src_table_name].append(fkey_name)
691
1044
 
@@ -755,6 +1108,9 @@ class Graph:
755
1108
  graph.render(path, cleanup=True)
756
1109
  elif isinstance(path, io.BytesIO):
757
1110
  path.write(graph.pipe())
1111
+ elif in_snowflake_notebook():
1112
+ import streamlit as st
1113
+ st.graphviz_chart(graph)
758
1114
  elif in_notebook():
759
1115
  from IPython.display import display
760
1116
  display(graph)
@@ -778,8 +1134,8 @@ class Graph:
778
1134
  # Helpers #################################################################
779
1135
 
780
1136
  def _to_api_graph_definition(self) -> GraphDefinition:
781
- tables: Dict[str, TableDefinition] = {}
782
- col_groups: List[ColumnKeyGroup] = []
1137
+ tables: dict[str, TableDefinition] = {}
1138
+ col_groups: list[ColumnKeyGroup] = []
783
1139
  for table_name, table in self.tables.items():
784
1140
  tables[table_name] = table._to_api_table_definition()
785
1141
  if table.primary_key is None:
@@ -822,3 +1178,7 @@ class Graph:
822
1178
  f' tables={tables},\n'
823
1179
  f' edges={edges},\n'
824
1180
  f')')
1181
+
1182
+ def __del__(self) -> None:
1183
+ if hasattr(self, '_connection'):
1184
+ self._connection.close()