kumoai 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl → 2.15.0.dev202601141731__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +51 -24
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/__init__.py +4 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +62 -110
- kumoai/experimental/rfm/backend/local/sampler.py +312 -0
- kumoai/experimental/rfm/backend/local/table.py +35 -31
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +407 -0
- kumoai/experimental/rfm/backend/snow/table.py +178 -50
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +456 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +131 -48
- kumoai/experimental/rfm/base/__init__.py +22 -4
- kumoai/experimental/rfm/base/column.py +96 -10
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +696 -47
- kumoai/experimental/rfm/base/source.py +2 -1
- kumoai/experimental/rfm/base/sql_sampler.py +385 -0
- kumoai/experimental/rfm/base/table.py +384 -207
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +359 -187
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +10 -5
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +5 -4
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +30 -32
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +770 -467
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +192 -13
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/RECORD +54 -42
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512041731.dist-info → kumoai-2.15.0.dev202601141731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
4
|
+
import copy
|
|
2
5
|
import io
|
|
3
6
|
import warnings
|
|
4
7
|
from collections import defaultdict
|
|
8
|
+
from collections.abc import Sequence
|
|
5
9
|
from dataclasses import dataclass, field
|
|
10
|
+
from itertools import chain
|
|
6
11
|
from pathlib import Path
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
12
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
8
13
|
|
|
9
14
|
import pandas as pd
|
|
10
15
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -13,9 +18,10 @@ from kumoapi.typing import Stype
|
|
|
13
18
|
from typing_extensions import Self
|
|
14
19
|
|
|
15
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
-
from kumoai.experimental.rfm import Table
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
17
22
|
from kumoai.graph import Edge
|
|
18
23
|
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
19
25
|
|
|
20
26
|
if TYPE_CHECKING:
|
|
21
27
|
import graphviz
|
|
@@ -25,8 +31,8 @@ if TYPE_CHECKING:
|
|
|
25
31
|
|
|
26
32
|
@dataclass
|
|
27
33
|
class SqliteConnectionConfig(CastMixin):
|
|
28
|
-
uri:
|
|
29
|
-
kwargs:
|
|
34
|
+
uri: str | Path
|
|
35
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
class Graph:
|
|
@@ -86,27 +92,35 @@ class Graph:
|
|
|
86
92
|
def __init__(
|
|
87
93
|
self,
|
|
88
94
|
tables: Sequence[Table],
|
|
89
|
-
edges:
|
|
95
|
+
edges: Sequence[Edge] | None = None,
|
|
90
96
|
) -> None:
|
|
91
97
|
|
|
92
|
-
self._tables:
|
|
93
|
-
self._edges:
|
|
98
|
+
self._tables: dict[str, Table] = {}
|
|
99
|
+
self._edges: list[Edge] = []
|
|
94
100
|
|
|
95
101
|
for table in tables:
|
|
96
102
|
self.add_table(table)
|
|
97
103
|
|
|
98
|
-
for table in tables:
|
|
104
|
+
for table in tables: # Use links from source metadata:
|
|
105
|
+
if not any(column.is_source for column in table.columns):
|
|
106
|
+
continue
|
|
99
107
|
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
-
if fkey.name not in table
|
|
108
|
+
if fkey.name not in table:
|
|
109
|
+
continue
|
|
110
|
+
if not table[fkey.name].is_source:
|
|
111
|
+
continue
|
|
112
|
+
dst_table_names = [
|
|
113
|
+
table.name for table in self.tables.values()
|
|
114
|
+
if table.source_name == fkey.dst_table
|
|
115
|
+
]
|
|
116
|
+
if len(dst_table_names) != 1:
|
|
117
|
+
continue
|
|
118
|
+
dst_table = self[dst_table_names[0]]
|
|
119
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
101
120
|
continue
|
|
102
|
-
if
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
-
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
-
f"and '{fkey.primary_key}' in table "
|
|
108
|
-
f"'{fkey.dst_table}'.")
|
|
109
|
-
self.link(table.name, fkey.name, fkey.dst_table)
|
|
121
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
122
|
+
continue
|
|
123
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
110
124
|
|
|
111
125
|
for edge in (edges or []):
|
|
112
126
|
_edge = Edge._cast(edge)
|
|
@@ -117,8 +131,8 @@ class Graph:
|
|
|
117
131
|
@classmethod
|
|
118
132
|
def from_data(
|
|
119
133
|
cls,
|
|
120
|
-
df_dict:
|
|
121
|
-
edges:
|
|
134
|
+
df_dict: dict[str, pd.DataFrame],
|
|
135
|
+
edges: Sequence[Edge] | None = None,
|
|
122
136
|
infer_metadata: bool = True,
|
|
123
137
|
verbose: bool = True,
|
|
124
138
|
) -> Self:
|
|
@@ -156,15 +170,17 @@ class Graph:
|
|
|
156
170
|
verbose: Whether to print verbose output.
|
|
157
171
|
"""
|
|
158
172
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
173
|
|
|
161
|
-
graph = cls(
|
|
174
|
+
graph = cls(
|
|
175
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
176
|
+
edges=edges or [],
|
|
177
|
+
)
|
|
162
178
|
|
|
163
179
|
if infer_metadata:
|
|
164
|
-
graph.infer_metadata(False)
|
|
180
|
+
graph.infer_metadata(verbose=False)
|
|
165
181
|
|
|
166
182
|
if edges is None:
|
|
167
|
-
graph.infer_links(False)
|
|
183
|
+
graph.infer_links(verbose=False)
|
|
168
184
|
|
|
169
185
|
if verbose:
|
|
170
186
|
graph.print_metadata()
|
|
@@ -180,10 +196,10 @@ class Graph:
|
|
|
180
196
|
SqliteConnectionConfig,
|
|
181
197
|
str,
|
|
182
198
|
Path,
|
|
183
|
-
|
|
199
|
+
dict[str, Any],
|
|
184
200
|
],
|
|
185
|
-
|
|
186
|
-
edges:
|
|
201
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
202
|
+
edges: Sequence[Edge] | None = None,
|
|
187
203
|
infer_metadata: bool = True,
|
|
188
204
|
verbose: bool = True,
|
|
189
205
|
) -> Self:
|
|
@@ -199,17 +215,25 @@ class Graph:
|
|
|
199
215
|
>>> # Create a graph from a SQLite database:
|
|
200
216
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
217
|
|
|
218
|
+
>>> # Fine-grained control over table specification:
|
|
219
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
220
|
+
... 'USERS',
|
|
221
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
222
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
223
|
+
... ])
|
|
224
|
+
|
|
202
225
|
Args:
|
|
203
226
|
connection: An open connection from
|
|
204
227
|
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
228
|
path to the database file.
|
|
206
|
-
|
|
207
|
-
all tables present
|
|
229
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
230
|
+
arguments to include. If ``None``, will add all tables present
|
|
231
|
+
in the database.
|
|
208
232
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
209
233
|
add to the graph. If not provided, edges will be automatically
|
|
210
234
|
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
-
infer_metadata: Whether to infer metadata for all tables in
|
|
212
|
-
graph.
|
|
235
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
236
|
+
the graph.
|
|
213
237
|
verbose: Whether to print verbose output.
|
|
214
238
|
"""
|
|
215
239
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
@@ -218,27 +242,41 @@ class Graph:
|
|
|
218
242
|
connect,
|
|
219
243
|
)
|
|
220
244
|
|
|
245
|
+
internal_connection = False
|
|
221
246
|
if not isinstance(connection, Connection):
|
|
222
247
|
connection = SqliteConnectionConfig._cast(connection)
|
|
223
248
|
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
249
|
connection = connect(connection.uri, **connection.kwargs)
|
|
250
|
+
internal_connection = True
|
|
225
251
|
assert isinstance(connection, Connection)
|
|
226
252
|
|
|
227
|
-
if
|
|
253
|
+
if tables is None:
|
|
228
254
|
with connection.cursor() as cursor:
|
|
229
255
|
cursor.execute("SELECT name FROM sqlite_master "
|
|
230
256
|
"WHERE type='table'")
|
|
231
|
-
|
|
257
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
232
258
|
|
|
233
|
-
|
|
259
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
260
|
+
for table in tables:
|
|
261
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
262
|
+
table_kwargs.append(kwargs)
|
|
263
|
+
|
|
264
|
+
graph = cls(
|
|
265
|
+
tables=[
|
|
266
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
267
|
+
for kwargs in table_kwargs
|
|
268
|
+
],
|
|
269
|
+
edges=edges or [],
|
|
270
|
+
)
|
|
234
271
|
|
|
235
|
-
|
|
272
|
+
if internal_connection:
|
|
273
|
+
graph._connection = connection # type: ignore
|
|
236
274
|
|
|
237
275
|
if infer_metadata:
|
|
238
|
-
graph.infer_metadata(False)
|
|
276
|
+
graph.infer_metadata(verbose=False)
|
|
239
277
|
|
|
240
278
|
if edges is None:
|
|
241
|
-
graph.infer_links(False)
|
|
279
|
+
graph.infer_links(verbose=False)
|
|
242
280
|
|
|
243
281
|
if verbose:
|
|
244
282
|
graph.print_metadata()
|
|
@@ -249,11 +287,11 @@ class Graph:
|
|
|
249
287
|
@classmethod
|
|
250
288
|
def from_snowflake(
|
|
251
289
|
cls,
|
|
252
|
-
connection: Union['SnowflakeConnection',
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
edges:
|
|
290
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
291
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
292
|
+
database: str | None = None,
|
|
293
|
+
schema: str | None = None,
|
|
294
|
+
edges: Sequence[Edge] | None = None,
|
|
257
295
|
infer_metadata: bool = True,
|
|
258
296
|
verbose: bool = True,
|
|
259
297
|
) -> Self:
|
|
@@ -270,6 +308,13 @@ class Graph:
|
|
|
270
308
|
>>> # Create a graph directly in a Snowflake notebook:
|
|
271
309
|
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
272
310
|
|
|
311
|
+
>>> # Fine-grained control over table specification:
|
|
312
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
313
|
+
... 'USERS',
|
|
314
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
315
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
316
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
317
|
+
|
|
273
318
|
Args:
|
|
274
319
|
connection: An open connection from
|
|
275
320
|
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
@@ -277,10 +322,11 @@ class Graph:
|
|
|
277
322
|
connection. If ``None``, will re-use an active session in case
|
|
278
323
|
it exists, or create a new connection from credentials stored
|
|
279
324
|
in environment variables.
|
|
325
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
326
|
+
to include. If ``None``, will add all tables present in the
|
|
327
|
+
current database and schema.
|
|
280
328
|
database: The database.
|
|
281
329
|
schema: The schema.
|
|
282
|
-
table_names: Set of table names to include. If ``None``, will add
|
|
283
|
-
all tables present in the database.
|
|
284
330
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
285
331
|
add to the graph. If not provided, edges will be automatically
|
|
286
332
|
inferred from the data in case ``infer_metadata=True``.
|
|
@@ -298,37 +344,50 @@ class Graph:
|
|
|
298
344
|
connection = connect(**(connection or {}))
|
|
299
345
|
assert isinstance(connection, Connection)
|
|
300
346
|
|
|
301
|
-
if
|
|
347
|
+
if database is None or schema is None:
|
|
348
|
+
with connection.cursor() as cursor:
|
|
349
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
350
|
+
result = cursor.fetchone()
|
|
351
|
+
database = database or result[0]
|
|
352
|
+
assert database is not None
|
|
353
|
+
schema = schema or result[1]
|
|
354
|
+
|
|
355
|
+
if tables is None:
|
|
356
|
+
if schema is None:
|
|
357
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
358
|
+
"Snowflake schema manually")
|
|
359
|
+
|
|
302
360
|
with connection.cursor() as cursor:
|
|
303
|
-
if database is None and schema is None:
|
|
304
|
-
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
305
|
-
"CURRENT_SCHEMA()")
|
|
306
|
-
result = cursor.fetchone()
|
|
307
|
-
database = database or result[0]
|
|
308
|
-
schema = schema or result[1]
|
|
309
361
|
cursor.execute(f"""
|
|
310
362
|
SELECT TABLE_NAME
|
|
311
363
|
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
312
364
|
WHERE TABLE_SCHEMA = '{schema}'
|
|
313
365
|
""")
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
tables = [
|
|
317
|
-
SnowTable(
|
|
318
|
-
connection,
|
|
319
|
-
name=table_name,
|
|
320
|
-
database=database,
|
|
321
|
-
schema=schema,
|
|
322
|
-
) for table_name in table_names
|
|
323
|
-
]
|
|
366
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
324
367
|
|
|
325
|
-
|
|
368
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
369
|
+
for table in tables:
|
|
370
|
+
if isinstance(table, str):
|
|
371
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
372
|
+
else:
|
|
373
|
+
kwargs = copy.copy(table)
|
|
374
|
+
kwargs.setdefault('database', database)
|
|
375
|
+
kwargs.setdefault('schema', schema)
|
|
376
|
+
table_kwargs.append(kwargs)
|
|
377
|
+
|
|
378
|
+
graph = cls(
|
|
379
|
+
tables=[
|
|
380
|
+
SnowTable(connection=connection, **kwargs)
|
|
381
|
+
for kwargs in table_kwargs
|
|
382
|
+
],
|
|
383
|
+
edges=edges or [],
|
|
384
|
+
)
|
|
326
385
|
|
|
327
386
|
if infer_metadata:
|
|
328
|
-
graph.infer_metadata(False)
|
|
387
|
+
graph.infer_metadata(verbose=False)
|
|
329
388
|
|
|
330
389
|
if edges is None:
|
|
331
|
-
graph.infer_links(False)
|
|
390
|
+
graph.infer_links(verbose=False)
|
|
332
391
|
|
|
333
392
|
if verbose:
|
|
334
393
|
graph.print_metadata()
|
|
@@ -340,7 +399,7 @@ class Graph:
|
|
|
340
399
|
def from_snowflake_semantic_view(
|
|
341
400
|
cls,
|
|
342
401
|
semantic_view_name: str,
|
|
343
|
-
connection: Union['SnowflakeConnection',
|
|
402
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
344
403
|
verbose: bool = True,
|
|
345
404
|
) -> Self:
|
|
346
405
|
import yaml
|
|
@@ -358,43 +417,165 @@ class Graph:
|
|
|
358
417
|
with connection.cursor() as cursor:
|
|
359
418
|
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
360
419
|
f"'{semantic_view_name}')")
|
|
361
|
-
|
|
420
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
362
421
|
|
|
363
422
|
graph = cls(tables=[])
|
|
364
423
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
424
|
+
msgs = []
|
|
425
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
426
|
+
for table_cfg in cfg['tables']:
|
|
427
|
+
table_name = table_cfg['name']
|
|
428
|
+
source_table_name = table_cfg['base_table']['table']
|
|
429
|
+
database = table_cfg['base_table']['database']
|
|
430
|
+
schema = table_cfg['base_table']['schema']
|
|
431
|
+
|
|
432
|
+
primary_key: str | None = None
|
|
433
|
+
if 'primary_key' in table_cfg:
|
|
434
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
435
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
436
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
437
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
438
|
+
msgs.append(f"Failed to add primary key for table "
|
|
439
|
+
f"'{table_name}' since composite primary keys "
|
|
440
|
+
f"are not yet supported")
|
|
441
|
+
|
|
442
|
+
columns: list[ColumnSpec] = []
|
|
443
|
+
unsupported_columns: list[str] = []
|
|
444
|
+
for column_cfg in chain(
|
|
445
|
+
table_cfg.get('dimensions', []),
|
|
446
|
+
table_cfg.get('time_dimensions', []),
|
|
447
|
+
table_cfg.get('facts', []),
|
|
448
|
+
):
|
|
449
|
+
column_name = column_cfg['name']
|
|
450
|
+
column_expr = column_cfg.get('expr', None)
|
|
451
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
452
|
+
|
|
453
|
+
if column_expr is None:
|
|
454
|
+
columns.append(ColumnSpec(name=column_name))
|
|
455
|
+
continue
|
|
456
|
+
|
|
457
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
458
|
+
|
|
459
|
+
if column_expr == column_name:
|
|
460
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
# Drop expressions that reference other tables (for now):
|
|
464
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
465
|
+
unsupported_columns.append(column_name)
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
column = ColumnSpec(
|
|
469
|
+
name=column_name,
|
|
470
|
+
expr=column_expr,
|
|
471
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
472
|
+
)
|
|
473
|
+
columns.append(column)
|
|
474
|
+
|
|
475
|
+
if len(unsupported_columns) == 1:
|
|
476
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
477
|
+
f"of table '{table_name}' since its expression "
|
|
478
|
+
f"references other tables")
|
|
479
|
+
elif len(unsupported_columns) > 1:
|
|
480
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
481
|
+
f"of table '{table_name}' since their expressions "
|
|
482
|
+
f"reference other tables")
|
|
370
483
|
|
|
371
484
|
table = SnowTable(
|
|
372
485
|
connection,
|
|
373
|
-
name=
|
|
374
|
-
|
|
375
|
-
|
|
486
|
+
name=table_name,
|
|
487
|
+
source_name=source_table_name,
|
|
488
|
+
database=database,
|
|
489
|
+
schema=schema,
|
|
490
|
+
columns=columns,
|
|
376
491
|
primary_key=primary_key,
|
|
377
492
|
)
|
|
493
|
+
|
|
494
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
495
|
+
table.infer_time_column(verbose=False)
|
|
496
|
+
|
|
378
497
|
graph.add_table(table)
|
|
379
498
|
|
|
380
|
-
|
|
499
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
500
|
+
name = relation_cfg['name']
|
|
501
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
502
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
503
|
+
f"composite key references are not yet supported")
|
|
504
|
+
continue
|
|
381
505
|
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
506
|
+
left_table = relation_cfg['left_table']
|
|
507
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
508
|
+
right_table = relation_cfg['right_table']
|
|
509
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
510
|
+
|
|
511
|
+
if graph[right_table]._primary_key != right_key:
|
|
512
|
+
# Semantic view error - this should never be triggered:
|
|
513
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
514
|
+
f"referenced key '{right_key}' of table "
|
|
515
|
+
f"'{right_table}' is not a primary key")
|
|
516
|
+
continue
|
|
517
|
+
|
|
518
|
+
if graph[left_table]._primary_key == left_key:
|
|
519
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
520
|
+
f"referencing key '{left_key}' of table "
|
|
521
|
+
f"'{left_table}' is a primary key")
|
|
522
|
+
continue
|
|
523
|
+
|
|
524
|
+
if left_key not in graph[left_table]:
|
|
525
|
+
graph[left_table].add_column(left_key)
|
|
526
|
+
|
|
527
|
+
graph.link(left_table, left_key, right_table)
|
|
528
|
+
|
|
529
|
+
graph.validate()
|
|
390
530
|
|
|
391
531
|
if verbose:
|
|
392
532
|
graph.print_metadata()
|
|
393
533
|
graph.print_links()
|
|
394
534
|
|
|
535
|
+
if len(msgs) > 0:
|
|
536
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
537
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
538
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
539
|
+
|
|
395
540
|
return graph
|
|
396
541
|
|
|
397
|
-
|
|
542
|
+
@classmethod
|
|
543
|
+
def from_relbench(
|
|
544
|
+
cls,
|
|
545
|
+
dataset: str,
|
|
546
|
+
verbose: bool = True,
|
|
547
|
+
) -> Graph:
|
|
548
|
+
r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
|
|
549
|
+
:class:`Graph` instance.
|
|
550
|
+
|
|
551
|
+
.. code-block:: python
|
|
552
|
+
|
|
553
|
+
>>> # doctest: +SKIP
|
|
554
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
555
|
+
|
|
556
|
+
>>> graph = rfm.Graph.from_relbench("f1")
|
|
557
|
+
|
|
558
|
+
Args:
|
|
559
|
+
dataset: The RelBench dataset name.
|
|
560
|
+
verbose: Whether to print verbose output.
|
|
561
|
+
"""
|
|
562
|
+
from kumoai.experimental.rfm.relbench import from_relbench
|
|
563
|
+
graph = from_relbench(dataset, verbose=verbose)
|
|
564
|
+
|
|
565
|
+
if verbose:
|
|
566
|
+
graph.print_metadata()
|
|
567
|
+
graph.print_links()
|
|
568
|
+
|
|
569
|
+
return graph
|
|
570
|
+
|
|
571
|
+
# Backend #################################################################
|
|
572
|
+
|
|
573
|
+
@property
|
|
574
|
+
def backend(self) -> DataBackend | None:
|
|
575
|
+
backends = [table.backend for table in self._tables.values()]
|
|
576
|
+
return backends[0] if len(backends) > 0 else None
|
|
577
|
+
|
|
578
|
+
# Tables ##################################################################
|
|
398
579
|
|
|
399
580
|
def has_table(self, name: str) -> bool:
|
|
400
581
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -413,7 +594,7 @@ class Graph:
|
|
|
413
594
|
return self.tables[name]
|
|
414
595
|
|
|
415
596
|
@property
|
|
416
|
-
def tables(self) ->
|
|
597
|
+
def tables(self) -> dict[str, Table]:
|
|
417
598
|
r"""Returns the dictionary of table objects."""
|
|
418
599
|
return self._tables
|
|
419
600
|
|
|
@@ -433,13 +614,10 @@ class Graph:
|
|
|
433
614
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
434
615
|
f"this graph; table names must be globally unique.")
|
|
435
616
|
|
|
436
|
-
if
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
f"'{table.__class__.__name__}' to this "
|
|
441
|
-
f"graph since other tables are of type "
|
|
442
|
-
f"'{cls.__name__}'.")
|
|
617
|
+
if self.backend is not None and table.backend != self.backend:
|
|
618
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
619
|
+
f"'{table.backend}' to this graph since other "
|
|
620
|
+
f"tables have backend '{self.backend}'.")
|
|
443
621
|
|
|
444
622
|
self._tables[table.name] = table
|
|
445
623
|
|
|
@@ -471,28 +649,28 @@ class Graph:
|
|
|
471
649
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
472
650
|
information about the tables in this graph.
|
|
473
651
|
|
|
474
|
-
The returned dataframe has columns ``
|
|
475
|
-
``
|
|
476
|
-
view of the properties of the tables of this graph.
|
|
652
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
653
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
654
|
+
aggregated view of the properties of the tables of this graph.
|
|
477
655
|
|
|
478
656
|
Example:
|
|
479
657
|
>>> # doctest: +SKIP
|
|
480
658
|
>>> import kumoai.experimental.rfm as rfm
|
|
481
659
|
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
482
660
|
>>> graph.metadata # doctest: +SKIP
|
|
483
|
-
|
|
484
|
-
0 users
|
|
661
|
+
Name Primary Key Time Column End Time Column
|
|
662
|
+
0 users user_id - -
|
|
485
663
|
"""
|
|
486
664
|
tables = list(self.tables.values())
|
|
487
665
|
|
|
488
666
|
return pd.DataFrame({
|
|
489
|
-
'
|
|
667
|
+
'Name':
|
|
490
668
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
491
|
-
'
|
|
669
|
+
'Primary Key':
|
|
492
670
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
493
|
-
'
|
|
671
|
+
'Time Column':
|
|
494
672
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
495
|
-
'
|
|
673
|
+
'End Time Column':
|
|
496
674
|
pd.Series(
|
|
497
675
|
dtype=str,
|
|
498
676
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -501,24 +679,8 @@ class Graph:
|
|
|
501
679
|
|
|
502
680
|
def print_metadata(self) -> None:
|
|
503
681
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
507
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
508
|
-
elif in_notebook():
|
|
509
|
-
from IPython.display import Markdown, display
|
|
510
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
511
|
-
df = self.metadata
|
|
512
|
-
try:
|
|
513
|
-
if hasattr(df.style, 'hide'):
|
|
514
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
515
|
-
else:
|
|
516
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
517
|
-
except ImportError:
|
|
518
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
519
|
-
else:
|
|
520
|
-
print("🗂️ Graph Metadata:")
|
|
521
|
-
print(self.metadata.to_string(index=False))
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
522
684
|
|
|
523
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
524
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -541,52 +703,33 @@ class Graph:
|
|
|
541
703
|
# Edges ###################################################################
|
|
542
704
|
|
|
543
705
|
@property
|
|
544
|
-
def edges(self) ->
|
|
706
|
+
def edges(self) -> list[Edge]:
|
|
545
707
|
r"""Returns the edges of the graph."""
|
|
546
708
|
return self._edges
|
|
547
709
|
|
|
548
710
|
def print_links(self) -> None:
|
|
549
711
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
550
|
-
edges = [(
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
st.markdown("*No links registered*")
|
|
564
|
-
elif in_notebook():
|
|
565
|
-
from IPython.display import Markdown, display
|
|
566
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
567
|
-
if len(edges) > 0:
|
|
568
|
-
display(
|
|
569
|
-
Markdown('\n'.join([
|
|
570
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
571
|
-
for edge in edges
|
|
572
|
-
])))
|
|
573
|
-
else:
|
|
574
|
-
display(Markdown("*No links registered*"))
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
575
725
|
else:
|
|
576
|
-
|
|
577
|
-
if len(edges) > 0:
|
|
578
|
-
print('\n'.join([
|
|
579
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
580
|
-
for edge in edges
|
|
581
|
-
]))
|
|
582
|
-
else:
|
|
583
|
-
print("No links registered")
|
|
726
|
+
display.italic("No links registered")
|
|
584
727
|
|
|
585
728
|
def link(
|
|
586
729
|
self,
|
|
587
|
-
src_table:
|
|
730
|
+
src_table: str | Table,
|
|
588
731
|
fkey: str,
|
|
589
|
-
dst_table:
|
|
732
|
+
dst_table: str | Table,
|
|
590
733
|
) -> Self:
|
|
591
734
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
592
735
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -647,9 +790,9 @@ class Graph:
|
|
|
647
790
|
|
|
648
791
|
def unlink(
|
|
649
792
|
self,
|
|
650
|
-
src_table:
|
|
793
|
+
src_table: str | Table,
|
|
651
794
|
fkey: str,
|
|
652
|
-
dst_table:
|
|
795
|
+
dst_table: str | Table,
|
|
653
796
|
) -> Self:
|
|
654
797
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
655
798
|
|
|
@@ -687,6 +830,30 @@ class Graph:
|
|
|
687
830
|
"""
|
|
688
831
|
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
689
832
|
|
|
833
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
834
|
+
if not any(column.is_source for column in table.columns):
|
|
835
|
+
continue
|
|
836
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
837
|
+
if fkey.name not in table:
|
|
838
|
+
continue
|
|
839
|
+
if not table[fkey.name].is_source:
|
|
840
|
+
continue
|
|
841
|
+
if (table.name, fkey.name) in known_edges:
|
|
842
|
+
continue
|
|
843
|
+
dst_table_names = [
|
|
844
|
+
table.name for table in self.tables.values()
|
|
845
|
+
if table.source_name == fkey.dst_table
|
|
846
|
+
]
|
|
847
|
+
if len(dst_table_names) != 1:
|
|
848
|
+
continue
|
|
849
|
+
dst_table = self[dst_table_names[0]]
|
|
850
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
851
|
+
continue
|
|
852
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
853
|
+
continue
|
|
854
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
855
|
+
known_edges.add((table.name, fkey.name))
|
|
856
|
+
|
|
690
857
|
# A list of primary key candidates (+score) for every column:
|
|
691
858
|
candidate_dict: dict[
|
|
692
859
|
tuple[str, str],
|
|
@@ -786,13 +953,8 @@ class Graph:
|
|
|
786
953
|
if score < 5.0:
|
|
787
954
|
continue
|
|
788
955
|
|
|
789
|
-
candidate_dict[(
|
|
790
|
-
|
|
791
|
-
src_key.name,
|
|
792
|
-
)].append((
|
|
793
|
-
dst_table.name,
|
|
794
|
-
score,
|
|
795
|
-
))
|
|
956
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
957
|
+
(dst_table.name, score))
|
|
796
958
|
|
|
797
959
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
798
960
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -826,6 +988,10 @@ class Graph:
|
|
|
826
988
|
raise ValueError("At least one table needs to be added to the "
|
|
827
989
|
"graph")
|
|
828
990
|
|
|
991
|
+
backends = {table.backend for table in self._tables.values()}
|
|
992
|
+
if len(backends) != 1:
|
|
993
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
994
|
+
|
|
829
995
|
for edge in self.edges:
|
|
830
996
|
src_table, fkey, dst_table = edge
|
|
831
997
|
|
|
@@ -847,24 +1013,26 @@ class Graph:
|
|
|
847
1013
|
f"either the primary key or the link before "
|
|
848
1014
|
f"before proceeding.")
|
|
849
1015
|
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
1016
|
+
if self.backend == DataBackend.LOCAL:
|
|
1017
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1018
|
+
assert src_key.dtype is not None
|
|
1019
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1020
|
+
src_string = src_key.dtype.is_string()
|
|
1021
|
+
assert dst_key.dtype is not None
|
|
1022
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1023
|
+
dst_string = dst_key.dtype.is_string()
|
|
1024
|
+
|
|
1025
|
+
if not src_number and not src_string:
|
|
1026
|
+
raise ValueError(
|
|
1027
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1028
|
+
f"or string (got '{src_key.dtype}'")
|
|
1029
|
+
|
|
1030
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1031
|
+
raise ValueError(
|
|
1032
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1033
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1034
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1035
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
868
1036
|
|
|
869
1037
|
return self
|
|
870
1038
|
|
|
@@ -872,7 +1040,7 @@ class Graph:
|
|
|
872
1040
|
|
|
873
1041
|
def visualize(
|
|
874
1042
|
self,
|
|
875
|
-
path:
|
|
1043
|
+
path: str | io.BytesIO | None = None,
|
|
876
1044
|
show_columns: bool = True,
|
|
877
1045
|
) -> 'graphviz.Graph':
|
|
878
1046
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -911,19 +1079,19 @@ class Graph:
|
|
|
911
1079
|
"them as described at "
|
|
912
1080
|
"https://graphviz.org/download/.")
|
|
913
1081
|
|
|
914
|
-
format:
|
|
1082
|
+
format: str | None = None
|
|
915
1083
|
if isinstance(path, str):
|
|
916
1084
|
format = path.split('.')[-1]
|
|
917
1085
|
elif isinstance(path, io.BytesIO):
|
|
918
1086
|
format = 'svg'
|
|
919
1087
|
graph = graphviz.Graph(format=format)
|
|
920
1088
|
|
|
921
|
-
def left_align(keys:
|
|
1089
|
+
def left_align(keys: list[str]) -> str:
|
|
922
1090
|
if len(keys) == 0:
|
|
923
1091
|
return ""
|
|
924
1092
|
return '\\l'.join(keys) + '\\l'
|
|
925
1093
|
|
|
926
|
-
fkeys_dict:
|
|
1094
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
927
1095
|
for src_table_name, fkey_name, _ in self.edges:
|
|
928
1096
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
929
1097
|
|
|
@@ -1019,8 +1187,8 @@ class Graph:
|
|
|
1019
1187
|
# Helpers #################################################################
|
|
1020
1188
|
|
|
1021
1189
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
1022
|
-
tables:
|
|
1023
|
-
col_groups:
|
|
1190
|
+
tables: dict[str, TableDefinition] = {}
|
|
1191
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
1024
1192
|
for table_name, table in self.tables.items():
|
|
1025
1193
|
tables[table_name] = table._to_api_table_definition()
|
|
1026
1194
|
if table.primary_key is None:
|
|
@@ -1063,3 +1231,7 @@ class Graph:
|
|
|
1063
1231
|
f' tables={tables},\n'
|
|
1064
1232
|
f' edges={edges},\n'
|
|
1065
1233
|
f')')
|
|
1234
|
+
|
|
1235
|
+
def __del__(self) -> None:
|
|
1236
|
+
if hasattr(self, '_connection'):
|
|
1237
|
+
self._connection.close()
|