kumoai 2.13.0.dev202511191731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (58):
  1. kumoai/__init__.py +35 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +44 -9
  7. kumoai/experimental/rfm/__init__.py +70 -68
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +366 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +454 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/mapper.py +67 -0
  24. kumoai/experimental/rfm/base/sampler.py +782 -0
  25. kumoai/experimental/rfm/base/source.py +19 -0
  26. kumoai/experimental/rfm/base/sql_sampler.py +366 -0
  27. kumoai/experimental/rfm/base/table.py +741 -0
  28. kumoai/experimental/rfm/{local_graph.py → graph.py} +581 -154
  29. kumoai/experimental/rfm/infer/__init__.py +8 -0
  30. kumoai/experimental/rfm/infer/dtype.py +82 -0
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +128 -0
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +61 -0
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +775 -481
  39. kumoai/experimental/rfm/sagemaker.py +15 -7
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/pquery/predictive_query.py +10 -6
  42. kumoai/pquery/training_table.py +16 -2
  43. kumoai/testing/decorators.py +1 -1
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +190 -12
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +10 -8
  51. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +54 -30
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -182
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. kumoai/experimental/rfm/local_table.py +0 -545
  55. kumoai/experimental/rfm/utils.py +0 -344
  56. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
  57. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
  58. {kumoai-2.13.0.dev202511191731.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
4
+ import copy
2
5
  import io
3
6
  import warnings
4
7
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
8
+ from collections.abc import Sequence
9
+ from dataclasses import dataclass, field
10
+ from itertools import chain
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Union
7
13
 
8
14
  import pandas as pd
9
15
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,20 +17,30 @@ from kumoapi.table import TableDefinition
11
17
  from kumoapi.typing import Stype
12
18
  from typing_extensions import Self
13
19
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
20
+ from kumoai import in_notebook, in_snowflake_notebook
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
16
22
  from kumoai.graph import Edge
23
+ from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
17
25
 
18
26
  if TYPE_CHECKING:
19
27
  import graphviz
28
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
29
+ from snowflake.connector import SnowflakeConnection
30
+
31
+
32
+ @dataclass
33
+ class SqliteConnectionConfig(CastMixin):
34
+ uri: str | Path
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
20
36
 
21
37
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
38
+ class Graph:
39
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
40
  tables in a relational database.
25
41
 
26
42
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
43
+ :class:`Graph` is created, you can use it to initialize the
28
44
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
45
 
30
46
  .. code-block:: python
@@ -44,7 +60,7 @@ class LocalGraph:
44
60
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
61
 
46
62
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
63
+ >>> graph = rfm.Graph({
48
64
  ... "table1": table1,
49
65
  ... "table2": table2,
50
66
  ... "table3": table3,
@@ -75,33 +91,55 @@ class LocalGraph:
75
91
 
76
92
  def __init__(
77
93
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
94
+ tables: Sequence[Table],
95
+ edges: Sequence[Edge] | None = None,
80
96
  ) -> None:
81
97
 
82
- self._tables: Dict[str, LocalTable] = {}
83
- self._edges: List[Edge] = []
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
84
100
 
85
101
  for table in tables:
86
102
  self.add_table(table)
87
103
 
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
107
+ for fkey in table._source_foreign_key_dict.values():
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
117
+ continue
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
120
+ continue
121
+ if not dst_table[fkey.primary_key].is_source:
122
+ continue
123
+ self.link(table.name, fkey.name, dst_table.name)
124
+
88
125
  for edge in (edges or []):
89
126
  _edge = Edge._cast(edge)
90
127
  assert _edge is not None
91
- self.link(*_edge)
128
+ if _edge not in self._edges:
129
+ self.link(*_edge)
92
130
 
93
131
  @classmethod
94
132
  def from_data(
95
133
  cls,
96
- df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
134
+ df_dict: dict[str, pd.DataFrame],
135
+ edges: Sequence[Edge] | None = None,
98
136
  infer_metadata: bool = True,
99
137
  verbose: bool = True,
100
138
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
139
+ r"""Creates a :class:`Graph` from a dictionary of
102
140
  :class:`pandas.DataFrame` objects.
103
141
 
104
- Automatically infers table metadata and links.
142
+ Automatically infers table metadata and links by default.
105
143
 
106
144
  .. code-block:: python
107
145
 
@@ -115,59 +153,429 @@ class LocalGraph:
115
153
  >>> df3 = pd.DataFrame(...)
116
154
 
117
155
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
156
+ >>> graph = rfm.Graph.from_data({
119
157
  ... "table1": df1,
120
158
  ... "table2": df2,
121
159
  ... "table3": df3,
122
160
  ... })
123
161
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
162
  Args:
132
163
  df_dict: A dictionary of data frames, where the keys are the names
133
164
  of the tables and the values hold table data.
165
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
166
+ add to the graph. If not provided, edges will be automatically
167
+ inferred from the data in case ``infer_metadata=True``.
134
168
  infer_metadata: Whether to infer metadata for all tables in the
135
169
  graph.
170
+ verbose: Whether to print verbose output.
171
+ """
172
+ from kumoai.experimental.rfm.backend.local import LocalTable
173
+
174
+ graph = cls(
175
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
176
+ edges=edges or [],
177
+ )
178
+
179
+ if infer_metadata:
180
+ graph.infer_metadata(verbose=False)
181
+
182
+ if edges is None:
183
+ graph.infer_links(verbose=False)
184
+
185
+ if verbose:
186
+ graph.print_metadata()
187
+ graph.print_links()
188
+
189
+ return graph
190
+
191
+ @classmethod
192
+ def from_sqlite(
193
+ cls,
194
+ connection: Union[
195
+ 'AdbcSqliteConnection',
196
+ SqliteConnectionConfig,
197
+ str,
198
+ Path,
199
+ dict[str, Any],
200
+ ],
201
+ tables: Sequence[str | dict[str, Any]] | None = None,
202
+ edges: Sequence[Edge] | None = None,
203
+ infer_metadata: bool = True,
204
+ verbose: bool = True,
205
+ ) -> Self:
206
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
207
+
208
+ Automatically infers table metadata and links by default.
209
+
210
+ .. code-block:: python
211
+
212
+ >>> # doctest: +SKIP
213
+ >>> import kumoai.experimental.rfm as rfm
214
+
215
+ >>> # Create a graph from a SQLite database:
216
+ >>> graph = rfm.Graph.from_sqlite('data.db')
217
+
218
+ >>> # Fine-grained control over table specification:
219
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
220
+ ... 'USERS',
221
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
222
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
223
+ ... ])
224
+
225
+ Args:
226
+ connection: An open connection from
227
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
228
+ path to the database file.
229
+ tables: Set of table names or :class:`SQLiteTable` keyword
230
+ arguments to include. If ``None``, will add all tables present
231
+ in the database.
136
232
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
233
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
234
+ inferred from the data in case ``infer_metadata=True``.
235
+ infer_metadata: Whether to infer missing metadata for all tables in
236
+ the graph.
139
237
  verbose: Whether to print verbose output.
238
+ """
239
+ from kumoai.experimental.rfm.backend.sqlite import (
240
+ Connection,
241
+ SQLiteTable,
242
+ connect,
243
+ )
244
+
245
+ internal_connection = False
246
+ if not isinstance(connection, Connection):
247
+ connection = SqliteConnectionConfig._cast(connection)
248
+ assert isinstance(connection, SqliteConnectionConfig)
249
+ connection = connect(connection.uri, **connection.kwargs)
250
+ internal_connection = True
251
+ assert isinstance(connection, Connection)
252
+
253
+ if tables is None:
254
+ with connection.cursor() as cursor:
255
+ cursor.execute("SELECT name FROM sqlite_master "
256
+ "WHERE type='table'")
257
+ tables = [row[0] for row in cursor.fetchall()]
258
+
259
+ table_kwargs: list[dict[str, Any]] = []
260
+ for table in tables:
261
+ kwargs = dict(name=table) if isinstance(table, str) else table
262
+ table_kwargs.append(kwargs)
140
263
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
264
+ graph = cls(
265
+ tables=[
266
+ SQLiteTable(connection=connection, **kwargs)
267
+ for kwargs in table_kwargs
268
+ ],
269
+ edges=edges or [],
270
+ )
271
+
272
+ if internal_connection:
273
+ graph._connection = connection # type: ignore
274
+
275
+ if infer_metadata:
276
+ graph.infer_metadata(verbose=False)
277
+
278
+ if edges is None:
279
+ graph.infer_links(verbose=False)
280
+
281
+ if verbose:
282
+ graph.print_metadata()
283
+ graph.print_links()
284
+
285
+ return graph
286
+
287
+ @classmethod
288
+ def from_snowflake(
289
+ cls,
290
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
291
+ tables: Sequence[str | dict[str, Any]] | None = None,
292
+ database: str | None = None,
293
+ schema: str | None = None,
294
+ edges: Sequence[Edge] | None = None,
295
+ infer_metadata: bool = True,
296
+ verbose: bool = True,
297
+ ) -> Self:
298
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
299
+ schema.
300
+
301
+ Automatically infers table metadata and links by default.
302
+
303
+ .. code-block:: python
144
304
 
145
- Example:
146
305
  >>> # doctest: +SKIP
147
306
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
157
- """
158
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
159
307
 
160
- graph = cls(tables, edges=edges or [])
308
+ >>> # Create a graph directly in a Snowflake notebook:
309
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
310
+
311
+ >>> # Fine-grained control over table specification:
312
+ >>> graph = rfm.Graph.from_snowflake(tables=[
313
+ ... 'USERS',
314
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
315
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
316
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
317
+
318
+ Args:
319
+ connection: An open connection from
320
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
321
+ :class:`snowflake` connector keyword arguments to open a new
322
+ connection. If ``None``, will re-use an active session in case
323
+ it exists, or create a new connection from credentials stored
324
+ in environment variables.
325
+ tables: Set of table names or :class:`SnowTable` keyword arguments
326
+ to include. If ``None``, will add all tables present in the
327
+ current database and schema.
328
+ database: The database.
329
+ schema: The schema.
330
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
331
+ add to the graph. If not provided, edges will be automatically
332
+ inferred from the data in case ``infer_metadata=True``.
333
+ infer_metadata: Whether to infer metadata for all tables in the
334
+ graph.
335
+ verbose: Whether to print verbose output.
336
+ """
337
+ from kumoai.experimental.rfm.backend.snow import (
338
+ Connection,
339
+ SnowTable,
340
+ connect,
341
+ )
342
+
343
+ if not isinstance(connection, Connection):
344
+ connection = connect(**(connection or {}))
345
+ assert isinstance(connection, Connection)
346
+
347
+ if database is None or schema is None:
348
+ with connection.cursor() as cursor:
349
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
350
+ result = cursor.fetchone()
351
+ database = database or result[0]
352
+ assert database is not None
353
+ schema = schema or result[1]
354
+
355
+ if tables is None:
356
+ if schema is None:
357
+ raise ValueError("No current 'schema' set. Please specify the "
358
+ "Snowflake schema manually")
359
+
360
+ with connection.cursor() as cursor:
361
+ cursor.execute(f"""
362
+ SELECT TABLE_NAME
363
+ FROM {database}.INFORMATION_SCHEMA.TABLES
364
+ WHERE TABLE_SCHEMA = '{schema}'
365
+ """)
366
+ tables = [row[0] for row in cursor.fetchall()]
367
+
368
+ table_kwargs: list[dict[str, Any]] = []
369
+ for table in tables:
370
+ if isinstance(table, str):
371
+ kwargs = dict(name=table, database=database, schema=schema)
372
+ else:
373
+ kwargs = copy.copy(table)
374
+ kwargs.setdefault('database', database)
375
+ kwargs.setdefault('schema', schema)
376
+ table_kwargs.append(kwargs)
377
+
378
+ graph = cls(
379
+ tables=[
380
+ SnowTable(connection=connection, **kwargs)
381
+ for kwargs in table_kwargs
382
+ ],
383
+ edges=edges or [],
384
+ )
161
385
 
162
386
  if infer_metadata:
163
- graph.infer_metadata(verbose)
387
+ graph.infer_metadata(verbose=False)
164
388
 
165
389
  if edges is None:
166
- graph.infer_links(verbose)
390
+ graph.infer_links(verbose=False)
391
+
392
+ if verbose:
393
+ graph.print_metadata()
394
+ graph.print_links()
167
395
 
168
396
  return graph
169
397
 
170
- # Tables ##############################################################
398
+ @classmethod
399
+ def from_snowflake_semantic_view(
400
+ cls,
401
+ semantic_view_name: str,
402
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
403
+ verbose: bool = True,
404
+ ) -> Self:
405
+ import yaml
406
+
407
+ from kumoai.experimental.rfm.backend.snow import (
408
+ Connection,
409
+ SnowTable,
410
+ connect,
411
+ )
412
+
413
+ if not isinstance(connection, Connection):
414
+ connection = connect(**(connection or {}))
415
+ assert isinstance(connection, Connection)
416
+
417
+ with connection.cursor() as cursor:
418
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
419
+ f"'{semantic_view_name}')")
420
+ cfg = yaml.safe_load(cursor.fetchone()[0])
421
+
422
+ graph = cls(tables=[])
423
+
424
+ msgs = []
425
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
426
+ for table_cfg in cfg['tables']:
427
+ table_name = table_cfg['name']
428
+ source_table_name = table_cfg['base_table']['table']
429
+ database = table_cfg['base_table']['database']
430
+ schema = table_cfg['base_table']['schema']
431
+
432
+ primary_key: str | None = None
433
+ if 'primary_key' in table_cfg:
434
+ primary_key_cfg = table_cfg['primary_key']
435
+ if len(primary_key_cfg['columns']) == 1:
436
+ primary_key = primary_key_cfg['columns'][0]
437
+ elif len(primary_key_cfg['columns']) > 1:
438
+ msgs.append(f"Failed to add primary key for table "
439
+ f"'{table_name}' since composite primary keys "
440
+ f"are not yet supported")
441
+
442
+ columns: list[ColumnSpec] = []
443
+ unsupported_columns: list[str] = []
444
+ for column_cfg in chain(
445
+ table_cfg.get('dimensions', []),
446
+ table_cfg.get('time_dimensions', []),
447
+ table_cfg.get('facts', []),
448
+ ):
449
+ column_name = column_cfg['name']
450
+ column_expr = column_cfg.get('expr', None)
451
+ column_data_type = column_cfg.get('data_type', None)
452
+
453
+ if column_expr is None:
454
+ columns.append(ColumnSpec(name=column_name))
455
+ continue
456
+
457
+ column_expr = column_expr.replace(f'{table_name}.', '')
458
+
459
+ if column_expr == column_name:
460
+ columns.append(ColumnSpec(name=column_name))
461
+ continue
462
+
463
+ # Drop expressions that reference other tables (for now):
464
+ if any(f'{name}.' in column_expr for name in table_names):
465
+ unsupported_columns.append(column_name)
466
+ continue
467
+
468
+ column = ColumnSpec(
469
+ name=column_name,
470
+ expr=column_expr,
471
+ dtype=SnowTable._to_dtype(column_data_type),
472
+ )
473
+ columns.append(column)
474
+
475
+ if len(unsupported_columns) == 1:
476
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
477
+ f"of table '{table_name}' since its expression "
478
+ f"references other tables")
479
+ elif len(unsupported_columns) > 1:
480
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
481
+ f"of table '{table_name}' since their expressions "
482
+ f"reference other tables")
483
+
484
+ table = SnowTable(
485
+ connection,
486
+ name=table_name,
487
+ source_name=source_table_name,
488
+ database=database,
489
+ schema=schema,
490
+ columns=columns,
491
+ primary_key=primary_key,
492
+ )
493
+
494
+ # TODO Add a way to register time columns without heuristic usage.
495
+ table.infer_time_column(verbose=False)
496
+
497
+ graph.add_table(table)
498
+
499
+ for relation_cfg in cfg.get('relationships', []):
500
+ name = relation_cfg['name']
501
+ if len(relation_cfg['relationship_columns']) != 1:
502
+ msgs.append(f"Failed to add relationship '{name}' since "
503
+ f"composite key references are not yet supported")
504
+ continue
505
+
506
+ left_table = relation_cfg['left_table']
507
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
508
+ right_table = relation_cfg['right_table']
509
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
510
+
511
+ if graph[right_table]._primary_key != right_key:
512
+ # Semantic view error - this should never be triggered:
513
+ msgs.append(f"Failed to add relationship '{name}' since the "
514
+ f"referenced key '{right_key}' of table "
515
+ f"'{right_table}' is not a primary key")
516
+ continue
517
+
518
+ if graph[left_table]._primary_key == left_key:
519
+ msgs.append(f"Failed to add relationship '{name}' since the "
520
+ f"referencing key '{left_key}' of table "
521
+ f"'{left_table}' is a primary key")
522
+ continue
523
+
524
+ if left_key not in graph[left_table]:
525
+ graph[left_table].add_column(left_key)
526
+
527
+ graph.link(left_table, left_key, right_table)
528
+
529
+ graph.validate()
530
+
531
+ if verbose:
532
+ graph.print_metadata()
533
+ graph.print_links()
534
+
535
+ if len(msgs) > 0:
536
+ title = (f"Could not fully convert the semantic view definition "
537
+ f"'{semantic_view_name}' into a graph:\n")
538
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
539
+
540
+ return graph
541
+
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
564
+
565
+ if verbose:
566
+ graph.print_metadata()
567
+ graph.print_links()
568
+
569
+ return graph
570
+
571
+ # Backend #################################################################
572
+
573
+ @property
574
+ def backend(self) -> DataBackend | None:
575
+ backends = [table.backend for table in self._tables.values()]
576
+ return backends[0] if len(backends) > 0 else None
577
+
578
+ # Tables ##################################################################
171
579
 
172
580
  def has_table(self, name: str) -> bool:
173
581
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -175,7 +583,7 @@ class LocalGraph:
175
583
  """
176
584
  return name in self.tables
177
585
 
178
- def table(self, name: str) -> LocalTable:
586
+ def table(self, name: str) -> Table:
179
587
  r"""Returns the table with name ``name`` in the graph.
180
588
 
181
589
  Raises:
@@ -186,11 +594,11 @@ class LocalGraph:
186
594
  return self.tables[name]
187
595
 
188
596
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
597
+ def tables(self) -> dict[str, Table]:
190
598
  r"""Returns the dictionary of table objects."""
191
599
  return self._tables
192
600
 
193
- def add_table(self, table: LocalTable) -> Self:
601
+ def add_table(self, table: Table) -> Self:
194
602
  r"""Adds a table to the graph.
195
603
 
196
604
  Args:
@@ -199,11 +607,18 @@ class LocalGraph:
199
607
  Raises:
200
608
  KeyError: If a table with the same name already exists in the
201
609
  graph.
610
+ ValueError: If the table belongs to a different backend than the
611
+ rest of the tables in the graph.
202
612
  """
203
613
  if table.name in self._tables:
204
614
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
615
  f"this graph; table names must be globally unique.")
206
616
 
617
+ if self.backend is not None and table.backend != self.backend:
618
+ raise ValueError(f"Cannot register a table with backend "
619
+ f"'{table.backend}' to this graph since other "
620
+ f"tables have backend '{self.backend}'.")
621
+
207
622
  self._tables[table.name] = table
208
623
 
209
624
  return self
@@ -234,28 +649,28 @@ class LocalGraph:
234
649
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
235
650
  information about the tables in this graph.
236
651
 
237
- The returned dataframe has columns ``name``, ``primary_key``,
238
- ``time_column``, and ``end_time_column``, which provide an aggregate
239
- view of the properties of the tables of this graph.
652
+ The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
653
+ ``"Time Column"``, and ``"End Time Column"``, which provide an
654
+ aggregated view of the properties of the tables of this graph.
240
655
 
241
656
  Example:
242
657
  >>> # doctest: +SKIP
243
658
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
659
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
660
  >>> graph.metadata # doctest: +SKIP
246
- name primary_key time_column end_time_column
247
- 0 users user_id - -
661
+ Name Primary Key Time Column End Time Column
662
+ 0 users user_id - -
248
663
  """
249
664
  tables = list(self.tables.values())
250
665
 
251
666
  return pd.DataFrame({
252
- 'name':
667
+ 'Name':
253
668
  pd.Series(dtype=str, data=[t.name for t in tables]),
254
- 'primary_key':
669
+ 'Primary Key':
255
670
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
256
- 'time_column':
671
+ 'Time Column':
257
672
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
258
- 'end_time_column':
673
+ 'End Time Column':
259
674
  pd.Series(
260
675
  dtype=str,
261
676
  data=[t._end_time_column or '-' for t in tables],
@@ -263,21 +678,9 @@ class LocalGraph:
263
678
  })
264
679
 
265
680
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
267
- if in_notebook():
268
- from IPython.display import Markdown, display
269
- display(Markdown('### 🗂️ Graph Metadata'))
270
- df = self.metadata
271
- try:
272
- if hasattr(df.style, 'hide'):
273
- display(df.style.hide(axis='index')) # pandas=2
274
- else:
275
- display(df.style.hide_index()) # pandas<1.3
276
- except ImportError:
277
- print(df.to_string(index=False)) # missing jinja2
278
- else:
279
- print("🗂️ Graph Metadata:")
280
- print(self.metadata.to_string(index=False))
681
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
281
684
 
282
685
  def infer_metadata(self, verbose: bool = True) -> Self:
283
686
  r"""Infers metadata for all tables in the graph.
@@ -287,7 +690,7 @@ class LocalGraph:
287
690
 
288
691
  Note:
289
692
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
693
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
694
  """
292
695
  for table in self.tables.values():
293
696
  table.infer_metadata(verbose=False)
@@ -300,42 +703,33 @@ class LocalGraph:
300
703
  # Edges ###################################################################
301
704
 
302
705
  @property
303
- def edges(self) -> List[Edge]:
706
+ def edges(self) -> list[Edge]:
304
707
  r"""Returns the edges of the graph."""
305
708
  return self._edges
306
709
 
307
710
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
309
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
- edge.src_table, edge.fkey) for edge in self.edges]
311
- edges = sorted(edges)
312
-
313
- if in_notebook():
314
- from IPython.display import Markdown, display
315
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
316
- if len(edges) > 0:
317
- display(
318
- Markdown('\n'.join([
319
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
320
- for edge in edges
321
- ])))
322
- else:
323
- display(Markdown('*No links registered*'))
711
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
324
725
  else:
325
- print("🕸️ Graph Links (FK ↔️ PK):")
326
- if len(edges) > 0:
327
- print('\n'.join([
328
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
329
- for edge in edges
330
- ]))
331
- else:
332
- print('No links registered')
726
+ display.italic("No links registered")
333
727
 
334
728
  def link(
335
729
  self,
336
- src_table: Union[str, LocalTable],
730
+ src_table: str | Table,
337
731
  fkey: str,
338
- dst_table: Union[str, LocalTable],
732
+ dst_table: str | Table,
339
733
  ) -> Self:
340
734
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
735
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +752,11 @@ class LocalGraph:
358
752
  table does not exist in the graph, if the source key does not
359
753
  exist in the source table.
360
754
  """
361
- if isinstance(src_table, LocalTable):
755
+ if isinstance(src_table, Table):
362
756
  src_table = src_table.name
363
757
  assert isinstance(src_table, str)
364
758
 
365
- if isinstance(dst_table, LocalTable):
759
+ if isinstance(dst_table, Table):
366
760
  dst_table = dst_table.name
367
761
  assert isinstance(dst_table, str)
368
762
 
@@ -396,9 +790,9 @@ class LocalGraph:
396
790
 
397
791
  def unlink(
398
792
  self,
399
- src_table: Union[str, LocalTable],
793
+ src_table: str | Table,
400
794
  fkey: str,
401
- dst_table: Union[str, LocalTable],
795
+ dst_table: str | Table,
402
796
  ) -> Self:
403
797
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
798
 
@@ -410,11 +804,11 @@ class LocalGraph:
410
804
  Raises:
411
805
  ValueError: if the edge is not present in the graph.
412
806
  """
413
- if isinstance(src_table, LocalTable):
807
+ if isinstance(src_table, Table):
414
808
  src_table = src_table.name
415
809
  assert isinstance(src_table, str)
416
810
 
417
- if isinstance(dst_table, LocalTable):
811
+ if isinstance(dst_table, Table):
418
812
  dst_table = dst_table.name
419
813
  assert isinstance(dst_table, str)
420
814
 
@@ -428,17 +822,37 @@ class LocalGraph:
428
822
  return self
429
823
 
430
824
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
825
+ r"""Infers missing links for the tables and adds them as edges to the
826
+ graph.
432
827
 
433
828
  Args:
434
829
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
830
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
831
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
832
+
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
442
856
 
443
857
  # A list of primary key candidates (+score) for every column:
444
858
  candidate_dict: dict[
@@ -463,6 +877,9 @@ class LocalGraph:
463
877
  src_table_name = src_table.name.lower()
464
878
 
465
879
  for src_key in src_table.columns:
880
+ if (src_table.name, src_key.name) in known_edges:
881
+ continue
882
+
466
883
  if src_key == src_table.primary_key:
467
884
  continue # Cannot link to primary key.
468
885
 
@@ -528,19 +945,16 @@ class LocalGraph:
528
945
  score += 1.0
529
946
 
530
947
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
948
+ if (src_table._num_rows is not None
949
+ and dst_table._num_rows is not None
950
+ and src_table._num_rows > dst_table._num_rows):
532
951
  score += 1.0
533
952
 
534
953
  if score < 5.0:
535
954
  continue
536
955
 
537
- candidate_dict[(
538
- src_table.name,
539
- src_key.name,
540
- )].append((
541
- dst_table.name,
542
- score,
543
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
544
958
 
545
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
546
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -574,6 +988,10 @@ class LocalGraph:
574
988
  raise ValueError("At least one table needs to be added to the "
575
989
  "graph")
576
990
 
991
+ backends = {table.backend for table in self._tables.values()}
992
+ if len(backends) != 1:
993
+ raise ValueError("Found multiple table backends in the graph")
994
+
577
995
  for edge in self.edges:
578
996
  src_table, fkey, dst_table = edge
579
997
 
@@ -595,24 +1013,26 @@ class LocalGraph:
595
1013
  f"either the primary key or the link before "
596
1014
  f"before proceeding.")
597
1015
 
598
- # Check that fkey/pkey have valid and consistent data types:
599
- assert src_key.dtype is not None
600
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
601
- src_string = src_key.dtype.is_string()
602
- assert dst_key.dtype is not None
603
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
604
- dst_string = dst_key.dtype.is_string()
605
-
606
- if not src_number and not src_string:
607
- raise ValueError(f"{edge} is invalid as foreign key must be a "
608
- f"number or string (got '{src_key.dtype}'")
609
-
610
- if src_number != dst_number or src_string != dst_string:
611
- raise ValueError(f"{edge} is invalid as foreign key "
612
- f"'{fkey}' and primary key '{dst_key.name}' "
613
- f"have incompatible data types (got "
614
- f"fkey.dtype '{src_key.dtype}' and "
615
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
616
1036
 
617
1037
  return self
618
1038
 
@@ -620,7 +1040,7 @@ class LocalGraph:
620
1040
 
621
1041
  def visualize(
622
1042
  self,
623
- path: Optional[Union[str, io.BytesIO]] = None,
1043
+ path: str | io.BytesIO | None = None,
624
1044
  show_columns: bool = True,
625
1045
  ) -> 'graphviz.Graph':
626
1046
  r"""Visualizes the tables and edges in this graph using the
@@ -645,33 +1065,33 @@ class LocalGraph:
645
1065
 
646
1066
  return True
647
1067
 
648
- # Check basic dependency:
649
- if not find_spec('graphviz'):
650
- raise ModuleNotFoundError("The 'graphviz' package is required for "
651
- "visualization")
652
- elif not has_graphviz_executables():
1068
+ try: # Check basic dependency:
1069
+ import graphviz
1070
+ except ImportError as e:
1071
+ raise ImportError("The 'graphviz' package is required for "
1072
+ "visualization") from e
1073
+
1074
+ if not in_snowflake_notebook() and not has_graphviz_executables():
653
1075
  raise RuntimeError("Could not visualize graph as 'graphviz' "
654
1076
  "executables are not installed. These "
655
1077
  "dependencies are required in addition to the "
656
1078
  "'graphviz' Python package. Please install "
657
1079
  "them as described at "
658
1080
  "https://graphviz.org/download/.")
659
- else:
660
- import graphviz
661
1081
 
662
- format: Optional[str] = None
1082
+ format: str | None = None
663
1083
  if isinstance(path, str):
664
1084
  format = path.split('.')[-1]
665
1085
  elif isinstance(path, io.BytesIO):
666
1086
  format = 'svg'
667
1087
  graph = graphviz.Graph(format=format)
668
1088
 
669
- def left_align(keys: List[str]) -> str:
1089
+ def left_align(keys: list[str]) -> str:
670
1090
  if len(keys) == 0:
671
1091
  return ""
672
1092
  return '\\l'.join(keys) + '\\l'
673
1093
 
674
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1094
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
675
1095
  for src_table_name, fkey_name, _ in self.edges:
676
1096
  fkeys_dict[src_table_name].append(fkey_name)
677
1097
 
@@ -741,6 +1161,9 @@ class LocalGraph:
741
1161
  graph.render(path, cleanup=True)
742
1162
  elif isinstance(path, io.BytesIO):
743
1163
  path.write(graph.pipe())
1164
+ elif in_snowflake_notebook():
1165
+ import streamlit as st
1166
+ st.graphviz_chart(graph)
744
1167
  elif in_notebook():
745
1168
  from IPython.display import display
746
1169
  display(graph)
@@ -764,8 +1187,8 @@ class LocalGraph:
764
1187
  # Helpers #################################################################
765
1188
 
766
1189
  def _to_api_graph_definition(self) -> GraphDefinition:
767
- tables: Dict[str, TableDefinition] = {}
768
- col_groups: List[ColumnKeyGroup] = []
1190
+ tables: dict[str, TableDefinition] = {}
1191
+ col_groups: list[ColumnKeyGroup] = []
769
1192
  for table_name, table in self.tables.items():
770
1193
  tables[table_name] = table._to_api_table_definition()
771
1194
  if table.primary_key is None:
@@ -790,7 +1213,7 @@ class LocalGraph:
790
1213
  def __contains__(self, name: str) -> bool:
791
1214
  return self.has_table(name)
792
1215
 
793
- def __getitem__(self, name: str) -> LocalTable:
1216
+ def __getitem__(self, name: str) -> Table:
794
1217
  return self.table(name)
795
1218
 
796
1219
  def __delitem__(self, name: str) -> None:
@@ -808,3 +1231,7 @@ class LocalGraph:
808
1231
  f' tables={tables},\n'
809
1232
  f' edges={edges},\n'
810
1233
  f')')
1234
+
1235
+ def __del__(self) -> None:
1236
+ if hasattr(self, '_connection'):
1237
+ self._connection.close()