kumoai 2.13.0.dev202512040252__cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +181 -51
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +23 -3
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +783 -0
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +385 -203
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +374 -172
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -41
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512040252.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
4
+ import copy
2
5
  import io
3
6
  import warnings
4
7
  from collections import defaultdict
8
+ from collections.abc import Sequence
5
9
  from dataclasses import dataclass, field
6
- from importlib.util import find_spec
10
+ from itertools import chain
7
11
  from pathlib import Path
8
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
12
+ from typing import TYPE_CHECKING, Any, Union
9
13
 
10
14
  import pandas as pd
11
15
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,10 +17,11 @@ from kumoapi.table import TableDefinition
13
17
  from kumoapi.typing import Stype
14
18
  from typing_extensions import Self
15
19
 
16
- from kumoai import in_notebook
17
- from kumoai.experimental.rfm import Table
20
+ from kumoai import in_notebook, in_snowflake_notebook
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
18
22
  from kumoai.graph import Edge
19
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
20
25
 
21
26
  if TYPE_CHECKING:
22
27
  import graphviz
@@ -26,8 +31,8 @@ if TYPE_CHECKING:
26
31
 
27
32
  @dataclass
28
33
  class SqliteConnectionConfig(CastMixin):
29
- uri: Union[str, Path]
30
- kwargs: Dict[str, Any] = field(default_factory=dict)
34
+ uri: str | Path
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
31
36
 
32
37
 
33
38
  class Graph:
@@ -87,27 +92,35 @@ class Graph:
87
92
  def __init__(
88
93
  self,
89
94
  tables: Sequence[Table],
90
- edges: Optional[Sequence[Edge]] = None,
95
+ edges: Sequence[Edge] | None = None,
91
96
  ) -> None:
92
97
 
93
- self._tables: Dict[str, Table] = {}
94
- self._edges: List[Edge] = []
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
95
100
 
96
101
  for table in tables:
97
102
  self.add_table(table)
98
103
 
99
- for table in tables:
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
100
107
  for fkey in table._source_foreign_key_dict.values():
101
- if fkey.name not in table or fkey.dst_table not in self:
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
102
117
  continue
103
- if self[fkey.dst_table].primary_key is None:
104
- self[fkey.dst_table].primary_key = fkey.primary_key
105
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
106
- raise ValueError(f"Found duplicate primary key definition "
107
- f"'{self[fkey.dst_table]._primary_key}' "
108
- f"and '{fkey.primary_key}' in table "
109
- f"'{fkey.dst_table}'.")
110
- self.link(table.name, fkey.name, fkey.dst_table)
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
120
+ continue
121
+ if not dst_table[fkey.primary_key].is_source:
122
+ continue
123
+ self.link(table.name, fkey.name, dst_table.name)
111
124
 
112
125
  for edge in (edges or []):
113
126
  _edge = Edge._cast(edge)
@@ -118,8 +131,8 @@ class Graph:
118
131
  @classmethod
119
132
  def from_data(
120
133
  cls,
121
- df_dict: Dict[str, pd.DataFrame],
122
- edges: Optional[Sequence[Edge]] = None,
134
+ df_dict: dict[str, pd.DataFrame],
135
+ edges: Sequence[Edge] | None = None,
123
136
  infer_metadata: bool = True,
124
137
  verbose: bool = True,
125
138
  ) -> Self:
@@ -157,15 +170,17 @@ class Graph:
157
170
  verbose: Whether to print verbose output.
