kumoai 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50):
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/experimental/rfm/__init__.py +49 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  11. kumoai/experimental/rfm/backend/local/table.py +32 -14
  12. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +186 -39
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +131 -41
  18. kumoai/experimental/rfm/base/__init__.py +23 -3
  19. kumoai/experimental/rfm/base/column.py +96 -10
  20. kumoai/experimental/rfm/base/expression.py +44 -0
  21. kumoai/experimental/rfm/base/sampler.py +761 -0
  22. kumoai/experimental/rfm/base/source.py +2 -1
  23. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  24. kumoai/experimental/rfm/base/table.py +380 -185
  25. kumoai/experimental/rfm/graph.py +404 -144
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +52 -60
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +1 -2
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +283 -230
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/pquery/predictive_query.py +10 -6
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/trainer/distilled_trainer.py +175 -0
  40. kumoai/utils/__init__.py +3 -2
  41. kumoai/utils/display.py +51 -0
  42. kumoai/utils/progress_logger.py +178 -12
  43. kumoai/utils/sql.py +3 -0
  44. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +4 -2
  45. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +48 -38
  46. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  47. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  48. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
  49. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
  50. {kumoai-2.13.0.dev202512031731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
4
+ import copy
2
5
  import io
3
6
  import warnings
4
7
  from collections import defaultdict
8
+ from collections.abc import Sequence
5
9
  from dataclasses import dataclass, field
6
- from importlib.util import find_spec
10
+ from itertools import chain
7
11
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
12
+ from typing import TYPE_CHECKING, Any, Union
9
13
 
10
14
  import pandas as pd
11
15
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,10 +17,11 @@ from kumoapi.table import TableDefinition
13
17
  from kumoapi.typing import Stype
14
18
  from typing_extensions import Self
15
19
 
16
- from kumoai import in_notebook
17
- from kumoai.experimental.rfm import Table
20
+ from kumoai import in_notebook, in_snowflake_notebook
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
18
22
  from kumoai.graph import Edge
19
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
20
25
 
21
26
  if TYPE_CHECKING:
22
27
  import graphviz
@@ -26,8 +31,8 @@ if TYPE_CHECKING:
26
31
 
27
32
  @dataclass
28
33
  class SqliteConnectionConfig(CastMixin):
29
- uri: Union[str, Path]
30
- kwargs: Dict[str, Any] = field(default_factory=dict)
34
+ uri: str | Path
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
31
36
 
32
37
 
33
38
  class Graph:
@@ -87,27 +92,35 @@ class Graph:
87
92
  def __init__(
88
93
  self,
89
94
  tables: Sequence[Table],
90
- edges: Optional[Sequence[Edge]] = None,
95
+ edges: Sequence[Edge] | None = None,
91
96
  ) -> None:
92
97
 
93
- self._tables: Dict[str, Table] = {}
94
- self._edges: List[Edge] = []
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
95
100
 
96
101
  for table in tables:
97
102
  self.add_table(table)
98
103
 
99
- for table in tables:
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
100
107
  for fkey in table._source_foreign_key_dict.values():
101
- if fkey.name not in table or fkey.dst_table not in self:
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
117
+ continue
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
120
+ continue
121
+ if not dst_table[fkey.primary_key].is_source:
102
122
  continue
103
- if self[fkey.dst_table].primary_key is None:
104
- self[fkey.dst_table].primary_key = fkey.primary_key
105
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
106
- raise ValueError(f"Found duplicate primary key definition "
107
- f"'{self[fkey.dst_table]._primary_key}' "
108
- f"and '{fkey.primary_key}' in table "
109
- f"'{fkey.dst_table}'.")
110
- self.link(table.name, fkey.name, fkey.dst_table)
123
+ self.link(table.name, fkey.name, dst_table.name)
111
124
 
112
125
  for edge in (edges or []):
113
126
  _edge = Edge._cast(edge)
@@ -118,8 +131,8 @@ class Graph:
118
131
  @classmethod
119
132
  def from_data(
120
133
  cls,
121
- df_dict: Dict[str, pd.DataFrame],
122
- edges: Optional[Sequence[Edge]] = None,
134
+ df_dict: dict[str, pd.DataFrame],
135
+ edges: Sequence[Edge] | None = None,
123
136
  infer_metadata: bool = True,
124
137
  verbose: bool = True,
125
138
  ) -> Self:
@@ -157,15 +170,17 @@ class Graph:
157
170
  verbose: Whether to print verbose output.
158
171
  """
159
172
  from kumoai.experimental.rfm.backend.local import LocalTable
160
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
161
173
 
162
- graph = cls(tables, edges=edges or [])
174
+ graph = cls(
175
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
176
+ edges=edges or [],
177
+ )
163
178
 
164
179
  if infer_metadata:
165
- graph.infer_metadata(False)
180
+ graph.infer_metadata(verbose=False)
166
181
 
167
182
  if edges is None:
168
- graph.infer_links(False)
183
+ graph.infer_links(verbose=False)
169
184
 
170
185
  if verbose:
171
186
  graph.print_metadata()
@@ -181,10 +196,10 @@ class Graph:
181
196
  SqliteConnectionConfig,
182
197
  str,
183
198
  Path,
184
- Dict[str, Any],
199
+ dict[str, Any],
185
200
  ],
186
- table_names: Optional[Sequence[str]] = None,
187
- edges: Optional[Sequence[Edge]] = None,
201
+ tables: Sequence[str | dict[str, Any]] | None = None,
202
+ edges: Sequence[Edge] | None = None,
188
203
  infer_metadata: bool = True,
189
204
  verbose: bool = True,
190
205
  ) -> Self:
@@ -200,17 +215,25 @@ class Graph:
200
215
  >>> # Create a graph from a SQLite database:
201
216
  >>> graph = rfm.Graph.from_sqlite('data.db')
202
217
 
218
+ >>> # Fine-grained control over table specification:
219
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
220
+ ... 'USERS',
221
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
222
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
223
+ ... ])
224
+
203
225
  Args:
204
226
  connection: An open connection from
205
227
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
206
228
  path to the database file.
207
- table_names: Set of table names to include. If ``None``, will add
208
- all tables present in the database.
229
+ tables: Set of table names or :class:`SQLiteTable` keyword
230
+ arguments to include. If ``None``, will add all tables present
231
+ in the database.
209
232
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
210
233
  add to the graph. If not provided, edges will be automatically
211
234
  inferred from the data in case ``infer_metadata=True``.
212
- infer_metadata: Whether to infer metadata for all tables in the
213
- graph.
235
+ infer_metadata: Whether to infer missing metadata for all tables in
236
+ the graph.
214
237
  verbose: Whether to print verbose output.
215
238
  """
216
239
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +242,41 @@ class Graph:
219
242
  connect,
220
243
  )
