kumoai 2.13.0.dev202511181731__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. kumoai/__init__.py +12 -0
  2. kumoai/_version.py +1 -1
  3. kumoai/client/pquery.py +6 -2
  4. kumoai/connector/utils.py +23 -2
  5. kumoai/experimental/rfm/__init__.py +52 -52
  6. kumoai/experimental/rfm/authenticate.py +3 -4
  7. kumoai/experimental/rfm/backend/__init__.py +0 -0
  8. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  9. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +57 -110
  10. kumoai/experimental/rfm/backend/local/sampler.py +315 -0
  11. kumoai/experimental/rfm/backend/local/table.py +114 -0
  12. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  13. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  14. kumoai/experimental/rfm/backend/snow/table.py +169 -0
  15. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  16. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  17. kumoai/experimental/rfm/backend/sqlite/table.py +154 -0
  18. kumoai/experimental/rfm/base/__init__.py +33 -0
  19. kumoai/experimental/rfm/base/column.py +68 -0
  20. kumoai/experimental/rfm/base/column_expression.py +50 -0
  21. kumoai/experimental/rfm/base/sampler.py +773 -0
  22. kumoai/experimental/rfm/base/source.py +19 -0
  23. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  24. kumoai/experimental/rfm/base/sql_table.py +229 -0
  25. kumoai/experimental/rfm/{local_table.py → base/table.py} +219 -189
  26. kumoai/experimental/rfm/{local_graph.py → graph.py} +510 -91
  27. kumoai/experimental/rfm/infer/__init__.py +8 -0
  28. kumoai/experimental/rfm/infer/dtype.py +79 -0
  29. kumoai/experimental/rfm/infer/pkey.py +128 -0
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +61 -0
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  34. kumoai/experimental/rfm/rfm.py +313 -246
  35. kumoai/experimental/rfm/sagemaker.py +15 -7
  36. kumoai/pquery/predictive_query.py +10 -6
  37. kumoai/testing/decorators.py +1 -1
  38. kumoai/testing/snow.py +50 -0
  39. kumoai/utils/__init__.py +3 -2
  40. kumoai/utils/progress_logger.py +178 -12
  41. kumoai/utils/sql.py +3 -0
  42. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/METADATA +10 -8
  43. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/RECORD +46 -26
  44. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  45. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  46. kumoai/experimental/rfm/utils.py +0 -344
  47. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/WHEEL +0 -0
  48. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/licenses/LICENSE +0 -0
  49. {kumoai-2.13.0.dev202511181731.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
6
+ from collections.abc import Sequence
7
+ from dataclasses import dataclass, field
8
+ from itertools import chain
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any, Union
7
11
 
8
12
  import pandas as pd
9
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,20 +15,34 @@ from kumoapi.table import TableDefinition
11
15
  from kumoapi.typing import Stype
12
16
  from typing_extensions import Self
13
17
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
18
+ from kumoai import in_notebook, in_snowflake_notebook
19
+ from kumoai.experimental.rfm.base import (
20
+ ColumnExpressionSpec,
21
+ DataBackend,
22
+ SQLTable,
23
+ Table,
24
+ )
16
25
  from kumoai.graph import Edge
26
+ from kumoai.mixin import CastMixin
17
27
 
18
28
  if TYPE_CHECKING:
19
29
  import graphviz
30
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
31
+ from snowflake.connector import SnowflakeConnection
20
32
 
21
33
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
34
+ @dataclass
35
+ class SqliteConnectionConfig(CastMixin):
36
+ uri: str | Path
37
+ kwargs: dict[str, Any] = field(default_factory=dict)
38
+
39
+
40
+ class Graph:
41
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
42
  tables in a relational database.
25
43
 
26
44
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
45
+ :class:`Graph` is created, you can use it to initialize the
28
46
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
47
 
30
48
  .. code-block:: python
@@ -44,7 +62,7 @@ class LocalGraph:
44
62
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
63
 
46
64
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
65
+ >>> graph = rfm.Graph({
48
66
  ... "table1": table1,
49
67
  ... "table2": table2,
50
68
  ... "table3": table3,
@@ -75,33 +93,58 @@ class LocalGraph:
75
93
 
76
94
  def __init__(
77
95
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
96
+ tables: Sequence[Table],
97
+ edges: Sequence[Edge] | None = None,
80
98
  ) -> None:
81
99
 
82
- self._tables: Dict[str, LocalTable] = {}
83
- self._edges: List[Edge] = []
100
+ self._tables: dict[str, Table] = {}
101
+ self._edges: list[Edge] = []
84
102
 
85
103
  for table in tables:
86
104
  self.add_table(table)
87
105
 
106
+ for table in tables:
107
+ if not isinstance(table, SQLTable):
108
+ continue
109
+ if '_source_column_dict' not in table.__dict__:
110
+ continue
111
+ for fkey in table._source_foreign_key_dict.values():
112
+ if fkey.name not in table:
113
+ continue
114
+ if not table[fkey.name].is_physical:
115
+ continue
116
+ dst_table_names = [
117
+ table.name for table in self.tables.values()
118
+ if isinstance(table, SQLTable)
119
+ and table._source_name == fkey.dst_table
120
+ ]
121
+ if len(dst_table_names) != 1:
122
+ continue
123
+ dst_table = self[dst_table_names[0]]
124
+ if dst_table._primary_key != fkey.primary_key:
125
+ continue
126
+ if not dst_table[fkey.primary_key].is_physical:
127
+ continue
128
+ self.link(table.name, fkey.name, dst_table.name)
129
+
88
130
  for edge in (edges or []):
89
131
  _edge = Edge._cast(edge)
90
132
  assert _edge is not None
91
- self.link(*_edge)
133
+ if _edge not in self._edges:
134
+ self.link(*_edge)
92
135
 
93
136
  @classmethod
94
137
  def from_data(
95
138
  cls,
96
- df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
139
+ df_dict: dict[str, pd.DataFrame],
140
+ edges: Sequence[Edge] | None = None,
98
141
  infer_metadata: bool = True,
99
142
  verbose: bool = True,
100
143
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
144
+ r"""Creates a :class:`Graph` from a dictionary of
102
145
  :class:`pandas.DataFrame` objects.
103
146
 
104
- Automatically infers table metadata and links.
147
+ Automatically infers table metadata and links by default.
105
148
 
106
149
  .. code-block:: python
107
150
 
@@ -115,59 +158,402 @@ class LocalGraph:
115
158
  >>> df3 = pd.DataFrame(...)
116
159
 
117
160
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
161
+ >>> graph = rfm.Graph.from_data({
119
162
  ... "table1": df1,
120
163
  ... "table2": df2,
121
164
  ... "table3": df3,
122
165
  ... })
123
166
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
167
  Args:
132
168
  df_dict: A dictionary of data frames, where the keys are the names
133
169
  of the tables and the values hold table data.
170
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
171
+ add to the graph. If not provided, edges will be automatically
172
+ inferred from the data in case ``infer_metadata=True``.
134
173
  infer_metadata: Whether to infer metadata for all tables in the
135
174
  graph.
175
+ verbose: Whether to print verbose output.
176
+ """
177
+ from kumoai.experimental.rfm.backend.local import LocalTable
178
+
179
+ graph = cls(
180
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
181
+ edges=edges or [],
182
+ )
183
+
184
+ if infer_metadata:
185
+ graph.infer_metadata(verbose=False)
186
+
187
+ if edges is None:
188
+ graph.infer_links(verbose=False)
189
+
190
+ if verbose:
191
+ graph.print_metadata()
192
+ graph.print_links()
193
+
194
+ return graph
195
+
196
+ @classmethod
197
+ def from_sqlite(
198
+ cls,
199
+ connection: Union[
200
+ 'AdbcSqliteConnection',
201
+ SqliteConnectionConfig,
202
+ str,
203
+ Path,
204
+ dict[str, Any],
205
+ ],
206
+ tables: Sequence[str | dict[str, Any]] | None = None,
207
+ edges: Sequence[Edge] | None = None,
208
+ infer_metadata: bool = True,
209
+ verbose: bool = True,
210
+ ) -> Self:
211
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
212
+
213
+ Automatically infers table metadata and links by default.
214
+
215
+ .. code-block:: python
216
+
217
+ >>> # doctest: +SKIP
218
+ >>> import kumoai.experimental.rfm as rfm
219
+
220
+ >>> # Create a graph from a SQLite database:
221
+ >>> graph = rfm.Graph.from_sqlite('data.db')
222
+
223
+ >>> # Fine-grained control over table specification:
224
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
225
+ ... 'USERS',
226
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
227
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
228
+ ... ])
229
+
230
+ Args:
231
+ connection: An open connection from
232
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
233
+ path to the database file.
234
+ tables: Set of table names or :class:`SQLiteTable` keyword
235
+ arguments to include. If ``None``, will add all tables present
236
+ in the database.
136
237
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
238
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
239
+ inferred from the data in case ``infer_metadata=True``.
240
+ infer_metadata: Whether to infer missing metadata for all tables in
241
+ the graph.
139
242
  verbose: Whether to print verbose output.
243
+ """
244
+ from kumoai.experimental.rfm.backend.sqlite import (
245
+ Connection,
246
+ SQLiteTable,
247
+ connect,
248
+ )
249
+
250
+ internal_connection = False
251
+ if not isinstance(connection, Connection):
252
+ connection = SqliteConnectionConfig._cast(connection)
253
+ assert isinstance(connection, SqliteConnectionConfig)
254
+ connection = connect(connection.uri, **connection.kwargs)
255
+ internal_connection = True
256
+ assert isinstance(connection, Connection)
257
+
258
+ if tables is None:
259
+ with connection.cursor() as cursor:
260
+ cursor.execute("SELECT name FROM sqlite_master "
261
+ "WHERE type='table'")
262
+ tables = [row[0] for row in cursor.fetchall()]
263
+
264
+ table_kwargs: list[dict[str, Any]] = []
265
+ for table in tables:
266
+ kwargs = dict(name=table) if isinstance(table, str) else table
267
+ table_kwargs.append(kwargs)
140
268
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
269
+ graph = cls(
270
+ tables=[
271
+ SQLiteTable(connection=connection, **kwargs)
272
+ for kwargs in table_kwargs
273
+ ],
274
+ edges=edges or [],
275
+ )
276
+
277
+ if internal_connection:
278
+ graph._connection = connection # type: ignore
279
+
280
+ if infer_metadata:
281
+ graph.infer_metadata(verbose=False)
282
+
283
+ if edges is None:
284
+ graph.infer_links(verbose=False)
285
+
286
+ if verbose:
287
+ graph.print_metadata()
288
+ graph.print_links()
289
+
290
+ return graph
291
+
292
+ @classmethod
293
+ def from_snowflake(
294
+ cls,
295
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
296
+ tables: Sequence[str | dict[str, Any]] | None = None,
297
+ database: str | None = None,
298
+ schema: str | None = None,
299
+ edges: Sequence[Edge] | None = None,
300
+ infer_metadata: bool = True,
301
+ verbose: bool = True,
302
+ ) -> Self:
303
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
304
+ schema.
305
+
306
+ Automatically infers table metadata and links by default.
307
+
308
+ .. code-block:: python
144
309
 
145
- Example:
146
310
  >>> # doctest: +SKIP
147
311
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
157
- """
158
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
159
312
 
160
- graph = cls(tables, edges=edges or [])
313
+ >>> # Create a graph directly in a Snowflake notebook:
314
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
315
+
316
+ >>> # Fine-grained control over table specification:
317
+ >>> graph = rfm.Graph.from_snowflake(tables=[
318
+ ... 'USERS',
319
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
320
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
321
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
322
+
323
+ Args:
324
+ connection: An open connection from
325
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
326
+ :class:`snowflake` connector keyword arguments to open a new
327
+ connection. If ``None``, will re-use an active session in case
328
+ it exists, or create a new connection from credentials stored
329
+ in environment variables.
330
+ tables: Set of table names or :class:`SnowTable` keyword arguments
331
+ to include. If ``None``, will add all tables present in the
332
+ current database and schema.
333
+ database: The database.
334
+ schema: The schema.
335
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
336
+ add to the graph. If not provided, edges will be automatically
337
+ inferred from the data in case ``infer_metadata=True``.
338
+ infer_metadata: Whether to infer metadata for all tables in the
339
+ graph.
340
+ verbose: Whether to print verbose output.
341
+ """
342
+ from kumoai.experimental.rfm.backend.snow import (
343
+ Connection,
344
+ SnowTable,
345
+ connect,
346
+ )
347
+
348
+ if not isinstance(connection, Connection):
349
+ connection = connect(**(connection or {}))
350
+ assert isinstance(connection, Connection)
351
+
352
+ if database is None or schema is None:
353
+ with connection.cursor() as cursor:
354
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
355
+ result = cursor.fetchone()
356
+ database = database or result[0]
357
+ assert database is not None
358
+ schema = schema or result[1]
359
+
360
+ if tables is None:
361
+ if schema is None:
362
+ raise ValueError("No current 'schema' set. Please specify the "
363
+ "Snowflake schema manually")
364
+
365
+ with connection.cursor() as cursor:
366
+ cursor.execute(f"""
367
+ SELECT TABLE_NAME
368
+ FROM {database}.INFORMATION_SCHEMA.TABLES
369
+ WHERE TABLE_SCHEMA = '{schema}'
370
+ """)
371
+ tables = [row[0] for row in cursor.fetchall()]
372
+
373
+ table_kwargs: list[dict[str, Any]] = []
374
+ for table in tables:
375
+ if isinstance(table, str):
376
+ kwargs = dict(name=table, database=database, schema=schema)
377
+ else:
378
+ kwargs = copy.copy(table)
379
+ kwargs.setdefault('database', database)
380
+ kwargs.setdefault('schema', schema)
381
+ table_kwargs.append(kwargs)
382
+
383
+ graph = cls(
384
+ tables=[
385
+ SnowTable(connection=connection, **kwargs)
386
+ for kwargs in table_kwargs
387
+ ],
388
+ edges=edges or [],
389
+ )
161
390
 
162
391
  if infer_metadata:
163
- graph.infer_metadata(verbose)
392
+ graph.infer_metadata(verbose=False)
164
393
 
165
394
  if edges is None:
166
- graph.infer_links(verbose)
395
+ graph.infer_links(verbose=False)
396
+
397
+ if verbose:
398
+ graph.print_metadata()
399
+ graph.print_links()
400
+
401
+ return graph
402
+
403
+ @classmethod
404
+ def from_snowflake_semantic_view(
405
+ cls,
406
+ semantic_view_name: str,
407
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
408
+ verbose: bool = True,
409
+ ) -> Self:
410
+ import yaml
411
+
412
+ from kumoai.experimental.rfm.backend.snow import (
413
+ Connection,
414
+ SnowTable,
415
+ connect,
416
+ )
417
+
418
+ if not isinstance(connection, Connection):
419
+ connection = connect(**(connection or {}))
420
+ assert isinstance(connection, Connection)
421
+
422
+ with connection.cursor() as cursor:
423
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
424
+ f"'{semantic_view_name}')")
425
+ cfg = yaml.safe_load(cursor.fetchone()[0])
426
+
427
+ graph = cls(tables=[])
428
+
429
+ msgs = []
430
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
431
+ for table_cfg in cfg['tables']:
432
+ table_name = table_cfg['name']
433
+ source_table_name = table_cfg['base_table']['table']
434
+ database = table_cfg['base_table']['database']
435
+ schema = table_cfg['base_table']['schema']
436
+
437
+ primary_key: str | None = None
438
+ if 'primary_key' in table_cfg:
439
+ primary_key_cfg = table_cfg['primary_key']
440
+ if len(primary_key_cfg['columns']) == 1:
441
+ primary_key = primary_key_cfg['columns'][0]
442
+ elif len(primary_key_cfg['columns']) > 1:
443
+ msgs.append(f"Failed to add primary key for table "
444
+ f"'{table_name}' since composite primary keys "
445
+ f"are not yet supported")
446
+
447
+ columns: list[str] = []
448
+ unsupported_columns: list[str] = []
449
+ column_expression_specs: list[ColumnExpressionSpec] = []
450
+ for column_cfg in chain(
451
+ table_cfg.get('dimensions', []),
452
+ table_cfg.get('time_dimensions', []),
453
+ table_cfg.get('facts', []),
454
+ ):
455
+ column_name = column_cfg['name']
456
+ column_expr = column_cfg.get('expr', None)
457
+ column_data_type = column_cfg.get('data_type', None)
458
+
459
+ if column_expr is None:
460
+ columns.append(column_name)
461
+ continue
462
+
463
+ column_expr = column_expr.replace(f'{table_name}.', '')
464
+
465
+ if column_expr == column_name:
466
+ columns.append(column_name)
467
+ continue
468
+
469
+ # Drop expressions that reference other tables (for now):
470
+ if any(f'{name}.' in column_expr for name in table_names):
471
+ unsupported_columns.append(column_name)
472
+ continue
473
+
474
+ spec = ColumnExpressionSpec(
475
+ name=column_name,
476
+ expr=column_expr,
477
+ dtype=SnowTable.to_dtype(column_data_type),
478
+ )
479
+ column_expression_specs.append(spec)
480
+
481
+ if len(unsupported_columns) == 1:
482
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
483
+ f"of table '{table_name}' since its expression "
484
+ f"references other tables")
485
+ elif len(unsupported_columns) > 1:
486
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
487
+ f"of table '{table_name}' since their expressions "
488
+ f"reference other tables")
489
+
490
+ table = SnowTable(
491
+ connection,
492
+ name=table_name,
493
+ source_name=source_table_name,
494
+ database=database,
495
+ schema=schema,
496
+ columns=columns,
497
+ column_expressions=column_expression_specs,
498
+ primary_key=primary_key,
499
+ )
500
+
501
+ # TODO Add a way to register time columns without heuristic usage.
502
+ table.infer_time_column(verbose=False)
503
+
504
+ graph.add_table(table)
505
+
506
+ for relation_cfg in cfg.get('relationships', []):
507
+ name = relation_cfg['name']
508
+ if len(relation_cfg['relationship_columns']) != 1:
509
+ msgs.append(f"Failed to add relationship '{name}' since "
510
+ f"composite key references are not yet supported")
511
+ continue
512
+
513
+ left_table = relation_cfg['left_table']
514
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
515
+ right_table = relation_cfg['right_table']
516
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
517
+
518
+ if graph[right_table]._primary_key != right_key:
519
+ # Semantic view error - this should never be triggered:
520
+ msgs.append(f"Failed to add relationship '{name}' since the "
521
+ f"referenced key '{right_key}' of table "
522
+ f"'{right_table}' is not a primary key")
523
+ continue
524
+
525
+ if graph[left_table]._primary_key == left_key:
526
+ msgs.append(f"Failed to add relationship '{name}' since the "
527
+ f"referencing key '{left_key}' of table "
528
+ f"'{left_table}' is a primary key")
529
+ continue
530
+
531
+ if left_key not in graph[left_table]:
532
+ graph[left_table].add_column(left_key)
533
+
534
+ graph.link(left_table, left_key, right_table)
535
+
536
+ graph.validate()
537
+
538
+ if verbose:
539
+ graph.print_metadata()
540
+ graph.print_links()
541
+
542
+ if len(msgs) > 0:
543
+ title = (f"Could not fully convert the semantic view definition "
544
+ f"'{semantic_view_name}' into a graph:\n")
545
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
167
546
 
168
547
  return graph
169
548
 
170
- # Tables ##############################################################
549
+ # Backend #################################################################
550
+
551
+ @property
552
+ def backend(self) -> DataBackend | None:
553
+ backends = [table.backend for table in self._tables.values()]
554
+ return backends[0] if len(backends) > 0 else None
555
+
556
+ # Tables ##################################################################
171
557
 
172
558
  def has_table(self, name: str) -> bool:
173
559
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -175,7 +561,7 @@ class LocalGraph:
175
561
  """
176
562
  return name in self.tables
177
563
 
178
- def table(self, name: str) -> LocalTable:
564
+ def table(self, name: str) -> Table:
179
565
  r"""Returns the table with name ``name`` in the graph.
180
566
 
181
567
  Raises:
@@ -186,11 +572,11 @@ class LocalGraph:
186
572
  return self.tables[name]
187
573
 
188
574
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
575
+ def tables(self) -> dict[str, Table]:
190
576
  r"""Returns the dictionary of table objects."""
191
577
  return self._tables
192
578
 
193
- def add_table(self, table: LocalTable) -> Self:
579
+ def add_table(self, table: Table) -> Self:
194
580
  r"""Adds a table to the graph.
195
581
 
196
582
  Args:
@@ -199,11 +585,18 @@ class LocalGraph:
199
585
  Raises:
200
586
  KeyError: If a table with the same name already exists in the
201
587
  graph.
588
+ ValueError: If the table belongs to a different backend than the
589
+ rest of the tables in the graph.
202
590
  """
203
591
  if table.name in self._tables:
204
592
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
593
  f"this graph; table names must be globally unique.")
206
594
 
595
+ if self.backend is not None and table.backend != self.backend:
596
+ raise ValueError(f"Cannot register a table with backend "
597
+ f"'{table.backend}' to this graph since other "
598
+ f"tables have backend '{self.backend}'.")
599
+
207
600
  self._tables[table.name] = table
208
601
 
209
602
  return self
@@ -241,7 +634,7 @@ class LocalGraph:
241
634
  Example:
242
635
  >>> # doctest: +SKIP
243
636
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
637
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
638
  >>> graph.metadata # doctest: +SKIP
246
639
  name primary_key time_column end_time_column
247
640
  0 users user_id - -
@@ -263,10 +656,14 @@ class LocalGraph:
263
656
  })
264
657
 
265
658
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
267
- if in_notebook():
659
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
660
+ if in_snowflake_notebook():
661
+ import streamlit as st
662
+ st.markdown("### 🗂️ Graph Metadata")
663
+ st.dataframe(self.metadata, hide_index=True)
664
+ elif in_notebook():
268
665
  from IPython.display import Markdown, display
269
- display(Markdown('### 🗂️ Graph Metadata'))
666
+ display(Markdown("### 🗂️ Graph Metadata"))
270
667
  df = self.metadata
271
668
  try:
272
669
  if hasattr(df.style, 'hide'):
@@ -287,7 +684,7 @@ class LocalGraph:
287
684
 
288
685
  Note:
289
686
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
687
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
688
  """
292
689
  for table in self.tables.values():
293
690
  table.infer_metadata(verbose=False)
@@ -300,42 +697,52 @@ class LocalGraph:
300
697
  # Edges ###################################################################
301
698
 
302
699
  @property
303
- def edges(self) -> List[Edge]:
700
+ def edges(self) -> list[Edge]:
304
701
  r"""Returns the edges of the graph."""
305
702
  return self._edges
306
703
 
307
704
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
705
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
706
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
707
  edge.src_table, edge.fkey) for edge in self.edges]
311
708
  edges = sorted(edges)
312
709
 
313
- if in_notebook():
710
+ if in_snowflake_notebook():
711
+ import streamlit as st
712
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
713
+ if len(edges) > 0:
714
+ st.markdown('\n'.join([
715
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
716
+ for edge in edges
717
+ ]))
718
+ else:
719
+ st.markdown("*No links registered*")
720
+ elif in_notebook():
314
721
  from IPython.display import Markdown, display
315
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
722
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
316
723
  if len(edges) > 0:
317
724
  display(
318
725
  Markdown('\n'.join([
319
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
726
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
320
727
  for edge in edges
321
728
  ])))
322
729
  else:
323
- display(Markdown('*No links registered*'))
730
+ display(Markdown("*No links registered*"))
324
731
  else:
325
732
  print("🕸️ Graph Links (FK ↔️ PK):")
326
733
  if len(edges) > 0:
327
734
  print('\n'.join([
328
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
735
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
329
736
  for edge in edges
330
737
  ]))
331
738
  else:
332
- print('No links registered')
739
+ print("No links registered")
333
740
 
334
741
  def link(
335
742
  self,
336
- src_table: Union[str, LocalTable],
743
+ src_table: str | Table,
337
744
  fkey: str,
338
- dst_table: Union[str, LocalTable],
745
+ dst_table: str | Table,
339
746
  ) -> Self:
340
747
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
748
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +765,11 @@ class LocalGraph:
358
765
  table does not exist in the graph, if the source key does not
359
766
  exist in the source table.
360
767
  """
361
- if isinstance(src_table, LocalTable):
768
+ if isinstance(src_table, Table):
362
769
  src_table = src_table.name
363
770
  assert isinstance(src_table, str)
364
771
 
365
- if isinstance(dst_table, LocalTable):
772
+ if isinstance(dst_table, Table):
366
773
  dst_table = dst_table.name
367
774
  assert isinstance(dst_table, str)
368
775
 
@@ -396,9 +803,9 @@ class LocalGraph:
396
803
 
397
804
  def unlink(
398
805
  self,
399
- src_table: Union[str, LocalTable],
806
+ src_table: str | Table,
400
807
  fkey: str,
401
- dst_table: Union[str, LocalTable],
808
+ dst_table: str | Table,
402
809
  ) -> Self:
403
810
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
811
 
@@ -410,11 +817,11 @@ class LocalGraph:
410
817
  Raises:
411
818
  ValueError: if the edge is not present in the graph.
412
819
  """
413
- if isinstance(src_table, LocalTable):
820
+ if isinstance(src_table, Table):
414
821
  src_table = src_table.name
415
822
  assert isinstance(src_table, str)
416
823
 
417
- if isinstance(dst_table, LocalTable):
824
+ if isinstance(dst_table, Table):
418
825
  dst_table = dst_table.name
419
826
  assert isinstance(dst_table, str)
420
827
 
@@ -428,17 +835,13 @@ class LocalGraph:
428
835
  return self
429
836
 
430
837
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
838
+ r"""Infers missing links for the tables and adds them as edges to the
839
+ graph.
432
840
 
433
841
  Args:
434
842
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
843
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
844
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
442
845
 
443
846
  # A list of primary key candidates (+score) for every column:
444
847
  candidate_dict: dict[
@@ -463,6 +866,9 @@ class LocalGraph:
463
866
  src_table_name = src_table.name.lower()
464
867
 
465
868
  for src_key in src_table.columns:
869
+ if (src_table.name, src_key.name) in known_edges:
870
+ continue
871
+
466
872
  if src_key == src_table.primary_key:
467
873
  continue # Cannot link to primary key.
468
874
 
@@ -528,7 +934,9 @@ class LocalGraph:
528
934
  score += 1.0
529
935
 
530
936
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
937
+ if (src_table._num_rows is not None
938
+ and dst_table._num_rows is not None
939
+ and src_table._num_rows > dst_table._num_rows):
532
940
  score += 1.0
533
941
 
534
942
  if score < 5.0:
@@ -574,6 +982,10 @@ class LocalGraph:
574
982
  raise ValueError("At least one table needs to be added to the "
575
983
  "graph")
576
984
 
985
+ backends = {table.backend for table in self._tables.values()}
986
+ if len(backends) != 1:
987
+ raise ValueError("Found multiple table backends in the graph")
988
+
577
989
  for edge in self.edges:
578
990
  src_table, fkey, dst_table = edge
579
991
 
@@ -620,7 +1032,7 @@ class LocalGraph:
620
1032
 
621
1033
  def visualize(
622
1034
  self,
623
- path: Optional[Union[str, io.BytesIO]] = None,
1035
+ path: str | io.BytesIO | None = None,
624
1036
  show_columns: bool = True,
625
1037
  ) -> 'graphviz.Graph':
626
1038
  r"""Visualizes the tables and edges in this graph using the
@@ -645,33 +1057,33 @@ class LocalGraph:
645
1057
 
646
1058
  return True
647
1059
 
648
- # Check basic dependency:
649
- if not find_spec('graphviz'):
650
- raise ModuleNotFoundError("The 'graphviz' package is required for "
651
- "visualization")
652
- elif not has_graphviz_executables():
1060
+ try: # Check basic dependency:
1061
+ import graphviz
1062
+ except ImportError as e:
1063
+ raise ImportError("The 'graphviz' package is required for "
1064
+ "visualization") from e
1065
+
1066
+ if not in_snowflake_notebook() and not has_graphviz_executables():
653
1067
  raise RuntimeError("Could not visualize graph as 'graphviz' "
654
1068
  "executables are not installed. These "
655
1069
  "dependencies are required in addition to the "
656
1070
  "'graphviz' Python package. Please install "
657
1071
  "them as described at "
658
1072
  "https://graphviz.org/download/.")
659
- else:
660
- import graphviz
661
1073
 
662
- format: Optional[str] = None
1074
+ format: str | None = None
663
1075
  if isinstance(path, str):
664
1076
  format = path.split('.')[-1]
665
1077
  elif isinstance(path, io.BytesIO):
666
1078
  format = 'svg'
667
1079
  graph = graphviz.Graph(format=format)
668
1080
 
669
- def left_align(keys: List[str]) -> str:
1081
+ def left_align(keys: list[str]) -> str:
670
1082
  if len(keys) == 0:
671
1083
  return ""
672
1084
  return '\\l'.join(keys) + '\\l'
673
1085
 
674
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1086
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
675
1087
  for src_table_name, fkey_name, _ in self.edges:
676
1088
  fkeys_dict[src_table_name].append(fkey_name)
677
1089
 
@@ -741,6 +1153,9 @@ class LocalGraph:
741
1153
  graph.render(path, cleanup=True)
742
1154
  elif isinstance(path, io.BytesIO):
743
1155
  path.write(graph.pipe())
1156
+ elif in_snowflake_notebook():
1157
+ import streamlit as st
1158
+ st.graphviz_chart(graph)
744
1159
  elif in_notebook():
745
1160
  from IPython.display import display
746
1161
  display(graph)
@@ -764,8 +1179,8 @@ class LocalGraph:
764
1179
  # Helpers #################################################################
765
1180
 
766
1181
  def _to_api_graph_definition(self) -> GraphDefinition:
767
- tables: Dict[str, TableDefinition] = {}
768
- col_groups: List[ColumnKeyGroup] = []
1182
+ tables: dict[str, TableDefinition] = {}
1183
+ col_groups: list[ColumnKeyGroup] = []
769
1184
  for table_name, table in self.tables.items():
770
1185
  tables[table_name] = table._to_api_table_definition()
771
1186
  if table.primary_key is None:
@@ -790,7 +1205,7 @@ class LocalGraph:
790
1205
  def __contains__(self, name: str) -> bool:
791
1206
  return self.has_table(name)
792
1207
 
793
- def __getitem__(self, name: str) -> LocalTable:
1208
+ def __getitem__(self, name: str) -> Table:
794
1209
  return self.table(name)
795
1210
 
796
1211
  def __delitem__(self, name: str) -> None:
@@ -808,3 +1223,7 @@ class LocalGraph:
808
1223
  f' tables={tables},\n'
809
1224
  f' edges={edges},\n'
810
1225
  f')')
1226
+
1227
+ def __del__(self) -> None:
1228
+ if hasattr(self, '_connection'):
1229
+ self._connection.close()