kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +51 -24
  7. kumoai/experimental/rfm/authenticate.py +3 -4
  8. kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
  9. kumoai/experimental/rfm/backend/local/sampler.py +4 -5
  10. kumoai/experimental/rfm/backend/local/table.py +24 -30
  11. kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
  12. kumoai/experimental/rfm/backend/snow/table.py +166 -56
  13. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  14. kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
  15. kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
  16. kumoai/experimental/rfm/base/__init__.py +8 -1
  17. kumoai/experimental/rfm/base/column.py +96 -10
  18. kumoai/experimental/rfm/base/expression.py +44 -0
  19. kumoai/experimental/rfm/base/mapper.py +69 -0
  20. kumoai/experimental/rfm/base/sampler.py +28 -18
  21. kumoai/experimental/rfm/base/source.py +1 -1
  22. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  23. kumoai/experimental/rfm/base/table.py +374 -208
  24. kumoai/experimental/rfm/base/utils.py +36 -0
  25. kumoai/experimental/rfm/graph.py +335 -180
  26. kumoai/experimental/rfm/infer/__init__.py +6 -4
  27. kumoai/experimental/rfm/infer/dtype.py +10 -5
  28. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  29. kumoai/experimental/rfm/infer/pkey.py +4 -2
  30. kumoai/experimental/rfm/infer/stype.py +35 -0
  31. kumoai/experimental/rfm/infer/time_col.py +5 -4
  32. kumoai/experimental/rfm/pquery/executor.py +27 -27
  33. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  34. kumoai/experimental/rfm/relbench.py +76 -0
  35. kumoai/experimental/rfm/rfm.py +606 -361
  36. kumoai/experimental/rfm/sagemaker.py +4 -4
  37. kumoai/experimental/rfm/task_table.py +292 -0
  38. kumoai/pquery/training_table.py +16 -2
  39. kumoai/testing/snow.py +3 -3
  40. kumoai/trainer/distilled_trainer.py +175 -0
  41. kumoai/utils/__init__.py +1 -2
  42. kumoai/utils/display.py +87 -0
  43. kumoai/utils/progress_logger.py +192 -13
  44. kumoai/utils/sql.py +2 -2
  45. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
  46. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
  47. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
  48. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
  49. {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
4
+ import copy
2
5
  import io
3
6
  import warnings
4
7
  from collections import defaultdict
8
+ from collections.abc import Sequence
5
9
  from dataclasses import dataclass, field
10
+ from itertools import chain
6
11
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
12
+ from typing import TYPE_CHECKING, Any, Union
8
13
 
9
14
  import pandas as pd
10
15
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,9 +18,10 @@ from kumoapi.typing import Stype
13
18
  from typing_extensions import Self
14
19
 
15
20
  from kumoai import in_notebook, in_snowflake_notebook
16
- from kumoai.experimental.rfm.base import DataBackend, Table
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
17
22
  from kumoai.graph import Edge
18
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
19
25
 
20
26
  if TYPE_CHECKING:
21
27
  import graphviz
@@ -25,8 +31,8 @@ if TYPE_CHECKING:
25
31
 
26
32
  @dataclass
27
33
  class SqliteConnectionConfig(CastMixin):
28
- uri: Union[str, Path]
29
- kwargs: Dict[str, Any] = field(default_factory=dict)
34
+ uri: str | Path
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
30
36
 
31
37
 
32
38
  class Graph:
@@ -86,27 +92,35 @@ class Graph:
86
92
  def __init__(
87
93
  self,
88
94
  tables: Sequence[Table],
89
- edges: Optional[Sequence[Edge]] = None,
95
+ edges: Sequence[Edge] | None = None,
90
96
  ) -> None:
91
97
 
92
- self._tables: Dict[str, Table] = {}
93
- self._edges: List[Edge] = []
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
94
100
 
95
101
  for table in tables:
96
102
  self.add_table(table)
97
103
 
98
- for table in tables:
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
99
107
  for fkey in table._source_foreign_key_dict.values():
100
- if fkey.name not in table or fkey.dst_table not in self:
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
101
117
  continue
102
- if self[fkey.dst_table].primary_key is None:
103
- self[fkey.dst_table].primary_key = fkey.primary_key
104
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
- raise ValueError(f"Found duplicate primary key definition "
106
- f"'{self[fkey.dst_table]._primary_key}' "
107
- f"and '{fkey.primary_key}' in table "
108
- f"'{fkey.dst_table}'.")
109
- self.link(table.name, fkey.name, fkey.dst_table)
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
120
+ continue
121
+ if not dst_table[fkey.primary_key].is_source:
122
+ continue
123
+ self.link(table.name, fkey.name, dst_table.name)
110
124
 
111
125
  for edge in (edges or []):
112
126
  _edge = Edge._cast(edge)
@@ -117,8 +131,8 @@ class Graph:
117
131
  @classmethod
118
132
  def from_data(
119
133
  cls,
120
- df_dict: Dict[str, pd.DataFrame],
121
- edges: Optional[Sequence[Edge]] = None,
134
+ df_dict: dict[str, pd.DataFrame],
135
+ edges: Sequence[Edge] | None = None,
122
136
  infer_metadata: bool = True,
123
137
  verbose: bool = True,
124
138
  ) -> Self:
@@ -156,15 +170,17 @@ class Graph:
156
170
  verbose: Whether to print verbose output.
157
171
  """
158
172
  from kumoai.experimental.rfm.backend.local import LocalTable
159
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
173
 
161
- graph = cls(tables, edges=edges or [])
174
+ graph = cls(
175
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
176
+ edges=edges or [],
177
+ )
162
178
 
163
179
  if infer_metadata:
164
- graph.infer_metadata(False)
180
+ graph.infer_metadata(verbose=False)
165
181
 
166
182
  if edges is None:
167
- graph.infer_links(False)
183
+ graph.infer_links(verbose=False)
168
184
 
169
185
  if verbose:
170
186
  graph.print_metadata()
@@ -180,10 +196,10 @@ class Graph:
180
196
  SqliteConnectionConfig,
181
197
  str,
182
198
  Path,
183
- Dict[str, Any],
199
+ dict[str, Any],
184
200
  ],
185
- table_names: Optional[Sequence[str]] = None,
186
- edges: Optional[Sequence[Edge]] = None,
201
+ tables: Sequence[str | dict[str, Any]] | None = None,
202
+ edges: Sequence[Edge] | None = None,
187
203
  infer_metadata: bool = True,
188
204
  verbose: bool = True,
189
205
  ) -> Self:
@@ -199,17 +215,25 @@ class Graph:
199
215
  >>> # Create a graph from a SQLite database:
200
216
  >>> graph = rfm.Graph.from_sqlite('data.db')
201
217
 
218
+ >>> # Fine-grained control over table specification:
219
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
220
+ ... 'USERS',
221
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
222
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
223
+ ... ])
224
+
202
225
  Args:
203
226
  connection: An open connection from
204
227
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
228
  path to the database file.
206
- table_names: Set of table names to include. If ``None``, will add
207
- all tables present in the database.
229
+ tables: Set of table names or :class:`SQLiteTable` keyword
230
+ arguments to include. If ``None``, will add all tables present
231
+ in the database.
208
232
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
209
233
  add to the graph. If not provided, edges will be automatically
210
234
  inferred from the data in case ``infer_metadata=True``.
211
- infer_metadata: Whether to infer metadata for all tables in the
212
- graph.
235
+ infer_metadata: Whether to infer missing metadata for all tables in
236
+ the graph.
213
237
  verbose: Whether to print verbose output.
214
238
  """
215
239
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -226,24 +250,33 @@ class Graph:
226
250
  internal_connection = True
227
251
  assert isinstance(connection, Connection)
228
252
 
229
- if table_names is None:
253
+ if tables is None:
230
254
  with connection.cursor() as cursor:
231
255
  cursor.execute("SELECT name FROM sqlite_master "
232
256
  "WHERE type='table'")
233
- table_names = [row[0] for row in cursor.fetchall()]
257
+ tables = [row[0] for row in cursor.fetchall()]
234
258
 
235
- tables = [SQLiteTable(connection, name) for name in table_names]
236
-
237
- graph = cls(tables, edges=edges or [])
259
+ table_kwargs: list[dict[str, Any]] = []
260
+ for table in tables:
261
+ kwargs = dict(name=table) if isinstance(table, str) else table
262
+ table_kwargs.append(kwargs)
263
+
264
+ graph = cls(
265
+ tables=[
266
+ SQLiteTable(connection=connection, **kwargs)
267
+ for kwargs in table_kwargs
268
+ ],
269
+ edges=edges or [],
270
+ )
238
271
 
239
272
  if internal_connection:
240
273
  graph._connection = connection # type: ignore
241
274
 
242
275
  if infer_metadata:
243
- graph.infer_metadata(False)
276
+ graph.infer_metadata(verbose=False)
244
277
 
245
278
  if edges is None:
246
- graph.infer_links(False)
279
+ graph.infer_links(verbose=False)
247
280
 
248
281
  if verbose:
249
282
  graph.print_metadata()
@@ -254,11 +287,11 @@ class Graph:
254
287
  @classmethod
255
288
  def from_snowflake(
256
289
  cls,
257
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
258
- database: Optional[str] = None,
259
- schema: Optional[str] = None,
260
- table_names: Optional[Sequence[str]] = None,
261
- edges: Optional[Sequence[Edge]] = None,
290
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
291
+ tables: Sequence[str | dict[str, Any]] | None = None,
292
+ database: str | None = None,
293
+ schema: str | None = None,
294
+ edges: Sequence[Edge] | None = None,
262
295
  infer_metadata: bool = True,
263
296
  verbose: bool = True,
264
297
  ) -> Self:
@@ -275,6 +308,13 @@ class Graph:
275
308
  >>> # Create a graph directly in a Snowflake notebook:
276
309
  >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
277
310
 
311
+ >>> # Fine-grained control over table specification:
312
+ >>> graph = rfm.Graph.from_snowflake(tables=[
313
+ ... 'USERS',
314
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
315
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
316
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
317
+
278
318
  Args:
279
319
  connection: An open connection from
280
320
  :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
@@ -282,10 +322,11 @@ class Graph:
282
322
  connection. If ``None``, will re-use an active session in case
283
323
  it exists, or create a new connection from credentials stored
284
324
  in environment variables.
325
+ tables: Set of table names or :class:`SnowTable` keyword arguments
326
+ to include. If ``None``, will add all tables present in the
327
+ current database and schema.
285
328
  database: The database.
286
329
  schema: The schema.
287
- table_names: Set of table names to include. If ``None``, will add
288
- all tables present in the database.
289
330
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
290
331
  add to the graph. If not provided, edges will be automatically
291
332
  inferred from the data in case ``infer_metadata=True``.
@@ -303,37 +344,50 @@ class Graph:
303
344
  connection = connect(**(connection or {}))
304
345
  assert isinstance(connection, Connection)
305
346
 
306
- if table_names is None:
347
+ if database is None or schema is None:
348
+ with connection.cursor() as cursor:
349
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
350
+ result = cursor.fetchone()
351
+ database = database or result[0]
352
+ assert database is not None
353
+ schema = schema or result[1]
354
+
355
+ if tables is None:
356
+ if schema is None:
357
+ raise ValueError("No current 'schema' set. Please specify the "
358
+ "Snowflake schema manually")
359
+
307
360
  with connection.cursor() as cursor:
308
- if database is None and schema is None:
309
- cursor.execute("SELECT CURRENT_DATABASE(), "
310
- "CURRENT_SCHEMA()")
311
- result = cursor.fetchone()
312
- database = database or result[0]
313
- schema = schema or result[1]
314
361
  cursor.execute(f"""
315
362
  SELECT TABLE_NAME
316
363
  FROM {database}.INFORMATION_SCHEMA.TABLES
317
364
  WHERE TABLE_SCHEMA = '{schema}'
318
365
  """)
319
- table_names = [row[0] for row in cursor.fetchall()]
320
-
321
- tables = [
322
- SnowTable(
323
- connection,
324
- name=table_name,
325
- database=database,
326
- schema=schema,
327
- ) for table_name in table_names
328
- ]
366
+ tables = [row[0] for row in cursor.fetchall()]
329
367
 
330
- graph = cls(tables, edges=edges or [])
368
+ table_kwargs: list[dict[str, Any]] = []
369
+ for table in tables:
370
+ if isinstance(table, str):
371
+ kwargs = dict(name=table, database=database, schema=schema)
372
+ else:
373
+ kwargs = copy.copy(table)
374
+ kwargs.setdefault('database', database)
375
+ kwargs.setdefault('schema', schema)
376
+ table_kwargs.append(kwargs)
377
+
378
+ graph = cls(
379
+ tables=[
380
+ SnowTable(connection=connection, **kwargs)
381
+ for kwargs in table_kwargs
382
+ ],
383
+ edges=edges or [],
384
+ )
331
385
 
332
386
  if infer_metadata:
333
- graph.infer_metadata(False)
387
+ graph.infer_metadata(verbose=False)
334
388
 
335
389
  if edges is None:
336
- graph.infer_links(False)
390
+ graph.infer_links(verbose=False)
337
391
 
338
392
  if verbose:
339
393
  graph.print_metadata()
@@ -345,7 +399,7 @@ class Graph:
345
399
  def from_snowflake_semantic_view(
346
400
  cls,
347
401
  semantic_view_name: str,
348
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
402
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
349
403
  verbose: bool = True,
350
404
  ) -> Self:
351
405
  import yaml
@@ -363,35 +417,150 @@ class Graph:
363
417
  with connection.cursor() as cursor:
364
418
  cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
365
419
  f"'{semantic_view_name}')")
366
- view = yaml.safe_load(cursor.fetchone()[0])
420
+ cfg = yaml.safe_load(cursor.fetchone()[0])
367
421
 
368
422
  graph = cls(tables=[])
369
423
 
370
- for table_desc in view['tables']:
371
- primary_key: Optional[str] = None
372
- if ('primary_key' in table_desc # NOTE No composite keys yet.
373
- and len(table_desc['primary_key']['columns']) == 1):
374
- primary_key = table_desc['primary_key']['columns'][0]
424
+ msgs = []
425
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
426
+ for table_cfg in cfg['tables']:
427
+ table_name = table_cfg['name']
428
+ source_table_name = table_cfg['base_table']['table']
429
+ database = table_cfg['base_table']['database']
430
+ schema = table_cfg['base_table']['schema']
431
+
432
+ primary_key: str | None = None
433
+ if 'primary_key' in table_cfg:
434
+ primary_key_cfg = table_cfg['primary_key']
435
+ if len(primary_key_cfg['columns']) == 1:
436
+ primary_key = primary_key_cfg['columns'][0]
437
+ elif len(primary_key_cfg['columns']) > 1:
438
+ msgs.append(f"Failed to add primary key for table "
439
+ f"'{table_name}' since composite primary keys "
440
+ f"are not yet supported")
441
+
442
+ columns: list[ColumnSpec] = []
443
+ unsupported_columns: list[str] = []
444
+ for column_cfg in chain(
445
+ table_cfg.get('dimensions', []),
446
+ table_cfg.get('time_dimensions', []),
447
+ table_cfg.get('facts', []),
448
+ ):
449
+ column_name = column_cfg['name']
450
+ column_expr = column_cfg.get('expr', None)
451
+ column_data_type = column_cfg.get('data_type', None)
452
+
453
+ if column_expr is None:
454
+ columns.append(ColumnSpec(name=column_name))
455
+ continue
456
+
457
+ column_expr = column_expr.replace(f'{table_name}.', '')
458
+
459
+ if column_expr == column_name:
460
+ columns.append(ColumnSpec(name=column_name))
461
+ continue
462
+
463
+ # Drop expressions that reference other tables (for now):
464
+ if any(f'{name}.' in column_expr for name in table_names):
465
+ unsupported_columns.append(column_name)
466
+ continue
467
+
468
+ column = ColumnSpec(
469
+ name=column_name,
470
+ expr=column_expr,
471
+ dtype=SnowTable._to_dtype(column_data_type),
472
+ )
473
+ columns.append(column)
474
+
475
+ if len(unsupported_columns) == 1:
476
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
477
+ f"of table '{table_name}' since its expression "
478
+ f"references other tables")
479
+ elif len(unsupported_columns) > 1:
480
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
481
+ f"of table '{table_name}' since their expressions "
482
+ f"reference other tables")
375
483
 
