kumoai 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512061731__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/connector/utils.py +23 -2
  4. kumoai/experimental/rfm/__init__.py +20 -45
  5. kumoai/experimental/rfm/backend/__init__.py +0 -0
  6. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  7. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
  8. kumoai/experimental/rfm/backend/local/sampler.py +131 -0
  9. kumoai/experimental/rfm/backend/local/table.py +109 -0
  10. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  11. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  12. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  14. kumoai/experimental/rfm/base/__init__.py +14 -0
  15. kumoai/experimental/rfm/base/column.py +66 -0
  16. kumoai/experimental/rfm/base/sampler.py +287 -0
  17. kumoai/experimental/rfm/base/source.py +18 -0
  18. kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
  19. kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
  20. kumoai/experimental/rfm/infer/__init__.py +6 -0
  21. kumoai/experimental/rfm/infer/dtype.py +79 -0
  22. kumoai/experimental/rfm/infer/pkey.py +126 -0
  23. kumoai/experimental/rfm/infer/time_col.py +62 -0
  24. kumoai/experimental/rfm/local_graph_sampler.py +43 -2
  25. kumoai/experimental/rfm/local_pquery_driver.py +1 -1
  26. kumoai/experimental/rfm/rfm.py +7 -17
  27. kumoai/experimental/rfm/sagemaker.py +11 -3
  28. kumoai/testing/decorators.py +1 -1
  29. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/METADATA +9 -8
  30. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/RECORD +33 -19
  31. kumoai/experimental/rfm/utils.py +0 -344
  32. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/WHEEL +0 -0
  33. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/licenses/LICENSE +0 -0
  34. {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512061731.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,9 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
8
 
8
9
  import pandas as pd
9
10
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,20 +12,29 @@ from kumoapi.table import TableDefinition
11
12
  from kumoapi.typing import Stype
12
13
  from typing_extensions import Self
13
14
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
15
+ from kumoai import in_notebook, in_snowflake_notebook
16
+ from kumoai.experimental.rfm import Table
16
17
  from kumoai.graph import Edge
18
+ from kumoai.mixin import CastMixin
17
19
 
18
20
  if TYPE_CHECKING:
19
21
  import graphviz
22
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
23
+ from snowflake.connector import SnowflakeConnection
20
24
 
21
25
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
26
+ @dataclass
27
+ class SqliteConnectionConfig(CastMixin):
28
+ uri: Union[str, Path]
29
+ kwargs: Dict[str, Any] = field(default_factory=dict)
30
+
31
+
32
+ class Graph:
33
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
34
  tables in a relational database.
25
35
 
26
36
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
37
+ :class:`Graph` is created, you can use it to initialize the
28
38
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
39
 
30
40
  .. code-block:: python
@@ -44,7 +54,7 @@ class LocalGraph:
44
54
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
55
 
46
56
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
57
+ >>> graph = rfm.Graph({
48
58
  ... "table1": table1,
49
59
  ... "table2": table2,
50
60
  ... "table3": table3,
@@ -75,33 +85,47 @@ class LocalGraph:
75
85
 
76
86
  def __init__(
77
87
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
88
+ tables: Sequence[Table],
89
+ edges: Optional[Sequence[Edge]] = None,
80
90
  ) -> None:
81
91
 
82
- self._tables: Dict[str, LocalTable] = {}
92
+ self._tables: Dict[str, Table] = {}
83
93
  self._edges: List[Edge] = []
84
94
 
85
95
  for table in tables:
86
96
  self.add_table(table)
87
97
 
98
+ for table in tables:
99
+ for fkey in table._source_foreign_key_dict.values():
100
+ if fkey.name not in table or fkey.dst_table not in self:
101
+ continue
102
+ if self[fkey.dst_table].primary_key is None:
103
+ self[fkey.dst_table].primary_key = fkey.primary_key
104
+ elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
+ raise ValueError(f"Found duplicate primary key definition "
106
+ f"'{self[fkey.dst_table]._primary_key}' "
107
+ f"and '{fkey.primary_key}' in table "
108
+ f"'{fkey.dst_table}'.")
109
+ self.link(table.name, fkey.name, fkey.dst_table)
110
+
88
111
  for edge in (edges or []):
89
112
  _edge = Edge._cast(edge)
90
113
  assert _edge is not None
91
- self.link(*_edge)
114
+ if _edge not in self._edges:
115
+ self.link(*_edge)
92
116
 
93
117
  @classmethod
94
118
  def from_data(
95
119
  cls,
96
120
  df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
121
+ edges: Optional[Sequence[Edge]] = None,
98
122
  infer_metadata: bool = True,
99
123
  verbose: bool = True,
100
124
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
125
+ r"""Creates a :class:`Graph` from a dictionary of
102
126
  :class:`pandas.DataFrame` objects.
103
127
 
104
- Automatically infers table metadata and links.
128
+ Automatically infers table metadata and links by default.
105
129
 
106
130
  .. code-block:: python
107
131
 
@@ -115,55 +139,258 @@ class LocalGraph:
115
139
  >>> df3 = pd.DataFrame(...)
116
140
 
117
141
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
142
+ >>> graph = rfm.Graph.from_data({
119
143
  ... "table1": df1,
120
144
  ... "table2": df2,
121
145
  ... "table3": df3,
122
146
  ... })
123
147
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
148
  Args:
132
149
  df_dict: A dictionary of data frames, where the keys are the names
133
150
  of the tables and the values hold table data.
151
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
152
+ add to the graph. If not provided, edges will be automatically
153
+ inferred from the data in case ``infer_metadata=True``.
134
154
  infer_metadata: Whether to infer metadata for all tables in the
135
155
  graph.
156
+ verbose: Whether to print verbose output.
157
+ """
158
+ from kumoai.experimental.rfm.backend.local import LocalTable
159
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
+
161
+ graph = cls(tables, edges=edges or [])
162
+
163
+ if infer_metadata:
164
+ graph.infer_metadata(False)
165
+
166
+ if edges is None:
167
+ graph.infer_links(False)
168
+
169
+ if verbose:
170
+ graph.print_metadata()
171
+ graph.print_links()
172
+
173
+ return graph
174
+
175
+ @classmethod
176
+ def from_sqlite(
177
+ cls,
178
+ connection: Union[
179
+ 'AdbcSqliteConnection',
180
+ SqliteConnectionConfig,
181
+ str,
182
+ Path,
183
+ Dict[str, Any],
184
+ ],
185
+ table_names: Optional[Sequence[str]] = None,
186
+ edges: Optional[Sequence[Edge]] = None,
187
+ infer_metadata: bool = True,
188
+ verbose: bool = True,
189
+ ) -> Self:
190
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
191
+
192
+ Automatically infers table metadata and links by default.
193
+
194
+ .. code-block:: python
195
+
196
+ >>> # doctest: +SKIP
197
+ >>> import kumoai.experimental.rfm as rfm
198
+
199
+ >>> # Create a graph from a SQLite database:
200
+ >>> graph = rfm.Graph.from_sqlite('data.db')
201
+
202
+ Args:
203
+ connection: An open connection from
204
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
+ path to the database file.
206
+ table_names: Set of table names to include. If ``None``, will add
207
+ all tables present in the database.
136
208
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
209
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
210
+ inferred from the data in case ``infer_metadata=True``.
211
+ infer_metadata: Whether to infer metadata for all tables in the
212
+ graph.
139
213
  verbose: Whether to print verbose output.
214
+ """
215
+ from kumoai.experimental.rfm.backend.sqlite import (
216
+ Connection,
217
+ SQLiteTable,
218
+ connect,
219
+ )
220
+
221
+ if not isinstance(connection, Connection):
222
+ connection = SqliteConnectionConfig._cast(connection)
223
+ assert isinstance(connection, SqliteConnectionConfig)
224
+ connection = connect(connection.uri, **connection.kwargs)
225
+ assert isinstance(connection, Connection)
226
+
227
+ if table_names is None:
228
+ with connection.cursor() as cursor:
229
+ cursor.execute("SELECT name FROM sqlite_master "
230
+ "WHERE type='table'")
231
+ table_names = [row[0] for row in cursor.fetchall()]
232
+
233
+ tables = [SQLiteTable(connection, name) for name in table_names]
140
234
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
235
+ graph = cls(tables, edges=edges or [])
236
+
237
+ if infer_metadata:
238
+ graph.infer_metadata(False)
239
+
240
+ if edges is None:
241
+ graph.infer_links(False)
242
+
243
+ if verbose:
244
+ graph.print_metadata()
245
+ graph.print_links()
246
+
247
+ return graph
248
+
249
+ @classmethod
250
+ def from_snowflake(
251
+ cls,
252
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
253
+ database: Optional[str] = None,
254
+ schema: Optional[str] = None,
255
+ table_names: Optional[Sequence[str]] = None,
256
+ edges: Optional[Sequence[Edge]] = None,
257
+ infer_metadata: bool = True,
258
+ verbose: bool = True,
259
+ ) -> Self:
260
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
261
+ schema.
262
+
263
+ Automatically infers table metadata and links by default.
264
+
265
+ .. code-block:: python
144
266
 
145
- Example:
146
267
  >>> # doctest: +SKIP
147
268
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
269
+
270
+ >>> # Create a graph directly in a Snowflake notebook:
271
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
272
+
273
+ Args:
274
+ connection: An open connection from
275
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
276
+ :class:`snowflake` connector keyword arguments to open a new
277
+ connection. If ``None``, will re-use an active session in case
278
+ it exists, or create a new connection from credentials stored
279
+ in environment variables.
280
+ database: The database.
281
+ schema: The schema.
282
+ table_names: Set of table names to include. If ``None``, will add
283
+ all tables present in the database.
284
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
285
+ add to the graph. If not provided, edges will be automatically
286
+ inferred from the data in case ``infer_metadata=True``.
287
+ infer_metadata: Whether to infer metadata for all tables in the
288
+ graph.
289
+ verbose: Whether to print verbose output.
157
290
  """
158
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
291
+ from kumoai.experimental.rfm.backend.snow import (
292
+ Connection,
293
+ SnowTable,
294
+ connect,
295
+ )
296
+
297
+ if not isinstance(connection, Connection):
298
+ connection = connect(**(connection or {}))
299
+ assert isinstance(connection, Connection)
300
+
301
+ if table_names is None:
302
+ with connection.cursor() as cursor:
303
+ if database is None and schema is None:
304
+ cursor.execute("SELECT CURRENT_DATABASE(), "
305
+ "CURRENT_SCHEMA()")
306
+ result = cursor.fetchone()
307
+ database = database or result[0]
308
+ schema = schema or result[1]
309
+ cursor.execute(f"""
310
+ SELECT TABLE_NAME
311
+ FROM {database}.INFORMATION_SCHEMA.TABLES
312
+ WHERE TABLE_SCHEMA = '{schema}'
313
+ """)
314
+ table_names = [row[0] for row in cursor.fetchall()]
315
+
316
+ tables = [
317
+ SnowTable(
318
+ connection,
319
+ name=table_name,
320
+ database=database,
321
+ schema=schema,
322
+ ) for table_name in table_names
323
+ ]
159
324
 
160
325
  graph = cls(tables, edges=edges or [])
161
326
 
162
327
  if infer_metadata:
163
- graph.infer_metadata(verbose)
328
+ graph.infer_metadata(False)
164
329
 
165
330
  if edges is None:
166
- graph.infer_links(verbose)
331
+ graph.infer_links(False)
332
+
333
+ if verbose:
334
+ graph.print_metadata()
335
+ graph.print_links()
336
+
337
+ return graph
338
+
339
+ @classmethod
340
+ def from_snowflake_semantic_view(
341
+ cls,
342
+ semantic_view_name: str,
343
+ connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
344
+ verbose: bool = True,
345
+ ) -> Self:
346
+ import yaml
347
+
348
+ from kumoai.experimental.rfm.backend.snow import (
349
+ Connection,
350
+ SnowTable,
351
+ connect,
352
+ )
353
+
354
+ if not isinstance(connection, Connection):
355
+ connection = connect(**(connection or {}))
356
+ assert isinstance(connection, Connection)
357
+
358
+ with connection.cursor() as cursor:
359
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
360
+ f"'{semantic_view_name}')")
361
+ view = yaml.safe_load(cursor.fetchone()[0])
362
+
363
+ graph = cls(tables=[])
364
+
365
+ for table_desc in view['tables']:
366
+ primary_key: Optional[str] = None
367
+ if ('primary_key' in table_desc # NOTE No composite keys yet.
368
+ and len(table_desc['primary_key']['columns']) == 1):
369
+ primary_key = table_desc['primary_key']['columns'][0]
370
+
371
+ table = SnowTable(
372
+ connection,
373
+ name=table_desc['base_table']['table'],
374
+ database=table_desc['base_table']['database'],
375
+ schema=table_desc['base_table']['schema'],
376
+ primary_key=primary_key,
377
+ )
378
+ graph.add_table(table)
379
+
380
+ # TODO Find a solution to register time columns!
381
+
382
+ for relations in view['relationships']:
383
+ if len(relations['relationship_columns']) != 1:
384
+ continue # NOTE No composite keys yet.
385
+ graph.link(
386
+ src_table=relations['left_table'],
387
+ fkey=relations['relationship_columns'][0]['left_column'],
388
+ dst_table=relations['right_table'],
389
+ )
390
+
391
+ if verbose:
392
+ graph.print_metadata()
393
+ graph.print_links()
167
394
 
168
395
  return graph
169
396
 
@@ -175,7 +402,7 @@ class LocalGraph:
175
402
  """
176
403
  return name in self.tables
177
404
 
178
- def table(self, name: str) -> LocalTable:
405
+ def table(self, name: str) -> Table:
179
406
  r"""Returns the table with name ``name`` in the graph.
180
407
 
181
408
  Raises:
@@ -186,11 +413,11 @@ class LocalGraph:
186
413
  return self.tables[name]
187
414
 
188
415
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
416
+ def tables(self) -> Dict[str, Table]:
190
417
  r"""Returns the dictionary of table objects."""
191
418
  return self._tables
192
419
 
193
- def add_table(self, table: LocalTable) -> Self:
420
+ def add_table(self, table: Table) -> Self:
194
421
  r"""Adds a table to the graph.
195
422
 
196
423
  Args:
@@ -199,11 +426,21 @@ class LocalGraph:
199
426
  Raises:
200
427
  KeyError: If a table with the same name already exists in the
201
428
  graph.
429
+ ValueError: If the table belongs to a different backend than the
430
+ rest of the tables in the graph.
202
431
  """
203
432
  if table.name in self._tables:
204
433
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
434
  f"this graph; table names must be globally unique.")
206
435
 
436
+ if len(self._tables) > 0:
437
+ cls = next(iter(self._tables.values())).__class__
438
+ if table.__class__ != cls:
439
+ raise ValueError(f"Cannot register a "
440
+ f"'{table.__class__.__name__}' to this "
441
+ f"graph since other tables are of type "
442
+ f"'{cls.__name__}'.")
443
+
207
444
  self._tables[table.name] = table
208
445
 
209
446
  return self
@@ -241,7 +478,7 @@ class LocalGraph:
241
478
  Example:
242
479
  >>> # doctest: +SKIP
243
480
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
481
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
482
  >>> graph.metadata # doctest: +SKIP
246
483
  name primary_key time_column end_time_column
247
484
  0 users user_id - -
@@ -263,10 +500,14 @@ class LocalGraph:
263
500
  })