158
171
  """
159
172
  from kumoai.experimental.rfm.backend.local import LocalTable
160
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
161
173
 
162
- graph = cls(tables, edges=edges or [])
174
+ graph = cls(
175
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
176
+ edges=edges or [],
177
+ )
163
178
 
164
179
  if infer_metadata:
165
- graph.infer_metadata(False)
180
+ graph.infer_metadata(verbose=False)
166
181
 
167
182
  if edges is None:
168
- graph.infer_links(False)
183
+ graph.infer_links(verbose=False)
169
184
 
170
185
  if verbose:
171
186
  graph.print_metadata()
@@ -181,10 +196,10 @@ class Graph:
181
196
  SqliteConnectionConfig,
182
197
  str,
183
198
  Path,
184
- Dict[str, Any],
199
+ dict[str, Any],
185
200
  ],
186
- table_names: Optional[Sequence[str]] = None,
187
- edges: Optional[Sequence[Edge]] = None,
201
+ tables: Sequence[str | dict[str, Any]] | None = None,
202
+ edges: Sequence[Edge] | None = None,
188
203
  infer_metadata: bool = True,
189
204
  verbose: bool = True,
190
205
  ) -> Self:
@@ -200,17 +215,25 @@ class Graph:
200
215
  >>> # Create a graph from a SQLite database:
201
216
  >>> graph = rfm.Graph.from_sqlite('data.db')
202
217
 
218
+ >>> # Fine-grained control over table specification:
219
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
220
+ ... 'USERS',
221
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
222
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
223
+ ... ])
224
+
203
225
  Args:
204
226
  connection: An open connection from
205
227
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
206
228
  path to the database file.
207
- table_names: Set of table names to include. If ``None``, will add
208
- all tables present in the database.
229
+ tables: Set of table names or :class:`SQLiteTable` keyword
230
+ arguments to include. If ``None``, will add all tables present
231
+ in the database.
209
232
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
210
233
  add to the graph. If not provided, edges will be automatically
211
234
  inferred from the data in case ``infer_metadata=True``.
212
- infer_metadata: Whether to infer metadata for all tables in the
213
- graph.
235
+ infer_metadata: Whether to infer missing metadata for all tables in
236
+ the graph.
214
237
  verbose: Whether to print verbose output.
215
238
  """
216
239
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -219,27 +242,41 @@ class Graph:
219
242
  connect,
220
243
  )
221
244
 
245
+ internal_connection = False
222
246
  if not isinstance(connection, Connection):
223
247
  connection = SqliteConnectionConfig._cast(connection)
224
248
  assert isinstance(connection, SqliteConnectionConfig)
225
249
  connection = connect(connection.uri, **connection.kwargs)
250
+ internal_connection = True
226
251
  assert isinstance(connection, Connection)
227
252
 
228
- if table_names is None:
253
+ if tables is None:
229
254
  with connection.cursor() as cursor:
230
255
  cursor.execute("SELECT name FROM sqlite_master "
231
256
  "WHERE type='table'")
232
- table_names = [row[0] for row in cursor.fetchall()]
257
+ tables = [row[0] for row in cursor.fetchall()]
233
258
 
234
- tables = [SQLiteTable(connection, name) for name in table_names]
259
+ table_kwargs: list[dict[str, Any]] = []
260
+ for table in tables:
261
+ kwargs = dict(name=table) if isinstance(table, str) else table
262
+ table_kwargs.append(kwargs)
263
+
264
+ graph = cls(
265
+ tables=[
266
+ SQLiteTable(connection=connection, **kwargs)
267
+ for kwargs in table_kwargs
268
+ ],
269
+ edges=edges or [],
270
+ )
235
271
 
236
- graph = cls(tables, edges=edges or [])
272
+ if internal_connection:
273
+ graph._connection = connection # type: ignore
237
274
 
238
275
  if infer_metadata:
239
- graph.infer_metadata(False)
276
+ graph.infer_metadata(verbose=False)
240
277
 
241
278
  if edges is None:
242
- graph.infer_links(False)
279
+ graph.infer_links(verbose=False)
243
280
 
244
281
  if verbose:
245
282
  graph.print_metadata()
@@ -250,9 +287,11 @@ class Graph:
250
287
  @classmethod
251
288
  def from_snowflake(
252
289
  cls,
253
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
254
- table_names: Optional[Sequence[str]] = None,
255
- edges: Optional[Sequence[Edge]] = None,
290
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
291
+ tables: Sequence[str | dict[str, Any]] | None = None,
292
+ database: str | None = None,
293
+ schema: str | None = None,
294
+ edges: Sequence[Edge] | None = None,
256
295
  infer_metadata: bool = True,
257
296
  verbose: bool = True,
258
297
  ) -> Self:
@@ -267,7 +306,14 @@ class Graph:
267
306
  >>> import kumoai.experimental.rfm as rfm
