kumoai 2.13.0.dev202511211730__py3-none-any.whl → 2.14.0.dev202512141732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +37 -90
- kumoai/experimental/rfm/backend/local/sampler.py +313 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +119 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +112 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +115 -0
- kumoai/experimental/rfm/base/__init__.py +23 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +152 -141
- kumoai/experimental/rfm/{local_graph.py → graph.py} +352 -80
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/pquery/pandas_executor.py +1 -1
- kumoai/experimental/rfm/rfm.py +224 -167
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +2 -0
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/RECORD +39 -23
- kumoai/experimental/rfm/local_graph_sampler.py +0 -182
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511211730.dist-info → kumoai-2.14.0.dev202512141732.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,9 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
8
|
|
|
8
9
|
import pandas as pd
|
|
9
10
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,20 +12,29 @@ from kumoapi.table import TableDefinition
|
|
|
11
12
|
from kumoapi.typing import Stype
|
|
12
13
|
from typing_extensions import Self
|
|
13
14
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
+
from kumoai.experimental.rfm.base import DataBackend, Table
|
|
16
17
|
from kumoai.graph import Edge
|
|
18
|
+
from kumoai.mixin import CastMixin
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
import graphviz
|
|
22
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
23
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
24
|
|
|
21
25
|
|
|
22
|
-
|
|
23
|
-
|
|
26
|
+
@dataclass
|
|
27
|
+
class SqliteConnectionConfig(CastMixin):
|
|
28
|
+
uri: Union[str, Path]
|
|
29
|
+
kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Graph:
|
|
33
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
34
|
tables in a relational database.
|
|
25
35
|
|
|
26
36
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
37
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
38
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
39
|
|
|
30
40
|
.. code-block:: python
|
|
@@ -44,7 +54,7 @@ class LocalGraph:
|
|
|
44
54
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
55
|
|
|
46
56
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
57
|
+
>>> graph = rfm.Graph({
|
|
48
58
|
... "table1": table1,
|
|
49
59
|
... "table2": table2,
|
|
50
60
|
... "table3": table3,
|
|
@@ -75,33 +85,47 @@ class LocalGraph:
|
|
|
75
85
|
|
|
76
86
|
def __init__(
|
|
77
87
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges: Optional[
|
|
88
|
+
tables: Sequence[Table],
|
|
89
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
80
90
|
) -> None:
|
|
81
91
|
|
|
82
|
-
self._tables: Dict[str,
|
|
92
|
+
self._tables: Dict[str, Table] = {}
|
|
83
93
|
self._edges: List[Edge] = []
|
|
84
94
|
|
|
85
95
|
for table in tables:
|
|
86
96
|
self.add_table(table)
|
|
87
97
|
|
|
98
|
+
for table in tables:
|
|
99
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
+
if fkey.name not in table or fkey.dst_table not in self:
|
|
101
|
+
continue
|
|
102
|
+
if self[fkey.dst_table].primary_key is None:
|
|
103
|
+
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
+
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
+
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
+
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
+
f"and '{fkey.primary_key}' in table "
|
|
108
|
+
f"'{fkey.dst_table}'.")
|
|
109
|
+
self.link(table.name, fkey.name, fkey.dst_table)
|
|
110
|
+
|
|
88
111
|
for edge in (edges or []):
|
|
89
112
|
_edge = Edge._cast(edge)
|
|
90
113
|
assert _edge is not None
|
|
91
|
-
self.
|
|
114
|
+
if _edge not in self._edges:
|
|
115
|
+
self.link(*_edge)
|
|
92
116
|
|
|
93
117
|
@classmethod
|
|
94
118
|
def from_data(
|
|
95
119
|
cls,
|
|
96
120
|
df_dict: Dict[str, pd.DataFrame],
|
|
97
|
-
edges: Optional[
|
|
121
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
98
122
|
infer_metadata: bool = True,
|
|
99
123
|
verbose: bool = True,
|
|
100
124
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
125
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
126
|
:class:`pandas.DataFrame` objects.
|
|
103
127
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
128
|
+
Automatically infers table metadata and links by default.
|
|
105
129
|
|
|
106
130
|
.. code-block:: python
|
|
107
131
|
|
|
@@ -115,59 +139,274 @@ class LocalGraph:
|
|
|
115
139
|
>>> df3 = pd.DataFrame(...)
|
|
116
140
|
|
|
117
141
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
142
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
143
|
... "table1": df1,
|
|
120
144
|
... "table2": df2,
|
|
121
145
|
... "table3": df3,
|
|
122
146
|
... })
|
|
123
147
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
148
|
Args:
|
|
132
149
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
150
|
of the tables and the values hold table data.
|
|
151
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
152
|
+
add to the graph. If not provided, edges will be automatically
|
|
153
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
154
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
155
|
graph.
|
|
156
|
+
verbose: Whether to print verbose output.
|
|
157
|
+
"""
|
|
158
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
|
+
|
|
161
|
+
graph = cls(tables, edges=edges or [])
|
|
162
|
+
|
|
163
|
+
if infer_metadata:
|
|
164
|
+
graph.infer_metadata(False)
|
|
165
|
+
|
|
166
|
+
if edges is None:
|
|
167
|
+
graph.infer_links(False)
|
|
168
|
+
|
|
169
|
+
if verbose:
|
|
170
|
+
graph.print_metadata()
|
|
171
|
+
graph.print_links()
|
|
172
|
+
|
|
173
|
+
return graph
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def from_sqlite(
|
|
177
|
+
cls,
|
|
178
|
+
connection: Union[
|
|
179
|
+
'AdbcSqliteConnection',
|
|
180
|
+
SqliteConnectionConfig,
|
|
181
|
+
str,
|
|
182
|
+
Path,
|
|
183
|
+
Dict[str, Any],
|
|
184
|
+
],
|
|
185
|
+
table_names: Optional[Sequence[str]] = None,
|
|
186
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
187
|
+
infer_metadata: bool = True,
|
|
188
|
+
verbose: bool = True,
|
|
189
|
+
) -> Self:
|
|
190
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
191
|
+
|
|
192
|
+
Automatically infers table metadata and links by default.
|
|
193
|
+
|
|
194
|
+
.. code-block:: python
|
|
195
|
+
|
|
196
|
+
>>> # doctest: +SKIP
|
|
197
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
198
|
+
|
|
199
|
+
>>> # Create a graph from a SQLite database:
|
|
200
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
connection: An open connection from
|
|
204
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
|
+
path to the database file.
|
|
206
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
207
|
+
all tables present in the database.
|
|
136
208
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
209
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
210
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
212
|
+
graph.
|
|
139
213
|
verbose: Whether to print verbose output.
|
|
214
|
+
"""
|
|
215
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
216
|
+
Connection,
|
|
217
|
+
SQLiteTable,
|
|
218
|
+
connect,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
internal_connection = False
|
|
222
|
+
if not isinstance(connection, Connection):
|
|
223
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
224
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
225
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
226
|
+
internal_connection = True
|
|
227
|
+
assert isinstance(connection, Connection)
|
|
228
|
+
|
|
229
|
+
if table_names is None:
|
|
230
|
+
with connection.cursor() as cursor:
|
|
231
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
232
|
+
"WHERE type='table'")
|
|
233
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
234
|
+
|
|
235
|
+
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
140
236
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
237
|
+
graph = cls(tables, edges=edges or [])
|
|
238
|
+
|
|
239
|
+
if internal_connection:
|
|
240
|
+
graph._connection = connection # type: ignore
|
|
241
|
+
|
|
242
|
+
if infer_metadata:
|
|
243
|
+
graph.infer_metadata(False)
|
|
244
|
+
|
|
245
|
+
if edges is None:
|
|
246
|
+
graph.infer_links(False)
|
|
247
|
+
|
|
248
|
+
if verbose:
|
|
249
|
+
graph.print_metadata()
|
|
250
|
+
graph.print_links()
|
|
251
|
+
|
|
252
|
+
return graph
|
|
253
|
+
|
|
254
|
+
@classmethod
|
|
255
|
+
def from_snowflake(
|
|
256
|
+
cls,
|
|
257
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
258
|
+
database: Optional[str] = None,
|
|
259
|
+
schema: Optional[str] = None,
|
|
260
|
+
table_names: Optional[Sequence[str]] = None,
|
|
261
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
262
|
+
infer_metadata: bool = True,
|
|
263
|
+
verbose: bool = True,
|
|
264
|
+
) -> Self:
|
|
265
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
266
|
+
schema.
|
|
267
|
+
|
|
268
|
+
Automatically infers table metadata and links by default.
|
|
269
|
+
|
|
270
|
+
.. code-block:: python
|
|
144
271
|
|
|
145
|
-
Example:
|
|
146
272
|
>>> # doctest: +SKIP
|
|
147
273
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
|
|
149
|
-
>>>
|
|
150
|
-
>>>
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
274
|
+
|
|
275
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
276
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
connection: An open connection from
|
|
280
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
281
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
282
|
+
connection. If ``None``, will re-use an active session in case
|
|
283
|
+
it exists, or create a new connection from credentials stored
|
|
284
|
+
in environment variables.
|
|
285
|
+
database: The database.
|
|
286
|
+
schema: The schema.
|
|
287
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
288
|
+
all tables present in the database.
|
|
289
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
290
|
+
add to the graph. If not provided, edges will be automatically
|
|
291
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
292
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
293
|
+
graph.
|
|
294
|
+
verbose: Whether to print verbose output.
|
|
157
295
|
"""
|
|
158
|
-
|
|
296
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
297
|
+
Connection,
|
|
298
|
+
SnowTable,
|
|
299
|
+
connect,
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
if not isinstance(connection, Connection):
|
|
303
|
+
connection = connect(**(connection or {}))
|
|
304
|
+
assert isinstance(connection, Connection)
|
|
305
|
+
|
|
306
|
+
if table_names is None:
|
|
307
|
+
with connection.cursor() as cursor:
|
|
308
|
+
if database is None and schema is None:
|
|
309
|
+
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
310
|
+
"CURRENT_SCHEMA()")
|
|
311
|
+
result = cursor.fetchone()
|
|
312
|
+
database = database or result[0]
|
|
313
|
+
schema = schema or result[1]
|
|
314
|
+
cursor.execute(f"""
|
|
315
|
+
SELECT TABLE_NAME
|
|
316
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
317
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
318
|
+
""")
|
|
319
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
320
|
+
|
|
321
|
+
tables = [
|
|
322
|
+
SnowTable(
|
|
323
|
+
connection,
|
|
324
|
+
name=table_name,
|
|
325
|
+
database=database,
|
|
326
|
+
schema=schema,
|
|
327
|
+
) for table_name in table_names
|
|
328
|
+
]
|
|
159
329
|
|
|
160
330
|
graph = cls(tables, edges=edges or [])
|
|
161
331
|
|
|
162
332
|
if infer_metadata:
|
|
163
|
-
graph.infer_metadata(
|
|
333
|
+
graph.infer_metadata(False)
|
|
164
334
|
|
|
165
335
|
if edges is None:
|
|
166
|
-
graph.infer_links(
|
|
336
|
+
graph.infer_links(False)
|
|
337
|
+
|
|
338
|
+
if verbose:
|
|
339
|
+
graph.print_metadata()
|
|
340
|
+
graph.print_links()
|
|
341
|
+
|
|
342
|
+
return graph
|
|
343
|
+
|
|
344
|
+
@classmethod
|
|
345
|
+
def from_snowflake_semantic_view(
|
|
346
|
+
cls,
|
|
347
|
+
semantic_view_name: str,
|
|
348
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
349
|
+
verbose: bool = True,
|
|
350
|
+
) -> Self:
|
|
351
|
+
import yaml
|
|
352
|
+
|
|
353
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
354
|
+
Connection,
|
|
355
|
+
SnowTable,
|
|
356
|
+
connect,
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
if not isinstance(connection, Connection):
|
|
360
|
+
connection = connect(**(connection or {}))
|
|
361
|
+
assert isinstance(connection, Connection)
|
|
362
|
+
|
|
363
|
+
with connection.cursor() as cursor:
|
|
364
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
365
|
+
f"'{semantic_view_name}')")
|
|
366
|
+
view = yaml.safe_load(cursor.fetchone()[0])
|
|
367
|
+
|
|
368
|
+
graph = cls(tables=[])
|
|
369
|
+
|
|
370
|
+
for table_desc in view['tables']:
|
|
371
|
+
primary_key: Optional[str] = None
|
|
372
|
+
if ('primary_key' in table_desc # NOTE No composite keys yet.
|
|
373
|
+
and len(table_desc['primary_key']['columns']) == 1):
|
|
374
|
+
primary_key = table_desc['primary_key']['columns'][0]
|
|
375
|
+
|
|
376
|
+
table = SnowTable(
|
|
377
|
+
connection,
|
|
378
|
+
name=table_desc['base_table']['table'],
|
|
379
|
+
database=table_desc['base_table']['database'],
|
|
380
|
+
schema=table_desc['base_table']['schema'],
|
|
381
|
+
primary_key=primary_key,
|
|
382
|
+
)
|
|
383
|
+
graph.add_table(table)
|
|
384
|
+
|
|
385
|
+
# TODO Find a solution to register time columns!
|
|
386
|
+
|
|
387
|
+
for relations in view['relationships']:
|
|
388
|
+
if len(relations['relationship_columns']) != 1:
|
|
389
|
+
continue # NOTE No composite keys yet.
|
|
390
|
+
graph.link(
|
|
391
|
+
src_table=relations['left_table'],
|
|
392
|
+
fkey=relations['relationship_columns'][0]['left_column'],
|
|
393
|
+
dst_table=relations['right_table'],
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
if verbose:
|
|
397
|
+
graph.print_metadata()
|
|
398
|
+
graph.print_links()
|
|
167
399
|
|
|
168
400
|
return graph
|
|
169
401
|
|
|
170
|
-
#
|
|
402
|
+
# Backend #################################################################
|
|
403
|
+
|
|
404
|
+
@property
|
|
405
|
+
def backend(self) -> DataBackend | None:
|
|
406
|
+
backends = [table.backend for table in self._tables.values()]
|
|
407
|
+
return backends[0] if len(backends) > 0 else None
|
|
408
|
+
|
|
409
|
+
# Tables ##################################################################
|
|
171
410
|
|
|
172
411
|
def has_table(self, name: str) -> bool:
|
|
173
412
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -175,7 +414,7 @@ class LocalGraph:
|
|
|
175
414
|
"""
|
|
176
415
|
return name in self.tables
|
|
177
416
|
|
|
178
|
-
def table(self, name: str) ->
|
|
417
|
+
def table(self, name: str) -> Table:
|
|
179
418
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
419
|
|
|
181
420
|
Raises:
|
|
@@ -186,11 +425,11 @@ class LocalGraph:
|
|
|
186
425
|
return self.tables[name]
|
|
187
426
|
|
|
188
427
|
@property
|
|
189
|
-
def tables(self) -> Dict[str,
|
|
428
|
+
def tables(self) -> Dict[str, Table]:
|
|
190
429
|
r"""Returns the dictionary of table objects."""
|
|
191
430
|
return self._tables
|
|
192
431
|
|
|
193
|
-
def add_table(self, table:
|
|
432
|
+
def add_table(self, table: Table) -> Self:
|
|
194
433
|
r"""Adds a table to the graph.
|
|
195
434
|
|
|
196
435
|
Args:
|
|
@@ -199,11 +438,18 @@ class LocalGraph:
|
|
|
199
438
|
Raises:
|
|
200
439
|
KeyError: If a table with the same name already exists in the
|
|
201
440
|
graph.
|
|
441
|
+
ValueError: If the table belongs to a different backend than the
|
|
442
|
+
rest of the tables in the graph.
|
|
202
443
|
"""
|
|
203
444
|
if table.name in self._tables:
|
|
204
445
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
446
|
f"this graph; table names must be globally unique.")
|
|
206
447
|
|
|
448
|
+
if self.backend is not None and table.backend != self.backend:
|
|
449
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
450
|
+
f"'{table.backend}' to this graph since other "
|
|
451
|
+
f"tables have backend '{self.backend}'.")
|
|
452
|
+
|
|
207
453
|
self._tables[table.name] = table
|
|
208
454
|
|
|
209
455
|
return self
|
|
@@ -241,7 +487,7 @@ class LocalGraph:
|
|
|
241
487
|
Example:
|
|
242
488
|
>>> # doctest: +SKIP
|
|
243
489
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
490
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
491
|
>>> graph.metadata # doctest: +SKIP
|
|
246
492
|
name primary_key time_column end_time_column
|
|
247
493
|
0 users user_id - -
|
|
@@ -263,10 +509,14 @@ class LocalGraph:
|
|
|
263
509
|
})
|
|
264
510
|
|
|
265
511
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
267
|
-
if
|
|
512
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
513
|
+
if in_snowflake_notebook():
|
|
514
|
+
import streamlit as st
|
|
515
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
516
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
517
|
+
elif in_notebook():
|
|
268
518
|
from IPython.display import Markdown, display
|
|
269
|
-
display(Markdown(
|
|
519
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
270
520
|
df = self.metadata
|
|
271
521
|
try:
|
|
272
522
|
if hasattr(df.style, 'hide'):
|
|
@@ -287,7 +537,7 @@ class LocalGraph:
|
|
|
287
537
|
|
|
288
538
|
Note:
|
|
289
539
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
540
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
541
|
"""
|
|
292
542
|
for table in self.tables.values():
|
|
293
543
|
table.infer_metadata(verbose=False)
|
|
@@ -305,37 +555,47 @@ class LocalGraph:
|
|
|
305
555
|
return self._edges
|
|
306
556
|
|
|
307
557
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
558
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
559
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
560
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
561
|
edges = sorted(edges)
|
|
312
562
|
|
|
313
|
-
if
|
|
563
|
+
if in_snowflake_notebook():
|
|
564
|
+
import streamlit as st
|
|
565
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
566
|
+
if len(edges) > 0:
|
|
567
|
+
st.markdown('\n'.join([
|
|
568
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
569
|
+
for edge in edges
|
|
570
|
+
]))
|
|
571
|
+
else:
|
|
572
|
+
st.markdown("*No links registered*")
|
|
573
|
+
elif in_notebook():
|
|
314
574
|
from IPython.display import Markdown, display
|
|
315
|
-
display(Markdown(
|
|
575
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
316
576
|
if len(edges) > 0:
|
|
317
577
|
display(
|
|
318
578
|
Markdown('\n'.join([
|
|
319
|
-
f
|
|
579
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
320
580
|
for edge in edges
|
|
321
581
|
])))
|
|
322
582
|
else:
|
|
323
|
-
display(Markdown(
|
|
583
|
+
display(Markdown("*No links registered*"))
|
|
324
584
|
else:
|
|
325
585
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
326
586
|
if len(edges) > 0:
|
|
327
587
|
print('\n'.join([
|
|
328
|
-
f
|
|
588
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
329
589
|
for edge in edges
|
|
330
590
|
]))
|
|
331
591
|
else:
|
|
332
|
-
print(
|
|
592
|
+
print("No links registered")
|
|
333
593
|
|
|
334
594
|
def link(
|
|
335
595
|
self,
|
|
336
|
-
src_table: Union[str,
|
|
596
|
+
src_table: Union[str, Table],
|
|
337
597
|
fkey: str,
|
|
338
|
-
dst_table: Union[str,
|
|
598
|
+
dst_table: Union[str, Table],
|
|
339
599
|
) -> Self:
|
|
340
600
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
601
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +618,11 @@ class LocalGraph:
|
|
|
358
618
|
table does not exist in the graph, if the source key does not
|
|
359
619
|
exist in the source table.
|
|
360
620
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
621
|
+
if isinstance(src_table, Table):
|
|
362
622
|
src_table = src_table.name
|
|
363
623
|
assert isinstance(src_table, str)
|
|
364
624
|
|
|
365
|
-
if isinstance(dst_table,
|
|
625
|
+
if isinstance(dst_table, Table):
|
|
366
626
|
dst_table = dst_table.name
|
|
367
627
|
assert isinstance(dst_table, str)
|
|
368
628
|
|
|
@@ -396,9 +656,9 @@ class LocalGraph:
|
|
|
396
656
|
|
|
397
657
|
def unlink(
|
|
398
658
|
self,
|
|
399
|
-
src_table: Union[str,
|
|
659
|
+
src_table: Union[str, Table],
|
|
400
660
|
fkey: str,
|
|
401
|
-
dst_table: Union[str,
|
|
661
|
+
dst_table: Union[str, Table],
|
|
402
662
|
) -> Self:
|
|
403
663
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
664
|
|
|
@@ -410,11 +670,11 @@ class LocalGraph:
|
|
|
410
670
|
Raises:
|
|
411
671
|
ValueError: if the edge is not present in the graph.
|
|
412
672
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
673
|
+
if isinstance(src_table, Table):
|
|
414
674
|
src_table = src_table.name
|
|
415
675
|
assert isinstance(src_table, str)
|
|
416
676
|
|
|
417
|
-
if isinstance(dst_table,
|
|
677
|
+
if isinstance(dst_table, Table):
|
|
418
678
|
dst_table = dst_table.name
|
|
419
679
|
assert isinstance(dst_table, str)
|
|
420
680
|
|
|
@@ -428,17 +688,13 @@ class LocalGraph:
|
|
|
428
688
|
return self
|
|
429
689
|
|
|
430
690
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
691
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
692
|
+
graph.
|
|
432
693
|
|
|
433
694
|
Args:
|
|
434
695
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
696
|
"""
|
|
439
|
-
|
|
440
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
441
|
-
return self
|
|
697
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
442
698
|
|
|
443
699
|
# A list of primary key candidates (+score) for every column:
|
|
444
700
|
candidate_dict: dict[
|
|
@@ -463,6 +719,9 @@ class LocalGraph:
|
|
|
463
719
|
src_table_name = src_table.name.lower()
|
|
464
720
|
|
|
465
721
|
for src_key in src_table.columns:
|
|
722
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
723
|
+
continue
|
|
724
|
+
|
|
466
725
|
if src_key == src_table.primary_key:
|
|
467
726
|
continue # Cannot link to primary key.
|
|
468
727
|
|
|
@@ -528,7 +787,9 @@ class LocalGraph:
|
|
|
528
787
|
score += 1.0
|
|
529
788
|
|
|
530
789
|
# Cardinality ratio:
|
|
531
|
-
if
|
|
790
|
+
if (src_table._num_rows is not None
|
|
791
|
+
and dst_table._num_rows is not None
|
|
792
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
532
793
|
score += 1.0
|
|
533
794
|
|
|
534
795
|
if score < 5.0:
|
|
@@ -574,6 +835,10 @@ class LocalGraph:
|
|
|
574
835
|
raise ValueError("At least one table needs to be added to the "
|
|
575
836
|
"graph")
|
|
576
837
|
|
|
838
|
+
backends = {table.backend for table in self._tables.values()}
|
|
839
|
+
if len(backends) != 1:
|
|
840
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
841
|
+
|
|
577
842
|
for edge in self.edges:
|
|
578
843
|
src_table, fkey, dst_table = edge
|
|
579
844
|
|
|
@@ -645,19 +910,19 @@ class LocalGraph:
|
|
|
645
910
|
|
|
646
911
|
return True
|
|
647
912
|
|
|
648
|
-
# Check basic dependency:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
913
|
+
try: # Check basic dependency:
|
|
914
|
+
import graphviz
|
|
915
|
+
except ImportError as e:
|
|
916
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
917
|
+
"visualization") from e
|
|
918
|
+
|
|
919
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
653
920
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
654
921
|
"executables are not installed. These "
|
|
655
922
|
"dependencies are required in addition to the "
|
|
656
923
|
"'graphviz' Python package. Please install "
|
|
657
924
|
"them as described at "
|
|
658
925
|
"https://graphviz.org/download/.")
|
|
659
|
-
else:
|
|
660
|
-
import graphviz
|
|
661
926
|
|
|
662
927
|
format: Optional[str] = None
|
|
663
928
|
if isinstance(path, str):
|
|
@@ -741,6 +1006,9 @@ class LocalGraph:
|
|
|
741
1006
|
graph.render(path, cleanup=True)
|
|
742
1007
|
elif isinstance(path, io.BytesIO):
|
|
743
1008
|
path.write(graph.pipe())
|
|
1009
|
+
elif in_snowflake_notebook():
|
|
1010
|
+
import streamlit as st
|
|
1011
|
+
st.graphviz_chart(graph)
|
|
744
1012
|
elif in_notebook():
|
|
745
1013
|
from IPython.display import display
|
|
746
1014
|
display(graph)
|
|
@@ -790,7 +1058,7 @@ class LocalGraph:
|
|
|
790
1058
|
def __contains__(self, name: str) -> bool:
|
|
791
1059
|
return self.has_table(name)
|
|
792
1060
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
1061
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
1062
|
return self.table(name)
|
|
795
1063
|
|
|
796
1064
|
def __delitem__(self, name: str) -> None:
|
|
@@ -808,3 +1076,7 @@ class LocalGraph:
|
|
|
808
1076
|
f' tables={tables},\n'
|
|
809
1077
|
f' edges={edges},\n'
|
|
810
1078
|
f')')
|
|
1079
|
+
|
|
1080
|
+
def __del__(self) -> None:
|
|
1081
|
+
if hasattr(self, '_connection'):
|
|
1082
|
+
self._connection.close()
|