kumoai 2.10.0.dev202509231831__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512161731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/pquery.py +6 -2
- kumoai/client/rfm.py +37 -8
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +49 -86
- kumoai/experimental/rfm/backend/local/sampler.py +315 -0
- kumoai/experimental/rfm/backend/local/table.py +119 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +37 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +274 -0
- kumoai/experimental/rfm/backend/snow/table.py +135 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/sampler.py +353 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +126 -0
- kumoai/experimental/rfm/base/__init__.py +25 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +773 -0
- kumoai/experimental/rfm/base/source.py +19 -0
- kumoai/experimental/rfm/base/sql_sampler.py +60 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +245 -156
- kumoai/experimental/rfm/{local_graph.py → graph.py} +425 -137
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +24 -58
- kumoai/experimental/rfm/pquery/{pandas_backend.py → pandas_executor.py} +278 -224
- kumoai/experimental/rfm/rfm.py +669 -246
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +1 -0
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/testing/snow.py +50 -0
- kumoai/trainer/trainer.py +12 -10
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +239 -4
- kumoai/utils/sql.py +3 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/METADATA +15 -5
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/RECORD +50 -32
- kumoai/experimental/rfm/local_graph_sampler.py +0 -176
- kumoai/experimental/rfm/local_pquery_driver.py +0 -404
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/WHEEL +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.10.0.dev202509231831.dist-info → kumoai-2.14.0.dev202512161731.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,9 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
8
|
|
|
8
9
|
import pandas as pd
|
|
9
10
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,160 +12,401 @@ from kumoapi.table import TableDefinition
|
|
|
11
12
|
from kumoapi.typing import Stype
|
|
12
13
|
from typing_extensions import Self
|
|
13
14
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
+
from kumoai.experimental.rfm.base import DataBackend, Table
|
|
16
17
|
from kumoai.graph import Edge
|
|
18
|
+
from kumoai.mixin import CastMixin
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
import graphviz
|
|
22
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
23
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
24
|
|
|
21
25
|
|
|
22
|
-
|
|
23
|
-
|
|
26
|
+
@dataclass
|
|
27
|
+
class SqliteConnectionConfig(CastMixin):
|
|
28
|
+
uri: Union[str, Path]
|
|
29
|
+
kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Graph:
|
|
33
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
34
|
tables in a relational database.
|
|
25
35
|
|
|
26
36
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
37
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
38
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
39
|
|
|
30
40
|
.. code-block:: python
|
|
31
41
|
|
|
32
|
-
|
|
33
|
-
import
|
|
42
|
+
>>> # doctest: +SKIP
|
|
43
|
+
>>> import pandas as pd
|
|
44
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
34
45
|
|
|
35
|
-
# Load data frames into memory:
|
|
36
|
-
df1 = pd.DataFrame(...)
|
|
37
|
-
df2 = pd.DataFrame(...)
|
|
38
|
-
df3 = pd.DataFrame(...)
|
|
46
|
+
>>> # Load data frames into memory:
|
|
47
|
+
>>> df1 = pd.DataFrame(...)
|
|
48
|
+
>>> df2 = pd.DataFrame(...)
|
|
49
|
+
>>> df3 = pd.DataFrame(...)
|
|
39
50
|
|
|
40
|
-
# Define tables from data frames:
|
|
41
|
-
table1 = rfm.LocalTable(name="table1", data=df1)
|
|
42
|
-
table2 = rfm.LocalTable(name="table2", data=df2)
|
|
43
|
-
table3 = rfm.LocalTable(name="table3", data=df3)
|
|
51
|
+
>>> # Define tables from data frames:
|
|
52
|
+
>>> table1 = rfm.LocalTable(name="table1", data=df1)
|
|
53
|
+
>>> table2 = rfm.LocalTable(name="table2", data=df2)
|
|
54
|
+
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
44
55
|
|
|
45
|
-
# Create a graph from a dictionary of tables:
|
|
46
|
-
graph = rfm.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
})
|
|
56
|
+
>>> # Create a graph from a dictionary of tables:
|
|
57
|
+
>>> graph = rfm.Graph({
|
|
58
|
+
... "table1": table1,
|
|
59
|
+
... "table2": table2,
|
|
60
|
+
... "table3": table3,
|
|
61
|
+
... })
|
|
51
62
|
|
|
52
|
-
# Infer table metadata:
|
|
53
|
-
graph.infer_metadata()
|
|
63
|
+
>>> # Infer table metadata:
|
|
64
|
+
>>> graph.infer_metadata()
|
|
54
65
|
|
|
55
|
-
# Infer links/edges:
|
|
56
|
-
graph.infer_links()
|
|
66
|
+
>>> # Infer links/edges:
|
|
67
|
+
>>> graph.infer_links()
|
|
57
68
|
|
|
58
|
-
# Inspect table metadata:
|
|
59
|
-
for table in graph.tables.values():
|
|
60
|
-
|
|
69
|
+
>>> # Inspect table metadata:
|
|
70
|
+
>>> for table in graph.tables.values():
|
|
71
|
+
... table.print_metadata()
|
|
61
72
|
|
|
62
|
-
# Visualize graph (if graphviz is installed):
|
|
63
|
-
graph.visualize()
|
|
73
|
+
>>> # Visualize graph (if graphviz is installed):
|
|
74
|
+
>>> graph.visualize()
|
|
64
75
|
|
|
65
|
-
# Add/Remove edges between tables:
|
|
66
|
-
graph.link(src_table="table1", fkey="id1", dst_table="table2")
|
|
67
|
-
graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
|
|
76
|
+
>>> # Add/Remove edges between tables:
|
|
77
|
+
>>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
|
|
78
|
+
>>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
|
|
68
79
|
|
|
69
|
-
# Validate graph:
|
|
70
|
-
graph.validate()
|
|
80
|
+
>>> # Validate graph:
|
|
81
|
+
>>> graph.validate()
|
|
71
82
|
"""
|
|
72
83
|
|
|
73
84
|
# Constructors ############################################################
|
|
74
85
|
|
|
75
86
|
def __init__(
|
|
76
87
|
self,
|
|
77
|
-
tables:
|
|
78
|
-
edges: Optional[
|
|
88
|
+
tables: Sequence[Table],
|
|
89
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
79
90
|
) -> None:
|
|
80
91
|
|
|
81
|
-
self._tables: Dict[str,
|
|
92
|
+
self._tables: Dict[str, Table] = {}
|
|
82
93
|
self._edges: List[Edge] = []
|
|
83
94
|
|
|
84
95
|
for table in tables:
|
|
85
96
|
self.add_table(table)
|
|
86
97
|
|
|
98
|
+
for table in tables:
|
|
99
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
+
if fkey.name not in table or fkey.dst_table not in self:
|
|
101
|
+
continue
|
|
102
|
+
if self[fkey.dst_table].primary_key is None:
|
|
103
|
+
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
+
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
+
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
+
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
+
f"and '{fkey.primary_key}' in table "
|
|
108
|
+
f"'{fkey.dst_table}'.")
|
|
109
|
+
self.link(table.name, fkey.name, fkey.dst_table)
|
|
110
|
+
|
|
87
111
|
for edge in (edges or []):
|
|
88
112
|
_edge = Edge._cast(edge)
|
|
89
113
|
assert _edge is not None
|
|
90
|
-
self.
|
|
114
|
+
if _edge not in self._edges:
|
|
115
|
+
self.link(*_edge)
|
|
91
116
|
|
|
92
117
|
@classmethod
|
|
93
118
|
def from_data(
|
|
94
119
|
cls,
|
|
95
120
|
df_dict: Dict[str, pd.DataFrame],
|
|
96
|
-
edges: Optional[
|
|
121
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
97
122
|
infer_metadata: bool = True,
|
|
98
123
|
verbose: bool = True,
|
|
99
124
|
) -> Self:
|
|
100
|
-
r"""Creates a :class:`
|
|
125
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
101
126
|
:class:`pandas.DataFrame` objects.
|
|
102
127
|
|
|
103
|
-
Automatically infers table metadata and links.
|
|
128
|
+
Automatically infers table metadata and links by default.
|
|
104
129
|
|
|
105
130
|
.. code-block:: python
|
|
106
131
|
|
|
107
|
-
|
|
108
|
-
import
|
|
109
|
-
|
|
110
|
-
# Load data frames into memory:
|
|
111
|
-
df1 = pd.DataFrame(...)
|
|
112
|
-
df2 = pd.DataFrame(...)
|
|
113
|
-
df3 = pd.DataFrame(...)
|
|
114
|
-
|
|
115
|
-
# Create a graph from a dictionary of data frames:
|
|
116
|
-
graph = rfm.LocalGraph.from_data({
|
|
117
|
-
"table1": df1,
|
|
118
|
-
"table2": df2,
|
|
119
|
-
"table3": df3,
|
|
120
|
-
})
|
|
132
|
+
>>> # doctest: +SKIP
|
|
133
|
+
>>> import pandas as pd
|
|
134
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
121
135
|
|
|
122
|
-
#
|
|
123
|
-
|
|
124
|
-
|
|
136
|
+
>>> # Load data frames into memory:
|
|
137
|
+
>>> df1 = pd.DataFrame(...)
|
|
138
|
+
>>> df2 = pd.DataFrame(...)
|
|
139
|
+
>>> df3 = pd.DataFrame(...)
|
|
125
140
|
|
|
126
|
-
#
|
|
127
|
-
graph.
|
|
141
|
+
>>> # Create a graph from a dictionary of data frames:
|
|
142
|
+
>>> graph = rfm.Graph.from_data({
|
|
143
|
+
... "table1": df1,
|
|
144
|
+
... "table2": df2,
|
|
145
|
+
... "table3": df3,
|
|
146
|
+
... })
|
|
128
147
|
|
|
129
148
|
Args:
|
|
130
149
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
131
150
|
of the tables and the values hold table data.
|
|
151
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
152
|
+
add to the graph. If not provided, edges will be automatically
|
|
153
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
132
154
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
133
155
|
graph.
|
|
156
|
+
verbose: Whether to print verbose output.
|
|
157
|
+
"""
|
|
158
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
|
+
|
|
161
|
+
graph = cls(tables, edges=edges or [])
|
|
162
|
+
|
|
163
|
+
if infer_metadata:
|
|
164
|
+
graph.infer_metadata(False)
|
|
165
|
+
|
|
166
|
+
if edges is None:
|
|
167
|
+
graph.infer_links(False)
|
|
168
|
+
|
|
169
|
+
if verbose:
|
|
170
|
+
graph.print_metadata()
|
|
171
|
+
graph.print_links()
|
|
172
|
+
|
|
173
|
+
return graph
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def from_sqlite(
|
|
177
|
+
cls,
|
|
178
|
+
connection: Union[
|
|
179
|
+
'AdbcSqliteConnection',
|
|
180
|
+
SqliteConnectionConfig,
|
|
181
|
+
str,
|
|
182
|
+
Path,
|
|
183
|
+
Dict[str, Any],
|
|
184
|
+
],
|
|
185
|
+
table_names: Optional[Sequence[str]] = None,
|
|
186
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
187
|
+
infer_metadata: bool = True,
|
|
188
|
+
verbose: bool = True,
|
|
189
|
+
) -> Self:
|
|
190
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
191
|
+
|
|
192
|
+
Automatically infers table metadata and links by default.
|
|
193
|
+
|
|
194
|
+
.. code-block:: python
|
|
195
|
+
|
|
196
|
+
>>> # doctest: +SKIP
|
|
197
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
198
|
+
|
|
199
|
+
>>> # Create a graph from a SQLite database:
|
|
200
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
connection: An open connection from
|
|
204
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
|
+
path to the database file.
|
|
206
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
207
|
+
all tables present in the database.
|
|
134
208
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
135
209
|
add to the graph. If not provided, edges will be automatically
|
|
136
|
-
inferred from the data
|
|
210
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
212
|
+
graph.
|
|
137
213
|
verbose: Whether to print verbose output.
|
|
214
|
+
"""
|
|
215
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
216
|
+
Connection,
|
|
217
|
+
SQLiteTable,
|
|
218
|
+
connect,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
internal_connection = False
|
|
222
|
+
if not isinstance(connection, Connection):
|
|
223
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
224
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
225
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
226
|
+
internal_connection = True
|
|
227
|
+
assert isinstance(connection, Connection)
|
|
228
|
+
|
|
229
|
+
if table_names is None:
|
|
230
|
+
with connection.cursor() as cursor:
|
|
231
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
232
|
+
"WHERE type='table'")
|
|
233
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
234
|
+
|
|
235
|
+
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
138
236
|
|
|
139
|
-
|
|
140
|
-
This method will automatically infer metadata and links for the
|
|
141
|
-
graph.
|
|
237
|
+
graph = cls(tables, edges=edges or [])
|
|
142
238
|
|
|
143
|
-
|
|
239
|
+
if internal_connection:
|
|
240
|
+
graph._connection = connection # type: ignore
|
|
241
|
+
|
|
242
|
+
if infer_metadata:
|
|
243
|
+
graph.infer_metadata(False)
|
|
244
|
+
|
|
245
|
+
if edges is None:
|
|
246
|
+
graph.infer_links(False)
|
|
247
|
+
|
|
248
|
+
if verbose:
|
|
249
|
+
graph.print_metadata()
|
|
250
|
+
graph.print_links()
|
|
251
|
+
|
|
252
|
+
return graph
|
|
253
|
+
|
|
254
|
+
@classmethod
|
|
255
|
+
def from_snowflake(
|
|
256
|
+
cls,
|
|
257
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
258
|
+
database: Optional[str] = None,
|
|
259
|
+
schema: Optional[str] = None,
|
|
260
|
+
table_names: Optional[Sequence[str]] = None,
|
|
261
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
262
|
+
infer_metadata: bool = True,
|
|
263
|
+
verbose: bool = True,
|
|
264
|
+
) -> Self:
|
|
265
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
266
|
+
schema.
|
|
267
|
+
|
|
268
|
+
Automatically infers table metadata and links by default.
|
|
269
|
+
|
|
270
|
+
.. code-block:: python
|
|
271
|
+
|
|
272
|
+
>>> # doctest: +SKIP
|
|
144
273
|
>>> import kumoai.experimental.rfm as rfm
|
|
145
|
-
|
|
146
|
-
>>>
|
|
147
|
-
>>>
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
274
|
+
|
|
275
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
276
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
277
|
+
|
|
278
|
+
Args:
|
|
279
|
+
connection: An open connection from
|
|
280
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
281
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
282
|
+
connection. If ``None``, will re-use an active session in case
|
|
283
|
+
it exists, or create a new connection from credentials stored
|
|
284
|
+
in environment variables.
|
|
285
|
+
database: The database.
|
|
286
|
+
schema: The schema.
|
|
287
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
288
|
+
all tables present in the database.
|
|
289
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
290
|
+
add to the graph. If not provided, edges will be automatically
|
|
291
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
292
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
293
|
+
graph.
|
|
294
|
+
verbose: Whether to print verbose output.
|
|
154
295
|
"""
|
|
155
|
-
|
|
296
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
297
|
+
Connection,
|
|
298
|
+
SnowTable,
|
|
299
|
+
connect,
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
if not isinstance(connection, Connection):
|
|
303
|
+
connection = connect(**(connection or {}))
|
|
304
|
+
assert isinstance(connection, Connection)
|
|
305
|
+
|
|
306
|
+
if table_names is None:
|
|
307
|
+
with connection.cursor() as cursor:
|
|
308
|
+
if database is None and schema is None:
|
|
309
|
+
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
310
|
+
"CURRENT_SCHEMA()")
|
|
311
|
+
result = cursor.fetchone()
|
|
312
|
+
database = database or result[0]
|
|
313
|
+
schema = schema or result[1]
|
|
314
|
+
cursor.execute(f"""
|
|
315
|
+
SELECT TABLE_NAME
|
|
316
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
317
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
318
|
+
""")
|
|
319
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
320
|
+
|
|
321
|
+
tables = [
|
|
322
|
+
SnowTable(
|
|
323
|
+
connection,
|
|
324
|
+
name=table_name,
|
|
325
|
+
database=database,
|
|
326
|
+
schema=schema,
|
|
327
|
+
) for table_name in table_names
|
|
328
|
+
]
|
|
156
329
|
|
|
157
330
|
graph = cls(tables, edges=edges or [])
|
|
158
331
|
|
|
159
332
|
if infer_metadata:
|
|
160
|
-
graph.infer_metadata(
|
|
333
|
+
graph.infer_metadata(False)
|
|
161
334
|
|
|
162
335
|
if edges is None:
|
|
163
|
-
graph.infer_links(
|
|
336
|
+
graph.infer_links(False)
|
|
337
|
+
|
|
338
|
+
if verbose:
|
|
339
|
+
graph.print_metadata()
|
|
340
|
+
graph.print_links()
|
|
164
341
|
|
|
165
342
|
return graph
|
|
166
343
|
|
|
167
|
-
|
|
344
|
+
@classmethod
|
|
345
|
+
def from_snowflake_semantic_view(
|
|
346
|
+
cls,
|
|
347
|
+
semantic_view_name: str,
|
|
348
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
349
|
+
verbose: bool = True,
|
|
350
|
+
) -> Self:
|
|
351
|
+
import yaml
|
|
352
|
+
|
|
353
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
354
|
+
Connection,
|
|
355
|
+
SnowTable,
|
|
356
|
+
connect,
|
|
357
|
+
)
|
|
358
|
+
|
|
359
|
+
if not isinstance(connection, Connection):
|
|
360
|
+
connection = connect(**(connection or {}))
|
|
361
|
+
assert isinstance(connection, Connection)
|
|
362
|
+
|
|
363
|
+
with connection.cursor() as cursor:
|
|
364
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
365
|
+
f"'{semantic_view_name}')")
|
|
366
|
+
view = yaml.safe_load(cursor.fetchone()[0])
|
|
367
|
+
|
|
368
|
+
graph = cls(tables=[])
|
|
369
|
+
|
|
370
|
+
for table_desc in view['tables']:
|
|
371
|
+
primary_key: Optional[str] = None
|
|
372
|
+
if ('primary_key' in table_desc # NOTE No composite keys yet.
|
|
373
|
+
and len(table_desc['primary_key']['columns']) == 1):
|
|
374
|
+
primary_key = table_desc['primary_key']['columns'][0]
|
|
375
|
+
|
|
376
|
+
table = SnowTable(
|
|
377
|
+
connection,
|
|
378
|
+
name=table_desc['base_table']['table'],
|
|
379
|
+
database=table_desc['base_table']['database'],
|
|
380
|
+
schema=table_desc['base_table']['schema'],
|
|
381
|
+
primary_key=primary_key,
|
|
382
|
+
)
|
|
383
|
+
graph.add_table(table)
|
|
384
|
+
|
|
385
|
+
# TODO Find a solution to register time columns!
|
|
386
|
+
|
|
387
|
+
for relations in view['relationships']:
|
|
388
|
+
if len(relations['relationship_columns']) != 1:
|
|
389
|
+
continue # NOTE No composite keys yet.
|
|
390
|
+
graph.link(
|
|
391
|
+
src_table=relations['left_table'],
|
|
392
|
+
fkey=relations['relationship_columns'][0]['left_column'],
|
|
393
|
+
dst_table=relations['right_table'],
|
|
394
|
+
)
|
|
395
|
+
|
|
396
|
+
if verbose:
|
|
397
|
+
graph.print_metadata()
|
|
398
|
+
graph.print_links()
|
|
399
|
+
|
|
400
|
+
return graph
|
|
401
|
+
|
|
402
|
+
# Backend #################################################################
|
|
403
|
+
|
|
404
|
+
@property
|
|
405
|
+
def backend(self) -> DataBackend | None:
|
|
406
|
+
backends = [table.backend for table in self._tables.values()]
|
|
407
|
+
return backends[0] if len(backends) > 0 else None
|
|
408
|
+
|
|
409
|
+
# Tables ##################################################################
|
|
168
410
|
|
|
169
411
|
def has_table(self, name: str) -> bool:
|
|
170
412
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -172,7 +414,7 @@ class LocalGraph:
|
|
|
172
414
|
"""
|
|
173
415
|
return name in self.tables
|
|
174
416
|
|
|
175
|
-
def table(self, name: str) ->
|
|
417
|
+
def table(self, name: str) -> Table:
|
|
176
418
|
r"""Returns the table with name ``name`` in the graph.
|
|
177
419
|
|
|
178
420
|
Raises:
|
|
@@ -183,11 +425,11 @@ class LocalGraph:
|
|
|
183
425
|
return self.tables[name]
|
|
184
426
|
|
|
185
427
|
@property
|
|
186
|
-
def tables(self) -> Dict[str,
|
|
428
|
+
def tables(self) -> Dict[str, Table]:
|
|
187
429
|
r"""Returns the dictionary of table objects."""
|
|
188
430
|
return self._tables
|
|
189
431
|
|
|
190
|
-
def add_table(self, table:
|
|
432
|
+
def add_table(self, table: Table) -> Self:
|
|
191
433
|
r"""Adds a table to the graph.
|
|
192
434
|
|
|
193
435
|
Args:
|
|
@@ -196,11 +438,18 @@ class LocalGraph:
|
|
|
196
438
|
Raises:
|
|
197
439
|
KeyError: If a table with the same name already exists in the
|
|
198
440
|
graph.
|
|
441
|
+
ValueError: If the table belongs to a different backend than the
|
|
442
|
+
rest of the tables in the graph.
|
|
199
443
|
"""
|
|
200
444
|
if table.name in self._tables:
|
|
201
445
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
202
446
|
f"this graph; table names must be globally unique.")
|
|
203
447
|
|
|
448
|
+
if self.backend is not None and table.backend != self.backend:
|
|
449
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
450
|
+
f"'{table.backend}' to this graph since other "
|
|
451
|
+
f"tables have backend '{self.backend}'.")
|
|
452
|
+
|
|
204
453
|
self._tables[table.name] = table
|
|
205
454
|
|
|
206
455
|
return self
|
|
@@ -231,16 +480,17 @@ class LocalGraph:
|
|
|
231
480
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
232
481
|
information about the tables in this graph.
|
|
233
482
|
|
|
234
|
-
The returned dataframe has columns ``name``, ``primary_key``,
|
|
235
|
-
``time_column``, which provide an aggregate
|
|
236
|
-
the tables of this graph.
|
|
483
|
+
The returned dataframe has columns ``name``, ``primary_key``,
|
|
484
|
+
``time_column``, and ``end_time_column``, which provide an aggregate
|
|
485
|
+
view of the properties of the tables of this graph.
|
|
237
486
|
|
|
238
487
|
Example:
|
|
488
|
+
>>> # doctest: +SKIP
|
|
239
489
|
>>> import kumoai.experimental.rfm as rfm
|
|
240
|
-
>>> graph = rfm.
|
|
241
|
-
>>> graph.metadata
|
|
242
|
-
name
|
|
243
|
-
0 users
|
|
490
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
491
|
+
>>> graph.metadata # doctest: +SKIP
|
|
492
|
+
name primary_key time_column end_time_column
|
|
493
|
+
0 users user_id - -
|
|
244
494
|
"""
|
|
245
495
|
tables = list(self.tables.values())
|
|
246
496
|
|
|
@@ -251,13 +501,22 @@ class LocalGraph:
|
|
|
251
501
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
252
502
|
'time_column':
|
|
253
503
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
504
|
+
'end_time_column':
|
|
505
|
+
pd.Series(
|
|
506
|
+
dtype=str,
|
|
507
|
+
data=[t._end_time_column or '-' for t in tables],
|
|
508
|
+
),
|
|
254
509
|
})
|
|
255
510
|
|
|
256
511
|
def print_metadata(self) -> None:
|
|
257
|
-
r"""Prints the :meth:`~
|
|
258
|
-
if
|
|
512
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
513
|
+
if in_snowflake_notebook():
|
|
514
|
+
import streamlit as st
|
|
515
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
516
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
517
|
+
elif in_notebook():
|
|
259
518
|
from IPython.display import Markdown, display
|
|
260
|
-
display(Markdown(
|
|
519
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
261
520
|
df = self.metadata
|
|
262
521
|
try:
|
|
263
522
|
if hasattr(df.style, 'hide'):
|
|
@@ -278,7 +537,7 @@ class LocalGraph:
|
|
|
278
537
|
|
|
279
538
|
Note:
|
|
280
539
|
For more information, please see
|
|
281
|
-
:meth:`kumoai.experimental.rfm.
|
|
540
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
282
541
|
"""
|
|
283
542
|
for table in self.tables.values():
|
|
284
543
|
table.infer_metadata(verbose=False)
|
|
@@ -296,37 +555,47 @@ class LocalGraph:
|
|
|
296
555
|
return self._edges
|
|
297
556
|
|
|
298
557
|
def print_links(self) -> None:
|
|
299
|
-
r"""Prints the :meth:`~
|
|
558
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
300
559
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
301
560
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
302
561
|
edges = sorted(edges)
|
|
303
562
|
|
|
304
|
-
if
|
|
563
|
+
if in_snowflake_notebook():
|
|
564
|
+
import streamlit as st
|
|
565
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
566
|
+
if len(edges) > 0:
|
|
567
|
+
st.markdown('\n'.join([
|
|
568
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
569
|
+
for edge in edges
|
|
570
|
+
]))
|
|
571
|
+
else:
|
|
572
|
+
st.markdown("*No links registered*")
|
|
573
|
+
elif in_notebook():
|
|
305
574
|
from IPython.display import Markdown, display
|
|
306
|
-
display(Markdown(
|
|
575
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
307
576
|
if len(edges) > 0:
|
|
308
577
|
display(
|
|
309
578
|
Markdown('\n'.join([
|
|
310
|
-
f
|
|
579
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
311
580
|
for edge in edges
|
|
312
581
|
])))
|
|
313
582
|
else:
|
|
314
|
-
display(Markdown(
|
|
583
|
+
display(Markdown("*No links registered*"))
|
|
315
584
|
else:
|
|
316
585
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
317
586
|
if len(edges) > 0:
|
|
318
587
|
print('\n'.join([
|
|
319
|
-
f
|
|
588
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
320
589
|
for edge in edges
|
|
321
590
|
]))
|
|
322
591
|
else:
|
|
323
|
-
print(
|
|
592
|
+
print("No links registered")
|
|
324
593
|
|
|
325
594
|
def link(
|
|
326
595
|
self,
|
|
327
|
-
src_table: Union[str,
|
|
596
|
+
src_table: Union[str, Table],
|
|
328
597
|
fkey: str,
|
|
329
|
-
dst_table: Union[str,
|
|
598
|
+
dst_table: Union[str, Table],
|
|
330
599
|
) -> Self:
|
|
331
600
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
332
601
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -349,11 +618,11 @@ class LocalGraph:
|
|
|
349
618
|
table does not exist in the graph, if the source key does not
|
|
350
619
|
exist in the source table.
|
|
351
620
|
"""
|
|
352
|
-
if isinstance(src_table,
|
|
621
|
+
if isinstance(src_table, Table):
|
|
353
622
|
src_table = src_table.name
|
|
354
623
|
assert isinstance(src_table, str)
|
|
355
624
|
|
|
356
|
-
if isinstance(dst_table,
|
|
625
|
+
if isinstance(dst_table, Table):
|
|
357
626
|
dst_table = dst_table.name
|
|
358
627
|
assert isinstance(dst_table, str)
|
|
359
628
|
|
|
@@ -387,9 +656,9 @@ class LocalGraph:
|
|
|
387
656
|
|
|
388
657
|
def unlink(
|
|
389
658
|
self,
|
|
390
|
-
src_table: Union[str,
|
|
659
|
+
src_table: Union[str, Table],
|
|
391
660
|
fkey: str,
|
|
392
|
-
dst_table: Union[str,
|
|
661
|
+
dst_table: Union[str, Table],
|
|
393
662
|
) -> Self:
|
|
394
663
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
395
664
|
|
|
@@ -401,11 +670,11 @@ class LocalGraph:
|
|
|
401
670
|
Raises:
|
|
402
671
|
ValueError: if the edge is not present in the graph.
|
|
403
672
|
"""
|
|
404
|
-
if isinstance(src_table,
|
|
673
|
+
if isinstance(src_table, Table):
|
|
405
674
|
src_table = src_table.name
|
|
406
675
|
assert isinstance(src_table, str)
|
|
407
676
|
|
|
408
|
-
if isinstance(dst_table,
|
|
677
|
+
if isinstance(dst_table, Table):
|
|
409
678
|
dst_table = dst_table.name
|
|
410
679
|
assert isinstance(dst_table, str)
|
|
411
680
|
|
|
@@ -419,17 +688,13 @@ class LocalGraph:
|
|
|
419
688
|
return self
|
|
420
689
|
|
|
421
690
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
422
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
691
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
692
|
+
graph.
|
|
423
693
|
|
|
424
694
|
Args:
|
|
425
695
|
verbose: Whether to print verbose output.
|
|
426
|
-
|
|
427
|
-
Note:
|
|
428
|
-
This function expects graph edges to be undefined upfront.
|
|
429
696
|
"""
|
|
430
|
-
|
|
431
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
432
|
-
return self
|
|
697
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
433
698
|
|
|
434
699
|
# A list of primary key candidates (+score) for every column:
|
|
435
700
|
candidate_dict: dict[
|
|
@@ -454,6 +719,9 @@ class LocalGraph:
|
|
|
454
719
|
src_table_name = src_table.name.lower()
|
|
455
720
|
|
|
456
721
|
for src_key in src_table.columns:
|
|
722
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
723
|
+
continue
|
|
724
|
+
|
|
457
725
|
if src_key == src_table.primary_key:
|
|
458
726
|
continue # Cannot link to primary key.
|
|
459
727
|
|
|
@@ -519,7 +787,9 @@ class LocalGraph:
|
|
|
519
787
|
score += 1.0
|
|
520
788
|
|
|
521
789
|
# Cardinality ratio:
|
|
522
|
-
if
|
|
790
|
+
if (src_table._num_rows is not None
|
|
791
|
+
and dst_table._num_rows is not None
|
|
792
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
523
793
|
score += 1.0
|
|
524
794
|
|
|
525
795
|
if score < 5.0:
|
|
@@ -565,6 +835,10 @@ class LocalGraph:
|
|
|
565
835
|
raise ValueError("At least one table needs to be added to the "
|
|
566
836
|
"graph")
|
|
567
837
|
|
|
838
|
+
backends = {table.backend for table in self._tables.values()}
|
|
839
|
+
if len(backends) != 1:
|
|
840
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
841
|
+
|
|
568
842
|
for edge in self.edges:
|
|
569
843
|
src_table, fkey, dst_table = edge
|
|
570
844
|
|
|
@@ -602,8 +876,8 @@ class LocalGraph:
|
|
|
602
876
|
raise ValueError(f"{edge} is invalid as foreign key "
|
|
603
877
|
f"'{fkey}' and primary key '{dst_key.name}' "
|
|
604
878
|
f"have incompatible data types (got "
|
|
605
|
-
f"fkey.dtype '{
|
|
606
|
-
f"pkey.dtype '{
|
|
879
|
+
f"fkey.dtype '{src_key.dtype}' and "
|
|
880
|
+
f"pkey.dtype '{dst_key.dtype}')")
|
|
607
881
|
|
|
608
882
|
return self
|
|
609
883
|
|
|
@@ -636,19 +910,19 @@ class LocalGraph:
|
|
|
636
910
|
|
|
637
911
|
return True
|
|
638
912
|
|
|
639
|
-
# Check basic dependency:
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
913
|
+
try: # Check basic dependency:
|
|
914
|
+
import graphviz
|
|
915
|
+
except ImportError as e:
|
|
916
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
917
|
+
"visualization") from e
|
|
918
|
+
|
|
919
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
644
920
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
645
921
|
"executables are not installed. These "
|
|
646
922
|
"dependencies are required in addition to the "
|
|
647
923
|
"'graphviz' Python package. Please install "
|
|
648
924
|
"them as described at "
|
|
649
925
|
"https://graphviz.org/download/.")
|
|
650
|
-
else:
|
|
651
|
-
import graphviz
|
|
652
926
|
|
|
653
927
|
format: Optional[str] = None
|
|
654
928
|
if isinstance(path, str):
|
|
@@ -676,6 +950,11 @@ class LocalGraph:
|
|
|
676
950
|
]
|
|
677
951
|
if time_column := table.time_column:
|
|
678
952
|
keys += [f'{time_column.name}: Time ({time_column.dtype})']
|
|
953
|
+
if end_time_column := table.end_time_column:
|
|
954
|
+
keys += [
|
|
955
|
+
f'{end_time_column.name}: '
|
|
956
|
+
f'End Time ({end_time_column.dtype})'
|
|
957
|
+
]
|
|
679
958
|
key_repr = left_align(keys)
|
|
680
959
|
|
|
681
960
|
columns = []
|
|
@@ -683,9 +962,9 @@ class LocalGraph:
|
|
|
683
962
|
columns += [
|
|
684
963
|
f'{column.name}: {column.stype} ({column.dtype})'
|
|
685
964
|
for column in table.columns
|
|
686
|
-
if column.name not in fkeys_dict[table_name]
|
|
687
|
-
and column.name != table.
|
|
688
|
-
and column.name != table.
|
|
965
|
+
if column.name not in fkeys_dict[table_name] and
|
|
966
|
+
column.name != table._primary_key and column.name != table.
|
|
967
|
+
_time_column and column.name != table._end_time_column
|
|
689
968
|
]
|
|
690
969
|
column_repr = left_align(columns)
|
|
691
970
|
|
|
@@ -727,6 +1006,9 @@ class LocalGraph:
|
|
|
727
1006
|
graph.render(path, cleanup=True)
|
|
728
1007
|
elif isinstance(path, io.BytesIO):
|
|
729
1008
|
path.write(graph.pipe())
|
|
1009
|
+
elif in_snowflake_notebook():
|
|
1010
|
+
import streamlit as st
|
|
1011
|
+
st.graphviz_chart(graph)
|
|
730
1012
|
elif in_notebook():
|
|
731
1013
|
from IPython.display import display
|
|
732
1014
|
display(graph)
|
|
@@ -752,16 +1034,18 @@ class LocalGraph:
|
|
|
752
1034
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
753
1035
|
tables: Dict[str, TableDefinition] = {}
|
|
754
1036
|
col_groups: List[ColumnKeyGroup] = []
|
|
755
|
-
for
|
|
756
|
-
tables[
|
|
1037
|
+
for table_name, table in self.tables.items():
|
|
1038
|
+
tables[table_name] = table._to_api_table_definition()
|
|
757
1039
|
if table.primary_key is None:
|
|
758
1040
|
continue
|
|
759
|
-
keys = [ColumnKey(
|
|
1041
|
+
keys = [ColumnKey(table_name, table.primary_key.name)]
|
|
760
1042
|
for edge in self.edges:
|
|
761
|
-
if edge.dst_table ==
|
|
1043
|
+
if edge.dst_table == table_name:
|
|
762
1044
|
keys.append(ColumnKey(edge.src_table, edge.fkey))
|
|
763
|
-
keys = sorted(
|
|
764
|
-
|
|
1045
|
+
keys = sorted(
|
|
1046
|
+
list(set(keys)),
|
|
1047
|
+
key=lambda x: f'{x.table_name}.{x.col_name}',
|
|
1048
|
+
)
|
|
765
1049
|
if len(keys) > 1:
|
|
766
1050
|
col_groups.append(ColumnKeyGroup(keys))
|
|
767
1051
|
return GraphDefinition(tables, col_groups)
|
|
@@ -774,7 +1058,7 @@ class LocalGraph:
|
|
|
774
1058
|
def __contains__(self, name: str) -> bool:
|
|
775
1059
|
return self.has_table(name)
|
|
776
1060
|
|
|
777
|
-
def __getitem__(self, name: str) ->
|
|
1061
|
+
def __getitem__(self, name: str) -> Table:
|
|
778
1062
|
return self.table(name)
|
|
779
1063
|
|
|
780
1064
|
def __delitem__(self, name: str) -> None:
|
|
@@ -792,3 +1076,7 @@ class LocalGraph:
|
|
|
792
1076
|
f' tables={tables},\n'
|
|
793
1077
|
f' edges={edges},\n'
|
|
794
1078
|
f')')
|
|
1079
|
+
|
|
1080
|
+
def __del__(self) -> None:
|
|
1081
|
+
if hasattr(self, '_connection'):
|
|
1082
|
+
self._connection.close()
|