268
307
 
269
308
  >>> # Create a graph directly in a Snowflake notebook:
270
- >>> graph = rfm.Graph.from_snowflake()
309
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
310
+
311
+ >>> # Fine-grained control over table specification:
312
+ >>> graph = rfm.Graph.from_snowflake(tables=[
313
+ ... 'USERS',
314
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
315
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
316
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
271
317
 
272
318
  Args:
273
319
  connection: An open connection from
@@ -276,8 +322,11 @@ class Graph:
276
322
  connection. If ``None``, will re-use an active session in case
277
323
  it exists, or create a new connection from credentials stored
278
324
  in environment variables.
279
- table_names: Set of table names to include. If ``None``, will add
280
- all tables present in the database.
325
+ tables: Set of table names or :class:`SnowTable` keyword arguments
326
+ to include. If ``None``, will add all tables present in the
327
+ current database and schema.
328
+ database: The database.
329
+ schema: The schema.
281
330
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
282
331
  add to the graph. If not provided, edges will be automatically
283
332
  inferred from the data in case ``infer_metadata=True``.
@@ -295,27 +344,50 @@ class Graph:
295
344
  connection = connect(**(connection or {}))
296
345
  assert isinstance(connection, Connection)
297
346
 
298
- if table_names is None:
347
+ if database is None or schema is None:
299
348
  with connection.cursor() as cursor:
300
349
  cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
301
- database, schema = cursor.fetchone()
302
- query = f"""
350
+ result = cursor.fetchone()
351
+ database = database or result[0]
352
+ assert database is not None
353
+ schema = schema or result[1]
354
+
355
+ if tables is None:
356
+ if schema is None:
357
+ raise ValueError("No current 'schema' set. Please specify the "
358
+ "Snowflake schema manually")
359
+
360
+ with connection.cursor() as cursor:
361
+ cursor.execute(f"""
303
362
  SELECT TABLE_NAME
304
363
  FROM {database}.INFORMATION_SCHEMA.TABLES
305
364
  WHERE TABLE_SCHEMA = '{schema}'
306
- """
307
- cursor.execute(query)
308
- table_names = [row[0] for row in cursor.fetchall()]
365
+ """)
366
+ tables = [row[0] for row in cursor.fetchall()]
309
367
 
310
- tables = [SnowTable(connection, name) for name in table_names]
311
-
312
- graph = cls(tables, edges=edges or [])
368
+ table_kwargs: list[dict[str, Any]] = []
369
+ for table in tables:
370
+ if isinstance(table, str):
371
+ kwargs = dict(name=table, database=database, schema=schema)
372
+ else:
373
+ kwargs = copy.copy(table)
374
+ kwargs.setdefault('database', database)
375
+ kwargs.setdefault('schema', schema)
376
+ table_kwargs.append(kwargs)
377
+
378
+ graph = cls(
379
+ tables=[
380
+ SnowTable(connection=connection, **kwargs)
381
+ for kwargs in table_kwargs
382
+ ],
383
+ edges=edges or [],
384
+ )
313
385
 
314
386
  if infer_metadata:
315
- graph.infer_metadata(False)
387
+ graph.infer_metadata(verbose=False)
316
388
 
317
389
  if edges is None:
318
- graph.infer_links(False)
390
+ graph.infer_links(verbose=False)
319
391
 
320
392
  if verbose:
321
393
  graph.print_metadata()
@@ -327,7 +399,7 @@ class Graph:
327
399
  def from_snowflake_semantic_view(
328
400
  cls,
329
401
  semantic_view_name: str,
330
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
402
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
331
403
  verbose: bool = True,
332
404
  ) -> Self:
333
405
  import yaml
@@ -345,43 +417,165 @@ class Graph:
345
417
  with connection.cursor() as cursor:
346
418
  cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
347
419
  f"'{semantic_view_name}')")
348
- view = yaml.safe_load(cursor.fetchone()[0])
420
+ cfg = yaml.safe_load(cursor.fetchone()[0])
349
421
 
350
422
  graph = cls(tables=[])
351
423
 
