kumoai 2.12.0.dev202511061731__cp311-cp311-win_amd64.whl → 2.14.0.dev202512311733__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. kumoai/__init__.py +41 -35
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +15 -13
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/client/pquery.py +6 -2
  6. kumoai/client/rfm.py +15 -7
  7. kumoai/connector/utils.py +23 -2
  8. kumoai/experimental/rfm/__init__.py +191 -48
  9. kumoai/experimental/rfm/authenticate.py +3 -4
  10. kumoai/experimental/rfm/backend/__init__.py +0 -0
  11. kumoai/experimental/rfm/backend/local/__init__.py +42 -0
  12. kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +65 -127
  13. kumoai/experimental/rfm/backend/local/sampler.py +312 -0
  14. kumoai/experimental/rfm/backend/local/table.py +113 -0
  15. kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
  16. kumoai/experimental/rfm/backend/snow/sampler.py +297 -0
  17. kumoai/experimental/rfm/backend/snow/table.py +242 -0
  18. kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
  19. kumoai/experimental/rfm/backend/sqlite/sampler.py +398 -0
  20. kumoai/experimental/rfm/backend/sqlite/table.py +184 -0
  21. kumoai/experimental/rfm/base/__init__.py +30 -0
  22. kumoai/experimental/rfm/base/column.py +152 -0
  23. kumoai/experimental/rfm/base/expression.py +44 -0
  24. kumoai/experimental/rfm/base/sampler.py +761 -0
  25. kumoai/experimental/rfm/base/source.py +19 -0
  26. kumoai/experimental/rfm/base/sql_sampler.py +143 -0
  27. kumoai/experimental/rfm/base/table.py +735 -0
  28. kumoai/experimental/rfm/graph.py +1237 -0
  29. kumoai/experimental/rfm/infer/__init__.py +8 -0
  30. kumoai/experimental/rfm/infer/dtype.py +82 -0
  31. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  32. kumoai/experimental/rfm/infer/pkey.py +128 -0
  33. kumoai/experimental/rfm/infer/stype.py +35 -0
  34. kumoai/experimental/rfm/infer/time_col.py +61 -0
  35. kumoai/experimental/rfm/pquery/executor.py +27 -27
  36. kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
  37. kumoai/experimental/rfm/relbench.py +76 -0
  38. kumoai/experimental/rfm/rfm.py +346 -248
  39. kumoai/experimental/rfm/sagemaker.py +138 -0
  40. kumoai/kumolib.cp311-win_amd64.pyd +0 -0
  41. kumoai/pquery/predictive_query.py +10 -6
  42. kumoai/spcs.py +1 -3
  43. kumoai/testing/decorators.py +1 -1
  44. kumoai/testing/snow.py +50 -0
  45. kumoai/trainer/distilled_trainer.py +175 -0
  46. kumoai/utils/__init__.py +3 -2
  47. kumoai/utils/display.py +51 -0
  48. kumoai/utils/progress_logger.py +188 -16
  49. kumoai/utils/sql.py +3 -0
  50. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/METADATA +13 -2
  51. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/RECORD +54 -31
  52. kumoai/experimental/rfm/local_graph.py +0 -810
  53. kumoai/experimental/rfm/local_graph_sampler.py +0 -184
  54. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  55. kumoai/experimental/rfm/local_table.py +0 -545
  56. kumoai/experimental/rfm/utils.py +0 -344
  57. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/WHEEL +0 -0
  58. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/licenses/LICENSE +0 -0
  59. {kumoai-2.12.0.dev202511061731.dist-info → kumoai-2.14.0.dev202512311733.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1237 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import copy
5
+ import io
6
+ import warnings
7
+ from collections import defaultdict
8
+ from collections.abc import Sequence
9
+ from dataclasses import dataclass, field
10
+ from itertools import chain
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Union
13
+
14
+ import pandas as pd
15
+ from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
16
+ from kumoapi.table import TableDefinition
17
+ from kumoapi.typing import Stype
18
+ from typing_extensions import Self
19
+
20
+ from kumoai import in_notebook, in_snowflake_notebook
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
22
+ from kumoai.graph import Edge
23
+ from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
25
+
26
+ if TYPE_CHECKING:
27
+ import graphviz
28
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
29
+ from snowflake.connector import SnowflakeConnection
30
+
31
+
32
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Settings used to open a new SQLite connection.

    :meth:`Graph.from_sqlite` casts its ``connection`` argument into this
    configuration (via ``_cast``) and then calls
    ``connect(uri, **kwargs)`` with its fields.
    """
    # Location of the database; passed as the first positional argument to
    # `connect`:
    uri: str | Path
    # Additional keyword arguments forwarded to `connect`:
    kwargs: dict[str, Any] = field(default_factory=dict)
36
+
37
+
38
+ class Graph:
39
+ r"""A graph of :class:`Table` objects, akin to relationships between
40
+ tables in a relational database.
41
+
42
+ Creating a graph is the final step of data definition; after a
43
+ :class:`Graph` is created, you can use it to initialize the
44
+ Kumo Relational Foundation Model (:class:`KumoRFM`).
45
+
46
+ .. code-block:: python
47
+
48
+ >>> # doctest: +SKIP
49
+ >>> import pandas as pd
50
+ >>> import kumoai.experimental.rfm as rfm
51
+
52
+ >>> # Load data frames into memory:
53
+ >>> df1 = pd.DataFrame(...)
54
+ >>> df2 = pd.DataFrame(...)
55
+ >>> df3 = pd.DataFrame(...)
56
+
57
+ >>> # Define tables from data frames:
58
+ >>> table1 = rfm.LocalTable(name="table1", data=df1)
59
+ >>> table2 = rfm.LocalTable(name="table2", data=df2)
60
+ >>> table3 = rfm.LocalTable(name="table3", data=df3)
61
+
62
+ >>> # Create a graph from a list of tables:
63
+ >>> graph = rfm.Graph(tables=[
64
+ ... table1,
65
+ ... table2,
66
+ ... table3,
67
+ ... ])
68
+
69
+ >>> # Infer table metadata:
70
+ >>> graph.infer_metadata()
71
+
72
+ >>> # Infer links/edges:
73
+ >>> graph.infer_links()
74
+
75
+ >>> # Inspect table metadata:
76
+ >>> for table in graph.tables.values():
77
+ ... table.print_metadata()
78
+
79
+ >>> # Visualize graph (if graphviz is installed):
80
+ >>> graph.visualize()
81
+
82
+ >>> # Add/Remove edges between tables:
83
+ >>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
84
+ >>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
85
+
86
+ >>> # Validate graph:
87
+ >>> graph.validate()
88
+ """
89
+
90
+ # Constructors ############################################################
91
+
92
+ def __init__(
93
+ self,
94
+ tables: Sequence[Table],
95
+ edges: Sequence[Edge] | None = None,
96
+ ) -> None:
97
+
98
+ self._tables: dict[str, Table] = {}
99
+ self._edges: list[Edge] = []
100
+
101
+ for table in tables:
102
+ self.add_table(table)
103
+
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
106
+ continue
107
+ for fkey in table._source_foreign_key_dict.values():
108
+ if fkey.name not in table:
109
+ continue
110
+ if not table[fkey.name].is_source:
111
+ continue
112
+ dst_table_names = [
113
+ table.name for table in self.tables.values()
114
+ if table.source_name == fkey.dst_table
115
+ ]
116
+ if len(dst_table_names) != 1:
117
+ continue
118
+ dst_table = self[dst_table_names[0]]
119
+ if dst_table._primary_key != fkey.primary_key:
120
+ continue
121
+ if not dst_table[fkey.primary_key].is_source:
122
+ continue
123
+ self.link(table.name, fkey.name, dst_table.name)
124
+
125
+ for edge in (edges or []):
126
+ _edge = Edge._cast(edge)
127
+ assert _edge is not None
128
+ if _edge not in self._edges:
129
+ self.link(*_edge)
130
+
131
@classmethod
def from_data(
    cls,
    df_dict: dict[str, pd.DataFrame],
    edges: Sequence[Edge] | None = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Builds a :class:`Graph` out of in-memory
    :class:`pandas.DataFrame` objects.

    Every data frame is wrapped into a :class:`LocalTable`. By default,
    table metadata and links are inferred automatically.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import pandas as pd
        >>> import kumoai.experimental.rfm as rfm

        >>> # Load data frames into memory:
        >>> df1 = pd.DataFrame(...)
        >>> df2 = pd.DataFrame(...)
        >>> df3 = pd.DataFrame(...)

        >>> # Create a graph from a dictionary of data frames:
        >>> graph = rfm.Graph.from_data({
        ...     "table1": df1,
        ...     "table2": df2,
        ...     "table3": df3,
        ... })

    Args:
        df_dict: Maps table names to the data frames holding their data.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If ``None``, links are inferred via
            :meth:`infer_links`.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print the resulting metadata and links.
    """
    from kumoai.experimental.rfm.backend.local import LocalTable

    local_tables = [LocalTable(df, name) for name, df in df_dict.items()]
    graph = cls(tables=local_tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(verbose=False)

    # Links are only inferred if the user did not pass any edges at all:
    if edges is None:
        graph.infer_links(verbose=False)

    if not verbose:
        return graph

    graph.print_metadata()
    graph.print_links()
    return graph
190
+
191
@classmethod
def from_sqlite(
    cls,
    connection: Union[
        'AdbcSqliteConnection',
        SqliteConnectionConfig,
        str,
        Path,
        dict[str, Any],
    ],
    tables: Sequence[str | dict[str, Any]] | None = None,
    edges: Sequence[Edge] | None = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`sqlite` database.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph from a SQLite database:
        >>> graph = rfm.Graph.from_sqlite('data.db')

        >>> # Fine-grained control over table specification:
        >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
        ...     'USERS',
        ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
        ...     dict(name='ITEMS', primary_key='ITEM_ID'),
        ... ])

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`, or
            anything that can be cast into a
            :class:`SqliteConnectionConfig` (e.g., the path to the
            database file), in which case a new connection is opened.
        tables: Set of table names or :class:`SQLiteTable` keyword
            arguments to include. If ``None``, will add all user tables
            present in the database; SQLite-internal tables (names
            starting with ``sqlite_``) are skipped.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If ``None``, links are inferred via
            :meth:`infer_links`.
        infer_metadata: Whether to infer missing metadata for all tables in
            the graph.
        verbose: Whether to print verbose output.
    """
    from kumoai.experimental.rfm.backend.sqlite import (
        Connection,
        SQLiteTable,
        connect,
    )

    internal_connection = False
    if not isinstance(connection, Connection):
        connection = SqliteConnectionConfig._cast(connection)
        assert isinstance(connection, SqliteConnectionConfig)
        connection = connect(connection.uri, **connection.kwargs)
        internal_connection = True
    assert isinstance(connection, Connection)

    if tables is None:
        with connection.cursor() as cursor:
            cursor.execute("SELECT name FROM sqlite_master "
                           "WHERE type='table'")
            # The "sqlite_" prefix is reserved for SQLite's own bookkeeping
            # tables (e.g., "sqlite_sequence"), which must not end up as
            # tables of the graph:
            tables = [
                row[0] for row in cursor.fetchall()
                if not row[0].startswith('sqlite_')
            ]

    table_kwargs: list[dict[str, Any]] = []
    for table in tables:
        kwargs = dict(name=table) if isinstance(table, str) else table
        table_kwargs.append(kwargs)

    graph = cls(
        tables=[
            SQLiteTable(connection=connection, **kwargs)
            for kwargs in table_kwargs
        ],
        edges=edges or [],
    )

    # Attach connections that were opened by this method to the graph:
    if internal_connection:
        graph._connection = connection  # type: ignore

    if infer_metadata:
        graph.infer_metadata(verbose=False)

    # Links are only inferred if the user did not pass any edges at all:
    if edges is None:
        graph.infer_links(verbose=False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
286
+
287
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
    tables: Sequence[str | dict[str, Any]] | None = None,
    database: str | None = None,
    schema: str | None = None,
    edges: Sequence[Edge] | None = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

        >>> # Fine-grained control over table specification:
        >>> graph = rfm.Graph.from_snowflake(tables=[
        ...     'USERS',
        ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
        ...     dict(name='ITEMS', schema='OTHER_SCHEMA'),
        ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            keyword arguments to pass to it in order to open a new
            connection. If ``None``, ``connect()`` is called without any
            arguments.
        tables: Set of table names or :class:`SnowTable` keyword arguments
            to include. If ``None``, will add all tables present in the
            given (or current) database and schema.
        database: The database. If ``None``, the current database of the
            connection is used.
        schema: The schema. If ``None``, the current schema of the
            connection is used.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If ``None``, links are inferred via
            :meth:`infer_links`.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Raises:
        ValueError: If no database is given and the connection has no
            current database, or if ``tables`` is ``None`` and no schema
            can be determined.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if database is None or schema is None:
        # Fall back to the current database/schema of the session:
        with connection.cursor() as cursor:
            cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
            result = cursor.fetchone()
        database = database or result[0]
        schema = schema or result[1]

    # Raise explicitly instead of asserting: assertions are removed when
    # running with `-O`, and `None` would then leak into the SQL below.
    if database is None:
        raise ValueError("No current 'database' set. Please specify the "
                         "Snowflake database manually")

    if tables is None:
        if schema is None:
            raise ValueError("No current 'schema' set. Please specify the "
                             "Snowflake schema manually")

        with connection.cursor() as cursor:
            cursor.execute(f"""
                SELECT TABLE_NAME
                FROM {database}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{schema}'
            """)
            tables = [row[0] for row in cursor.fetchall()]

    table_kwargs: list[dict[str, Any]] = []
    for table in tables:
        if isinstance(table, str):
            kwargs = dict(name=table, database=database, schema=schema)
        else:
            # Copy in order to not modify the user-provided dictionary:
            kwargs = copy.copy(table)
            kwargs.setdefault('database', database)
            kwargs.setdefault('schema', schema)
        table_kwargs.append(kwargs)

    graph = cls(
        tables=[
            SnowTable(connection=connection, **kwargs)
            for kwargs in table_kwargs
        ],
        edges=edges or [],
    )

    if infer_metadata:
        graph.infer_metadata(verbose=False)

    # Links are only inferred if the user did not pass any edges at all:
    if edges is None:
        graph.infer_links(verbose=False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
397
+
398
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a Snowflake semantic view.

    Reads the YAML definition of the semantic view, registers one
    :class:`SnowTable` per table entry and one link per single-column
    relationship. Parts of the definition that cannot be converted
    (composite keys, expressions referencing other tables, invalid
    relationships) are skipped and reported in a single warning.

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            keyword arguments to pass to it in order to open a new
            connection. If ``None``, ``connect()`` is called without any
            arguments.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    with connection.cursor() as cursor:
        cursor.execute(
            f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW"
            f"('{semantic_view_name}')")
        cfg = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    # Collects everything that could not be converted:
    problems = []
    all_table_names = {entry['name'] for entry in cfg['tables']}

    for table_cfg in cfg['tables']:
        table_name = table_cfg['name']
        base_table_cfg = table_cfg['base_table']
        source_table_name = base_table_cfg['table']
        database = base_table_cfg['database']
        schema = base_table_cfg['schema']

        # Only single-column primary keys can be represented:
        primary_key: str | None = None
        if 'primary_key' in table_cfg:
            pkey_columns = table_cfg['primary_key']['columns']
            if len(pkey_columns) == 1:
                primary_key = pkey_columns[0]
            elif len(pkey_columns) > 1:
                problems.append(
                    f"Failed to add primary key for table "
                    f"'{table_name}' since composite primary keys "
                    f"are not yet supported")

        columns: list[ColumnSpec] = []
        skipped_columns: list[str] = []
        for column_cfg in chain(
                table_cfg.get('dimensions', []),
                table_cfg.get('time_dimensions', []),
                table_cfg.get('facts', []),
        ):
            column_name = column_cfg['name']
            column_expr = column_cfg.get('expr', None)
            column_data_type = column_cfg.get('data_type', None)

            # Strip references to the table itself from the expression:
            if column_expr is not None:
                column_expr = column_expr.replace(f'{table_name}.', '')

            if column_expr is None or column_expr == column_name:
                # Plain column without a (non-trivial) expression:
                columns.append(ColumnSpec(name=column_name))
            elif any(f'{name}.' in column_expr
                     for name in all_table_names):
                # Drop expressions that reference other tables (for now):
                skipped_columns.append(column_name)
            else:
                columns.append(
                    ColumnSpec(
                        name=column_name,
                        expr=column_expr,
                        dtype=SnowTable._to_dtype(column_data_type),
                    ))

        if len(skipped_columns) == 1:
            problems.append(
                f"Failed to add column '{skipped_columns[0]}' "
                f"of table '{table_name}' since its expression "
                f"references other tables")
        elif len(skipped_columns) > 1:
            problems.append(
                f"Failed to add columns '{skipped_columns}' "
                f"of table '{table_name}' since their expressions "
                f"reference other tables")

        table = SnowTable(
            connection,
            name=table_name,
            source_name=source_table_name,
            database=database,
            schema=schema,
            columns=columns,
            primary_key=primary_key,
        )

        # TODO Add a way to register time columns without heuristic usage.
        table.infer_time_column(verbose=False)

        graph.add_table(table)

    for relation_cfg in cfg.get('relationships', []):
        name = relation_cfg['name']
        key_pairs = relation_cfg['relationship_columns']
        if len(key_pairs) != 1:
            problems.append(
                f"Failed to add relationship '{name}' since "
                f"composite key references are not yet supported")
            continue

        left_table = relation_cfg['left_table']
        left_key = key_pairs[0]['left_column']
        right_table = relation_cfg['right_table']
        right_key = key_pairs[0]['right_column']

        if graph[right_table]._primary_key != right_key:
            # Semantic view error - this should never be triggered:
            problems.append(
                f"Failed to add relationship '{name}' since the "
                f"referenced key '{right_key}' of table "
                f"'{right_table}' is not a primary key")
            continue

        if graph[left_table]._primary_key == left_key:
            problems.append(
                f"Failed to add relationship '{name}' since the "
                f"referencing key '{left_key}' of table "
                f"'{left_table}' is a primary key")
            continue

        # The referencing key may not be listed as a column of the view:
        if left_key not in graph[left_table]:
            graph[left_table].add_column(left_key)

        graph.link(left_table, left_key, right_table)

    graph.validate()

    if verbose:
        graph.print_metadata()
        graph.print_links()

    if len(problems) > 0:
        title = (f"Could not fully convert the semantic view definition "
                 f"'{semantic_view_name}' into a graph:\n")
        warnings.warn(title + '\n'.join(f'- {msg}' for msg in problems))

    return graph
541
+
542
@classmethod
def from_relbench(
    cls,
    dataset: str,
    verbose: bool = True,
) -> Graph:
    r"""Creates a :class:`Graph` from a
    `RelBench <https://relbench.stanford.edu>`_ dataset.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_relbench("f1")

    Args:
        dataset: The RelBench dataset name.
        verbose: Whether to print verbose output. Also forwarded to the
            underlying loading routine.
    """
    from kumoai.experimental.rfm.relbench import (
        from_relbench as load_relbench,
    )

    graph = load_relbench(dataset, verbose=verbose)

    if not verbose:
        return graph

    graph.print_metadata()
    graph.print_links()
    return graph
570
+
571
+ # Backend #################################################################
572
+
573
+ @property
574
+ def backend(self) -> DataBackend | None:
575
+ backends = [table.backend for table in self._tables.values()]
576
+ return backends[0] if len(backends) > 0 else None
577
+
578
+ # Tables ##################################################################
579
+
580
+ def has_table(self, name: str) -> bool:
581
+ r"""Returns ``True`` if the graph has a table with name ``name``;
582
+ ``False`` otherwise.
583
+ """
584
+ return name in self.tables
585
+
586
+ def table(self, name: str) -> Table:
587
+ r"""Returns the table with name ``name`` in the graph.
588
+
589
+ Raises:
590
+ KeyError: If ``name`` is not present in the graph.
591
+ """
592
+ if not self.has_table(name):
593
+ raise KeyError(f"Table '{name}' not found in graph")
594
+ return self.tables[name]
595
+
596
+ @property
597
+ def tables(self) -> dict[str, Table]:
598
+ r"""Returns the dictionary of table objects."""
599
+ return self._tables
600
+
601
+ def add_table(self, table: Table) -> Self:
602
+ r"""Adds a table to the graph.
603
+
604
+ Args:
605
+ table: The table to add.
606
+
607
+ Raises:
608
+ KeyError: If a table with the same name already exists in the
609
+ graph.
610
+ ValueError: If the table belongs to a different backend than the
611
+ rest of the tables in the graph.
612
+ """
613
+ if table.name in self._tables:
614
+ raise KeyError(f"Cannot add table with name '{table.name}' to "
615
+ f"this graph; table names must be globally unique.")
616
+
617
+ if self.backend is not None and table.backend != self.backend:
618
+ raise ValueError(f"Cannot register a table with backend "
619
+ f"'{table.backend}' to this graph since other "
620
+ f"tables have backend '{self.backend}'.")
621
+
622
+ self._tables[table.name] = table
623
+
624
+ return self
625
+
626
+ def remove_table(self, name: str) -> Self:
627
+ r"""Removes a table with ``name`` from the graph.
628
+
629
+ Args:
630
+ name: The table to remove.
631
+
632
+ Raises:
633
+ KeyError: If no such table is present in the graph.
634
+ """
635
+ if not self.has_table(name):
636
+ raise KeyError(f"Table '{name}' not found in the graph")
637
+
638
+ del self._tables[name]
639
+
640
+ self._edges = [
641
+ edge for edge in self._edges
642
+ if edge.src_table != name and edge.dst_table != name
643
+ ]
644
+
645
+ return self
646
+
647
+ @property
648
+ def metadata(self) -> pd.DataFrame:
649
+ r"""Returns a :class:`pandas.DataFrame` object containing metadata
650
+ information about the tables in this graph.
651
+
652
+ The returned dataframe has columns ``name``, ``primary_key``,
653
+ ``time_column``, and ``end_time_column``, which provide an aggregate
654
+ view of the properties of the tables of this graph.
655
+
656
+ Example:
657
+ >>> # doctest: +SKIP
658
+ >>> import kumoai.experimental.rfm as rfm
659
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
660
+ >>> graph.metadata # doctest: +SKIP
661
+ name primary_key time_column end_time_column
662
+ 0 users user_id - -
663
+ """
664
+ tables = list(self.tables.values())
665
+
666
+ return pd.DataFrame({
667
+ 'name':
668
+ pd.Series(dtype=str, data=[t.name for t in tables]),
669
+ 'primary_key':
670
+ pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
671
+ 'time_column':
672
+ pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
673
+ 'end_time_column':
674
+ pd.Series(
675
+ dtype=str,
676
+ data=[t._end_time_column or '-' for t in tables],
677
+ ),
678
+ })
679
+
680
def print_metadata(self) -> None:
    r"""Displays a title followed by the :meth:`~Graph.metadata` table of
    the graph.
    """
    display.title("🗂️ Graph Metadata")
    metadata_df = self.metadata
    display.dataframe(metadata_df)
684
+
685
+ def infer_metadata(self, verbose: bool = True) -> Self:
686
+ r"""Infers metadata for all tables in the graph.
687
+
688
+ Args:
689
+ verbose: Whether to print verbose output.
690
+
691
+ Note:
692
+ For more information, please see
693
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
694
+ """
695
+ for table in self.tables.values():
696
+ table.infer_metadata(verbose=False)
697
+
698
+ if verbose:
699
+ self.print_metadata()
700
+
701
+ return self
702
+
703
+ # Edges ###################################################################
704
+
705
+ @property
706
+ def edges(self) -> list[Edge]:
707
+ r"""Returns the edges of the graph."""
708
+ return self._edges
709
+
710
def print_links(self) -> None:
    r"""Displays the :meth:`~Graph.edges` of the graph as a list of
    foreign key/primary key pairs, sorted by destination table.
    """
    # Sort by (dst_table, primary_key, src_table, fkey):
    rows = sorted((
        edge.dst_table,
        self[edge.dst_table]._primary_key,
        edge.src_table,
        edge.fkey,
    ) for edge in self.edges)

    display.title("🕸️ Graph Links (FK ↔️ PK)")

    if len(rows) == 0:
        display.italic("No links registered")
        return

    display.unordered_list(items=[
        f"`{src_table}.{fkey}` ↔️ `{dst_table}.{pkey}`"
        for dst_table, pkey, src_table, fkey in rows
    ])
727
+
728
def link(
    self,
    src_table: str | Table,
    fkey: str,
    dst_table: str | Table,
) -> Self:
    r"""Adds an edge from the foreign key ``fkey`` of ``src_table`` to the
    primary key of ``dst_table``.

    The link is treated as bidirectional.

    Args:
        src_table: The source table of the edge (or its name). It must
            hold a column with name :obj:`fkey`.
        fkey: The name of the foreign key in the source table.
        dst_table: The destination table of the edge (or its name).

    Returns:
        The graph itself, which allows for chaining calls.

    Raises:
        ValueError: If the edge is already present in the graph, if the
            source or the destination table does not exist in the graph,
            if the foreign key does not exist as a column in the source
            table, or if the data type of the foreign key is not
            supported by the ``ID`` semantic type.
    """
    # Tables can be given as objects or by name:
    src_table = src_table.name if isinstance(src_table, Table) else src_table
    assert isinstance(src_table, str)

    dst_table = dst_table.name if isinstance(dst_table, Table) else dst_table
    assert isinstance(dst_table, str)

    edge = Edge(src_table, fkey, dst_table)

    if edge in self.edges:
        raise ValueError(f"{edge} already exists in the graph")

    if not self.has_table(src_table):
        raise ValueError(f"Source table '{src_table}' does not exist in "
                         f"the graph")

    if not self.has_table(dst_table):
        raise ValueError(f"Destination table '{dst_table}' does not exist "
                         f"in the graph")

    if not self[src_table].has_column(fkey):
        raise ValueError(f"Source key '{fkey}' does not exist as a column "
                         f"in source table '{src_table}'")

    fkey_dtype = self[src_table][fkey].dtype
    if not Stype.ID.supports_dtype(fkey_dtype):
        raise ValueError(f"Cannot use '{fkey}' in source table "
                         f"'{src_table}' as a foreign key due to its "
                         f"incompatible data type. Foreign keys must have "
                         f"data type 'int', 'float' or 'string' "
                         f"(got '{fkey_dtype}')")

    self._edges.append(edge)

    return self
790
+
791
def unlink(
    self,
    src_table: str | Table,
    fkey: str,
    dst_table: str | Table,
) -> Self:
    r"""Deletes an :class:`~kumoai.graph.Edge` from the graph.

    Args:
        src_table: The source table of the edge (or its name).
        fkey: The name of the foreign key in the source table.
        dst_table: The destination table of the edge (or its name).

    Returns:
        The graph itself, which allows for chaining calls.

    Raises:
        ValueError: If the edge is not present in the graph.
    """
    # Tables can be given as objects or by name:
    src_table = src_table.name if isinstance(src_table, Table) else src_table
    assert isinstance(src_table, str)

    dst_table = dst_table.name if isinstance(dst_table, Table) else dst_table
    assert isinstance(dst_table, str)

    edge = Edge(src_table, fkey, dst_table)

    if edge in self.edges:
        self._edges.remove(edge)
        return self

    raise ValueError(f"{edge} is not present in the graph")
823
+
824
def infer_links(self, verbose: bool = True) -> Self:
    r"""Infers missing links for the tables and adds them as edges to the
    graph.

    Inference runs in two passes:

    1. Foreign keys reported by the source metadata of a table
       (``table._source_foreign_key_dict``) are linked whenever the
       referenced table and its primary key can be resolved unambiguously.
    2. Every remaining column is scored against the primary key of every
       table via name similarity, data type compatibility and table
       cardinality. A link is added for the best-scoring candidate if its
       score reaches the threshold and it is not tied with the runner-up.

    Args:
        verbose: Whether to print verbose output.

    Returns:
        The graph itself, to allow for method chaining.
    """
    # `(source table, foreign key)` pairs that are already linked; these are
    # never re-inferred:
    known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}

    for table in self.tables.values():  # Use links from source metadata:
        if not any(column.is_source for column in table.columns):
            continue
        for fkey in table._source_foreign_key_dict.values():
            if fkey.name not in table:
                continue
            if not table[fkey.name].is_source:
                continue
            if (table.name, fkey.name) in known_edges:
                continue
            # NOTE The comprehension variable `table` lives in its own
            # scope and does not overwrite the `table` of the outer loop.
            dst_table_names = [
                table.name for table in self.tables.values()
                if table.source_name == fkey.dst_table
            ]
            # Skip if the referenced source table is missing or registered
            # more than once in this graph:
            if len(dst_table_names) != 1:
                continue
            dst_table = self[dst_table_names[0]]
            # The referenced column needs to be the primary key of the
            # destination table:
            if dst_table._primary_key != fkey.primary_key:
                continue
            if not dst_table[fkey.primary_key].is_source:
                continue
            self.link(table.name, fkey.name, dst_table.name)
            known_edges.add((table.name, fkey.name))

    # A list of primary key candidates (+score) for every column:
    candidate_dict: dict[
        tuple[str, str],
        list[tuple[str, float]],
    ] = defaultdict(list)

    for dst_table in self.tables.values():
        dst_key = dst_table.primary_key

        if dst_key is None:
            continue

        assert dst_key.dtype is not None
        dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
        dst_string = dst_key.dtype.is_string()

        # All name comparisons below are case-insensitive:
        dst_table_name = dst_table.name.lower()
        dst_key_name = dst_key.name.lower()

        for src_table in self.tables.values():
            src_table_name = src_table.name.lower()

            for src_key in src_table.columns:
                if (src_table.name, src_key.name) in known_edges:
                    continue

                if src_key == src_table.primary_key:
                    continue  # Cannot link to primary key.

                src_number = (src_key.dtype.is_int()
                              or src_key.dtype.is_float())
                src_string = src_key.dtype.is_string()

                if src_number != dst_number or src_string != dst_string:
                    continue  # Non-compatible data types.

                src_key_name = src_key.name.lower()

                score = 0.0

                # Name similarity (only the first matching rule applies):
                if src_key_name == dst_key_name:
                    score += 7.0
                elif (dst_key_name != 'id'
                      and src_key_name.endswith(dst_key_name)):
                    score += 4.0
                elif src_key_name.endswith(  # e.g., user.id -> user_id
                        f'{dst_table_name}_{dst_key_name}'):
                    score += 4.0
                elif src_key_name.endswith(  # e.g., user.id -> userid
                        f'{dst_table_name}{dst_key_name}'):
                    score += 4.0
                elif (dst_table_name.endswith('s') and
                      src_key_name.endswith(  # e.g., users.id -> user_id
                          f'{dst_table_name[:-1]}_{dst_key_name}')):
                    score += 4.0
                elif (dst_table_name.endswith('s') and
                      src_key_name.endswith(  # e.g., users.id -> userid
                          f'{dst_table_name[:-1]}{dst_key_name}')):
                    score += 4.0
                elif src_key_name.endswith(dst_table_name):
                    score += 4.0  # e.g., users -> users
                elif (dst_table_name.endswith('s')  # e.g., users -> user
                      and src_key_name.endswith(dst_table_name[:-1])):
                    score += 4.0
                elif ((src_key_name == 'parentid'
                       or src_key_name == 'parent_id')
                      and src_table_name == dst_table_name):
                    score += 2.0

                # `rel-bench` hard-coding :(
                elif (src_table.name == 'posts'
                      and src_key.name == 'AcceptedAnswerId'
                      and dst_table.name == 'posts'):
                    score += 2.0
                elif (src_table.name == 'user_friends'
                      and src_key.name == 'friend'
                      and dst_table.name == 'users'):
                    score += 3.0

                # For non-exact matching, at least one additional
                # requirement needs to be met.
                # The bonuses below sum up to at most 4.0, so a column
                # without any name match can never reach the threshold,
                # while an exact name match (7.0) always does.

                # Exact data type compatibility:
                if src_key.stype == Stype.ID:
                    score += 2.0

                if src_key.dtype == dst_key.dtype:
                    score += 1.0

                # Cardinality ratio:
                if (src_table._num_rows is not None
                        and dst_table._num_rows is not None
                        and src_table._num_rows > dst_table._num_rows):
                    score += 1.0

                if score < 5.0:
                    continue

                candidate_dict[(src_table.name, src_key.name)].append(
                    (dst_table.name, score))

    for (src_table_name, src_key_name), scores in candidate_dict.items():
        # Best candidate first:
        scores.sort(key=lambda x: x[-1], reverse=True)

        if len(scores) > 1 and scores[0][1] == scores[1][1]:
            continue  # Cannot uniquely infer link.

        dst_table_name = scores[0][0]
        self.link(src_table_name, src_key_name, dst_table_name)

    if verbose:
        self.print_links()

    return self
972
+
973
+ # Metadata ################################################################
974
+
975
def validate(self) -> Self:
    r"""Validates the graph to ensure that all relevant metadata is
    specified for its tables and edges.

    Concretely, validation ensures that edges properly link foreign keys to
    primary keys between valid tables.
    It additionally ensures that primary and foreign keys between tables
    in an :class:`~kumoai.graph.Edge` are of the same data type.

    Returns:
        The graph itself, to allow for method chaining.

    Raises:
        ValueError: if validation fails.
    """
    if len(self.tables) == 0:
        raise ValueError("At least one table needs to be added to the "
                         "graph")

    # All tables of a graph need to share the same backend:
    backends = {table.backend for table in self._tables.values()}
    if len(backends) != 1:
        raise ValueError("Found multiple table backends in the graph")

    for edge in self.edges:
        src_table, fkey, dst_table = edge

        src_key = self[src_table][fkey]
        dst_key = self[dst_table].primary_key

        # Check that the destination table defines a primary key:
        if dst_key is None:
            raise ValueError(f"Edge {edge} is invalid since table "
                             f"'{dst_table}' does not have a primary key. "
                             f"Add either a primary key or remove the "
                             f"link before proceeding.")

        # Ensure that foreign key is not a primary key:
        src_pkey = self[src_table].primary_key
        if src_pkey is not None and src_pkey.name == fkey:
            raise ValueError(f"Cannot treat the primary key of table "
                             f"'{src_table}' as a foreign key. Remove "
                             f"either the primary key or the link "
                             f"before proceeding.")

        # Data type checks are only performed for the local backend:
        if self.backend == DataBackend.LOCAL:
            # Check that fkey/pkey have valid and consistent data types:
            assert src_key.dtype is not None
            src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
            src_string = src_key.dtype.is_string()
            assert dst_key.dtype is not None
            dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
            dst_string = dst_key.dtype.is_string()

            if not src_number and not src_string:
                raise ValueError(
                    f"{edge} is invalid as foreign key must be a number "
                    f"or string (got '{src_key.dtype}')")

            if src_number != dst_number or src_string != dst_string:
                raise ValueError(
                    f"{edge} is invalid as foreign key '{fkey}' and "
                    f"primary key '{dst_key.name}' have incompatible data "
                    f"types (got foreign key data type '{src_key.dtype}' "
                    f"and primary key data type '{dst_key.dtype}')")

    return self
1038
+
1039
+ # Visualization ###########################################################
1040
+
1041
def visualize(
    self,
    path: str | io.BytesIO | None = None,
    show_columns: bool = True,
) -> 'graphviz.Graph':
    r"""Draws the tables and edges of this graph via the :class:`graphviz`
    library.

    Args:
        path: Where to write the produced image to. A string is treated as
            a file path whose extension selects the output format, while
            an :class:`io.BytesIO` buffer receives the rendered SVG. If
            ``None``, nothing is written and the graph is displayed
            instead.
        show_columns: Whether to list all columns of every table. If
            ``False``, only the primary key, foreign key(s), time column
            and end time column of each table are shown.

    Returns:
        A ``graphviz.Graph`` instance representing the visualized graph.

    Raises:
        ImportError: If the ``graphviz`` Python package is missing.
        RuntimeError: If the ``graphviz`` executables are missing (not
            checked inside Snowflake notebooks).
    """
    try:  # Check basic dependency:
        import graphviz
    except ImportError as e:
        raise ImportError("The 'graphviz' package is required for "
                          "visualization") from e

    def executables_available() -> bool:
        # Rendering an empty graph fails if the binaries are not installed.
        try:
            graphviz.Digraph().pipe()
        except graphviz.backend.ExecutableNotFound:
            return False
        return True

    if not (in_snowflake_notebook() or executables_available()):
        raise RuntimeError("Could not visualize graph as 'graphviz' "
                           "executables are not installed. These "
                           "dependencies are required in addition to the "
                           "'graphviz' Python package. Please install "
                           "them as described at "
                           "https://graphviz.org/download/.")

    # Derive the output format from the requested destination:
    if isinstance(path, str):
        out_format: str | None = path.split('.')[-1]
    elif isinstance(path, io.BytesIO):
        out_format = 'svg'
    else:
        out_format = None
    graph = graphviz.Graph(format=out_format)

    def left_align(entries: list[str]) -> str:
        # Every entry is terminated by `\l` (left-justified line break).
        return ''.join(f'{entry}\\l' for entry in entries)

    # Foreign keys grouped by the table that holds them:
    fkeys_dict: dict[str, list[str]] = defaultdict(list)
    for edge_src, edge_fkey, _ in self.edges:
        fkeys_dict[edge_src].append(edge_fkey)

    for table_name, table in self.tables.items():
        # Key section: primary key, foreign keys, time columns.
        keys = []
        pkey = table.primary_key
        if pkey:
            keys.append(f'{pkey.name}: PK ({pkey.dtype})')
        for fkey_name in fkeys_dict[table_name]:
            keys.append(
                f'{fkey_name}: FK ({self[table_name][fkey_name].dtype})')
        time_col = table.time_column
        if time_col:
            keys.append(f'{time_col.name}: Time ({time_col.dtype})')
        end_time_col = table.end_time_column
        if end_time_col:
            keys.append(f'{end_time_col.name}: '
                        f'End Time ({end_time_col.dtype})')
        key_repr = left_align(keys)

        # Column section: everything not already listed as a key.
        columns = []
        if show_columns:
            for column in table.columns:
                if column.name in fkeys_dict[table_name]:
                    continue
                if column.name == table._primary_key:
                    continue
                if column.name == table._time_column:
                    continue
                if column.name == table._end_time_column:
                    continue
                columns.append(
                    f'{column.name}: {column.stype} ({column.dtype})')
        column_repr = left_align(columns)

        # Record label: `{name|keys|columns}`, skipping empty sections.
        sections = [table_name]
        if len(keys) > 0:
            sections.append(key_repr)
        if len(columns) > 0:
            sections.append(column_repr)
        label = '{' + '|'.join(sections) + '}'

        graph.node(table_name, shape='record', label=label)

    for src_table_name, fkey_name, dst_table_name in self.edges:
        pkey_name = self[dst_table_name]._primary_key
        if pkey_name is None:
            continue  # Invalid edge.

        if fkey_name == pkey_name:
            label = f' {fkey_name} '
        else:
            label = f' {fkey_name}\n< >\n{pkey_name} '

        graph.edge(
            src_table_name,
            dst_table_name,
            label=label,
            headlabel='1',
            taillabel='*',
            minlen='2',
            fontsize='11pt',
            labeldistance='1.5',
        )

    if isinstance(path, str):
        # Strip the extension again since `render` appends the format.
        stem = '.'.join(path.split('.')[:-1])
        graph.render(stem, cleanup=True)
    elif isinstance(path, io.BytesIO):
        path.write(graph.pipe())
    elif in_snowflake_notebook():
        import streamlit as st
        st.graphviz_chart(graph)
    elif in_notebook():
        from IPython.display import display
        display(graph)
    else:
        try:
            # Anything written to stderr is treated as a sign that the
            # system viewer could not be launched:
            captured_stderr = io.StringIO()
            with contextlib.redirect_stderr(captured_stderr):
                graph.view(cleanup=True)
            if captured_stderr.getvalue():
                warnings.warn("Could not visualize graph since your "
                              "system does not know how to open or "
                              "display PDF files from the command line. "
                              "Please specify 'visualize(path=...)' and "
                              "open the generated file yourself.")
        except Exception as e:
            warnings.warn(f"Could not visualize graph due to an "
                          f"unexpected error in 'graphviz'. Error: {e}")

    return graph
1186
+
1187
+ # Helpers #################################################################
1188
+
1189
def _to_api_graph_definition(self) -> GraphDefinition:
    r"""Converts the graph into its API :class:`GraphDefinition`.

    Every table contributes its table definition. In addition, for each
    table with a primary key, the primary key and all foreign keys
    pointing to that table form one :class:`ColumnKeyGroup` (only emitted
    if the group holds more than one key).
    """
    table_defs: dict[str, TableDefinition] = {}
    key_groups: list[ColumnKeyGroup] = []

    for table_name, table in self.tables.items():
        table_defs[table_name] = table._to_api_table_definition()

        if table.primary_key is None:
            continue

        # Primary key first, followed by all foreign keys referencing it:
        group = [ColumnKey(table_name, table.primary_key.name)]
        group.extend(
            ColumnKey(edge.src_table, edge.fkey) for edge in self.edges
            if edge.dst_table == table_name)

        # De-duplicate and order by `table.column`:
        group = sorted(
            list(set(group)),
            key=lambda x: f'{x.table_name}.{x.col_name}',
        )
        if len(group) > 1:
            key_groups.append(ColumnKeyGroup(group))

    return GraphDefinition(table_defs, key_groups)
1207
+
1208
+ # Class properties ########################################################
1209
+
1210
def __hash__(self) -> int:
    r"""Hashes the graph based on its edges and its sorted table names."""
    edge_part = tuple(self.edges)
    table_part = tuple(sorted(self.tables.keys()))
    return hash((edge_part, table_part))
1212
+
1213
def __contains__(self, name: str) -> bool:
    r"""Supports ``name in graph`` by delegating to :meth:`has_table`."""
    is_present = self.has_table(name)
    return is_present
1215
+
1216
def __getitem__(self, name: str) -> Table:
    r"""Supports ``graph[name]`` by delegating to :meth:`table`."""
    table = self.table(name)
    return table
1218
+
1219
def __delitem__(self, name: str) -> None:
    r"""Supports ``del graph[name]`` by delegating to
    :meth:`remove_table`.
    """
    remove = self.remove_table
    remove(name)
1221
+
1222
def __repr__(self) -> str:
    r"""Returns a multi-line summary listing the table names and the
    ``src.fkey ⇔ dst.pkey`` pairs of all edges.
    """
    table_lines = [f' {table},' for table in self.tables]
    tables = '\n'.join(table_lines)
    if len(tables) > 0:
        tables = f'[\n{tables}\n ]'
    else:
        tables = '[]'

    edge_lines = []
    for edge in self.edges:
        pkey = self[edge.dst_table]._primary_key
        edge_lines.append(f' {edge.src_table}.{edge.fkey}'
                          f' ⇔ {edge.dst_table}.{pkey},')
    edges = '\n'.join(edge_lines)
    if len(edges) > 0:
        edges = f'[\n{edges}\n ]'
    else:
        edges = '[]'

    return (f'{self.__class__.__name__}(\n'
            f' tables={tables},\n'
            f' edges={edges},\n'
            f')')
1234
+
1235
def __del__(self) -> None:
    r"""Closes the ``_connection`` attribute on garbage collection, if the
    instance has one.
    """
    try:
        connection = self._connection
    except AttributeError:
        return  # No connection was ever attached to this instance.
    connection.close()