kumoai 2.12.1__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +9 -13
  4. kumoai/client/pquery.py +6 -2
  5. kumoai/connector/utils.py +23 -2
  6. kumoai/experimental/rfm/__init__.py +162 -46
  7. kumoai/experimental/rfm/backend/__init__.py +0 -0
  8. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
  10. kumoai/experimental/rfm/backend/local/sampler.py +313 -0
  11. kumoai/experimental/rfm/backend/local/table.py +119 -0
  12. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +135 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
  18. kumoai/experimental/rfm/base/__init__.py +23 -0
  19. kumoai/experimental/rfm/base/column.py +66 -0
  20. kumoai/experimental/rfm/base/sampler.py +773 -0
  21. kumoai/experimental/rfm/base/source.py +19 -0
  22. kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
  23. kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
  24. kumoai/experimental/rfm/infer/__init__.py +6 -0
  25. kumoai/experimental/rfm/infer/dtype.py +79 -0
  26. kumoai/experimental/rfm/infer/pkey.py +126 -0
  27. kumoai/experimental/rfm/infer/time_col.py +62 -0
  28. kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
  29. kumoai/experimental/rfm/rfm.py +233 -174
  30. kumoai/experimental/rfm/sagemaker.py +138 -0
  31. kumoai/spcs.py +1 -3
  32. kumoai/testing/decorators.py +1 -1
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +2 -0
  35. kumoai/utils/sql.py +3 -0
  36. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +12 -2
  37. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +40 -23
  38. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  39. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  40. kumoai/experimental/rfm/utils.py +0 -344
  41. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
  42. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.12.1.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
8
 
8
9
  import pandas as pd
9
10
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,20 +12,29 @@ from kumoapi.table import TableDefinition
11
12
  from kumoapi.typing import Stype
12
13
  from typing_extensions import Self
13
14
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai import in_notebook, in_snowflake_notebook
16
+ from kumoai.experimental.rfm.base import DataBackend, Table
16
17
  from kumoai.graph import Edge
18
+ from kumoai.mixin import CastMixin
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  import graphviz
22
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
23
+ from snowflake.connector import SnowflakeConnection
20
24
 
21
25
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
26
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Parameters for opening a new SQLite connection.

    Args:
        uri: The URI or file path of the SQLite database.
        kwargs: Additional keyword arguments that are passed along when the
            connection is opened.
    """
    uri: Union[str, Path]
    kwargs: Dict[str, Any] = field(default_factory=dict)
30
+
31
+
32
+ class Graph:
33
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
34
  tables in a relational database.
25
35
 
26
36
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
37
+ :class:`Graph` is created, you can use it to initialize the
28
38
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
39
 
30
40
  .. code-block:: python
@@ -44,7 +54,7 @@ class LocalGraph:
44
54
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
55
 
46
56
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
57
+ >>> graph = rfm.Graph({
48
58
  ... "table1": table1,
49
59
  ... "table2": table2,
50
60
  ... "table3": table3,
@@ -75,33 +85,47 @@ class LocalGraph:
75
85
 
def __init__(
    self,
    tables: Sequence[Table],
    edges: Optional[Sequence[Edge]] = None,
) -> None:
    r"""Initializes the graph from a sequence of tables and optional edges.

    Args:
        tables: The tables to add to the graph.
        edges: Additional edges to link. Edges that are already present in
            the graph (e.g., derived from source foreign keys) are skipped.

    Raises:
        ValueError: If a foreign key declared in a data source references a
            primary key that differs from the one already defined on the
            destination table.
    """
    self._tables: Dict[str, Table] = {}
    self._edges: List[Edge] = []

    # Register all tables first so that foreign keys below can be resolved
    # against any table of the graph:
    for table in tables:
        self.add_table(table)

    # Adopt foreign keys declared in the underlying data source, provided
    # both the key column and the referenced table are part of this graph:
    for table in tables:
        for source_fkey in table._source_foreign_key_dict.values():
            is_linkable = (source_fkey.name in table
                           and source_fkey.dst_table in self)
            if not is_linkable:
                continue

            dst_table = self[source_fkey.dst_table]
            if dst_table.primary_key is None:
                # Take over the referenced key as the primary key:
                dst_table.primary_key = source_fkey.primary_key
            elif dst_table._primary_key != source_fkey.primary_key:
                raise ValueError(f"Found duplicate primary key definition "
                                 f"'{dst_table._primary_key}' "
                                 f"and '{source_fkey.primary_key}' in table "
                                 f"'{source_fkey.dst_table}'.")

            self.link(table.name, source_fkey.name, source_fkey.dst_table)

    for raw_edge in edges or []:
        edge = Edge._cast(raw_edge)
        assert edge is not None
        if edge in self._edges:
            continue  # Already linked, e.g., from source foreign keys.
        self.link(*edge)
92
116
 
93
117
  @classmethod
94
118
  def from_data(
95
119
  cls,
96
120
  df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
121
+ edges: Optional[Sequence[Edge]] = None,
98
122
  infer_metadata: bool = True,
99
123
  verbose: bool = True,
100
124
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
125
+ r"""Creates a :class:`Graph` from a dictionary of
102
126
  :class:`pandas.DataFrame` objects.