352
- for table_desc in view['tables']:
353
- primary_key: Optional[str] = None
354
- if ('primary_key' in table_desc # NOTE No composite keys yet.
355
- and len(table_desc['primary_key']['columns']) == 1):
356
- primary_key = table_desc['primary_key']['columns'][0]
424
+ msgs = []
425
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
426
+ for table_cfg in cfg['tables']:
427
+ table_name = table_cfg['name']
428
+ source_table_name = table_cfg['base_table']['table']
429
+ database = table_cfg['base_table']['database']
430
+ schema = table_cfg['base_table']['schema']
431
+
432
+ primary_key: str | None = None
433
+ if 'primary_key' in table_cfg:
434
+ primary_key_cfg = table_cfg['primary_key']
435
+ if len(primary_key_cfg['columns']) == 1:
436
+ primary_key = primary_key_cfg['columns'][0]
437
+ elif len(primary_key_cfg['columns']) > 1:
438
+ msgs.append(f"Failed to add primary key for table "
439
+ f"'{table_name}' since composite primary keys "
440
+ f"are not yet supported")
441
+
442
+ columns: list[ColumnSpec] = []
443
+ unsupported_columns: list[str] = []
444
+ for column_cfg in chain(
445
+ table_cfg.get('dimensions', []),
446
+ table_cfg.get('time_dimensions', []),
447
+ table_cfg.get('facts', []),
448
+ ):
449
+ column_name = column_cfg['name']
450
+ column_expr = column_cfg.get('expr', None)
451
+ column_data_type = column_cfg.get('data_type', None)
452
+
453
+ if column_expr is None:
454
+ columns.append(ColumnSpec(name=column_name))
455
+ continue
456
+
457
+ column_expr = column_expr.replace(f'{table_name}.', '')
458
+
459
+ if column_expr == column_name:
460
+ columns.append(ColumnSpec(name=column_name))
461
+ continue
462
+
463
+ # Drop expressions that reference other tables (for now):
464
+ if any(f'{name}.' in column_expr for name in table_names):
465
+ unsupported_columns.append(column_name)
466
+ continue
467
+
468
+ column = ColumnSpec(
469
+ name=column_name,
470
+ expr=column_expr,
471
+ dtype=SnowTable._to_dtype(column_data_type),
472
+ )
473
+ columns.append(column)
474
+
475
+ if len(unsupported_columns) == 1:
476
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
477
+ f"of table '{table_name}' since its expression "
478
+ f"references other tables")
479
+ elif len(unsupported_columns) > 1:
480
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
481
+ f"of table '{table_name}' since their expressions "
482
+ f"reference other tables")
357
483
 
358
484
  table = SnowTable(
359
485
  connection,
360
- name=table_desc['base_table']['table'],
361
- database=table_desc['base_table']['database'],
362
- schema=table_desc['base_table']['schema'],
486
+ name=table_name,
487
+ source_name=source_table_name,
488
+ database=database,
489
+ schema=schema,
490
+ columns=columns,
363
491
  primary_key=primary_key,
364
492
  )
493
+
494
+ # TODO Add a way to register time columns without heuristic usage.
495
+ table.infer_time_column(verbose=False)
496
+
365
497
  graph.add_table(table)
366
498
 
367
- # TODO Find a solution to register time columns!
499
+ for relation_cfg in cfg.get('relationships', []):
500
+ name = relation_cfg['name']
501
+ if len(relation_cfg['relationship_columns']) != 1:
502
+ msgs.append(f"Failed to add relationship '{name}' since "
503
+ f"composite key references are not yet supported")
504
+ continue
368
505
 