221
244
 
245
+ internal_connection = False
222
246
  if not isinstance(connection, Connection):
223
247
  connection = SqliteConnectionConfig._cast(connection)
224
248
  assert isinstance(connection, SqliteConnectionConfig)
225
249
  connection = connect(connection.uri, **connection.kwargs)
250
+ internal_connection = True
226
251
  assert isinstance(connection, Connection)
227
252
 
228
- if table_names is None:
253
+ if tables is None:
229
254
  with connection.cursor() as cursor:
230
255
  cursor.execute("SELECT name FROM sqlite_master "
231
256
  "WHERE type='table'")
232
- table_names = [row[0] for row in cursor.fetchall()]
257
+ tables = [row[0] for row in cursor.fetchall()]
233
258
 
234
- tables = [SQLiteTable(connection, name) for name in table_names]
259
+ table_kwargs: list[dict[str, Any]] = []
260
+ for table in tables:
261
+ kwargs = dict(name=table) if isinstance(table, str) else table
262
+ table_kwargs.append(kwargs)
263
+
264
+ graph = cls(
265
+ tables=[
266
+ SQLiteTable(connection=connection, **kwargs)
267
+ for kwargs in table_kwargs
268
+ ],
269
+ edges=edges or [],
270
+ )
235
271
 
236
- graph = cls(tables, edges=edges or [])
272
+ if internal_connection:
273
+ graph._connection = connection # type: ignore
237
274
 