376
484
  table = SnowTable(
377
485
  connection,
378
- name=table_desc['base_table']['table'],
379
- database=table_desc['base_table']['database'],
380
- schema=table_desc['base_table']['schema'],
486
+ name=table_name,
487
+ source_name=source_table_name,
488
+ database=database,
489
+ schema=schema,
490
+ columns=columns,
381
491
  primary_key=primary_key,
382
492
  )
493
+
494
+ # TODO Add a way to register time columns without heuristic usage.
495
+ table.infer_time_column(verbose=False)
496
+
383
497
  graph.add_table(table)
384
498
 
385
- # TODO Find a solution to register time columns!
499
+ for relation_cfg in cfg.get('relationships', []):
500
+ name = relation_cfg['name']
501
+ if len(relation_cfg['relationship_columns']) != 1:
502
+ msgs.append(f"Failed to add relationship '{name}' since "
503
+ f"composite key references are not yet supported")
504
+ continue
386
505
 
387
- for relations in view['relationships']:
388
- if len(relations['relationship_columns']) != 1:
389
- continue # NOTE No composite keys yet.
390
- graph.link(
391
- src_table=relations['left_table'],
392
- fkey=relations['relationship_columns'][0]['left_column'],
393
- dst_table=relations['right_table'],
394
- )
506
+ left_table = relation_cfg['left_table']
507
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
508
+ right_table = relation_cfg['right_table']
509
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
510
+
511
+ if graph[right_table]._primary_key != right_key:
512
+ # Semantic view error - this should never be triggered:
513
+ msgs.append(f"Failed to add relationship '{name}' since the "
514
+ f"referenced key '{right_key}' of table "
515
+ f"'{right_table}' is not a primary key")
516
+ continue
517
+
518
+ if graph[left_table]._primary_key == left_key:
519
+ msgs.append(f"Failed to add relationship '{name}' since the "
520
+ f"referencing key '{left_key}' of table "
521
+ f"'{left_table}' is a primary key")
522
+ continue
523
+
524
+ if left_key not in graph[left_table]:
525
+ graph[left_table].add_column(left_key)
526
+
527
+ graph.link(left_table, left_key, right_table)
528
+
529
+ graph.validate()
530
+
531
+ if verbose:
532
+ graph.print_metadata()
533
+ graph.print_links()
534
+
535
+ if len(msgs) > 0:
536
+ title = (f"Could not fully convert the semantic view definition "
537
+ f"'{semantic_view_name}' into a graph:\n")
538
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
539
+
540
+ return graph
541
+
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
395
564
 
