kumoai 2.13.0.dev202512041731-cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +26 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/connector/utils.py +21 -7
  7. kumoai/experimental/rfm/__init__.py +51 -24
  8. kumoai/experimental/rfm/authenticate.py +3 -4
  9. kumoai/experimental/rfm/backend/local/__init__.py +4 -0
  10. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
  11. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  12. kumoai/experimental/rfm/backend/local/table.py +35 -31
  13. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  14. kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
  15. kumoai/experimental/rfm/backend/snow/table.py +178 -50
  16. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  17. kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
  18. kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
  19. kumoai/experimental/rfm/base/__init__.py +22 -4
  20. kumoai/experimental/rfm/base/column.py +96 -10
  21. kumoai/experimental/rfm/base/expression.py +44 -0
  22. kumoai/experimental/rfm/base/mapper.py +69 -0
  23. kumoai/experimental/rfm/base/sampler.py +696 -47
  24. kumoai/experimental/rfm/base/source.py +2 -1
  25. kumoai/experimental/rfm/base/sql_sampler.py +385 -0
  26. kumoai/experimental/rfm/base/table.py +384 -207
  27. kumoai/experimental/rfm/base/utils.py +36 -0
  28. kumoai/experimental/rfm/graph.py +359 -187
  29. kumoai/experimental/rfm/infer/__init__.py +6 -4
  30. kumoai/experimental/rfm/infer/dtype.py +10 -5
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +4 -2
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +5 -4
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +770 -467
  39. kumoai/experimental/rfm/sagemaker.py +4 -4
  40. kumoai/experimental/rfm/task_table.py +292 -0
  41. kumoai/kumolib.cp310-win_amd64.pyd +0 -0
  42. kumoai/pquery/predictive_query.py +10 -6
  43. kumoai/pquery/training_table.py +16 -2
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +87 -0
  48. kumoai/utils/progress_logger.py +192 -13
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
  51. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
  52. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  53. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  54. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
  55. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
  56. {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
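The file list above shows the experimental RFM code being reorganized into per-backend packages (backend/local, backend/sqlite, backend/snow, plus a shared base package), replacing the former top-level local_graph_store.py, local_graph_sampler.py, and local_pquery_driver.py modules. A minimal sketch of the import layout this implies, using only module paths that appear in the graph.py diff below (whether these are intended as public entry points is not documented here):

>>> # Paths taken from the import statements in the graph.py diff below:
>>> from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
>>> from kumoai.experimental.rfm.backend.local import LocalTable
>>> from kumoai.experimental.rfm.backend.sqlite import SQLiteTable, connect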
@@ -1,10 +1,15 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
4
+ import copy
2
5
  import io
3
6
  import warnings
4
7
  from collections import defaultdict
8
+ from collections.abc import Sequence
5
9
  from dataclasses import dataclass, field
10
+ from itertools import chain
6
11
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
12
+ from typing import TYPE_CHECKING, Any, Union
8
13
 
9
14
  import pandas as pd
10
15
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,9 +18,10 @@ from kumoapi.typing import Stype
13
18
  from typing_extensions import Self
14
19
 
15
20
  from kumoai import in_notebook, in_snowflake_notebook
16
- from kumoai.experimental.rfm import Table
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
17
22
  from kumoai.graph import Edge
18
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
19
25
 
20
26
  if TYPE_CHECKING:
21
27
  import graphviz
@@ -25,8 +31,8 @@ if TYPE_CHECKING:
25
31
 
26
32
  @dataclass
27
33
  class SqliteConnectionConfig(CastMixin):
28
- uri: Union[str, Path]
29
- kwargs: Dict[str, Any] = field(default_factory=dict)
34
+ uri: str | Path
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
30
36
 
31
37
 
32
38
  class Graph:
@@ -86,27 +92,35 @@ class Graph:
86
92
  def __init__(
87
93
  self,
88
94
  tables: Sequence[Table],
89
- edges: Optional[Sequence[Edge]] = None,
95
+ edges: Sequence[Edge] | None = None,
90
96
  ) -> None:
91
97
 
92
- self._tables: Dict[str, Table] = {}
93
- self._edges: List[Edge] = []
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
94
100
 
95
101
  for table in tables:
96
102
  self.add_table(table)
97
103
 
98
- for table in tables:
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
99
107
  for fkey in table._source_foreign_key_dict.values():
100
- if fkey.name not in table or fkey.dst_table not in self:
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
117
+ continue
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
101
120
  continue
102
- if self[fkey.dst_table].primary_key is None:
103
- self[fkey.dst_table].primary_key = fkey.primary_key
104
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
- raise ValueError(f"Found duplicate primary key definition "
106
- f"'{self[fkey.dst_table]._primary_key}' "
107
- f"and '{fkey.primary_key}' in table "
108
- f"'{fkey.dst_table}'.")
109
- self.link(table.name, fkey.name, fkey.dst_table)
121
+ if not dst_table[fkey.primary_key].is_source:
122
+ continue
123
+ self.link(table.name, fkey.name, dst_table.name)
110
124
 
111
125
  for edge in (edges or []):
112
126
  _edge = Edge._cast(edge)
@@ -117,8 +131,8 @@ class Graph:
117
131
  @classmethod
118
132
  def from_data(
119
133
  cls,
120
- df_dict: Dict[str, pd.DataFrame],
121
- edges: Optional[Sequence[Edge]] = None,
134
+ df_dict: dict[str, pd.DataFrame],
135
+ edges: Sequence[Edge] | None = None,
122
136
  infer_metadata: bool = True,
123
137
  verbose: bool = True,
124
138
  ) -> Self:
@@ -156,15 +170,17 @@ class Graph:
156
170
  verbose: Whether to print verbose output.
157
171
  """
158
172
  from kumoai.experimental.rfm.backend.local import LocalTable
159
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
173
 
161
- graph = cls(tables, edges=edges or [])
174
+ graph = cls(
175
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
176
+ edges=edges or [],
177
+ )
162
178
 
163
179
  if infer_metadata:
164
- graph.infer_metadata(False)
180
+ graph.infer_metadata(verbose=False)
165
181
 
166
182
  if edges is None:
167
- graph.infer_links(False)
183
+ graph.infer_links(verbose=False)
168
184
 
169
185
  if verbose:
170
186
  graph.print_metadata()
@@ -180,10 +196,10 @@ class Graph:
180
196
  SqliteConnectionConfig,
181
197
  str,
182
198
  Path,
183
- Dict[str, Any],
199
+ dict[str, Any],
184
200
  ],
185
- table_names: Optional[Sequence[str]] = None,
186
- edges: Optional[Sequence[Edge]] = None,
201
+ tables: Sequence[str | dict[str, Any]] | None = None,
202
+ edges: Sequence[Edge] | None = None,
187
203
  infer_metadata: bool = True,
188
204
  verbose: bool = True,
189
205
  ) -> Self:
@@ -199,17 +215,25 @@ class Graph:
199
215
  >>> # Create a graph from a SQLite database:
200
216
  >>> graph = rfm.Graph.from_sqlite('data.db')
201
217
 
218
+ >>> # Fine-grained control over table specification:
219
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
220
+ ... 'USERS',
221
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
222
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
223
+ ... ])
224
+
202
225
  Args:
203
226
  connection: An open connection from
204
227
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
228
  path to the database file.
206
- table_names: Set of table names to include. If ``None``, will add
207
- all tables present in the database.
229
+ tables: Set of table names or :class:`SQLiteTable` keyword
230
+ arguments to include. If ``None``, will add all tables present
231
+ in the database.
208
232
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
209
233
  add to the graph. If not provided, edges will be automatically
210
234
  inferred from the data in case ``infer_metadata=True``.
211
- infer_metadata: Whether to infer metadata for all tables in the
212
- graph.
235
+ infer_metadata: Whether to infer missing metadata for all tables in
236
+ the graph.
213
237
  verbose: Whether to print verbose output.
214
238
  """
215
239
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -218,27 +242,41 @@ class Graph:
218
242
  connect,
219
243
  )
220
244
 
245
+ internal_connection = False
221
246
  if not isinstance(connection, Connection):
222
247
  connection = SqliteConnectionConfig._cast(connection)
223
248
  assert isinstance(connection, SqliteConnectionConfig)
224
249
  connection = connect(connection.uri, **connection.kwargs)
250
+ internal_connection = True
225
251
  assert isinstance(connection, Connection)
226
252
 
227
- if table_names is None:
253
+ if tables is None:
228
254
  with connection.cursor() as cursor:
229
255
  cursor.execute("SELECT name FROM sqlite_master "
230
256
  "WHERE type='table'")
231
- table_names = [row[0] for row in cursor.fetchall()]
257
+ tables = [row[0] for row in cursor.fetchall()]
232
258
 
233
- tables = [SQLiteTable(connection, name) for name in table_names]
259
+ table_kwargs: list[dict[str, Any]] = []
260
+ for table in tables:
261
+ kwargs = dict(name=table) if isinstance(table, str) else table
262
+ table_kwargs.append(kwargs)
263
+
264
+ graph = cls(
265
+ tables=[
266
+ SQLiteTable(connection=connection, **kwargs)
267
+ for kwargs in table_kwargs
268
+ ],
269
+ edges=edges or [],
270
+ )
234
271
 
235
- graph = cls(tables, edges=edges or [])
272
+ if internal_connection:
273
+ graph._connection = connection # type: ignore
236
274
 
237
275
  if infer_metadata:
238
- graph.infer_metadata(False)
276
+ graph.infer_metadata(verbose=False)
239
277
 
240
278
  if edges is None:
241
- graph.infer_links(False)
279
+ graph.infer_links(verbose=False)
242
280
 
243
281
  if verbose:
244
282
  graph.print_metadata()
@@ -249,11 +287,11 @@ class Graph:
249
287
  @classmethod
250
288
  def from_snowflake(
251
289
  cls,
252
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
253
- database: Optional[str] = None,
254
- schema: Optional[str] = None,
255
- table_names: Optional[Sequence[str]] = None,
256
- edges: Optional[Sequence[Edge]] = None,
290
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
291
+ tables: Sequence[str | dict[str, Any]] | None = None,
292
+ database: str | None = None,
293
+ schema: str | None = None,
294
+ edges: Sequence[Edge] | None = None,
257
295
  infer_metadata: bool = True,
258
296
  verbose: bool = True,
259
297
  ) -> Self:
@@ -270,6 +308,13 @@ class Graph:
270
308
  >>> # Create a graph directly in a Snowflake notebook:
271
309
  >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
272
310
 
311
+ >>> # Fine-grained control over table specification:
312
+ >>> graph = rfm.Graph.from_snowflake(tables=[
313
+ ... 'USERS',
314
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
315
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
316
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
317
+
273
318
  Args:
274
319
  connection: An open connection from
275
320
  :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
@@ -277,10 +322,11 @@ class Graph:
277
322
  connection. If ``None``, will re-use an active session in case
278
323
  it exists, or create a new connection from credentials stored
279
324
  in environment variables.
325
+ tables: Set of table names or :class:`SnowTable` keyword arguments
326
+ to include. If ``None``, will add all tables present in the
327
+ current database and schema.
280
328
  database: The database.
281
329
  schema: The schema.
282
- table_names: Set of table names to include. If ``None``, will add
283
- all tables present in the database.
284
330
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
285
331
  add to the graph. If not provided, edges will be automatically
286
332
  inferred from the data in case ``infer_metadata=True``.
@@ -298,37 +344,50 @@ class Graph:
298
344
  connection = connect(**(connection or {}))
299
345
  assert isinstance(connection, Connection)
300
346
 
301
- if table_names is None:
347
+ if database is None or schema is None:
348
+ with connection.cursor() as cursor:
349
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
350
+ result = cursor.fetchone()
351
+ database = database or result[0]
352
+ assert database is not None
353
+ schema = schema or result[1]
354
+
355
+ if tables is None:
356
+ if schema is None:
357
+ raise ValueError("No current 'schema' set. Please specify the "
358
+ "Snowflake schema manually")
359
+
302
360
  with connection.cursor() as cursor:
303
- if database is None and schema is None:
304
- cursor.execute("SELECT CURRENT_DATABASE(), "
305
- "CURRENT_SCHEMA()")
306
- result = cursor.fetchone()
307
- database = database or result[0]
308
- schema = schema or result[1]
309
361
  cursor.execute(f"""
310
362
  SELECT TABLE_NAME
311
363
  FROM {database}.INFORMATION_SCHEMA.TABLES
312
364
  WHERE TABLE_SCHEMA = '{schema}'
313
365
  """)
314
- table_names = [row[0] for row in cursor.fetchall()]
315
-
316
- tables = [
317
- SnowTable(
318
- connection,
319
- name=table_name,
320
- database=database,
321
- schema=schema,
322
- ) for table_name in table_names
323
- ]
366
+ tables = [row[0] for row in cursor.fetchall()]
324
367
 
325
- graph = cls(tables, edges=edges or [])
368
+ table_kwargs: list[dict[str, Any]] = []
369
+ for table in tables:
370
+ if isinstance(table, str):
371
+ kwargs = dict(name=table, database=database, schema=schema)
372
+ else:
373
+ kwargs = copy.copy(table)
374
+ kwargs.setdefault('database', database)
375
+ kwargs.setdefault('schema', schema)
376
+ table_kwargs.append(kwargs)
377
+
378
+ graph = cls(
379
+ tables=[
380
+ SnowTable(connection=connection, **kwargs)
381
+ for kwargs in table_kwargs
382
+ ],
383
+ edges=edges or [],
384
+ )
326
385
 
327
386
  if infer_metadata:
328
- graph.infer_metadata(False)
387
+ graph.infer_metadata(verbose=False)
329
388
 
330
389
  if edges is None:
331
- graph.infer_links(False)
390
+ graph.infer_links(verbose=False)
332
391
 
333
392
  if verbose:
334
393
  graph.print_metadata()
@@ -340,7 +399,7 @@ class Graph:
340
399
  def from_snowflake_semantic_view(
341
400
  cls,
342
401
  semantic_view_name: str,
343
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
402
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
344
403
  verbose: bool = True,
345
404
  ) -> Self:
346
405
  import yaml
@@ -358,43 +417,165 @@ class Graph:
358
417
  with connection.cursor() as cursor:
359
418
  cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
360
419
  f"'{semantic_view_name}')")
361
- view = yaml.safe_load(cursor.fetchone()[0])
420
+ cfg = yaml.safe_load(cursor.fetchone()[0])
362
421
 
363
422
  graph = cls(tables=[])
364
423
 
365
- for table_desc in view['tables']:
366
- primary_key: Optional[str] = None
367
- if ('primary_key' in table_desc # NOTE No composite keys yet.
368
- and len(table_desc['primary_key']['columns']) == 1):
369
- primary_key = table_desc['primary_key']['columns'][0]
424
+ msgs = []
425
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
426
+ for table_cfg in cfg['tables']:
427
+ table_name = table_cfg['name']
428
+ source_table_name = table_cfg['base_table']['table']
429
+ database = table_cfg['base_table']['database']
430
+ schema = table_cfg['base_table']['schema']
431
+
432
+ primary_key: str | None = None
433
+ if 'primary_key' in table_cfg:
434
+ primary_key_cfg = table_cfg['primary_key']
435
+ if len(primary_key_cfg['columns']) == 1:
436
+ primary_key = primary_key_cfg['columns'][0]
437
+ elif len(primary_key_cfg['columns']) > 1:
438
+ msgs.append(f"Failed to add primary key for table "
439
+ f"'{table_name}' since composite primary keys "
440
+ f"are not yet supported")
441
+
442
+ columns: list[ColumnSpec] = []
443
+ unsupported_columns: list[str] = []
444
+ for column_cfg in chain(
445
+ table_cfg.get('dimensions', []),
446
+ table_cfg.get('time_dimensions', []),
447
+ table_cfg.get('facts', []),
448
+ ):
449
+ column_name = column_cfg['name']
450
+ column_expr = column_cfg.get('expr', None)
451
+ column_data_type = column_cfg.get('data_type', None)
452
+
453
+ if column_expr is None:
454
+ columns.append(ColumnSpec(name=column_name))
455
+ continue
456
+
457
+ column_expr = column_expr.replace(f'{table_name}.', '')
458
+
459
+ if column_expr == column_name:
460
+ columns.append(ColumnSpec(name=column_name))
461
+ continue
462
+
463
+ # Drop expressions that reference other tables (for now):
464
+ if any(f'{name}.' in column_expr for name in table_names):
465
+ unsupported_columns.append(column_name)
466
+ continue
467
+
468
+ column = ColumnSpec(
469
+ name=column_name,
470
+ expr=column_expr,
471
+ dtype=SnowTable._to_dtype(column_data_type),
472
+ )
473
+ columns.append(column)
474
+
475
+ if len(unsupported_columns) == 1:
476
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
477
+ f"of table '{table_name}' since its expression "
478
+ f"references other tables")
479
+ elif len(unsupported_columns) > 1:
480
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
481
+ f"of table '{table_name}' since their expressions "
482
+ f"reference other tables")
370
483
 
371
484
  table = SnowTable(
372
485
  connection,
373
- name=table_desc['base_table']['table'],
374
- database=table_desc['base_table']['database'],
375
- schema=table_desc['base_table']['schema'],
486
+ name=table_name,
487
+ source_name=source_table_name,
488
+ database=database,
489
+ schema=schema,
490
+ columns=columns,
376
491
  primary_key=primary_key,
377
492
  )
493
+
494
+ # TODO Add a way to register time columns without heuristic usage.
495
+ table.infer_time_column(verbose=False)
496
+
378
497
  graph.add_table(table)
379
498
 
380
- # TODO Find a solution to register time columns!
499
+ for relation_cfg in cfg.get('relationships', []):
500
+ name = relation_cfg['name']
501
+ if len(relation_cfg['relationship_columns']) != 1:
502
+ msgs.append(f"Failed to add relationship '{name}' since "
503
+ f"composite key references are not yet supported")
504
+ continue
381
505
 
382
- for relations in view['relationships']:
383
- if len(relations['relationship_columns']) != 1:
384
- continue # NOTE No composite keys yet.
385
- graph.link(
386
- src_table=relations['left_table'],
387
- fkey=relations['relationship_columns'][0]['left_column'],
388
- dst_table=relations['right_table'],
389
- )
506
+ left_table = relation_cfg['left_table']
507
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
508
+ right_table = relation_cfg['right_table']
509
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
510
+
511
+ if graph[right_table]._primary_key != right_key:
512
+ # Semantic view error - this should never be triggered:
513
+ msgs.append(f"Failed to add relationship '{name}' since the "
514
+ f"referenced key '{right_key}' of table "
515
+ f"'{right_table}' is not a primary key")
516
+ continue
517
+
518
+ if graph[left_table]._primary_key == left_key:
519
+ msgs.append(f"Failed to add relationship '{name}' since the "
520
+ f"referencing key '{left_key}' of table "
521
+ f"'{left_table}' is a primary key")
522
+ continue
523
+
524
+ if left_key not in graph[left_table]:
525
+ graph[left_table].add_column(left_key)
526
+
527
+ graph.link(left_table, left_key, right_table)
528
+
529
+ graph.validate()
390
530
 
391
531
  if verbose:
392
532
  graph.print_metadata()
393
533
  graph.print_links()
394
534
 
535
+ if len(msgs) > 0:
536
+ title = (f"Could not fully convert the semantic view definition "
537
+ f"'{semantic_view_name}' into a graph:\n")
538
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
539
+
395
540
  return graph
396
541
 
397
- # Tables ##############################################################
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
564
+
565
+ if verbose:
566
+ graph.print_metadata()
567
+ graph.print_links()
568
+
569
+ return graph
570
+
571
+ # Backend #################################################################
572
+
573
+ @property
574
+ def backend(self) -> DataBackend | None:
575
+ backends = [table.backend for table in self._tables.values()]
576
+ return backends[0] if len(backends) > 0 else None
577
+
578
+ # Tables ##################################################################
398
579
 
399
580
  def has_table(self, name: str) -> bool:
400
581
  r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -413,7 +594,7 @@ class Graph:
413
594
  return self.tables[name]
414
595
 
415
596
  @property
416
- def tables(self) -> Dict[str, Table]:
597
+ def tables(self) -> dict[str, Table]:
417
598
  r"""Returns the dictionary of table objects."""
418
599
  return self._tables
419
600
 
@@ -433,13 +614,10 @@ class Graph:
433
614
  raise KeyError(f"Cannot add table with name '{table.name}' to "
434
615
  f"this graph; table names must be globally unique.")
435
616
 
436
- if len(self._tables) > 0:
437
- cls = next(iter(self._tables.values())).__class__
438
- if table.__class__ != cls:
439
- raise ValueError(f"Cannot register a "
440
- f"'{table.__class__.__name__}' to this "
441
- f"graph since other tables are of type "
442
- f"'{cls.__name__}'.")
617
+ if self.backend is not None and table.backend != self.backend:
618
+ raise ValueError(f"Cannot register a table with backend "
619
+ f"'{table.backend}' to this graph since other "
620
+ f"tables have backend '{self.backend}'.")
443
621
 
444
622
  self._tables[table.name] = table
445
623
 
@@ -471,28 +649,28 @@ class Graph:
471
649
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
472
650
  information about the tables in this graph.
473
651
 
474
- The returned dataframe has columns ``name``, ``primary_key``,
475
- ``time_column``, and ``end_time_column``, which provide an aggregate
476
- view of the properties of the tables of this graph.
652
+ The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
653
+ ``"Time Column"``, and ``"End Time Column"``, which provide an
654
+ aggregated view of the properties of the tables of this graph.
477
655
 
478
656
  Example:
479
657
  >>> # doctest: +SKIP
480
658
  >>> import kumoai.experimental.rfm as rfm
481
659
  >>> graph = rfm.Graph(tables=...).infer_metadata()
482
660
  >>> graph.metadata # doctest: +SKIP
483
- name primary_key time_column end_time_column
484
- 0 users user_id - -
661
+ Name Primary Key Time Column End Time Column
662
+ 0 users user_id - -
485
663
  """
486
664
  tables = list(self.tables.values())
487
665
 
488
666
  return pd.DataFrame({
489
- 'name':
667
+ 'Name':
490
668
  pd.Series(dtype=str, data=[t.name for t in tables]),
491
- 'primary_key':
669
+ 'Primary Key':
492
670
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
493
- 'time_column':
671
+ 'Time Column':
494
672
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
495
- 'end_time_column':
673
+ 'End Time Column':
496
674
  pd.Series(
497
675
  dtype=str,
498
676
  data=[t._end_time_column or '-' for t in tables],
@@ -501,24 +679,8 @@ class Graph:
501
679
 
502
680
  def print_metadata(self) -> None:
503
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
504
- if in_snowflake_notebook():
505
- import streamlit as st
506
- st.markdown("### 🗂️ Graph Metadata")
507
- st.dataframe(self.metadata, hide_index=True)
508
- elif in_notebook():
509
- from IPython.display import Markdown, display
510
- display(Markdown("### 🗂️ Graph Metadata"))
511
- df = self.metadata
512
- try:
513
- if hasattr(df.style, 'hide'):
514
- display(df.style.hide(axis='index')) # pandas=2
515
- else:
516
- display(df.style.hide_index()) # pandas<1.3
517
- except ImportError:
518
- print(df.to_string(index=False)) # missing jinja2
519
- else:
520
- print("🗂️ Graph Metadata:")
521
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
522
684
 
523
685
  def infer_metadata(self, verbose: bool = True) -> Self:
524
686
  r"""Infers metadata for all tables in the graph.
@@ -541,52 +703,33 @@ class Graph:
541
703
  # Edges ###################################################################
542
704
 
543
705
  @property
544
- def edges(self) -> List[Edge]:
706
+ def edges(self) -> list[Edge]:
545
707
  r"""Returns the edges of the graph."""
546
708
  return self._edges
547
709
 
548
710
  def print_links(self) -> None:
549
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
550
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
551
- edge.src_table, edge.fkey) for edge in self.edges]
552
- edges = sorted(edges)
553
-
554
- if in_snowflake_notebook():
555
- import streamlit as st
556
- st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
557
- if len(edges) > 0:
558
- st.markdown('\n'.join([
559
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
560
- for edge in edges
561
- ]))
562
- else:
563
- st.markdown("*No links registered*")
564
- elif in_notebook():
565
- from IPython.display import Markdown, display
566
- display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
567
- if len(edges) > 0:
568
- display(
569
- Markdown('\n'.join([
570
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
571
- for edge in edges
572
- ])))
573
- else:
574
- display(Markdown("*No links registered*"))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
575
725
  else:
576
- print("🕸️ Graph Links (FK ↔️ PK):")
577
- if len(edges) > 0:
578
- print('\n'.join([
579
- f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
580
- for edge in edges
581
- ]))
582
- else:
583
- print("No links registered")
726
+ display.italic("No links registered")
584
727
 
585
728
  def link(
586
729
  self,
587
- src_table: Union[str, Table],
730
+ src_table: str | Table,
588
731
  fkey: str,
589
- dst_table: Union[str, Table],
732
+ dst_table: str | Table,
590
733
  ) -> Self:
591
734
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
592
735
  key ``fkey`` in the source table to the primary key in the destination
@@ -647,9 +790,9 @@ class Graph:
647
790
 
648
791
  def unlink(
649
792
  self,
650
- src_table: Union[str, Table],
793
+ src_table: str | Table,
651
794
  fkey: str,
652
- dst_table: Union[str, Table],
795
+ dst_table: str | Table,
653
796
  ) -> Self:
654
797
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
655
798
 
@@ -687,6 +830,30 @@ class Graph:
687
830
  """
688
831
  known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
689
832
 
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
856
+
690
857
  # A list of primary key candidates (+score) for every column:
691
858
  candidate_dict: dict[
692
859
  tuple[str, str],
@@ -786,13 +953,8 @@ class Graph:
786
953
  if score < 5.0:
787
954
  continue
788
955
 
789
- candidate_dict[(
790
- src_table.name,
791
- src_key.name,
792
- )].append((
793
- dst_table.name,
794
- score,
795
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
796
958
 
797
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
798
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -826,6 +988,10 @@ class Graph:
826
988
  raise ValueError("At least one table needs to be added to the "
827
989
  "graph")
828
990
 
991
+ backends = {table.backend for table in self._tables.values()}
992
+ if len(backends) != 1:
993
+ raise ValueError("Found multiple table backends in the graph")
994
+
829
995
  for edge in self.edges:
830
996
  src_table, fkey, dst_table = edge
831
997
 
@@ -847,24 +1013,26 @@ class Graph:
847
1013
  f"either the primary key or the link before "
848
1014
  f"before proceeding.")
849
1015
 
850
- # Check that fkey/pkey have valid and consistent data types:
851
- assert src_key.dtype is not None
852
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
853
- src_string = src_key.dtype.is_string()
854
- assert dst_key.dtype is not None
855
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
856
- dst_string = dst_key.dtype.is_string()
857
-
858
- if not src_number and not src_string:
859
- raise ValueError(f"{edge} is invalid as foreign key must be a "
860
- f"number or string (got '{src_key.dtype}'")
861
-
862
- if src_number != dst_number or src_string != dst_string:
863
- raise ValueError(f"{edge} is invalid as foreign key "
864
- f"'{fkey}' and primary key '{dst_key.name}' "
865
- f"have incompatible data types (got "
866
- f"fkey.dtype '{src_key.dtype}' and "
867
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
868
1036
 
869
1037
  return self
870
1038
 
@@ -872,7 +1040,7 @@ class Graph:
872
1040
 
873
1041
  def visualize(
874
1042
  self,
875
- path: Optional[Union[str, io.BytesIO]] = None,
1043
+ path: str | io.BytesIO | None = None,
876
1044
  show_columns: bool = True,
877
1045
  ) -> 'graphviz.Graph':
878
1046
  r"""Visualizes the tables and edges in this graph using the
@@ -911,19 +1079,19 @@ class Graph:
911
1079
  "them as described at "
912
1080
  "https://graphviz.org/download/.")
913
1081
 
914
- format: Optional[str] = None
1082
+ format: str | None = None
915
1083
  if isinstance(path, str):
916
1084
  format = path.split('.')[-1]
917
1085
  elif isinstance(path, io.BytesIO):
918
1086
  format = 'svg'
919
1087
  graph = graphviz.Graph(format=format)
920
1088
 
921
- def left_align(keys: List[str]) -> str:
1089
+ def left_align(keys: list[str]) -> str:
922
1090
  if len(keys) == 0:
923
1091
  return ""
924
1092
  return '\\l'.join(keys) + '\\l'
925
1093
 
926
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1094
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
927
1095
  for src_table_name, fkey_name, _ in self.edges:
928
1096
  fkeys_dict[src_table_name].append(fkey_name)
929
1097
 
@@ -1019,8 +1187,8 @@ class Graph:
1019
1187
  # Helpers #################################################################
1020
1188
 
1021
1189
  def _to_api_graph_definition(self) -> GraphDefinition:
1022
- tables: Dict[str, TableDefinition] = {}
1023
- col_groups: List[ColumnKeyGroup] = []
1190
+ tables: dict[str, TableDefinition] = {}
1191
+ col_groups: list[ColumnKeyGroup] = []
1024
1192
  for table_name, table in self.tables.items():
1025
1193
  tables[table_name] = table._to_api_table_definition()
1026
1194
  if table.primary_key is None:
@@ -1063,3 +1231,7 @@ class Graph:
1063
1231
  f' tables={tables},\n'
1064
1232
  f' edges={edges},\n'
1065
1233
  f')')
1234
+
1235
+ def __del__(self) -> None:
1236
+ if hasattr(self, '_connection'):
1237
+ self._connection.close()