238
275
  if infer_metadata:
239
- graph.infer_metadata(False)
276
+ graph.infer_metadata(verbose=False)
240
277
 
241
278
  if edges is None:
242
- graph.infer_links(False)
279
+ graph.infer_links(verbose=False)
243
280
 
244
281
  if verbose:
245
282
  graph.print_metadata()
@@ -250,9 +287,11 @@ class Graph:
250
287
  @classmethod
251
288
  def from_snowflake(
252
289
  cls,
253
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
254
- table_names: Optional[Sequence[str]] = None,
255
- edges: Optional[Sequence[Edge]] = None,
290
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
291
+ tables: Sequence[str | dict[str, Any]] | None = None,
292
+ database: str | None = None,
293
+ schema: str | None = None,
294
+ edges: Sequence[Edge] | None = None,
256
295
  infer_metadata: bool = True,
257
296
  verbose: bool = True,
258
297
  ) -> Self:
@@ -267,7 +306,14 @@ class Graph:
267
306
  >>> import kumoai.experimental.rfm as rfm
268
307
 
269
308
  >>> # Create a graph directly in a Snowflake notebook:
270
- >>> graph = rfm.Graph.from_snowflake()
309
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
310
+
311
+ >>> # Fine-grained control over table specification:
312
+ >>> graph = rfm.Graph.from_snowflake(tables=[
313
+ ... 'USERS',
314
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
315
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
316
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
271
317
 
272
318
  Args:
273
319
  connection: An open connection from
@@ -276,8 +322,11 @@ class Graph:
276
322
  connection. If ``None``, will re-use an active session in case
277
323
  it exists, or create a new connection from credentials stored
278
324
  in environment variables.
279
- table_names: Set of table names to include. If ``None``, will add
280
- all tables present in the database.
325
+ tables: Set of table names or :class:`SnowTable` keyword arguments
326
+ to include. If ``None``, will add all tables present in the
327
+ current database and schema.
328
+ database: The database.
329
+ schema: The schema.
281
330
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
282
331
  add to the graph. If not provided, edges will be automatically
283
332
  inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +344,50 @@ class Graph:
295
344
  connection = connect(**(connection or {}))
296
345
  assert isinstance(connection, Connection)
297
346
 
298
- if table_names is None:
347
+ if database is None or schema is None:
299
348
  with connection.cursor() as cursor:
300
349
  cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
301
- database, schema = cursor.fetchone()
302
- query = f"""
350
+ result = cursor.fetchone()
351
+ database = database or result[0]
352
+ assert database is not None
353
+ schema = schema or result[1]
354
+
355
+ if tables is None:
356
+ if schema is None:
357
+ raise ValueError("No current 'schema' set. Please specify the "
358
+ "Snowflake schema manually")
359
+
360
+ with connection.cursor() as cursor:
361
+ cursor.execute(f"""
303
362
  SELECT TABLE_NAME
304
363
  FROM {database}.INFORMATION_SCHEMA.TABLES
305
364
  WHERE TABLE_SCHEMA = '{schema}'
306
- """
307
- cursor.execute(query)
308
- table_names = [row[0] for row in cursor.fetchall()]
309
-
310
- tables = [SnowTable(connection, name) for name in table_names]
365
+ """)
366
+ tables = [row[0] for row in cursor.fetchall()]
311
367
 
312
- graph = cls(tables, edges=edges or [])
368
+ table_kwargs: list[dict[str, Any]] = []
369
+ for table in tables:
370
+ if isinstance(table, str):
371
+ kwargs = dict(name=table, database=database, schema=schema)
372
+ else:
373
+ kwargs = copy.copy(table)
374
+ kwargs.setdefault('database', database)
375
+ kwargs.setdefault('schema', schema)
376
+ table_kwargs.append(kwargs)
377
+
378
+ graph = cls(
379
+ tables=[
380
+ SnowTable(connection=connection, **kwargs)
381
+ for kwargs in table_kwargs
382
+ ],
383
+ edges=edges or [],
384
+ )
313
385
 
314
386
  if infer_metadata:
315
- graph.infer_metadata(False)
387
+ graph.infer_metadata(verbose=False)
316
388
 
317
389
  if edges is None:
318
- graph.infer_links(False)
390
+ graph.infer_links(verbose=False)
319
391
 
320
392
  if verbose:
321
393
  graph.print_metadata()
@@ -323,7 +395,187 @@ class Graph:
323
395
 
324
396
  return graph
325
397
 
326
- # Tables ##############################################################
398
+ @classmethod
399
+ def from_snowflake_semantic_view(
400
+ cls,
401
+ semantic_view_name: str,
402
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
403
+ verbose: bool = True,
404
+ ) -> Self:
405
+ import yaml
406
+
407
+ from kumoai.experimental.rfm.backend.snow import (
408
+ Connection,
409
+ SnowTable,
410
+ connect,
411
+ )
412
+
413
+ if not isinstance(connection, Connection):
414
+ connection = connect(**(connection or {}))
415
+ assert isinstance(connection, Connection)
416
+
417
+ with connection.cursor() as cursor:
418
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
419
+ f"'{semantic_view_name}')")
420
+ cfg = yaml.safe_load(cursor.fetchone()[0])
421
+
422
+ graph = cls(tables=[])
423
+
424
+ msgs = []
425
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
426
+ for table_cfg in cfg['tables']:
427
+ table_name = table_cfg['name']
428
+ source_table_name = table_cfg['base_table']['table']
429
+ database = table_cfg['base_table']['database']
430
+ schema = table_cfg['base_table']['schema']
431
+
432
+ primary_key: str | None = None
433
+ if 'primary_key' in table_cfg:
434
+ primary_key_cfg = table_cfg['primary_key']
435
+ if len(primary_key_cfg['columns']) == 1:
436
+ primary_key = primary_key_cfg['columns'][0]
437
+ elif len(primary_key_cfg['columns']) > 1:
438
+ msgs.append(f"Failed to add primary key for table "
439
+ f"'{table_name}' since composite primary keys "
440
+ f"are not yet supported")
441
+
442
+ columns: list[ColumnSpec] = []
443
+ unsupported_columns: list[str] = []
444
+ for column_cfg in chain(
445
+ table_cfg.get('dimensions', []),
446
+ table_cfg.get('time_dimensions', []),
447
+ table_cfg.get('facts', []),
448
+ ):
449
+ column_name = column_cfg['name']
450
+ column_expr = column_cfg.get('expr', None)
451
+ column_data_type = column_cfg.get('data_type', None)
452
+
453
+ if column_expr is None:
454
+ columns.append(ColumnSpec(name=column_name))
455
+ continue
456
+
457
+ column_expr = column_expr.replace(f'{table_name}.', '')
458
+
459
+ if column_expr == column_name:
460
+ columns.append(ColumnSpec(name=column_name))
461
+ continue
462
+
463
+ # Drop expressions that reference other tables (for now):
464
+ if any(f'{name}.' in column_expr for name in table_names):
465
+ unsupported_columns.append(column_name)
466
+ continue
467
+
468
+ column = ColumnSpec(
469
+ name=column_name,
470
+ expr=column_expr,
471
+ dtype=SnowTable._to_dtype(column_data_type),
472
+ )
473
+ columns.append(column)
474
+
475
+ if len(unsupported_columns) == 1:
476
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
477
+ f"of table '{table_name}' since its expression "
478
+ f"references other tables")
479
+ elif len(unsupported_columns) > 1:
480
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
481
+ f"of table '{table_name}' since their expressions "
482
+ f"reference other tables")
483
+
484
+ table = SnowTable(
485
+ connection,
486
+ name=table_name,
487
+ source_name=source_table_name,
488
+ database=database,
489
+ schema=schema,
490
+ columns=columns,
491
+ primary_key=primary_key,
492
+ )
493
+
494
+ # TODO Add a way to register time columns without heuristic usage.
495
+ table.infer_time_column(verbose=False)
496
+
497
+ graph.add_table(table)
498
+
499
+ for relation_cfg in cfg.get('relationships', []):
500
+ name = relation_cfg['name']
501
+ if len(relation_cfg['relationship_columns']) != 1:
502
+ msgs.append(f"Failed to add relationship '{name}' since "
503
+ f"composite key references are not yet supported")
504
+ continue
505
+
506
+ left_table = relation_cfg['left_table']
507
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
508
+ right_table = relation_cfg['right_table']
509
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
510
+
511
+ if graph[right_table]._primary_key != right_key:
512
+ # Semantic view error - this should never be triggered:
513
+ msgs.append(f"Failed to add relationship '{name}' since the "
514
+ f"referenced key '{right_key}' of table "
515
+ f"'{right_table}' is not a primary key")
516
+ continue
517
+
518
+ if graph[left_table]._primary_key == left_key:
519
+ msgs.append(f"Failed to add relationship '{name}' since the "
520
+ f"referencing key '{left_key}' of table "
521
+ f"'{left_table}' is a primary key")
522
+ continue
523
+
524
+ if left_key not in graph[left_table]:
525
+ graph[left_table].add_column(left_key)
526
+
527
+ graph.link(left_table, left_key, right_table)
528
+
529
+ graph.validate()
530
+
531
+ if verbose:
532
+ graph.print_metadata()
533
+ graph.print_links()
534
+
535
+ if len(msgs) > 0:
536
+ title = (f"Could not fully convert the semantic view definition "
537
+ f"'{semantic_view_name}' into a graph:\n")
538
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
539
+
540
+ return graph
541
+
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
564
+
565
+ if verbose:
566
+ graph.print_metadata()
567
+ graph.print_links()
568
+
569
+ return graph
570
+
571
+ # Backend #################################################################
572
+
573
+ @property
574
+ def backend(self) -> DataBackend | None:
575
+ backends = [table.backend for table in self._tables.values()]
576
+ return backends[0] if len(backends) > 0 else None
577
+
578
+ # Tables ##################################################################
327
579
 