396
565
  if verbose:
397
566
  graph.print_metadata()
@@ -425,7 +594,7 @@ class Graph:
425
594
  return self.tables[name]
426
595
 
427
596
  @property
428
- def tables(self) -> Dict[str, Table]:
597
+ def tables(self) -> dict[str, Table]:
429
598
  r"""Returns the dictionary of table objects."""
430
599
  return self._tables
431
600
 
@@ -480,28 +649,28 @@ class Graph:
480
649
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
481
650
  information about the tables in this graph.
482
651
 
483
- The returned dataframe has columns ``name``, ``primary_key``,
484
- ``time_column``, and ``end_time_column``, which provide an aggregate
485
- view of the properties of the tables of this graph.
652
+ The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
653
+ ``"Time Column"``, and ``"End Time Column"``, which provide an
654
+ aggregated view of the properties of the tables of this graph.
486
655
 
487
656
  Example:
488
657
  >>> # doctest: +SKIP
489
658
  >>> import kumoai.experimental.rfm as rfm
490
659
  >>> graph = rfm.Graph(tables=...).infer_metadata()
491
660
  >>> graph.metadata # doctest: +SKIP
492
- name primary_key time_column end_time_column
493
- 0 users user_id - -
661
+ Name Primary Key Time Column End Time Column
662
+ 0 users user_id - -
494
663
  """
495
664
  tables = list(self.tables.values())
496
665
 
497
666
  return pd.DataFrame({
498
- 'name':
667
+ 'Name':
499
668
  pd.Series(dtype=str, data=[t.name for t in tables]),
500
- 'primary_key':
669
+ 'Primary Key':
501
670
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
502
- 'time_column':
671
+ 'Time Column':
503
672
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
504
- 'end_time_column':
673
+ 'End Time Column':
505
674
  pd.Series(
506
675
  dtype=str,
507
676
  data=[t._end_time_column or '-' for t in tables],
@@ -510,24 +679,8 @@ class Graph:
510
679
 
511
680
  def print_metadata(self) -> None:
512
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
513
- if in_snowflake_notebook():
514
- import streamlit as st
515
- st.markdown("### 🗂️ Graph Metadata")
516
- st.dataframe(self.metadata, hide_index=True)
517
- elif in_notebook():
518
- from IPython.display import Markdown, display
519
- display(Markdown("### 🗂️ Graph Metadata"))
520
- df = self.metadata
521
- try:
522
- if hasattr(df.style, 'hide'):
523
- display(df.style.hide(axis='index')) # pandas=2
524
- else:
525
- display(df.style.hide_index()) # pandas<1.3
526
- except ImportError:
527
- print(df.to_string(index=False)) # missing jinja2
528
- else:
529
- print("🗂️ Graph Metadata:")
530
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
531
684
 
532
685
  def infer_metadata(self, verbose: bool = True) -> Self:
533
686
  r"""Infers metadata for all tables in the graph.