369
- for relations in view['relationships']:
370
- if len(relations['relationship_columns']) != 1:
371
- continue # NOTE No composite keys yet.
372
- graph.link(
373
- src_table=relations['left_table'],
374
- fkey=relations['relationship_columns'][0]['left_column'],
375
- dst_table=relations['right_table'],
376
- )
506
+ left_table = relation_cfg['left_table']
507
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
508
+ right_table = relation_cfg['right_table']
509
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
510
+
511
+ if graph[right_table]._primary_key != right_key:
512
+ # Semantic view error - this should never be triggered:
513
+ msgs.append(f"Failed to add relationship '{name}' since the "
514
+ f"referenced key '{right_key}' of table "
515
+ f"'{right_table}' is not a primary key")
516
+ continue
517
+
518
+ if graph[left_table]._primary_key == left_key:
519
+ msgs.append(f"Failed to add relationship '{name}' since the "
520
+ f"referencing key '{left_key}' of table "
521
+ f"'{left_table}' is a primary key")
522
+ continue
523
+
524
+ if left_key not in graph[left_table]:
525
+ graph[left_table].add_column(left_key)
526
+
527
+ graph.link(left_table, left_key, right_table)
528
+
529
+ graph.validate()
377
530
 
378
531
  if verbose:
379
532
  graph.print_metadata()
380
533
  graph.print_links()
381
534
 
535
+ if len(msgs) > 0:
536
+ title = (f"Could not fully convert the semantic view definition "
537
+ f"'{semantic_view_name}' into a graph:\n")
538
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
539
+
382
540
  return graph
383
541
 
384
- # Tables ##############################################################
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
564
+
565
+ if verbose:
566
+ graph.print_metadata()
567
+ graph.print_links()
568
+
569
+ return graph
570
+
571
+ # Backend #################################################################
572
+
573
+ @property
574
+ def backend(self) -> DataBackend | None:
575
+ backends = [table.backend for table in self._tables.values()]
576
+ return backends[0] if len(backends) > 0 else None
577
+
578
+ # Tables ##################################################################
385
579
 
386
580
  def has_table(self, name: str) -> bool:
387
581
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -400,7 +594,7 @@ class Graph:
400
594
  return self.tables[name]
401
595
 
402
596
  @property
403
- def tables(self) -> Dict[str, Table]:
597
+ def tables(self) -> dict[str, Table]:
404
598
  r"""Returns the dictionary of table objects."""
405
599
  return self._tables
406
600
 
@@ -420,13 +614,10 @@ class Graph:
420
614
  raise KeyError(f"Cannot add table with name '{table.name}' to "
421
615
  f"this graph; table names must be globally unique.")
422
616
 
423
- if len(self._tables) > 0:
424
- cls = next(iter(self._tables.values())).__class__
425
- if table.__class__ != cls:
426
- raise ValueError(f"Cannot register a "
427
- f"'{table.__class__.__name__}' to this "
428
- f"graph since other tables are of type "
429
- f"'{cls.__name__}'.")
617
+ if self.backend is not None and table.backend != self.backend:
618
+ raise ValueError(f"Cannot register a table with backend "
619
+ f"'{table.backend}' to this graph since other "
620
+ f"tables have backend '{self.backend}'.")
430
621
 
431
622
  self._tables[table.name] = table
432
623
 
@@ -458,28 +649,28 @@ class Graph:
458
649
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
459
650
  information about the tables in this graph.
460
651
 
461
- The returned dataframe has columns ``name``, ``primary_key``,
462
- ``time_column``, and ``end_time_column``, which provide an aggregate
463
- view of the properties of the tables of this graph.
652
+ The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
653
+ ``"Time Column"``, and ``"End Time Column"``, which provide an
654
+ aggregated view of the properties of the tables of this graph.
464
655
 
465
656
  Example:
466
657
  >>> # doctest: +SKIP
467
658
  >>> import kumoai.experimental.rfm as rfm
468
659
  >>> graph = rfm.Graph(tables=...).infer_metadata()
469
660
  >>> graph.metadata # doctest: +SKIP