328
580
  def has_table(self, name: str) -> bool:
329
581
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -342,7 +594,7 @@ class Graph:
342
594
  return self.tables[name]
343
595
 
344
596
  @property
345
- def tables(self) -> Dict[str, Table]:
597
+ def tables(self) -> dict[str, Table]:
346
598
  r"""Returns the dictionary of table objects."""
347
599
  return self._tables
348
600
 
@@ -362,13 +614,10 @@ class Graph:
362
614
  raise KeyError(f"Cannot add table with name '{table.name}' to "
363
615
  f"this graph; table names must be globally unique.")
364
616
 
365
- if len(self._tables) > 0:
366
- cls = next(iter(self._tables.values())).__class__
367
- if table.__class__ != cls:
368
- raise ValueError(f"Cannot register a "
369
- f"'{table.__class__.__name__}' to this "
370
- f"graph since other tables are of type "
371
- f"'{cls.__name__}'.")
617
+ if self.backend is not None and table.backend != self.backend:
618
+ raise ValueError(f"Cannot register a table with backend "
619
+ f"'{table.backend}' to this graph since other "
620
+ f"tables have backend '{self.backend}'.")
372
621
 
373
622
  self._tables[table.name] = table
374
623
 
@@ -430,20 +679,8 @@ class Graph:
430
679
 