@@ -550,52 +703,33 @@ class Graph:
550
703
  # Edges ###################################################################
551
704
 
552
705
  @property
553
- def edges(self) -> List[Edge]:
706
+ def edges(self) -> list[Edge]:
554
707
  r"""Returns the edges of the graph."""
555
708
  return self._edges
556
709
 
557
710
  def print_links(self) -> None:
558
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
559
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
560
- edge.src_table, edge.fkey) for edge in self.edges]
561
- edges = sorted(edges)
562
-
563
- if in_snowflake_notebook():
564
- import streamlit as st
565
- st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
566
- if len(edges) > 0:
567
- st.markdown('\n'.join([
568
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
569
- for edge in edges
570
- ]))
571
- else:
572
- st.markdown("*No links registered*")
573
- elif in_notebook():
574
- from IPython.display import Markdown, display
575
- display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
576
- if len(edges) > 0:
577
- display(
578
- Markdown('\n'.join([
579
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
580
- for edge in edges
581
- ])))
582
- else:
583
- display(Markdown("*No links registered*"))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
584
725
  else:
585
- print("🕸️ Graph Links (FK ↔️ PK):")
586
- if len(edges) > 0:
587
- print('\n'.join([
588
- f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
589
- for edge in edges
590
- ]))
591
- else:
592
- print("No links registered")
726
+ display.italic("No links registered")
593
727
 