103
127
 
104
- Automatically infers table metadata and links.
128
+ Automatically infers table metadata and links by default.
105
129
 
106
130
  .. code-block:: python
107
131
 
@@ -115,59 +139,274 @@ class LocalGraph:
115
139
  >>> df3 = pd.DataFrame(...)
116
140
 
117
141
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
142
+ >>> graph = rfm.Graph.from_data({
119
143
  ... "table1": df1,
120
144
  ... "table2": df2,
121
145
  ... "table3": df3,
122
146
  ... })
123
147
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
148
  Args:
132
149
  df_dict: A dictionary of data frames, where the keys are the names
133
150
  of the tables and the values hold table data.
151
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
152
+ add to the graph. If not provided, edges will be automatically
153
+ inferred from the data in case ``infer_metadata=True``.
134
154
  infer_metadata: Whether to infer metadata for all tables in the
135
155
  graph.
156
+ verbose: Whether to print verbose output.
157
+ """
158
+ from kumoai.experimental.rfm.backend.local import LocalTable
159
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
+
161
+ graph = cls(tables, edges=edges or [])
162
+
163
+ if infer_metadata:
164
+ graph.infer_metadata(False)
165
+
166
+ if edges is None:
167
+ graph.infer_links(False)
168
+
169
+ if verbose:
170
+ graph.print_metadata()
171
+ graph.print_links()
172
+
173
+ return graph
174
+
175
@classmethod
def from_sqlite(
    cls,
    connection: Union[
        'AdbcSqliteConnection',
        SqliteConnectionConfig,
        str,
        Path,
        Dict[str, Any],
    ],
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`sqlite` database.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph from a SQLite database:
        >>> graph = rfm.Graph.from_sqlite('data.db')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`, or
            anything castable to :class:`SqliteConnectionConfig` (e.g., the
            path to the database file) to open a new connection.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the database, excluding SQLite-internal
            ``sqlite_*`` tables.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Note:
        If the connection is opened by this method, the returned graph
        takes ownership of it and closes it when the graph is garbage
        collected.
    """
    from kumoai.experimental.rfm.backend.sqlite import (
        Connection,
        SQLiteTable,
        connect,
    )

    internal_connection = False
    if not isinstance(connection, Connection):
        connection = SqliteConnectionConfig._cast(connection)
        assert isinstance(connection, SqliteConnectionConfig)
        connection = connect(connection.uri, **connection.kwargs)
        internal_connection = True
    assert isinstance(connection, Connection)

    try:
        if table_names is None:
            with connection.cursor() as cursor:
                # Names starting with `sqlite_` are reserved for SQLite's
                # internal tables (e.g., `sqlite_sequence`). `_` is a LIKE
                # wildcard, hence it needs to be escaped:
                cursor.execute("SELECT name FROM sqlite_master "
                               "WHERE type='table' "
                               "AND name NOT LIKE 'sqlite!_%' ESCAPE '!'")
                table_names = [row[0] for row in cursor.fetchall()]

        tables = [SQLiteTable(connection, name) for name in table_names]

        graph = cls(tables, edges=edges or [])
    except Exception:
        # Do not leak a connection we opened ourselves in case the graph
        # fails to be constructed before it can take ownership of it:
        if internal_connection:
            connection.close()
        raise

    if internal_connection:
        # The graph owns connections opened here and closes them on deletion:
        graph._connection = connection  # type: ignore

    if infer_metadata:
        graph.infer_metadata(False)

        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
253
+
254
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, will re-use an active session in case
            it exists, or create a new connection from credentials stored
            in environment variables.
        database: The database. If ``None`` and ``table_names`` is not
            given, the current database of the session is used.
        schema: The schema. If ``None`` and ``table_names`` is not given,
            the current schema of the session is used.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the database schema.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Raises:
        ValueError: If ``table_names`` is not given and the database or
            schema can neither be taken from the arguments nor from the
            current session.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # Fall back to the session's current database and/or schema for
            # whichever of the two was not given explicitly:
            if database is None or schema is None:
                cursor.execute("SELECT CURRENT_DATABASE(), "
                               "CURRENT_SCHEMA()")
                result = cursor.fetchone()
                database = database or result[0]
                schema = schema or result[1]

            if database is None or schema is None:
                raise ValueError("Could not determine the Snowflake "
                                 "database and schema to read tables from. "
                                 "Please pass 'database' and 'schema' "
                                 "explicitly.")

            # NOTE The database is an identifier and cannot be bound as a
            # query parameter, but the schema name is a plain value and is
            # bound to avoid SQL injection:
            cursor.execute(
                f"SELECT TABLE_NAME "
                f"FROM {database}.INFORMATION_SCHEMA.TABLES "
                f"WHERE TABLE_SCHEMA = %s",
                (schema, ),
            )
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [
        SnowTable(
            connection,
            name=table_name,
            database=database,
            schema=schema,
        ) for table_name in table_names
    ]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
343
+
344
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a Snowflake semantic view.

    Every table of the semantic view is registered via its underlying base
    table. Single-column primary keys and single-column relationships are
    taken over from the view definition; composite keys are skipped. Time
    columns are not registered yet.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_snowflake_semantic_view('my_view')

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, a connection is opened via
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` without
            arguments.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    with connection.cursor() as cursor:
        # Bind the view name as a query parameter instead of formatting it
        # into the SQL string to avoid SQL injection:
        cursor.execute(
            "SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW(%s)",
            (semantic_view_name, ),
        )
        view = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    # Relationships refer to tables by their logical name within the view,
    # while the graph registers tables under their base table name. Keep
    # track of the mapping so that links resolve to the registered tables:
    table_name_dict: Dict[str, str] = {}

    for table_desc in view['tables']:
        primary_key: Optional[str] = None
        if ('primary_key' in table_desc  # NOTE No composite keys yet.
                and len(table_desc['primary_key']['columns']) == 1):
            primary_key = table_desc['primary_key']['columns'][0]

        base_table = table_desc['base_table']
        table = SnowTable(
            connection,
            name=base_table['table'],
            database=base_table['database'],
            schema=base_table['schema'],
            primary_key=primary_key,
        )
        graph.add_table(table)
        table_name_dict[table_desc.get('name', table.name)] = table.name

    # TODO Find a solution to register time columns!

    # A semantic view does not need to define any relationships:
    for relation in view.get('relationships') or []:
        columns = relation['relationship_columns']
        if len(columns) != 1:
            continue  # NOTE No composite keys yet.
        left_table = relation['left_table']
        right_table = relation['right_table']
        graph.link(
            src_table=table_name_dict.get(left_table, left_table),
            fkey=columns[0]['left_column'],
            dst_table=table_name_dict.get(right_table, right_table),
        )

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
169
401
 
170
- # Tables ##############################################################
402
+ # Backend #################################################################
403
+
404
@property
def backend(self) -> Optional[DataBackend]:
    r"""Returns the data backend shared by the tables in the graph, or
    ``None`` if the graph does not hold any tables yet.
    """
    # Annotated via `Optional[...]` instead of `DataBackend | None` since
    # annotations are evaluated at definition time, and PEP 604 unions may
    # fail there on Python < 3.10. All tables share a single backend (checked
    # in `add_table`), so the first table is representative:
    return next((table.backend for table in self._tables.values()), None)
408
+
409
+ # Tables ##################################################################
171
410
 
172
411
  def has_table(self, name: str) -> bool:
173
412
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -175,7 +414,7 @@ class LocalGraph:
175
414
  """
176
415
  return name in self.tables
177
416
 
178
- def table(self, name: str) -> LocalTable:
417
+ def table(self, name: str) -> Table:
179
418
  r"""Returns the table with name ``name`` in the graph.
180
419
 
181
420
  Raises:
@@ -186,11 +425,11 @@ class LocalGraph:
186
425
  return self.tables[name]
187
426
 
188
427
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
428
+ def tables(self) -> Dict[str, Table]:
190
429
  r"""Returns the dictionary of table objects."""
191
430
  return self._tables
192
431
 
193
- def add_table(self, table: LocalTable) -> Self:
432
+ def add_table(self, table: Table) -> Self:
194
433
  r"""Adds a table to the graph.
195
434
 
196
435
  Args:
@@ -199,11 +438,18 @@ class LocalGraph:
199
438
  Raises:
200
439
  KeyError: If a table with the same name already exists in the
201
440
  graph.
441
+ ValueError: If the table belongs to a different backend than the
442
+ rest of the tables in the graph.
202
443
  """
203
444
  if table.name in self._tables:
204
445
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
446
  f"this graph; table names must be globally unique.")
206
447
 
448
+ if self.backend is not None and table.backend != self.backend:
449
+ raise ValueError(f"Cannot register a table with backend "
450
+ f"'{table.backend}' to this graph since other "
451
+ f"tables have backend '{self.backend}'.")
452
+
207
453
  self._tables[table.name] = table
208
454
 
209
455
  return self
@@ -241,7 +487,7 @@ class LocalGraph:
241
487
  Example:
242
488
  >>> # doctest: +SKIP
243
489
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
490
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
491
  >>> graph.metadata # doctest: +SKIP
246
492
  name primary_key time_column end_time_column
247
493
  0 users user_id - -
@@ -263,10 +509,14 @@ class LocalGraph:
263
509
  })
264
510
 
265
511
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
267
- if in_notebook():
512
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
513
+ if in_snowflake_notebook():
514
+ import streamlit as st
515
+ st.markdown("### 🗂️ Graph Metadata")
516
+ st.dataframe(self.metadata, hide_index=True)
517
+ elif in_notebook():
268
518
  from IPython.display import Markdown, display
269
- display(Markdown('### 🗂️ Graph Metadata'))
519
+ display(Markdown("### 🗂️ Graph Metadata"))
270
520
  df = self.metadata
271
521
  try:
272
522
  if hasattr(df.style, 'hide'):
@@ -287,7 +537,7 @@ class LocalGraph:
287
537
 
288
538
  Note:
289
539
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
540
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
541
  """
292
542
  for table in self.tables.values():
293
543
  table.infer_metadata(verbose=False)
@@ -305,37 +555,47 @@ class LocalGraph:
305
555
  return self._edges
306
556
 
def print_links(self) -> None:
    r"""Prints the :meth:`~Graph.edges` of the graph.

    Every link is shown as ``src_table.fkey ↔️ dst_table.primary_key`` in
    sorted order (by destination table first). Output is rendered via
    Streamlit inside Snowflake notebooks, as Markdown inside other
    notebooks, and as plain text otherwise.
    """
    links = sorted(
        (edge.dst_table, self[edge.dst_table]._primary_key, edge.src_table,
         edge.fkey) for edge in self.edges)

    is_snowflake = in_snowflake_notebook()
    if is_snowflake or in_notebook():
        title = "### 🕸️ Graph Links (FK ↔️ PK)"
        if links:
            body = '\n'.join(f"- `{src}.{fkey}` ↔️ `{dst}.{pkey}`"
                             for dst, pkey, src, fkey in links)
        else:
            body = "*No links registered*"

        if is_snowflake:
            import streamlit as st
            st.markdown(title)
            st.markdown(body)
        else:
            from IPython.display import Markdown, display
            display(Markdown(title))
            display(Markdown(body))
        return

    print("🕸️ Graph Links (FK ↔️ PK):")
    if not links:
        print("No links registered")
        return
    print('\n'.join(f"• {src}.{fkey} ↔️ {dst}.{pkey}"
                    for dst, pkey, src, fkey in links))
333
593
 
334
594
  def link(
335
595
  self,
336
- src_table: Union[str, LocalTable],
596
+ src_table: Union[str, Table],
337
597
  fkey: str,
338
- dst_table: Union[str, LocalTable],
598
+ dst_table: Union[str, Table],
339
599
  ) -> Self:
340
600
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
601
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +618,11 @@ class LocalGraph:
358
618
  table does not exist in the graph, if the source key does not
359
619
  exist in the source table.
360
620
  """
361
- if isinstance(src_table, LocalTable):
621
+ if isinstance(src_table, Table):
362
622
  src_table = src_table.name
363
623
  assert isinstance(src_table, str)
364
624
 
365
- if isinstance(dst_table, LocalTable):
625
+ if isinstance(dst_table, Table):
366
626
  dst_table = dst_table.name
367
627
  assert isinstance(dst_table, str)
368
628
 
@@ -396,9 +656,9 @@ class LocalGraph:
396
656
 
397
657
  def unlink(
398
658
  self,
399
- src_table: Union[str, LocalTable],
659
+ src_table: Union[str, Table],
400
660
  fkey: str,
401
- dst_table: Union[str, LocalTable],
661
+ dst_table: Union[str, Table],
402
662
  ) -> Self:
403
663
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
664
 
@@ -410,11 +670,11 @@ class LocalGraph:
410
670
  Raises:
411
671
  ValueError: if the edge is not present in the graph.
412
672
  """
413
- if isinstance(src_table, LocalTable):
673
+ if isinstance(src_table, Table):
414
674
  src_table = src_table.name
415
675
  assert isinstance(src_table, str)
416
676
 
417
- if isinstance(dst_table, LocalTable):
677
+ if isinstance(dst_table, Table):
418
678
  dst_table = dst_table.name
419
679
  assert isinstance(dst_table, str)
420
680
 
@@ -428,17 +688,13 @@ class LocalGraph:
428
688
  return self
429
689
 
430
690
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
691
+ r"""Infers missing links for the tables and adds them as edges to the
692
+ graph.
432
693
 
433
694
  Args:
434
695
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
696
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
697
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
442
698
 
443
699
  # A list of primary key candidates (+score) for every column:
444
700
  candidate_dict: dict[
@@ -463,6 +719,9 @@ class LocalGraph:
463
719
  src_table_name = src_table.name.lower()
464
720
 
465
721
  for src_key in src_table.columns:
722
+ if (src_table.name, src_key.name) in known_edges:
723
+ continue
724
+
466
725
  if src_key == src_table.primary_key:
467
726
  continue # Cannot link to primary key.
468
727
 
@@ -528,7 +787,9 @@ class LocalGraph:
528
787
  score += 1.0
529
788
 
530
789
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
790
+ if (src_table._num_rows is not None
791
+ and dst_table._num_rows is not None
792
+ and src_table._num_rows > dst_table._num_rows):
532
793
  score += 1.0
533
794
 
534
795
  if score < 5.0:
@@ -574,6 +835,10 @@ class LocalGraph:
574
835
  raise ValueError("At least one table needs to be added to the "
575
836
  "graph")
576
837
 
838
+ backends = {table.backend for table in self._tables.values()}
839
+ if len(backends) != 1:
840
+ raise ValueError("Found multiple table backends in the graph")
841
+
577
842
  for edge in self.edges:
578
843
  src_table, fkey, dst_table = edge
579
844
 
@@ -645,19 +910,19 @@ class LocalGraph:
645
910
 
646
911
  return True
647
912
 
648
- # Check basic dependency:
649
- if not find_spec('graphviz'):
650
- raise ModuleNotFoundError("The 'graphviz' package is required for "
651
- "visualization")
652
- elif not has_graphviz_executables():
913
+ try: # Check basic dependency:
914
+ import graphviz
915
+ except ImportError as e:
916
+ raise ImportError("The 'graphviz' package is required for "
917
+ "visualization") from e
918
+
919
+ if not in_snowflake_notebook() and not has_graphviz_executables():
653
920
  raise RuntimeError("Could not visualize graph as 'graphviz' "
654
921
  "executables are not installed. These "
655
922
  "dependencies are required in addition to the "
656
923
  "'graphviz' Python package. Please install "
657
924
  "them as described at "
658
925
  "https://graphviz.org/download/.")
659
- else:
660
- import graphviz
661
926
 
662
927
  format: Optional[str] = None
663
928
  if isinstance(path, str):
@@ -741,6 +1006,9 @@ class LocalGraph:
741
1006
  graph.render(path, cleanup=True)
742
1007
  elif isinstance(path, io.BytesIO):
743
1008
  path.write(graph.pipe())
1009
+ elif in_snowflake_notebook():
1010
+ import streamlit as st
1011
+ st.graphviz_chart(graph)
744
1012
  elif in_notebook():
745
1013
  from IPython.display import display
746
1014
  display(graph)
@@ -790,7 +1058,7 @@ class LocalGraph:
790
1058
  def __contains__(self, name: str) -> bool:
791
1059
  return self.has_table(name)
792
1060
 
793
- def __getitem__(self, name: str) -> LocalTable:
1061
+ def __getitem__(self, name: str) -> Table:
794
1062
  return self.table(name)
795
1063
 
796
1064
  def __delitem__(self, name: str) -> None:
@@ -808,3 +1076,7 @@ class LocalGraph:
808
1076
  f' tables={tables},\n'
809
1077
  f' edges={edges},\n'
810
1078
  f')')
1079
+
1080
def __del__(self) -> None:
    r"""Closes the connection that the graph opened itself (if any)."""
    # `_connection` is only set for connections that were opened internally
    # (see `from_sqlite`); user-provided connections are left untouched.
    connection = getattr(self, '_connection', None)
    if connection is None:
        return
    # Finalizers may run during interpreter shutdown or after the connection
    # was already closed. Nobody can handle a failure here, so ignore it
    # rather than emitting "Exception ignored in ..." noise. A plain `try`
    # avoids relying on module globals that may already be torn down:
    try:
        connection.close()
    except Exception:
        pass