431
680
  def print_metadata(self) -> None:
432
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
433
- if in_notebook():
434
- from IPython.display import Markdown, display
435
- display(Markdown('### 🗂️ Graph Metadata'))
436
- df = self.metadata
437
- try:
438
- if hasattr(df.style, 'hide'):
439
- display(df.style.hide(axis='index')) # pandas=2
440
- else:
441
- display(df.style.hide_index()) # pandas<1.3
442
- except ImportError:
443
- print(df.to_string(index=False)) # missing jinja2
444
- else:
445
- print("🗂️ Graph Metadata:")
446
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
447
684
 
448
685
  def infer_metadata(self, verbose: bool = True) -> Self:
449
686
  r"""Infers metadata for all tables in the graph.
@@ -466,42 +703,33 @@ class Graph:
466
703
  # Edges ###################################################################
467
704
 
468
705
  @property
469
- def edges(self) -> List[Edge]:
706
+ def edges(self) -> list[Edge]:
470
707
  r"""Returns the edges of the graph."""
471
708
  return self._edges
472
709
 
473
710
  def print_links(self) -> None:
474
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
475
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
476
- edge.src_table, edge.fkey) for edge in self.edges]
477
- edges = sorted(edges)
478
-
479
- if in_notebook():
480
- from IPython.display import Markdown, display
481
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
482
- if len(edges) > 0:
483
- display(
484
- Markdown('\n'.join([
485
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
486
- for edge in edges
487
- ])))
488
- else:
489
- display(Markdown('*No links registered*'))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
490
725
  else:
491
- print("🕸️ Graph Links (FK ↔️ PK):")
492
- if len(edges) > 0:
493
- print('\n'.join([
494
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
495
- for edge in edges
496
- ]))
497
- else:
498
- print('No links registered')
726
+ display.italic("No links registered")
499
727
 
500
728
  def link(
501
729
  self,
502
- src_table: Union[str, Table],
730
+ src_table: str | Table,
503
731
  fkey: str,
504
- dst_table: Union[str, Table],
732
+ dst_table: str | Table,
505
733
  ) -> Self:
506
734
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
507
735
  key ``fkey`` in the source table to the primary key in the destination
@@ -562,9 +790,9 @@ class Graph:
562
790
 
563
791
  def unlink(
564
792
  self,
565
- src_table: Union[str, Table],
793
+ src_table: str | Table,
566
794
  fkey: str,
567
- dst_table: Union[str, Table],
795
+ dst_table: str | Table,
568
796
  ) -> Self:
569
797
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
570
798
 
@@ -602,6 +830,30 @@ class Graph:
602
830
  """
