kumoai 2.14.0.dev202512141732__py3-none-any.whl → 2.15.0.dev202601131732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +37 -46
- kumoai/experimental/rfm/backend/local/sampler.py +4 -5
- kumoai/experimental/rfm/backend/local/table.py +24 -30
- kumoai/experimental/rfm/backend/snow/sampler.py +331 -43
- kumoai/experimental/rfm/backend/snow/table.py +166 -56
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +372 -30
- kumoai/experimental/rfm/backend/sqlite/table.py +117 -48
- kumoai/experimental/rfm/base/__init__.py +8 -1
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +374 -208
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +335 -180
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +10 -5
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +606 -361
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/METADATA +3 -2
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/RECORD +49 -40
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512141732.dist-info → kumoai-2.15.0.dev202601131732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
4
|
+
import copy
|
|
2
5
|
import io
|
|
3
6
|
import warnings
|
|
4
7
|
from collections import defaultdict
|
|
8
|
+
from collections.abc import Sequence
|
|
5
9
|
from dataclasses import dataclass, field
|
|
10
|
+
from itertools import chain
|
|
6
11
|
from pathlib import Path
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
8
13
|
|
|
9
14
|
import pandas as pd
|
|
10
15
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -13,9 +18,10 @@ from kumoapi.typing import Stype
|
|
|
13
18
|
from typing_extensions import Self
|
|
14
19
|
|
|
15
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
-
from kumoai.experimental.rfm.base import DataBackend, Table
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
17
22
|
from kumoai.graph import Edge
|
|
18
23
|
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
19
25
|
|
|
20
26
|
if TYPE_CHECKING:
|
|
21
27
|
import graphviz
|
|
@@ -25,8 +31,8 @@ if TYPE_CHECKING:
|
|
|
25
31
|
|
|
26
32
|
@dataclass
|
|
27
33
|
class SqliteConnectionConfig(CastMixin):
|
|
28
|
-
uri:
|
|
29
|
-
kwargs:
|
|
34
|
+
uri: str | Path
|
|
35
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
class Graph:
|
|
@@ -86,27 +92,35 @@ class Graph:
|
|
|
86
92
|
def __init__(
|
|
87
93
|
self,
|
|
88
94
|
tables: Sequence[Table],
|
|
89
|
-
edges:
|
|
95
|
+
edges: Sequence[Edge] | None = None,
|
|
90
96
|
) -> None:
|
|
91
97
|
|
|
92
|
-
self._tables:
|
|
93
|
-
self._edges:
|
|
98
|
+
self._tables: dict[str, Table] = {}
|
|
99
|
+
self._edges: list[Edge] = []
|
|
94
100
|
|
|
95
101
|
for table in tables:
|
|
96
102
|
self.add_table(table)
|
|
97
103
|
|
|
98
|
-
for table in tables:
|
|
104
|
+
for table in tables: # Use links from source metadata:
|
|
105
|
+
if not any(column.is_source for column in table.columns):
|
|
106
|
+
continue
|
|
99
107
|
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
-
if fkey.name not in table
|
|
108
|
+
if fkey.name not in table:
|
|
109
|
+
continue
|
|
110
|
+
if not table[fkey.name].is_source:
|
|
111
|
+
continue
|
|
112
|
+
dst_table_names = [
|
|
113
|
+
table.name for table in self.tables.values()
|
|
114
|
+
if table.source_name == fkey.dst_table
|
|
115
|
+
]
|
|
116
|
+
if len(dst_table_names) != 1:
|
|
101
117
|
continue
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
f"'{fkey.dst_table}'.")
|
|
109
|
-
self.link(table.name, fkey.name, fkey.dst_table)
|
|
118
|
+
dst_table = self[dst_table_names[0]]
|
|
119
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
120
|
+
continue
|
|
121
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
122
|
+
continue
|
|
123
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
110
124
|
|
|
111
125
|
for edge in (edges or []):
|
|
112
126
|
_edge = Edge._cast(edge)
|
|
@@ -117,8 +131,8 @@ class Graph:
|
|
|
117
131
|
@classmethod
|
|
118
132
|
def from_data(
|
|
119
133
|
cls,
|
|
120
|
-
df_dict:
|
|
121
|
-
edges:
|
|
134
|
+
df_dict: dict[str, pd.DataFrame],
|
|
135
|
+
edges: Sequence[Edge] | None = None,
|
|
122
136
|
infer_metadata: bool = True,
|
|
123
137
|
verbose: bool = True,
|
|
124
138
|
) -> Self:
|
|
@@ -156,15 +170,17 @@ class Graph:
|
|
|
156
170
|
verbose: Whether to print verbose output.
|
|
157
171
|
"""
|
|
158
172
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
173
|
|
|
161
|
-
graph = cls(
|
|
174
|
+
graph = cls(
|
|
175
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
176
|
+
edges=edges or [],
|
|
177
|
+
)
|
|
162
178
|
|
|
163
179
|
if infer_metadata:
|
|
164
|
-
graph.infer_metadata(False)
|
|
180
|
+
graph.infer_metadata(verbose=False)
|
|
165
181
|
|
|
166
182
|
if edges is None:
|
|
167
|
-
graph.infer_links(False)
|
|
183
|
+
graph.infer_links(verbose=False)
|
|
168
184
|
|
|
169
185
|
if verbose:
|
|
170
186
|
graph.print_metadata()
|
|
@@ -180,10 +196,10 @@ class Graph:
|
|
|
180
196
|
SqliteConnectionConfig,
|
|
181
197
|
str,
|
|
182
198
|
Path,
|
|
183
|
-
|
|
199
|
+
dict[str, Any],
|
|
184
200
|
],
|
|
185
|
-
|
|
186
|
-
edges:
|
|
201
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
202
|
+
edges: Sequence[Edge] | None = None,
|
|
187
203
|
infer_metadata: bool = True,
|
|
188
204
|
verbose: bool = True,
|
|
189
205
|
) -> Self:
|
|
@@ -199,17 +215,25 @@ class Graph:
|
|
|
199
215
|
>>> # Create a graph from a SQLite database:
|
|
200
216
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
217
|
|
|
218
|
+
>>> # Fine-grained control over table specification:
|
|
219
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
220
|
+
... 'USERS',
|
|
221
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
222
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
223
|
+
... ])
|
|
224
|
+
|
|
202
225
|
Args:
|
|
203
226
|
connection: An open connection from
|
|
204
227
|
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
228
|
path to the database file.
|
|
206
|
-
|
|
207
|
-
all tables present
|
|
229
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
230
|
+
arguments to include. If ``None``, will add all tables present
|
|
231
|
+
in the database.
|
|
208
232
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
209
233
|
add to the graph. If not provided, edges will be automatically
|
|
210
234
|
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
-
infer_metadata: Whether to infer metadata for all tables in
|
|
212
|
-
graph.
|
|
235
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
236
|
+
the graph.
|
|
213
237
|
verbose: Whether to print verbose output.
|
|
214
238
|
"""
|
|
215
239
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
@@ -226,24 +250,33 @@ class Graph:
|
|
|
226
250
|
internal_connection = True
|
|
227
251
|
assert isinstance(connection, Connection)
|
|
228
252
|
|
|
229
|
-
if
|
|
253
|
+
if tables is None:
|
|
230
254
|
with connection.cursor() as cursor:
|
|
231
255
|
cursor.execute("SELECT name FROM sqlite_master "
|
|
232
256
|
"WHERE type='table'")
|
|
233
|
-
|
|
257
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
234
258
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
259
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
260
|
+
for table in tables:
|
|
261
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
262
|
+
table_kwargs.append(kwargs)
|
|
263
|
+
|
|
264
|
+
graph = cls(
|
|
265
|
+
tables=[
|
|
266
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
267
|
+
for kwargs in table_kwargs
|
|
268
|
+
],
|
|
269
|
+
edges=edges or [],
|
|
270
|
+
)
|
|
238
271
|
|
|
239
272
|
if internal_connection:
|
|
240
273
|
graph._connection = connection # type: ignore
|
|
241
274
|
|
|
242
275
|
if infer_metadata:
|
|
243
|
-
graph.infer_metadata(False)
|
|
276
|
+
graph.infer_metadata(verbose=False)
|
|
244
277
|
|
|
245
278
|
if edges is None:
|
|
246
|
-
graph.infer_links(False)
|
|
279
|
+
graph.infer_links(verbose=False)
|
|
247
280
|
|
|
248
281
|
if verbose:
|
|
249
282
|
graph.print_metadata()
|
|
@@ -254,11 +287,11 @@ class Graph:
|
|
|
254
287
|
@classmethod
|
|
255
288
|
def from_snowflake(
|
|
256
289
|
cls,
|
|
257
|
-
connection: Union['SnowflakeConnection',
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
edges:
|
|
290
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
291
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
292
|
+
database: str | None = None,
|
|
293
|
+
schema: str | None = None,
|
|
294
|
+
edges: Sequence[Edge] | None = None,
|
|
262
295
|
infer_metadata: bool = True,
|
|
263
296
|
verbose: bool = True,
|
|
264
297
|
) -> Self:
|
|
@@ -275,6 +308,13 @@ class Graph:
|
|
|
275
308
|
>>> # Create a graph directly in a Snowflake notebook:
|
|
276
309
|
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
277
310
|
|
|
311
|
+
>>> # Fine-grained control over table specification:
|
|
312
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
313
|
+
... 'USERS',
|
|
314
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
315
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
316
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
317
|
+
|
|
278
318
|
Args:
|
|
279
319
|
connection: An open connection from
|
|
280
320
|
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
@@ -282,10 +322,11 @@ class Graph:
|
|
|
282
322
|
connection. If ``None``, will re-use an active session in case
|
|
283
323
|
it exists, or create a new connection from credentials stored
|
|
284
324
|
in environment variables.
|
|
325
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
326
|
+
to include. If ``None``, will add all tables present in the
|
|
327
|
+
current database and schema.
|
|
285
328
|
database: The database.
|
|
286
329
|
schema: The schema.
|
|
287
|
-
table_names: Set of table names to include. If ``None``, will add
|
|
288
|
-
all tables present in the database.
|
|
289
330
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
290
331
|
add to the graph. If not provided, edges will be automatically
|
|
291
332
|
inferred from the data in case ``infer_metadata=True``.
|
|
@@ -303,37 +344,50 @@ class Graph:
|
|
|
303
344
|
connection = connect(**(connection or {}))
|
|
304
345
|
assert isinstance(connection, Connection)
|
|
305
346
|
|
|
306
|
-
if
|
|
347
|
+
if database is None or schema is None:
|
|
348
|
+
with connection.cursor() as cursor:
|
|
349
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
350
|
+
result = cursor.fetchone()
|
|
351
|
+
database = database or result[0]
|
|
352
|
+
assert database is not None
|
|
353
|
+
schema = schema or result[1]
|
|
354
|
+
|
|
355
|
+
if tables is None:
|
|
356
|
+
if schema is None:
|
|
357
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
358
|
+
"Snowflake schema manually")
|
|
359
|
+
|
|
307
360
|
with connection.cursor() as cursor:
|
|
308
|
-
if database is None and schema is None:
|
|
309
|
-
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
310
|
-
"CURRENT_SCHEMA()")
|
|
311
|
-
result = cursor.fetchone()
|
|
312
|
-
database = database or result[0]
|
|
313
|
-
schema = schema or result[1]
|
|
314
361
|
cursor.execute(f"""
|
|
315
362
|
SELECT TABLE_NAME
|
|
316
363
|
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
317
364
|
WHERE TABLE_SCHEMA = '{schema}'
|
|
318
365
|
""")
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
tables = [
|
|
322
|
-
SnowTable(
|
|
323
|
-
connection,
|
|
324
|
-
name=table_name,
|
|
325
|
-
database=database,
|
|
326
|
-
schema=schema,
|
|
327
|
-
) for table_name in table_names
|
|
328
|
-
]
|
|
366
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
329
367
|
|
|
330
|
-
|
|
368
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
369
|
+
for table in tables:
|
|
370
|
+
if isinstance(table, str):
|
|
371
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
372
|
+
else:
|
|
373
|
+
kwargs = copy.copy(table)
|
|
374
|
+
kwargs.setdefault('database', database)
|
|
375
|
+
kwargs.setdefault('schema', schema)
|
|
376
|
+
table_kwargs.append(kwargs)
|
|
377
|
+
|
|
378
|
+
graph = cls(
|
|
379
|
+
tables=[
|
|
380
|
+
SnowTable(connection=connection, **kwargs)
|
|
381
|
+
for kwargs in table_kwargs
|
|
382
|
+
],
|
|
383
|
+
edges=edges or [],
|
|
384
|
+
)
|
|
331
385
|
|
|
332
386
|
if infer_metadata:
|
|
333
|
-
graph.infer_metadata(False)
|
|
387
|
+
graph.infer_metadata(verbose=False)
|
|
334
388
|
|
|
335
389
|
if edges is None:
|
|
336
|
-
graph.infer_links(False)
|
|
390
|
+
graph.infer_links(verbose=False)
|
|
337
391
|
|
|
338
392
|
if verbose:
|
|
339
393
|
graph.print_metadata()
|
|
@@ -345,7 +399,7 @@ class Graph:
|
|
|
345
399
|
def from_snowflake_semantic_view(
|
|
346
400
|
cls,
|
|
347
401
|
semantic_view_name: str,
|
|
348
|
-
connection: Union['SnowflakeConnection',
|
|
402
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
349
403
|
verbose: bool = True,
|
|
350
404
|
) -> Self:
|
|
351
405
|
import yaml
|
|
@@ -363,35 +417,150 @@ class Graph:
|
|
|
363
417
|
with connection.cursor() as cursor:
|
|
364
418
|
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
365
419
|
f"'{semantic_view_name}')")
|
|
366
|
-
|
|
420
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
367
421
|
|
|
368
422
|
graph = cls(tables=[])
|
|
369
423
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
424
|
+
msgs = []
|
|
425
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
426
|
+
for table_cfg in cfg['tables']:
|
|
427
|
+
table_name = table_cfg['name']
|
|
428
|
+
source_table_name = table_cfg['base_table']['table']
|
|
429
|
+
database = table_cfg['base_table']['database']
|
|
430
|
+
schema = table_cfg['base_table']['schema']
|
|
431
|
+
|
|
432
|
+
primary_key: str | None = None
|
|
433
|
+
if 'primary_key' in table_cfg:
|
|
434
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
435
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
436
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
437
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
438
|
+
msgs.append(f"Failed to add primary key for table "
|
|
439
|
+
f"'{table_name}' since composite primary keys "
|
|
440
|
+
f"are not yet supported")
|
|
441
|
+
|
|
442
|
+
columns: list[ColumnSpec] = []
|
|
443
|
+
unsupported_columns: list[str] = []
|
|
444
|
+
for column_cfg in chain(
|
|
445
|
+
table_cfg.get('dimensions', []),
|
|
446
|
+
table_cfg.get('time_dimensions', []),
|
|
447
|
+
table_cfg.get('facts', []),
|
|
448
|
+
):
|
|
449
|
+
column_name = column_cfg['name']
|
|
450
|
+
column_expr = column_cfg.get('expr', None)
|
|
451
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
452
|
+
|
|
453
|
+
if column_expr is None:
|
|
454
|
+
columns.append(ColumnSpec(name=column_name))
|
|
455
|
+
continue
|
|
456
|
+
|
|
457
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
458
|
+
|
|
459
|
+
if column_expr == column_name:
|
|
460
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
# Drop expressions that reference other tables (for now):
|
|
464
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
465
|
+
unsupported_columns.append(column_name)
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
column = ColumnSpec(
|
|
469
|
+
name=column_name,
|
|
470
|
+
expr=column_expr,
|
|
471
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
472
|
+
)
|
|
473
|
+
columns.append(column)
|
|
474
|
+
|
|
475
|
+
if len(unsupported_columns) == 1:
|
|
476
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
477
|
+
f"of table '{table_name}' since its expression "
|
|
478
|
+
f"references other tables")
|
|
479
|
+
elif len(unsupported_columns) > 1:
|
|
480
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
481
|
+
f"of table '{table_name}' since their expressions "
|
|
482
|
+
f"reference other tables")
|
|
375
483
|
|
|
376
484
|
table = SnowTable(
|
|
377
485
|
connection,
|
|
378
|
-
name=
|
|
379
|
-
|
|
380
|
-
|
|
486
|
+
name=table_name,
|
|
487
|
+
source_name=source_table_name,
|
|
488
|
+
database=database,
|
|
489
|
+
schema=schema,
|
|
490
|
+
columns=columns,
|
|
381
491
|
primary_key=primary_key,
|
|
382
492
|
)
|
|
493
|
+
|
|
494
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
495
|
+
table.infer_time_column(verbose=False)
|
|
496
|
+
|
|
383
497
|
graph.add_table(table)
|
|
384
498
|
|
|
385
|
-
|
|
499
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
500
|
+
name = relation_cfg['name']
|
|
501
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
502
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
503
|
+
f"composite key references are not yet supported")
|
|
504
|
+
continue
|
|
386
505
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
506
|
+
left_table = relation_cfg['left_table']
|
|
507
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
508
|
+
right_table = relation_cfg['right_table']
|
|
509
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
510
|
+
|
|
511
|
+
if graph[right_table]._primary_key != right_key:
|
|
512
|
+
# Semantic view error - this should never be triggered:
|
|
513
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
514
|
+
f"referenced key '{right_key}' of table "
|
|
515
|
+
f"'{right_table}' is not a primary key")
|
|
516
|
+
continue
|
|
517
|
+
|
|
518
|
+
if graph[left_table]._primary_key == left_key:
|
|
519
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
520
|
+
f"referencing key '{left_key}' of table "
|
|
521
|
+
f"'{left_table}' is a primary key")
|
|
522
|
+
continue
|
|
523
|
+
|
|
524
|
+
if left_key not in graph[left_table]:
|
|
525
|
+
graph[left_table].add_column(left_key)
|
|
526
|
+
|
|
527
|
+
graph.link(left_table, left_key, right_table)
|
|
528
|
+
|
|
529
|
+
graph.validate()
|
|
530
|
+
|
|
531
|
+
if verbose:
|
|
532
|
+
graph.print_metadata()
|
|
533
|
+
graph.print_links()
|
|
534
|
+
|
|
535
|
+
if len(msgs) > 0:
|
|
536
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
537
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
538
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
539
|
+
|
|
540
|
+
return graph
|
|
541
|
+
|
|
542
|
+
@classmethod
|
|
543
|
+
def from_relbench(
|
|
544
|
+
cls,
|
|
545
|
+
dataset: str,
|
|
546
|
+
verbose: bool = True,
|
|
547
|
+
) -> Graph:
|
|
548
|
+
r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
|
|
549
|
+
:class:`Graph` instance.
|
|
550
|
+
|
|
551
|
+
.. code-block:: python
|
|
552
|
+
|
|
553
|
+
>>> # doctest: +SKIP
|
|
554
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
555
|
+
|
|
556
|
+
>>> graph = rfm.Graph.from_relbench("f1")
|
|
557
|
+
|
|
558
|
+
Args:
|
|
559
|
+
dataset: The RelBench dataset name.
|
|
560
|
+
verbose: Whether to print verbose output.
|
|
561
|
+
"""
|
|
562
|
+
from kumoai.experimental.rfm.relbench import from_relbench
|
|
563
|
+
graph = from_relbench(dataset, verbose=verbose)
|
|
395
564
|
|
|
396
565
|
if verbose:
|
|
397
566
|
graph.print_metadata()
|
|
@@ -425,7 +594,7 @@ class Graph:
|
|
|
425
594
|
return self.tables[name]
|
|
426
595
|
|
|
427
596
|
@property
|
|
428
|
-
def tables(self) ->
|
|
597
|
+
def tables(self) -> dict[str, Table]:
|
|
429
598
|
r"""Returns the dictionary of table objects."""
|
|
430
599
|
return self._tables
|
|
431
600
|
|
|
@@ -480,28 +649,28 @@ class Graph:
|
|
|
480
649
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
481
650
|
information about the tables in this graph.
|
|
482
651
|
|
|
483
|
-
The returned dataframe has columns ``
|
|
484
|
-
``
|
|
485
|
-
view of the properties of the tables of this graph.
|
|
652
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
653
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
654
|
+
aggregated view of the properties of the tables of this graph.
|
|
486
655
|
|
|
487
656
|
Example:
|
|
488
657
|
>>> # doctest: +SKIP
|
|
489
658
|
>>> import kumoai.experimental.rfm as rfm
|
|
490
659
|
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
491
660
|
>>> graph.metadata # doctest: +SKIP
|
|
492
|
-
|
|
493
|
-
0 users
|
|
661
|
+
Name Primary Key Time Column End Time Column
|
|
662
|
+
0 users user_id - -
|
|
494
663
|
"""
|
|
495
664
|
tables = list(self.tables.values())
|
|
496
665
|
|
|
497
666
|
return pd.DataFrame({
|
|
498
|
-
'
|
|
667
|
+
'Name':
|
|
499
668
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
500
|
-
'
|
|
669
|
+
'Primary Key':
|
|
501
670
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
502
|
-
'
|
|
671
|
+
'Time Column':
|
|
503
672
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
504
|
-
'
|
|
673
|
+
'End Time Column':
|
|
505
674
|
pd.Series(
|
|
506
675
|
dtype=str,
|
|
507
676
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -510,24 +679,8 @@ class Graph:
|
|
|
510
679
|
|
|
511
680
|
def print_metadata(self) -> None:
|
|
512
681
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
516
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
517
|
-
elif in_notebook():
|
|
518
|
-
from IPython.display import Markdown, display
|
|
519
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
520
|
-
df = self.metadata
|
|
521
|
-
try:
|
|
522
|
-
if hasattr(df.style, 'hide'):
|
|
523
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
524
|
-
else:
|
|
525
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
526
|
-
except ImportError:
|
|
527
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
528
|
-
else:
|
|
529
|
-
print("🗂️ Graph Metadata:")
|
|
530
|
-
print(self.metadata.to_string(index=False))
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
531
684
|
|
|
532
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
533
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -550,52 +703,33 @@ class Graph:
|
|
|
550
703
|
# Edges ###################################################################
|
|
551
704
|
|
|
552
705
|
@property
|
|
553
|
-
def edges(self) ->
|
|
706
|
+
def edges(self) -> list[Edge]:
|
|
554
707
|
r"""Returns the edges of the graph."""
|
|
555
708
|
return self._edges
|
|
556
709
|
|
|
557
710
|
def print_links(self) -> None:
|
|
558
711
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
559
|
-
edges = [(
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
st.markdown("*No links registered*")
|
|
573
|
-
elif in_notebook():
|
|
574
|
-
from IPython.display import Markdown, display
|
|
575
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
576
|
-
if len(edges) > 0:
|
|
577
|
-
display(
|
|
578
|
-
Markdown('\n'.join([
|
|
579
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
580
|
-
for edge in edges
|
|
581
|
-
])))
|
|
582
|
-
else:
|
|
583
|
-
display(Markdown("*No links registered*"))
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
584
725
|
else:
|
|
585
|
-
|
|
586
|
-
if len(edges) > 0:
|
|
587
|
-
print('\n'.join([
|
|
588
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
589
|
-
for edge in edges
|
|
590
|
-
]))
|
|
591
|
-
else:
|
|
592
|
-
print("No links registered")
|
|
726
|
+
display.italic("No links registered")
|
|
593
727
|
|
|
594
728
|
def link(
|
|
595
729
|
self,
|
|
596
|
-
src_table:
|
|
730
|
+
src_table: str | Table,
|
|
597
731
|
fkey: str,
|
|
598
|
-
dst_table:
|
|
732
|
+
dst_table: str | Table,
|
|
599
733
|
) -> Self:
|
|
600
734
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
601
735
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -656,9 +790,9 @@ class Graph:
|
|
|
656
790
|
|
|
657
791
|
def unlink(
|
|
658
792
|
self,
|
|
659
|
-
src_table:
|
|
793
|
+
src_table: str | Table,
|
|
660
794
|
fkey: str,
|
|
661
|
-
dst_table:
|
|
795
|
+
dst_table: str | Table,
|
|
662
796
|
) -> Self:
|
|
663
797
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
664
798
|
|
|
@@ -696,6 +830,30 @@ class Graph:
|
|
|
696
830
|
"""
|
|
697
831
|
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
698
832
|
|
|
833
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
834
|
+
if not any(column.is_source for column in table.columns):
|
|
835
|
+
continue
|
|
836
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
837
|
+
if fkey.name not in table:
|
|
838
|
+
continue
|
|
839
|
+
if not table[fkey.name].is_source:
|
|
840
|
+
continue
|
|
841
|
+
if (table.name, fkey.name) in known_edges:
|
|
842
|
+
continue
|
|
843
|
+
dst_table_names = [
|
|
844
|
+
table.name for table in self.tables.values()
|
|
845
|
+
if table.source_name == fkey.dst_table
|
|
846
|
+
]
|
|
847
|
+
if len(dst_table_names) != 1:
|
|
848
|
+
continue
|
|
849
|
+
dst_table = self[dst_table_names[0]]
|
|
850
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
851
|
+
continue
|
|
852
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
853
|
+
continue
|
|
854
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
855
|
+
known_edges.add((table.name, fkey.name))
|
|
856
|
+
|
|
699
857
|
# A list of primary key candidates (+score) for every column:
|
|
700
858
|
candidate_dict: dict[
|
|
701
859
|
tuple[str, str],
|
|
@@ -795,13 +953,8 @@ class Graph:
|
|
|
795
953
|
if score < 5.0:
|
|
796
954
|
continue
|
|
797
955
|
|
|
798
|
-
candidate_dict[(
|
|
799
|
-
|
|
800
|
-
src_key.name,
|
|
801
|
-
)].append((
|
|
802
|
-
dst_table.name,
|
|
803
|
-
score,
|
|
804
|
-
))
|
|
956
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
957
|
+
(dst_table.name, score))
|
|
805
958
|
|
|
806
959
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
807
960
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -860,24 +1013,26 @@ class Graph:
|
|
|
860
1013
|
f"either the primary key or the link before "
|
|
861
1014
|
f"before proceeding.")
|
|
862
1015
|
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
1016
|
+
if self.backend == DataBackend.LOCAL:
|
|
1017
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1018
|
+
assert src_key.dtype is not None
|
|
1019
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1020
|
+
src_string = src_key.dtype.is_string()
|
|
1021
|
+
assert dst_key.dtype is not None
|
|
1022
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1023
|
+
dst_string = dst_key.dtype.is_string()
|
|
1024
|
+
|
|
1025
|
+
if not src_number and not src_string:
|
|
1026
|
+
raise ValueError(
|
|
1027
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1028
|
+
f"or string (got '{src_key.dtype}'")
|
|
1029
|
+
|
|
1030
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1031
|
+
raise ValueError(
|
|
1032
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1033
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1034
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1035
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
881
1036
|
|
|
882
1037
|
return self
|
|
883
1038
|
|
|
@@ -885,7 +1040,7 @@ class Graph:
|
|
|
885
1040
|
|
|
886
1041
|
def visualize(
|
|
887
1042
|
self,
|
|
888
|
-
path:
|
|
1043
|
+
path: str | io.BytesIO | None = None,
|
|
889
1044
|
show_columns: bool = True,
|
|
890
1045
|
) -> 'graphviz.Graph':
|
|
891
1046
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -924,19 +1079,19 @@ class Graph:
|
|
|
924
1079
|
"them as described at "
|
|
925
1080
|
"https://graphviz.org/download/.")
|
|
926
1081
|
|
|
927
|
-
format:
|
|
1082
|
+
format: str | None = None
|
|
928
1083
|
if isinstance(path, str):
|
|
929
1084
|
format = path.split('.')[-1]
|
|
930
1085
|
elif isinstance(path, io.BytesIO):
|
|
931
1086
|
format = 'svg'
|
|
932
1087
|
graph = graphviz.Graph(format=format)
|
|
933
1088
|
|
|
934
|
-
def left_align(keys:
|
|
1089
|
+
def left_align(keys: list[str]) -> str:
|
|
935
1090
|
if len(keys) == 0:
|
|
936
1091
|
return ""
|
|
937
1092
|
return '\\l'.join(keys) + '\\l'
|
|
938
1093
|
|
|
939
|
-
fkeys_dict:
|
|
1094
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
940
1095
|
for src_table_name, fkey_name, _ in self.edges:
|
|
941
1096
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
942
1097
|
|
|
@@ -1032,8 +1187,8 @@ class Graph:
|
|
|
1032
1187
|
# Helpers #################################################################
|
|
1033
1188
|
|
|
1034
1189
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
1035
|
-
tables:
|
|
1036
|
-
col_groups:
|
|
1190
|
+
tables: dict[str, TableDefinition] = {}
|
|
1191
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
1037
1192
|
for table_name, table in self.tables.items():
|
|
1038
1193
|
tables[table_name] = table._to_api_table_definition()
|
|
1039
1194
|
if table.primary_key is None:
|