264
501
 
265
502
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
267
- if in_notebook():
503
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
504
+ if in_snowflake_notebook():
505
+ import streamlit as st
506
+ st.markdown("### 🗂️ Graph Metadata")
507
+ st.dataframe(self.metadata, hide_index=True)
508
+ elif in_notebook():
268
509
  from IPython.display import Markdown, display
269
- display(Markdown('### 🗂️ Graph Metadata'))
510
+ display(Markdown("### 🗂️ Graph Metadata"))
270
511
  df = self.metadata
271
512
  try:
272
513
  if hasattr(df.style, 'hide'):
@@ -287,7 +528,7 @@ class LocalGraph:
287
528
 
288
529
  Note:
289
530
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
531
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
532
  """
292
533
  for table in self.tables.values():
293
534
  table.infer_metadata(verbose=False)
@@ -305,37 +546,47 @@ class LocalGraph:
305
546
  return self._edges
306
547
 
307
548
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
549
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
550
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
551
  edge.src_table, edge.fkey) for edge in self.edges]
311
552
  edges = sorted(edges)
312
553
 
313
- if in_notebook():
554
+ if in_snowflake_notebook():
555
+ import streamlit as st
556
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
557
+ if len(edges) > 0:
558
+ st.markdown('\n'.join([
559
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
560
+ for edge in edges
561
+ ]))
562
+ else:
563
+ st.markdown("*No links registered*")
564
+ elif in_notebook():
314
565
  from IPython.display import Markdown, display
315
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
566
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
316
567
  if len(edges) > 0:
317
568
  display(
318
569
  Markdown('\n'.join([
319
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
570
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
320
571
  for edge in edges
321
572
  ])))
322
573
  else:
323
- display(Markdown('*No links registered*'))
574
+ display(Markdown("*No links registered*"))
324
575
  else:
325
576
  print("🕸️ Graph Links (FK ↔️ PK):")
326
577
  if len(edges) > 0:
327
578
  print('\n'.join([
328
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
579
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
329
580
  for edge in edges
330
581
  ]))
331
582
  else:
332
- print('No links registered')
583
+ print("No links registered")
333
584
 
334
585
  def link(
335
586
  self,
336
- src_table: Union[str, LocalTable],
587
+ src_table: Union[str, Table],
337
588
  fkey: str,
338
- dst_table: Union[str, LocalTable],
589
+ dst_table: Union[str, Table],
339
590
  ) -> Self:
340
591
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
592
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +609,11 @@ class LocalGraph:
358
609
  table does not exist in the graph, if the source key does not
359
610
  exist in the source table.
360
611
  """