603
831
  known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
604
832
 
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
856
+
605
857
  # A list of primary key candidates (+score) for every column:
606
858
  candidate_dict: dict[
607
859
  tuple[str, str],
@@ -701,13 +953,8 @@ class Graph:
701
953
  if score < 5.0:
702
954
  continue
703
955
 
704
- candidate_dict[(
705
- src_table.name,
706
- src_key.name,
707
- )].append((
708
- dst_table.name,
709
- score,
710
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
711
958
 
712
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
713
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -741,6 +988,10 @@ class Graph:
741
988
  raise ValueError("At least one table needs to be added to the "
742
989
  "graph")
743
990
 
991
+ backends = {table.backend for table in self._tables.values()}
992
+ if len(backends) != 1:
993
+ raise ValueError("Found multiple table backends in the graph")
994
+
744
995
  for edge in self.edges:
745
996
  src_table, fkey, dst_table = edge
746
997
 
@@ -762,24 +1013,26 @@ class Graph:
762
1013
  f"either the primary key or the link before "
763
1014
  f"before proceeding.")
764
1015
 
765
- # Check that fkey/pkey have valid and consistent data types:
766
- assert src_key.dtype is not None
767
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
768
- src_string = src_key.dtype.is_string()
769
- assert dst_key.dtype is not None
770
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
771
- dst_string = dst_key.dtype.is_string()
772
-
773
- if not src_number and not src_string:
774
- raise ValueError(f"{edge} is invalid as foreign key must be a "
775
- f"number or string (got '{src_key.dtype}'")
776
-
777
- if src_number != dst_number or src_string != dst_string:
778
- raise ValueError(f"{edge} is invalid as foreign key "
779
- f"'{fkey}' and primary key '{dst_key.name}' "
780
- f"have incompatible data types (got "
781
- f"fkey.dtype '{src_key.dtype}' and "
782
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
783
1036
 
784
1037
  return self
785
1038
 
@@ -787,7 +1040,7 @@ class Graph:
787
1040
 
788
1041
  def visualize(
789
1042
  self,
790
- path: Optional[Union[str, io.BytesIO]] = None,
1043
+ path: str | io.BytesIO | None = None,
791
1044
  show_columns: bool = True,
792
1045
  ) -> 'graphviz.Graph':
793
1046
  r"""Visualizes the tables and edges in this graph using the
@@ -812,33 +1065,33 @@ class Graph:
812
1065
 
813
1066
  return True
814
1067
 
815
- # Check basic dependency:
816
- if not find_spec('graphviz'):
817
- raise ModuleNotFoundError("The 'graphviz' package is required for "
818
- "visualization")
819
- elif not has_graphviz_executables():
1068
+ try: # Check basic dependency:
1069
+ import graphviz
1070
+ except ImportError as e:
1071
+ raise ImportError("The 'graphviz' package is required for "
1072
+ "visualization") from e
1073
+
1074
+ if not in_snowflake_notebook() and not has_graphviz_executables():
820
1075
  raise RuntimeError("Could not visualize graph as 'graphviz' "
821
1076
  "executables are not installed. These "
822
1077
  "dependencies are required in addition to the "
823
1078
  "'graphviz' Python package. Please install "
824
1079
  "them as described at "
825
1080
  "https://graphviz.org/download/.")
826
- else:
827
- import graphviz
828
1081
 
829
- format: Optional[str] = None
1082
+ format: str | None = None
830
1083
  if isinstance(path, str):
831
1084
  format = path.split('.')[-1]
832
1085
  elif isinstance(path, io.BytesIO):
833
1086
  format = 'svg'
834
1087
  graph = graphviz.Graph(format=format)
835
1088
 
836
- def left_align(keys: List[str]) -> str:
1089
+ def left_align(keys: list[str]) -> str:
837
1090
  if len(keys) == 0:
838
1091
  return ""
839
1092
  return '\\l'.join(keys) + '\\l'
840
1093
 
841
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1094
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
842
1095
  for src_table_name, fkey_name, _ in self.edges:
843
1096
  fkeys_dict[src_table_name].append(fkey_name)
844
1097
 
@@ -908,6 +1161,9 @@ class Graph:
908
1161
  graph.render(path, cleanup=True)
909
1162
  elif isinstance(path, io.BytesIO):
910
1163
  path.write(graph.pipe())
1164
+ elif in_snowflake_notebook():
1165
+ import streamlit as st
1166
+ st.graphviz_chart(graph)
911
1167
  elif in_notebook():
912
1168
  from IPython.display import display
913
1169
  display(graph)
@@ -931,8 +1187,8 @@ class Graph:
931
1187
  # Helpers #################################################################
932
1188
 
933
1189
  def _to_api_graph_definition(self) -> GraphDefinition:
934
- tables: Dict[str, TableDefinition] = {}
935
- col_groups: List[ColumnKeyGroup] = []
1190
+ tables: dict[str, TableDefinition] = {}
1191
+ col_groups: list[ColumnKeyGroup] = []
936
1192
  for table_name, table in self.tables.items():
937
1193
  tables[table_name] = table._to_api_table_definition()
938
1194
  if table.primary_key is None:
@@ -975,3 +1231,7 @@ class Graph:
975
1231
  f' tables={tables},\n'
976
1232
  f' edges={edges},\n'
977
1233
  f')')
1234
+
1235
+ def __del__(self) -> None:
1236
+ if hasattr(self, '_connection'):
1237
+ self._connection.close()