470
- name primary_key time_column end_time_column
471
- 0 users user_id - -
661
+ Name Primary Key Time Column End Time Column
662
+ 0 users user_id - -
472
663
  """
473
664
  tables = list(self.tables.values())
474
665
 
475
666
  return pd.DataFrame({
476
- 'name':
667
+ 'Name':
477
668
  pd.Series(dtype=str, data=[t.name for t in tables]),
478
- 'primary_key':
669
+ 'Primary Key':
479
670
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
480
- 'time_column':
671
+ 'Time Column':
481
672
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
482
- 'end_time_column':
673
+ 'End Time Column':
483
674
  pd.Series(
484
675
  dtype=str,
485
676
  data=[t._end_time_column or '-' for t in tables],
@@ -488,20 +679,8 @@ class Graph:
488
679
 
489
680
  def print_metadata(self) -> None:
490
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
491
- if in_notebook():
492
- from IPython.display import Markdown, display
493
- display(Markdown('### 🗂️ Graph Metadata'))
494
- df = self.metadata
495
- try:
496
- if hasattr(df.style, 'hide'):
497
- display(df.style.hide(axis='index')) # pandas=2
498
- else:
499
- display(df.style.hide_index()) # pandas<1.3
500
- except ImportError:
501
- print(df.to_string(index=False)) # missing jinja2
502
- else:
503
- print("🗂️ Graph Metadata:")
504
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
505
684
 
506
685
  def infer_metadata(self, verbose: bool = True) -> Self:
507
686
  r"""Infers metadata for all tables in the graph.
@@ -524,42 +703,33 @@ class Graph:
524
703
  # Edges ###################################################################
525
704
 
526
705
  @property
527
- def edges(self) -> List[Edge]:
706
+ def edges(self) -> list[Edge]:
528
707
  r"""Returns the edges of the graph."""
529
708
  return self._edges
530
709
 
531
710
  def print_links(self) -> None:
532
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
533
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
534
- edge.src_table, edge.fkey) for edge in self.edges]
535
- edges = sorted(edges)
536
-
537
- if in_notebook():
538
- from IPython.display import Markdown, display
539
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
540
- if len(edges) > 0:
541
- display(
542
- Markdown('\n'.join([
543
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
544
- for edge in edges
545
- ])))
546
- else:
547
- display(Markdown('*No links registered*'))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
548
725
  else:
549
- print("🕸️ Graph Links (FK ↔️ PK):")
550
- if len(edges) > 0:
551
- print('\n'.join([
552
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
553
- for edge in edges
554
- ]))
555
- else:
556
- print('No links registered')
726
+ display.italic("No links registered")
557
727
 
558
728
  def link(
559
729
  self,
560
- src_table: Union[str, Table],
730
+ src_table: str | Table,
561
731
  fkey: str,
562
- dst_table: Union[str, Table],
732
+ dst_table: str | Table,
563
733
  ) -> Self:
564
734
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
565
735
  key ``fkey`` in the source table to the primary key in the destination
@@ -620,9 +790,9 @@ class Graph:
620
790
 
621
791
  def unlink(
622
792
  self,
623
- src_table: Union[str, Table],
793
+ src_table: str | Table,
624
794
  fkey: str,
625
- dst_table: Union[str, Table],
795
+ dst_table: str | Table,
626
796
  ) -> Self:
627
797
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
628
798
 
@@ -660,6 +830,30 @@ class Graph:
660
830
  """
661
831
  known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
662
832
 
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
856
+
663
857
  # A list of primary key candidates (+score) for every column:
664
858
  candidate_dict: dict[
665
859
  tuple[str, str],
@@ -759,13 +953,8 @@ class Graph:
759
953
  if score < 5.0:
760
954
  continue
761
955
 
762
- candidate_dict[(
763
- src_table.name,
764
- src_key.name,
765
- )].append((
766
- dst_table.name,
767
- score,
768
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
769
958
 
770
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
771
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -799,6 +988,10 @@ class Graph:
799
988
  raise ValueError("At least one table needs to be added to the "
800
989
  "graph")
801
990
 
991
+ backends = {table.backend for table in self._tables.values()}
992
+ if len(backends) != 1:
993
+ raise ValueError("Found multiple table backends in the graph")
994
+
802
995
  for edge in self.edges:
803
996
  src_table, fkey, dst_table = edge
804
997
 
@@ -820,24 +1013,26 @@ class Graph:
820
1013
  f"either the primary key or the link before "
821
1014
  f"before proceeding.")
822
1015
 
823
- # Check that fkey/pkey have valid and consistent data types:
824
- assert src_key.dtype is not None
825
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
826
- src_string = src_key.dtype.is_string()
827
- assert dst_key.dtype is not None
828
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
829
- dst_string = dst_key.dtype.is_string()
830
-
831
- if not src_number and not src_string:
832
- raise ValueError(f"{edge} is invalid as foreign key must be a "
833
- f"number or string (got '{src_key.dtype}'")
834
-
835
- if src_number != dst_number or src_string != dst_string:
836
- raise ValueError(f"{edge} is invalid as foreign key "
837
- f"'{fkey}' and primary key '{dst_key.name}' "
838
- f"have incompatible data types (got "
839
- f"fkey.dtype '{src_key.dtype}' and "
840
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
841
1036
 
842
1037
  return self
843
1038
 
@@ -845,7 +1040,7 @@ class Graph:
845
1040
 
846
1041
  def visualize(
847
1042
  self,
848
- path: Optional[Union[str, io.BytesIO]] = None,
1043
+ path: str | io.BytesIO | None = None,
849
1044
  show_columns: bool = True,
850
1045
  ) -> 'graphviz.Graph':
851
1046
  r"""Visualizes the tables and edges in this graph using the
@@ -870,33 +1065,33 @@ class Graph:
870
1065
 
871
1066
  return True
872
1067
 
873
- # Check basic dependency:
874
- if not find_spec('graphviz'):
875
- raise ModuleNotFoundError("The 'graphviz' package is required for "
876
- "visualization")
877
- elif not has_graphviz_executables():
1068
+ try: # Check basic dependency:
1069
+ import graphviz
1070
+ except ImportError as e:
1071
+ raise ImportError("The 'graphviz' package is required for "
1072
+ "visualization") from e
1073
+
1074
+ if not in_snowflake_notebook() and not has_graphviz_executables():
878
1075
  raise RuntimeError("Could not visualize graph as 'graphviz' "
879
1076
  "executables are not installed. These "
880
1077
  "dependencies are required in addition to the "
881
1078
  "'graphviz' Python package. Please install "
882
1079
  "them as described at "
883
1080
  "https://graphviz.org/download/.")
884
- else:
885
- import graphviz
886
1081
 
887
- format: Optional[str] = None
1082
+ format: str | None = None
888
1083
  if isinstance(path, str):
889
1084
  format = path.split('.')[-1]
890
1085
  elif isinstance(path, io.BytesIO):
891
1086
  format = 'svg'
892
1087
  graph = graphviz.Graph(format=format)
893
1088
 
894
- def left_align(keys: List[str]) -> str:
1089
+ def left_align(keys: list[str]) -> str:
895
1090
  if len(keys) == 0:
896
1091
  return ""
897
1092
  return '\\l'.join(keys) + '\\l'
898
1093
 
899
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1094
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
900
1095
  for src_table_name, fkey_name, _ in self.edges:
901
1096
  fkeys_dict[src_table_name].append(fkey_name)
902
1097
 
@@ -966,6 +1161,9 @@ class Graph:
966
1161
  graph.render(path, cleanup=True)
967
1162
  elif isinstance(path, io.BytesIO):
968
1163
  path.write(graph.pipe())
1164
+ elif in_snowflake_notebook():
1165
+ import streamlit as st
1166
+ st.graphviz_chart(graph)
969
1167
  elif in_notebook():
970
1168
  from IPython.display import display
971
1169
  display(graph)
@@ -989,8 +1187,8 @@ class Graph:
989
1187
  # Helpers #################################################################
990
1188
 
991
1189
  def _to_api_graph_definition(self) -> GraphDefinition:
992
- tables: Dict[str, TableDefinition] = {}
993
- col_groups: List[ColumnKeyGroup] = []
1190
+ tables: dict[str, TableDefinition] = {}
1191
+ col_groups: list[ColumnKeyGroup] = []
994
1192
  for table_name, table in self.tables.items():
995
1193
  tables[table_name] = table._to_api_table_definition()
996
1194
  if table.primary_key is None:
@@ -1033,3 +1231,7 @@ class Graph:
1033
1231
  f' tables={tables},\n'
1034
1232
  f' edges={edges},\n'
1035
1233
  f')')
1234
+
1235
+ def __del__(self) -> None:
1236
+ if hasattr(self, '_connection'):
1237
+ self._connection.close()