kumoai 2.13.0.dev202511131731__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. kumoai/__init__.py +18 -9
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +23 -2
  7. kumoai/experimental/rfm/__init__.py +191 -50
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/__init__.py +0 -0
  10. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  11. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  12. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  13. kumoai/experimental/rfm/backend/local/table.py +113 -0
  14. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  15. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  16. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  17. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  18. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  19. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  20. kumoai/experimental/rfm/base/__init__.py +30 -0
  21. kumoai/experimental/rfm/base/column.py +152 -0
  22. kumoai/experimental/rfm/base/expression.py +44 -0
  23. kumoai/experimental/rfm/base/sampler.py +761 -0
  24. kumoai/experimental/rfm/base/source.py +19 -0
  25. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  26. kumoai/experimental/rfm/base/table.py +753 -0
  27. kumoai/experimental/rfm/{local_graph.py → graph.py} +546 -116
  28. kumoai/experimental/rfm/infer/__init__.py +8 -0
  29. kumoai/experimental/rfm/infer/dtype.py +81 -0
  30. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  31. kumoai/experimental/rfm/infer/pkey.py +128 -0
  32. kumoai/experimental/rfm/infer/stype.py +35 -0
  33. kumoai/experimental/rfm/infer/time_col.py +61 -0
  34. kumoai/experimental/rfm/pquery/executor.py +27 -27
  35. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  36. kumoai/experimental/rfm/rfm.py +322 -252
  37. kumoai/experimental/rfm/sagemaker.py +138 -0
  38. kumoai/pquery/predictive_query.py +10 -6
  39. kumoai/spcs.py +1 -3
  40. kumoai/testing/decorators.py +1 -1
  41. kumoai/testing/snow.py +50 -0
  42. kumoai/trainer/distilled_trainer.py +175 -0
  43. kumoai/utils/__init__.py +3 -2
  44. kumoai/utils/progress_logger.py +178 -12
  45. kumoai/utils/sql.py +3 -0
  46. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/METADATA +13 -2
  47. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/RECORD +50 -29
  48. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  49. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  50. kumoai/experimental/rfm/local_table.py +0 -545
  51. kumoai/experimental/rfm/utils.py +0 -344
  52. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/WHEEL +0 -0
  53. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/licenses/LICENSE +0 -0
  54. {kumoai-2.13.0.dev202511131731.dist-info → kumoai-2.14.0.dev202512271732.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
5
- from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
6
+ from collections.abc import Sequence
7
+ from dataclasses import dataclass, field
8
+ from itertools import chain
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any, Union
7
11
 
8
12
  import pandas as pd
9
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -11,20 +15,29 @@ from kumoapi.table import TableDefinition
11
15
  from kumoapi.typing import Stype
12
16
  from typing_extensions import Self
13
17
 
14
- from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
18
+ from kumoai import in_notebook, in_snowflake_notebook
19
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
16
20
  from kumoai.graph import Edge
21
+ from kumoai.mixin import CastMixin
17
22
 
18
23
  if TYPE_CHECKING:
19
24
  import graphviz
25
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
26
+ from snowflake.connector import SnowflakeConnection
20
27
 
21
28
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
29
+ @dataclass
30
+ class SqliteConnectionConfig(CastMixin):
31
+ uri: str | Path
32
+ kwargs: dict[str, Any] = field(default_factory=dict)
33
+
34
+
35
+ class Graph:
36
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
37
  tables in a relational database.
25
38
 
26
39
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
40
+ :class:`Graph` is created, you can use it to initialize the
28
41
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
42
 
30
43
  .. code-block:: python
@@ -44,7 +57,7 @@ class LocalGraph:
44
57
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
58
 
46
59
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
60
+ >>> graph = rfm.Graph({
48
61
  ... "table1": table1,
49
62
  ... "table2": table2,
50
63
  ... "table3": table3,
@@ -75,33 +88,55 @@ class LocalGraph:
75
88
 
76
89
  def __init__(
77
90
  self,
78
- tables: List[LocalTable],
79
- edges: Optional[List[Edge]] = None,
91
+ tables: Sequence[Table],
92
+ edges: Sequence[Edge] | None = None,
80
93
  ) -> None:
81
94
 
82
- self._tables: Dict[str, LocalTable] = {}
83
- self._edges: List[Edge] = []
95
+ self._tables: dict[str, Table] = {}
96
+ self._edges: list[Edge] = []
84
97
 
85
98
  for table in tables:
86
99
  self.add_table(table)
87
100
 
101
+ for table in tables: # Use links from source metadata:
102
+ if not any(column.is_source for column in table.columns):
103
+ continue
104
+ for fkey in table._source_foreign_key_dict.values():
105
+ if fkey.name not in table:
106
+ continue
107
+ if not table[fkey.name].is_source:
108
+ continue
109
+ dst_table_names = [
110
+ table.name for table in self.tables.values()
111
+ if table.source_name == fkey.dst_table
112
+ ]
113
+ if len(dst_table_names) != 1:
114
+ continue
115
+ dst_table = self[dst_table_names[0]]
116
+ if dst_table._primary_key != fkey.primary_key:
117
+ continue
118
+ if not dst_table[fkey.primary_key].is_source:
119
+ continue
120
+ self.link(table.name, fkey.name, dst_table.name)
121
+
88
122
  for edge in (edges or []):
89
123
  _edge = Edge._cast(edge)
90
124
  assert _edge is not None
91
- self.link(*_edge)
125
+ if _edge not in self._edges:
126
+ self.link(*_edge)
92
127
 
93
128
  @classmethod
94
129
  def from_data(
95
130
  cls,
96
- df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
131
+ df_dict: dict[str, pd.DataFrame],
132
+ edges: Sequence[Edge] | None = None,
98
133
  infer_metadata: bool = True,
99
134
  verbose: bool = True,
100
135
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
136
+ r"""Creates a :class:`Graph` from a dictionary of
102
137
  :class:`pandas.DataFrame` objects.
103
138
 
104
- Automatically infers table metadata and links.
139
+ Automatically infers table metadata and links by default.
105
140
 
106
141
  .. code-block:: python
107
142
 
@@ -115,59 +150,400 @@ class LocalGraph:
115
150
  >>> df3 = pd.DataFrame(...)
116
151
 
117
152
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
153
+ >>> graph = rfm.Graph.from_data({
119
154
  ... "table1": df1,
120
155
  ... "table2": df2,
121
156
  ... "table3": df3,
122
157
  ... })
123
158
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
159
  Args:
132
160
  df_dict: A dictionary of data frames, where the keys are the names
133
161
  of the tables and the values hold table data.
162
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
163
+ add to the graph. If not provided, edges will be automatically
164
+ inferred from the data in case ``infer_metadata=True``.
134
165
  infer_metadata: Whether to infer metadata for all tables in the
135
166
  graph.
167
+ verbose: Whether to print verbose output.
168
+ """
169
+ from kumoai.experimental.rfm.backend.local import LocalTable
170
+
171
+ graph = cls(
172
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
173
+ edges=edges or [],
174
+ )
175
+
176
+ if infer_metadata:
177
+ graph.infer_metadata(verbose=False)
178
+
179
+ if edges is None:
180
+ graph.infer_links(verbose=False)
181
+
182
+ if verbose:
183
+ graph.print_metadata()
184
+ graph.print_links()
185
+
186
+ return graph
187
+
188
+ @classmethod
189
+ def from_sqlite(
190
+ cls,
191
+ connection: Union[
192
+ 'AdbcSqliteConnection',
193
+ SqliteConnectionConfig,
194
+ str,
195
+ Path,
196
+ dict[str, Any],
197
+ ],
198
+ tables: Sequence[str | dict[str, Any]] | None = None,
199
+ edges: Sequence[Edge] | None = None,
200
+ infer_metadata: bool = True,
201
+ verbose: bool = True,
202
+ ) -> Self:
203
+ r"""Creates a :class:`Graph` from a :class:`sqlite` database.
204
+
205
+ Automatically infers table metadata and links by default.
206
+
207
+ .. code-block:: python
208
+
209
+ >>> # doctest: +SKIP
210
+ >>> import kumoai.experimental.rfm as rfm
211
+
212
+ >>> # Create a graph from a SQLite database:
213
+ >>> graph = rfm.Graph.from_sqlite('data.db')
214
+
215
+ >>> # Fine-grained control over table specification:
216
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
217
+ ... 'USERS',
218
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
219
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
220
+ ... ])
221
+
222
+ Args:
223
+ connection: An open connection from
224
+ :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
225
+ path to the database file.
226
+ tables: Set of table names or :class:`SQLiteTable` keyword
227
+ arguments to include. If ``None``, will add all tables present
228
+ in the database.
136
229
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
137
230
  add to the graph. If not provided, edges will be automatically
138
- inferred from the data.
231
+ inferred from the data in case ``infer_metadata=True``.
232
+ infer_metadata: Whether to infer missing metadata for all tables in
233
+ the graph.
139
234
  verbose: Whether to print verbose output.
235
+ """
236
+ from kumoai.experimental.rfm.backend.sqlite import (
237
+ Connection,
238
+ SQLiteTable,
239
+ connect,
240
+ )
241
+
242
+ internal_connection = False
243
+ if not isinstance(connection, Connection):
244
+ connection = SqliteConnectionConfig._cast(connection)
245
+ assert isinstance(connection, SqliteConnectionConfig)
246
+ connection = connect(connection.uri, **connection.kwargs)
247
+ internal_connection = True
248
+ assert isinstance(connection, Connection)
249
+
250
+ if tables is None:
251
+ with connection.cursor() as cursor:
252
+ cursor.execute("SELECT name FROM sqlite_master "
253
+ "WHERE type='table'")
254
+ tables = [row[0] for row in cursor.fetchall()]
255
+
256
+ table_kwargs: list[dict[str, Any]] = []
257
+ for table in tables:
258
+ kwargs = dict(name=table) if isinstance(table, str) else table
259
+ table_kwargs.append(kwargs)
140
260
 
141
- Note:
142
- This method will automatically infer metadata and links for the
143
- graph.
261
+ graph = cls(
262
+ tables=[
263
+ SQLiteTable(connection=connection, **kwargs)
264
+ for kwargs in table_kwargs
265
+ ],
266
+ edges=edges or [],
267
+ )
268
+
269
+ if internal_connection:
270
+ graph._connection = connection # type: ignore
271
+
272
+ if infer_metadata:
273
+ graph.infer_metadata(verbose=False)
274
+
275
+ if edges is None:
276
+ graph.infer_links(verbose=False)
277
+
278
+ if verbose:
279
+ graph.print_metadata()
280
+ graph.print_links()
281
+
282
+ return graph
283
+
284
+ @classmethod
285
+ def from_snowflake(
286
+ cls,
287
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
288
+ tables: Sequence[str | dict[str, Any]] | None = None,
289
+ database: str | None = None,
290
+ schema: str | None = None,
291
+ edges: Sequence[Edge] | None = None,
292
+ infer_metadata: bool = True,
293
+ verbose: bool = True,
294
+ ) -> Self:
295
+ r"""Creates a :class:`Graph` from a :class:`snowflake` database and
296
+ schema.
297
+
298
+ Automatically infers table metadata and links by default.
299
+
300
+ .. code-block:: python
144
301
 
145
- Example:
146
302
  >>> # doctest: +SKIP
147
303
  >>> import kumoai.experimental.rfm as rfm
148
- >>> df1 = pd.DataFrame(...)
149
- >>> df2 = pd.DataFrame(...)
150
- >>> df3 = pd.DataFrame(...)
151
- >>> graph = rfm.LocalGraph.from_data(data={
152
- ... "table1": df1,
153
- ... "table2": df2,
154
- ... "table3": df3,
155
- ... })
156
- >>> graph.validate()
157
- """
158
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
159
304
 
160
- graph = cls(tables, edges=edges or [])
305
+ >>> # Create a graph directly in a Snowflake notebook:
306
+ >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
307
+
308
+ >>> # Fine-grained control over table specification:
309
+ >>> graph = rfm.Graph.from_snowflake(tables=[
310
+ ... 'USERS',
311
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
312
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
313
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
314
+
315
+ Args:
316
+ connection: An open connection from
317
+ :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
318
+ :class:`snowflake` connector keyword arguments to open a new
319
+ connection. If ``None``, will re-use an active session in case
320
+ it exists, or create a new connection from credentials stored
321
+ in environment variables.
322
+ tables: Set of table names or :class:`SnowTable` keyword arguments
323
+ to include. If ``None``, will add all tables present in the
324
+ current database and schema.
325
+ database: The database.
326
+ schema: The schema.
327
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
328
+ add to the graph. If not provided, edges will be automatically
329
+ inferred from the data in case ``infer_metadata=True``.
330
+ infer_metadata: Whether to infer metadata for all tables in the
331
+ graph.
332
+ verbose: Whether to print verbose output.
333
+ """
334
+ from kumoai.experimental.rfm.backend.snow import (
335
+ Connection,
336
+ SnowTable,
337
+ connect,
338
+ )
339
+
340
+ if not isinstance(connection, Connection):
341
+ connection = connect(**(connection or {}))
342
+ assert isinstance(connection, Connection)
343
+
344
+ if database is None or schema is None:
345
+ with connection.cursor() as cursor:
346
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
347
+ result = cursor.fetchone()
348
+ database = database or result[0]
349
+ assert database is not None
350
+ schema = schema or result[1]
351
+
352
+ if tables is None:
353
+ if schema is None:
354
+ raise ValueError("No current 'schema' set. Please specify the "
355
+ "Snowflake schema manually")
356
+
357
+ with connection.cursor() as cursor:
358
+ cursor.execute(f"""
359
+ SELECT TABLE_NAME
360
+ FROM {database}.INFORMATION_SCHEMA.TABLES
361
+ WHERE TABLE_SCHEMA = '{schema}'
362
+ """)
363
+ tables = [row[0] for row in cursor.fetchall()]
364
+
365
+ table_kwargs: list[dict[str, Any]] = []
366
+ for table in tables:
367
+ if isinstance(table, str):
368
+ kwargs = dict(name=table, database=database, schema=schema)
369
+ else:
370
+ kwargs = copy.copy(table)
371
+ kwargs.setdefault('database', database)
372
+ kwargs.setdefault('schema', schema)
373
+ table_kwargs.append(kwargs)
374
+
375
+ graph = cls(
376
+ tables=[
377
+ SnowTable(connection=connection, **kwargs)
378
+ for kwargs in table_kwargs
379
+ ],
380
+ edges=edges or [],
381
+ )
161
382
 
162
383
  if infer_metadata:
163
- graph.infer_metadata(verbose)
384
+ graph.infer_metadata(verbose=False)
164
385
 
165
386
  if edges is None:
166
- graph.infer_links(verbose)
387
+ graph.infer_links(verbose=False)
388
+
389
+ if verbose:
390
+ graph.print_metadata()
391
+ graph.print_links()
167
392
 
168
393
  return graph
169
394
 
170
- # Tables ##############################################################
395
+ @classmethod
396
+ def from_snowflake_semantic_view(
397
+ cls,
398
+ semantic_view_name: str,
399
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
400
+ verbose: bool = True,
401
+ ) -> Self:
402
+ import yaml
403
+
404
+ from kumoai.experimental.rfm.backend.snow import (
405
+ Connection,
406
+ SnowTable,
407
+ connect,
408
+ )
409
+
410
+ if not isinstance(connection, Connection):
411
+ connection = connect(**(connection or {}))
412
+ assert isinstance(connection, Connection)
413
+
414
+ with connection.cursor() as cursor:
415
+ cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
416
+ f"'{semantic_view_name}')")
417
+ cfg = yaml.safe_load(cursor.fetchone()[0])
418
+
419
+ graph = cls(tables=[])
420
+
421
+ msgs = []
422
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
423
+ for table_cfg in cfg['tables']:
424
+ table_name = table_cfg['name']
425
+ source_table_name = table_cfg['base_table']['table']
426
+ database = table_cfg['base_table']['database']
427
+ schema = table_cfg['base_table']['schema']
428
+
429
+ primary_key: str | None = None
430
+ if 'primary_key' in table_cfg:
431
+ primary_key_cfg = table_cfg['primary_key']
432
+ if len(primary_key_cfg['columns']) == 1:
433
+ primary_key = primary_key_cfg['columns'][0]
434
+ elif len(primary_key_cfg['columns']) > 1:
435
+ msgs.append(f"Failed to add primary key for table "
436
+ f"'{table_name}' since composite primary keys "
437
+ f"are not yet supported")
438
+
439
+ columns: list[ColumnSpec] = []
440
+ unsupported_columns: list[str] = []
441
+ for column_cfg in chain(
442
+ table_cfg.get('dimensions', []),
443
+ table_cfg.get('time_dimensions', []),
444
+ table_cfg.get('facts', []),
445
+ ):
446
+ column_name = column_cfg['name']
447
+ column_expr = column_cfg.get('expr', None)
448
+ column_data_type = column_cfg.get('data_type', None)
449
+
450
+ if column_expr is None:
451
+ columns.append(ColumnSpec(name=column_name))
452
+ continue
453
+
454
+ column_expr = column_expr.replace(f'{table_name}.', '')
455
+
456
+ if column_expr == column_name:
457
+ columns.append(ColumnSpec(name=column_name))
458
+ continue
459
+
460
+ # Drop expressions that reference other tables (for now):
461
+ if any(f'{name}.' in column_expr for name in table_names):
462
+ unsupported_columns.append(column_name)
463
+ continue
464
+
465
+ column = ColumnSpec(
466
+ name=column_name,
467
+ expr=column_expr,
468
+ dtype=SnowTable._to_dtype(column_data_type),
469
+ )
470
+ columns.append(column)
471
+
472
+ if len(unsupported_columns) == 1:
473
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
474
+ f"of table '{table_name}' since its expression "
475
+ f"references other tables")
476
+ elif len(unsupported_columns) > 1:
477
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
478
+ f"of table '{table_name}' since their expressions "
479
+ f"reference other tables")
480
+
481
+ table = SnowTable(
482
+ connection,
483
+ name=table_name,
484
+ source_name=source_table_name,
485
+ database=database,
486
+ schema=schema,
487
+ columns=columns,
488
+ primary_key=primary_key,
489
+ )
490
+
491
+ # TODO Add a way to register time columns without heuristic usage.
492
+ table.infer_time_column(verbose=False)
493
+
494
+ graph.add_table(table)
495
+
496
+ for relation_cfg in cfg.get('relationships', []):
497
+ name = relation_cfg['name']
498
+ if len(relation_cfg['relationship_columns']) != 1:
499
+ msgs.append(f"Failed to add relationship '{name}' since "
500
+ f"composite key references are not yet supported")
501
+ continue
502
+
503
+ left_table = relation_cfg['left_table']
504
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
505
+ right_table = relation_cfg['right_table']
506
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
507
+
508
+ if graph[right_table]._primary_key != right_key:
509
+ # Semantic view error - this should never be triggered:
510
+ msgs.append(f"Failed to add relationship '{name}' since the "
511
+ f"referenced key '{right_key}' of table "
512
+ f"'{right_table}' is not a primary key")
513
+ continue
514
+
515
+ if graph[left_table]._primary_key == left_key:
516
+ msgs.append(f"Failed to add relationship '{name}' since the "
517
+ f"referencing key '{left_key}' of table "
518
+ f"'{left_table}' is a primary key")
519
+ continue
520
+
521
+ if left_key not in graph[left_table]:
522
+ graph[left_table].add_column(left_key)
523
+
524
+ graph.link(left_table, left_key, right_table)
525
+
526
+ graph.validate()
527
+
528
+ if verbose:
529
+ graph.print_metadata()
530
+ graph.print_links()
531
+
532
+ if len(msgs) > 0:
533
+ title = (f"Could not fully convert the semantic view definition "
534
+ f"'{semantic_view_name}' into a graph:\n")
535
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
536
+
537
+ return graph
538
+
539
+ # Backend #################################################################
540
+
541
+ @property
542
+ def backend(self) -> DataBackend | None:
543
+ backends = [table.backend for table in self._tables.values()]
544
+ return backends[0] if len(backends) > 0 else None
545
+
546
+ # Tables ##################################################################
171
547
 
172
548
  def has_table(self, name: str) -> bool:
173
549
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -175,7 +551,7 @@ class LocalGraph:
175
551
  """
176
552
  return name in self.tables
177
553
 
178
- def table(self, name: str) -> LocalTable:
554
+ def table(self, name: str) -> Table:
179
555
  r"""Returns the table with name ``name`` in the graph.
180
556
 
181
557
  Raises:
@@ -186,11 +562,11 @@ class LocalGraph:
186
562
  return self.tables[name]
187
563
 
188
564
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
565
+ def tables(self) -> dict[str, Table]:
190
566
  r"""Returns the dictionary of table objects."""
191
567
  return self._tables
192
568
 
193
- def add_table(self, table: LocalTable) -> Self:
569
+ def add_table(self, table: Table) -> Self:
194
570
  r"""Adds a table to the graph.
195
571
 
196
572
  Args:
@@ -199,11 +575,18 @@ class LocalGraph:
199
575
  Raises:
200
576
  KeyError: If a table with the same name already exists in the
201
577
  graph.
578
+ ValueError: If the table belongs to a different backend than the
579
+ rest of the tables in the graph.
202
580
  """
203
581
  if table.name in self._tables:
204
582
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
583
  f"this graph; table names must be globally unique.")
206
584
 
585
+ if self.backend is not None and table.backend != self.backend:
586
+ raise ValueError(f"Cannot register a table with backend "
587
+ f"'{table.backend}' to this graph since other "
588
+ f"tables have backend '{self.backend}'.")
589
+
207
590
  self._tables[table.name] = table
208
591
 
209
592
  return self
@@ -241,7 +624,7 @@ class LocalGraph:
241
624
  Example:
242
625
  >>> # doctest: +SKIP
243
626
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
627
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
628
  >>> graph.metadata # doctest: +SKIP
246
629
  name primary_key time_column end_time_column
247
630
  0 users user_id - -
@@ -263,10 +646,14 @@ class LocalGraph:
263
646
  })
264
647
 
265
648
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
267
- if in_notebook():
649
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
650
+ if in_snowflake_notebook():
651
+ import streamlit as st
652
+ st.markdown("### 🗂️ Graph Metadata")
653
+ st.dataframe(self.metadata, hide_index=True)
654
+ elif in_notebook():
268
655
  from IPython.display import Markdown, display
269
- display(Markdown('### 🗂️ Graph Metadata'))
656
+ display(Markdown("### 🗂️ Graph Metadata"))
270
657
  df = self.metadata
271
658
  try:
272
659
  if hasattr(df.style, 'hide'):
@@ -287,7 +674,7 @@ class LocalGraph:
287
674
 
288
675
  Note:
289
676
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
677
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
678
  """
292
679
  for table in self.tables.values():
293
680
  table.infer_metadata(verbose=False)
@@ -300,42 +687,52 @@ class LocalGraph:
300
687
  # Edges ###################################################################
301
688
 
302
689
  @property
303
- def edges(self) -> List[Edge]:
690
+ def edges(self) -> list[Edge]:
304
691
  r"""Returns the edges of the graph."""
305
692
  return self._edges
306
693
 
307
694
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
695
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
696
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
697
  edge.src_table, edge.fkey) for edge in self.edges]
311
698
  edges = sorted(edges)
312
699
 
313
- if in_notebook():
700
+ if in_snowflake_notebook():
701
+ import streamlit as st
702
+ st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
703
+ if len(edges) > 0:
704
+ st.markdown('\n'.join([
705
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
706
+ for edge in edges
707
+ ]))
708
+ else:
709
+ st.markdown("*No links registered*")
710
+ elif in_notebook():
314
711
  from IPython.display import Markdown, display
315
- display(Markdown('### 🕸️ Graph Links (FK ↔️ PK)'))
712
+ display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
316
713
  if len(edges) > 0:
317
714
  display(
318
715
  Markdown('\n'.join([
319
- f'- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`'
716
+ f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
320
717
  for edge in edges
321
718
  ])))
322
719
  else:
323
- display(Markdown('*No links registered*'))
720
+ display(Markdown("*No links registered*"))
324
721
  else:
325
722
  print("🕸️ Graph Links (FK ↔️ PK):")
326
723
  if len(edges) > 0:
327
724
  print('\n'.join([
328
- f'• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}'
725
+ f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
329
726
  for edge in edges
330
727
  ]))
331
728
  else:
332
- print('No links registered')
729
+ print("No links registered")
333
730
 
334
731
  def link(
335
732
  self,
336
- src_table: Union[str, LocalTable],
733
+ src_table: str | Table,
337
734
  fkey: str,
338
- dst_table: Union[str, LocalTable],
735
+ dst_table: str | Table,
339
736
  ) -> Self:
340
737
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
738
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +755,11 @@ class LocalGraph:
358
755
  table does not exist in the graph, if the source key does not
359
756
  exist in the source table.
360
757
  """
361
- if isinstance(src_table, LocalTable):
758
+ if isinstance(src_table, Table):
362
759
  src_table = src_table.name
363
760
  assert isinstance(src_table, str)
364
761
 
365
- if isinstance(dst_table, LocalTable):
762
+ if isinstance(dst_table, Table):
366
763
  dst_table = dst_table.name
367
764
  assert isinstance(dst_table, str)
368
765
 
@@ -396,9 +793,9 @@ class LocalGraph:
396
793
 
397
794
  def unlink(
398
795
  self,
399
- src_table: Union[str, LocalTable],
796
+ src_table: str | Table,
400
797
  fkey: str,
401
- dst_table: Union[str, LocalTable],
798
+ dst_table: str | Table,
402
799
  ) -> Self:
403
800
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
801
 
@@ -410,11 +807,11 @@ class LocalGraph:
410
807
  Raises:
411
808
  ValueError: if the edge is not present in the graph.
412
809
  """
413
- if isinstance(src_table, LocalTable):
810
+ if isinstance(src_table, Table):
414
811
  src_table = src_table.name
415
812
  assert isinstance(src_table, str)
416
813
 
417
- if isinstance(dst_table, LocalTable):
814
+ if isinstance(dst_table, Table):
418
815
  dst_table = dst_table.name
419
816
  assert isinstance(dst_table, str)
420
817
 
@@ -428,17 +825,37 @@ class LocalGraph:
428
825
  return self
429
826
 
430
827
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
828
+ r"""Infers missing links for the tables and adds them as edges to the
829
+ graph.
432
830
 
433
831
  Args:
434
832
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
833
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
834
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
835
+
836
+ for table in self.tables.values(): # Use links from source metadata:
837
+ if not any(column.is_source for column in table.columns):
838
+ continue
839
+ for fkey in table._source_foreign_key_dict.values():
840
+ if fkey.name not in table:
841
+ continue
842
+ if not table[fkey.name].is_source:
843
+ continue
844
+ if (table.name, fkey.name) in known_edges:
845
+ continue
846
+ dst_table_names = [
847
+ table.name for table in self.tables.values()
848
+ if table.source_name == fkey.dst_table
849
+ ]
850
+ if len(dst_table_names) != 1:
851
+ continue
852
+ dst_table = self[dst_table_names[0]]
853
+ if dst_table._primary_key != fkey.primary_key:
854
+ continue
855
+ if not dst_table[fkey.primary_key].is_source:
856
+ continue
857
+ self.link(table.name, fkey.name, dst_table.name)
858
+ known_edges.add((table.name, fkey.name))
442
859
 
443
860
  # A list of primary key candidates (+score) for every column:
444
861
  candidate_dict: dict[
@@ -463,6 +880,9 @@ class LocalGraph:
463
880
  src_table_name = src_table.name.lower()
464
881
 
465
882
  for src_key in src_table.columns:
883
+ if (src_table.name, src_key.name) in known_edges:
884
+ continue
885
+
466
886
  if src_key == src_table.primary_key:
467
887
  continue # Cannot link to primary key.
468
888
 
@@ -528,19 +948,16 @@ class LocalGraph:
528
948
  score += 1.0
529
949
 
530
950
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
951
+ if (src_table._num_rows is not None
952
+ and dst_table._num_rows is not None
953
+ and src_table._num_rows > dst_table._num_rows):
532
954
  score += 1.0
533
955
 
534
956
  if score < 5.0:
535
957
  continue
536
958
 
537
- candidate_dict[(
538
- src_table.name,
539
- src_key.name,
540
- )].append((
541
- dst_table.name,
542
- score,
543
- ))
959
+ candidate_dict[(src_table.name, src_key.name)].append(
960
+ (dst_table.name, score))
544
961
 
545
962
  for (src_table_name, src_key_name), scores in candidate_dict.items():
546
963
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -574,6 +991,10 @@ class LocalGraph:
574
991
  raise ValueError("At least one table needs to be added to the "
575
992
  "graph")
576
993
 
994
+ backends = {table.backend for table in self._tables.values()}
995
+ if len(backends) != 1:
996
+ raise ValueError("Found multiple table backends in the graph")
997
+
577
998
  for edge in self.edges:
578
999
  src_table, fkey, dst_table = edge
579
1000
 
@@ -595,24 +1016,26 @@ class LocalGraph:
595
1016
  f"either the primary key or the link before "
596
1017
  f"before proceeding.")
597
1018
 
598
- # Check that fkey/pkey have valid and consistent data types:
599
- assert src_key.dtype is not None
600
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
601
- src_string = src_key.dtype.is_string()
602
- assert dst_key.dtype is not None
603
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
604
- dst_string = dst_key.dtype.is_string()
605
-
606
- if not src_number and not src_string:
607
- raise ValueError(f"{edge} is invalid as foreign key must be a "
608
- f"number or string (got '{src_key.dtype}'")
609
-
610
- if src_number != dst_number or src_string != dst_string:
611
- raise ValueError(f"{edge} is invalid as foreign key "
612
- f"'{fkey}' and primary key '{dst_key.name}' "
613
- f"have incompatible data types (got "
614
- f"fkey.dtype '{src_key.dtype}' and "
615
- f"pkey.dtype '{dst_key.dtype}')")
1019
+ if self.backend == DataBackend.LOCAL:
1020
+ # Check that fkey/pkey have valid and consistent data types:
1021
+ assert src_key.dtype is not None
1022
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1023
+ src_string = src_key.dtype.is_string()
1024
+ assert dst_key.dtype is not None
1025
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1026
+ dst_string = dst_key.dtype.is_string()
1027
+
1028
+ if not src_number and not src_string:
1029
+ raise ValueError(
1030
+ f"{edge} is invalid as foreign key must be a number "
1031
+ f"or string (got '{src_key.dtype}'")
1032
+
1033
+ if src_number != dst_number or src_string != dst_string:
1034
+ raise ValueError(
1035
+ f"{edge} is invalid as foreign key '{fkey}' and "
1036
+ f"primary key '{dst_key.name}' have incompatible data "
1037
+ f"types (got foreign key data type '{src_key.dtype}' "
1038
+ f"and primary key data type '{dst_key.dtype}')")
616
1039
 
617
1040
  return self
618
1041
 
@@ -620,7 +1043,7 @@ class LocalGraph:
620
1043
 
621
1044
  def visualize(
622
1045
  self,
623
- path: Optional[Union[str, io.BytesIO]] = None,
1046
+ path: str | io.BytesIO | None = None,
624
1047
  show_columns: bool = True,
625
1048
  ) -> 'graphviz.Graph':
626
1049
  r"""Visualizes the tables and edges in this graph using the
@@ -645,33 +1068,33 @@ class LocalGraph:
645
1068
 
646
1069
  return True
647
1070
 
648
- # Check basic dependency:
649
- if not find_spec('graphviz'):
650
- raise ModuleNotFoundError("The 'graphviz' package is required for "
651
- "visualization")
652
- elif not has_graphviz_executables():
1071
+ try: # Check basic dependency:
1072
+ import graphviz
1073
+ except ImportError as e:
1074
+ raise ImportError("The 'graphviz' package is required for "
1075
+ "visualization") from e
1076
+
1077
+ if not in_snowflake_notebook() and not has_graphviz_executables():
653
1078
  raise RuntimeError("Could not visualize graph as 'graphviz' "
654
1079
  "executables are not installed. These "
655
1080
  "dependencies are required in addition to the "
656
1081
  "'graphviz' Python package. Please install "
657
1082
  "them as described at "
658
1083
  "https://graphviz.org/download/.")
659
- else:
660
- import graphviz
661
1084
 
662
- format: Optional[str] = None
1085
+ format: str | None = None
663
1086
  if isinstance(path, str):
664
1087
  format = path.split('.')[-1]
665
1088
  elif isinstance(path, io.BytesIO):
666
1089
  format = 'svg'
667
1090
  graph = graphviz.Graph(format=format)
668
1091
 
669
- def left_align(keys: List[str]) -> str:
1092
+ def left_align(keys: list[str]) -> str:
670
1093
  if len(keys) == 0:
671
1094
  return ""
672
1095
  return '\\l'.join(keys) + '\\l'
673
1096
 
674
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1097
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
675
1098
  for src_table_name, fkey_name, _ in self.edges:
676
1099
  fkeys_dict[src_table_name].append(fkey_name)
677
1100
 
@@ -741,6 +1164,9 @@ class LocalGraph:
741
1164
  graph.render(path, cleanup=True)
742
1165
  elif isinstance(path, io.BytesIO):
743
1166
  path.write(graph.pipe())
1167
+ elif in_snowflake_notebook():
1168
+ import streamlit as st
1169
+ st.graphviz_chart(graph)
744
1170
  elif in_notebook():
745
1171
  from IPython.display import display
746
1172
  display(graph)
@@ -764,8 +1190,8 @@ class LocalGraph:
764
1190
  # Helpers #################################################################
765
1191
 
766
1192
  def _to_api_graph_definition(self) -> GraphDefinition:
767
- tables: Dict[str, TableDefinition] = {}
768
- col_groups: List[ColumnKeyGroup] = []
1193
+ tables: dict[str, TableDefinition] = {}
1194
+ col_groups: list[ColumnKeyGroup] = []
769
1195
  for table_name, table in self.tables.items():
770
1196
  tables[table_name] = table._to_api_table_definition()
771
1197
  if table.primary_key is None:
@@ -790,7 +1216,7 @@ class LocalGraph:
790
1216
  def __contains__(self, name: str) -> bool:
791
1217
  return self.has_table(name)
792
1218
 
793
- def __getitem__(self, name: str) -> LocalTable:
1219
+ def __getitem__(self, name: str) -> Table:
794
1220
  return self.table(name)
795
1221
 
796
1222
  def __delitem__(self, name: str) -> None:
@@ -808,3 +1234,7 @@ class LocalGraph:
808
1234
  f' tables={tables},\n'
809
1235
  f' edges={edges},\n'
810
1236
  f')')
1237
+
1238
+ def __del__(self) -> None:
1239
+ if hasattr(self, '_connection'):
1240
+ self._connection.close()