594
728
  def link(
595
729
  self,
596
- src_table: Union[str, Table],
730
+ src_table: str | Table,
597
731
  fkey: str,
598
- dst_table: Union[str, Table],
732
+ dst_table: str | Table,
599
733
  ) -> Self:
600
734
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
601
735
  key ``fkey`` in the source table to the primary key in the destination
@@ -656,9 +790,9 @@ class Graph:
656
790
 
657
791
  def unlink(
658
792
  self,
659
- src_table: Union[str, Table],
793
+ src_table: str | Table,
660
794
  fkey: str,
661
- dst_table: Union[str, Table],
795
+ dst_table: str | Table,
662
796
  ) -> Self:
663
797
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
664
798
 
@@ -696,6 +830,30 @@ class Graph:
696
830
  """
697
831
  known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
698
832
 
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
856
+
699
857
  # A list of primary key candidates (+score) for every column:
700
858
  candidate_dict: dict[
701
859
  tuple[str, str],
@@ -795,13 +953,8 @@ class Graph:
795
953
  if score < 5.0:
796
954
  continue
797
955
 
798
- candidate_dict[(
799
- src_table.name,
800
- src_key.name,
801
- )].append((
802
- dst_table.name,
803
- score,
804
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
805
958
 
806
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
807
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -860,24 +1013,26 @@ class Graph:
860
1013
  f"either the primary key or the link before "
861
1014
  f"before proceeding.")
862
1015
 
863
- # Check that fkey/pkey have valid and consistent data types:
864
- assert src_key.dtype is not None
865
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
866
- src_string = src_key.dtype.is_string()
867
- assert dst_key.dtype is not None
868
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
869
- dst_string = dst_key.dtype.is_string()
870
-
871
- if not src_number and not src_string:
872
- raise ValueError(f"{edge} is invalid as foreign key must be a "
873
- f"number or string (got '{src_key.dtype}'")
874
-
875
- if src_number != dst_number or src_string != dst_string:
876
- raise ValueError(f"{edge} is invalid as foreign key "
877
- f"'{fkey}' and primary key '{dst_key.name}' "
878
- f"have incompatible data types (got "
879
- f"fkey.dtype '{src_key.dtype}' and "
880
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
881
1036
 
882
1037
  return self
883
1038
 
@@ -885,7 +1040,7 @@ class Graph:
885
1040
 
886
1041
  def visualize(
887
1042
  self,
888
- path: Optional[Union[str, io.BytesIO]] = None,
1043
+ path: str | io.BytesIO | None = None,
889
1044
  show_columns: bool = True,
890
1045
  ) -> 'graphviz.Graph':
891
1046
  r"""Visualizes the tables and edges in this graph using the