361
- if isinstance(src_table, LocalTable):
612
+ if isinstance(src_table, Table):
362
613
  src_table = src_table.name
363
614
  assert isinstance(src_table, str)
364
615
 
365
- if isinstance(dst_table, LocalTable):
616
+ if isinstance(dst_table, Table):
366
617
  dst_table = dst_table.name
367
618
  assert isinstance(dst_table, str)
368
619
 
@@ -396,9 +647,9 @@ class LocalGraph:
396
647
 
397
648
  def unlink(
398
649
  self,
399
- src_table: Union[str, LocalTable],
650
+ src_table: Union[str, Table],
400
651
  fkey: str,
401
- dst_table: Union[str, LocalTable],
652
+ dst_table: Union[str, Table],
402
653
  ) -> Self:
403
654
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
655
 
@@ -410,11 +661,11 @@ class LocalGraph:
410
661
  Raises:
411
662
  ValueError: if the edge is not present in the graph.
412
663
  """
413
- if isinstance(src_table, LocalTable):
664
+ if isinstance(src_table, Table):
414
665
  src_table = src_table.name
415
666
  assert isinstance(src_table, str)
416
667
 
417
- if isinstance(dst_table, LocalTable):
668
+ if isinstance(dst_table, Table):
418
669
  dst_table = dst_table.name
419
670
  assert isinstance(dst_table, str)
420
671
 
@@ -428,17 +679,13 @@ class LocalGraph:
428
679
  return self
429
680
 
430
681
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
682
+ r"""Infers missing links for the tables and adds them as edges to the
683
+ graph.
432
684
 
433
685
  Args:
434
686
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
687
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
688
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
442
689
 
443
690
  # A list of primary key candidates (+score) for every column:
444
691
  candidate_dict: dict[
@@ -463,6 +710,9 @@ class LocalGraph:
463
710
  src_table_name = src_table.name.lower()
464
711
 
465
712
  for src_key in src_table.columns:
713
+ if (src_table.name, src_key.name) in known_edges:
714
+ continue
715
+
466
716
  if src_key == src_table.primary_key:
467
717
  continue # Cannot link to primary key.
468
718
 
@@ -528,7 +778,9 @@ class LocalGraph:
528
778
  score += 1.0
529
779
 
530
780
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
781
+ if (src_table._num_rows is not None
782
+ and dst_table._num_rows is not None
783
+ and src_table._num_rows > dst_table._num_rows):
532
784
  score += 1.0
533
785
 
534
786
  if score < 5.0:
@@ -645,19 +897,19 @@ class LocalGraph:
645
897
 
646
898
  return True
647
899
 
648
- # Check basic dependency:
649
- if not find_spec('graphviz'):
650
- raise ModuleNotFoundError("The 'graphviz' package is required for "
651
- "visualization")
652
- elif not has_graphviz_executables():
900
+ try: # Check basic dependency:
901
+ import graphviz
902
+ except ImportError as e:
903
+ raise ImportError("The 'graphviz' package is required for "
904
+ "visualization") from e
905
+
906
+ if not in_snowflake_notebook() and not has_graphviz_executables():
653
907
  raise RuntimeError("Could not visualize graph as 'graphviz' "
654
908
  "executables are not installed. These "
655
909
  "dependencies are required in addition to the "
656
910
  "'graphviz' Python package. Please install "
657
911
  "them as described at "
658
912
  "https://graphviz.org/download/.")
659
- else:
660
- import graphviz
661
913
 
662
914
  format: Optional[str] = None
663
915
  if isinstance(path, str):
@@ -741,6 +993,9 @@ class LocalGraph:
741
993
  graph.render(path, cleanup=True)
742
994
  elif isinstance(path, io.BytesIO):
743
995
  path.write(graph.pipe())
996
+ elif in_snowflake_notebook():
997
+ import streamlit as st
998
+ st.graphviz_chart(graph)
744
999
  elif in_notebook():
745
1000
  from IPython.display import display
746
1001
  display(graph)
@@ -790,7 +1045,7 @@ class LocalGraph:
790
1045
  def __contains__(self, name: str) -> bool:
791
1046
  return self.has_table(name)
792
1047
 
793
- def __getitem__(self, name: str) -> LocalTable:
1048
+ def __getitem__(self, name: str) -> Table:
794
1049
  return self.table(name)
795
1050
 
796
1051
  def __delitem__(self, name: str) -> None:
@@ -1,9 +1,15 @@
1
+ from .dtype import infer_dtype
2
+ from .pkey import infer_primary_key
3
+ from .time_col import infer_time_column
1
4
  from .id import contains_id
2
5
  from .timestamp import contains_timestamp
3
6
  from .categorical import contains_categorical
4
7
  from .multicategorical import contains_multicategorical
5
8
 
6
9
  __all__ = [
10
+ 'infer_dtype',
11
+ 'infer_primary_key',
12
+ 'infer_time_column',
7
13
  'contains_id',
8
14
  'contains_timestamp',
9
15
  'contains_categorical',