@@ -924,19 +1079,19 @@ class Graph:
924
1079
  "them as described at "
925
1080
  "https://graphviz.org/download/.")
926
1081
 
927
- format: Optional[str] = None
1082
+ format: str | None = None
928
1083
  if isinstance(path, str):
929
1084
  format = path.split('.')[-1]
930
1085
  elif isinstance(path, io.BytesIO):
931
1086
  format = 'svg'
932
1087
  graph = graphviz.Graph(format=format)
933
1088
 
934
- def left_align(keys: List[str]) -> str:
1089
+ def left_align(keys: list[str]) -> str:
935
1090
  if len(keys) == 0:
936
1091
  return ""
937
1092
  return '\\l'.join(keys) + '\\l'
938
1093
 
939
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1094
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
940
1095
  for src_table_name, fkey_name, _ in self.edges:
941
1096
  fkeys_dict[src_table_name].append(fkey_name)
942
1097
 
@@ -1032,8 +1187,8 @@ class Graph:
1032
1187
  # Helpers #################################################################
1033
1188
 
1034
1189
  def _to_api_graph_definition(self) -> GraphDefinition:
1035
- tables: Dict[str, TableDefinition] = {}
1036
- col_groups: List[ColumnKeyGroup] = []
1190
+ tables: dict[str, TableDefinition] = {}
1191
+ col_groups: list[ColumnKeyGroup] = []
1037
1192
  for table_name, table in self.tables.items():
1038
1193
  tables[table_name] = table._to_api_table_definition()
1039
1194
  if table.primary_key is None: