kumoai 2.8.0.dev202508221830__cp312-cp312-win_amd64.whl → 2.13.0.dev202512041141__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of kumoai might be problematic. Click here for more details.
- kumoai/__init__.py +22 -11
- kumoai/_version.py +1 -1
- kumoai/client/client.py +17 -16
- kumoai/client/endpoints.py +1 -0
- kumoai/client/rfm.py +37 -8
- kumoai/connector/file_upload_connector.py +94 -85
- kumoai/connector/utils.py +1399 -210
- kumoai/experimental/rfm/__init__.py +164 -46
- kumoai/experimental/rfm/authenticate.py +8 -5
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +545 -0
- kumoai/experimental/rfm/{local_graph.py → graph.py} +413 -144
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/infer/timestamp.py +7 -4
- kumoai/experimental/rfm/local_graph_sampler.py +58 -11
- kumoai/experimental/rfm/local_graph_store.py +45 -37
- kumoai/experimental/rfm/local_pquery_driver.py +342 -46
- kumoai/experimental/rfm/pquery/__init__.py +4 -4
- kumoai/experimental/rfm/pquery/{backend.py → executor.py} +28 -58
- kumoai/experimental/rfm/pquery/pandas_executor.py +532 -0
- kumoai/experimental/rfm/rfm.py +559 -148
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/jobs.py +27 -1
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/pquery/prediction_table.py +5 -3
- kumoai/pquery/training_table.py +5 -3
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- kumoai/trainer/job.py +9 -30
- kumoai/trainer/trainer.py +19 -10
- kumoai/utils/__init__.py +2 -1
- kumoai/utils/progress_logger.py +96 -16
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/METADATA +14 -5
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/RECORD +49 -36
- kumoai/experimental/rfm/local_table.py +0 -448
- kumoai/experimental/rfm/pquery/pandas_backend.py +0 -437
- kumoai/experimental/rfm/utils.py +0 -347
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/WHEEL +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.8.0.dev202508221830.dist-info → kumoai-2.13.0.dev202512041141.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,9 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
8
|
|
|
8
9
|
import pandas as pd
|
|
9
10
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,156 +12,385 @@ from kumoapi.table import TableDefinition
|
|
|
11
12
|
from kumoapi.typing import Stype
|
|
12
13
|
from typing_extensions import Self
|
|
13
14
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
+
from kumoai.experimental.rfm import Table
|
|
16
17
|
from kumoai.graph import Edge
|
|
18
|
+
from kumoai.mixin import CastMixin
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
import graphviz
|
|
22
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
23
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
24
|
|
|
21
25
|
|
|
22
|
-
|
|
23
|
-
|
|
26
|
+
@dataclass
|
|
27
|
+
class SqliteConnectionConfig(CastMixin):
|
|
28
|
+
uri: Union[str, Path]
|
|
29
|
+
kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Graph:
|
|
33
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
34
|
tables in a relational database.
|
|
25
35
|
|
|
26
36
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
37
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
38
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
39
|
|
|
30
40
|
.. code-block:: python
|
|
31
41
|
|
|
32
|
-
|
|
33
|
-
import
|
|
42
|
+
>>> # doctest: +SKIP
|
|
43
|
+
>>> import pandas as pd
|
|
44
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
34
45
|
|
|
35
|
-
# Load data frames into memory:
|
|
36
|
-
df1 = pd.DataFrame(...)
|
|
37
|
-
df2 = pd.DataFrame(...)
|
|
38
|
-
df3 = pd.DataFrame(...)
|
|
46
|
+
>>> # Load data frames into memory:
|
|
47
|
+
>>> df1 = pd.DataFrame(...)
|
|
48
|
+
>>> df2 = pd.DataFrame(...)
|
|
49
|
+
>>> df3 = pd.DataFrame(...)
|
|
39
50
|
|
|
40
|
-
# Define tables from data frames:
|
|
41
|
-
table1 = rfm.LocalTable(name="table1", data=df1)
|
|
42
|
-
table2 = rfm.LocalTable(name="table2", data=df2)
|
|
43
|
-
table3 = rfm.LocalTable(name="table3", data=df3)
|
|
51
|
+
>>> # Define tables from data frames:
|
|
52
|
+
>>> table1 = rfm.LocalTable(name="table1", data=df1)
|
|
53
|
+
>>> table2 = rfm.LocalTable(name="table2", data=df2)
|
|
54
|
+
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
44
55
|
|
|
45
|
-
# Create a graph from a dictionary of tables:
|
|
46
|
-
graph = rfm.
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
})
|
|
56
|
+
>>> # Create a graph from a dictionary of tables:
|
|
57
|
+
>>> graph = rfm.Graph({
|
|
58
|
+
... "table1": table1,
|
|
59
|
+
... "table2": table2,
|
|
60
|
+
... "table3": table3,
|
|
61
|
+
... })
|
|
51
62
|
|
|
52
|
-
# Infer table metadata:
|
|
53
|
-
graph.infer_metadata()
|
|
63
|
+
>>> # Infer table metadata:
|
|
64
|
+
>>> graph.infer_metadata()
|
|
54
65
|
|
|
55
|
-
# Infer links/edges:
|
|
56
|
-
graph.infer_links()
|
|
66
|
+
>>> # Infer links/edges:
|
|
67
|
+
>>> graph.infer_links()
|
|
57
68
|
|
|
58
|
-
# Inspect table metadata:
|
|
59
|
-
for table in graph.tables.values():
|
|
60
|
-
|
|
69
|
+
>>> # Inspect table metadata:
|
|
70
|
+
>>> for table in graph.tables.values():
|
|
71
|
+
... table.print_metadata()
|
|
61
72
|
|
|
62
|
-
# Visualize graph (if graphviz is installed):
|
|
63
|
-
graph.visualize()
|
|
73
|
+
>>> # Visualize graph (if graphviz is installed):
|
|
74
|
+
>>> graph.visualize()
|
|
64
75
|
|
|
65
|
-
# Add/Remove edges between tables:
|
|
66
|
-
graph.link(src_table="table1", fkey="id1", dst_table="table2")
|
|
67
|
-
graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
|
|
76
|
+
>>> # Add/Remove edges between tables:
|
|
77
|
+
>>> graph.link(src_table="table1", fkey="id1", dst_table="table2")
|
|
78
|
+
>>> graph.unlink(src_table="table1", fkey="id1", dst_table="table2")
|
|
68
79
|
|
|
69
|
-
# Validate graph:
|
|
70
|
-
graph.validate()
|
|
80
|
+
>>> # Validate graph:
|
|
81
|
+
>>> graph.validate()
|
|
71
82
|
"""
|
|
72
83
|
|
|
73
84
|
# Constructors ############################################################
|
|
74
85
|
|
|
75
86
|
def __init__(
|
|
76
87
|
self,
|
|
77
|
-
tables:
|
|
78
|
-
edges: Optional[
|
|
88
|
+
tables: Sequence[Table],
|
|
89
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
79
90
|
) -> None:
|
|
80
91
|
|
|
81
|
-
self._tables: Dict[str,
|
|
92
|
+
self._tables: Dict[str, Table] = {}
|
|
82
93
|
self._edges: List[Edge] = []
|
|
83
94
|
|
|
84
95
|
for table in tables:
|
|
85
96
|
self.add_table(table)
|
|
86
97
|
|
|
98
|
+
for table in tables:
|
|
99
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
+
if fkey.name not in table or fkey.dst_table not in self:
|
|
101
|
+
continue
|
|
102
|
+
if self[fkey.dst_table].primary_key is None:
|
|
103
|
+
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
+
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
+
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
+
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
+
f"and '{fkey.primary_key}' in table "
|
|
108
|
+
f"'{fkey.dst_table}'.")
|
|
109
|
+
self.link(table.name, fkey.name, fkey.dst_table)
|
|
110
|
+
|
|
87
111
|
for edge in (edges or []):
|
|
88
112
|
_edge = Edge._cast(edge)
|
|
89
113
|
assert _edge is not None
|
|
90
|
-
self.
|
|
114
|
+
if _edge not in self._edges:
|
|
115
|
+
self.link(*_edge)
|
|
91
116
|
|
|
92
117
|
@classmethod
|
|
93
118
|
def from_data(
|
|
94
119
|
cls,
|
|
95
120
|
df_dict: Dict[str, pd.DataFrame],
|
|
96
|
-
edges: Optional[
|
|
121
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
97
122
|
infer_metadata: bool = True,
|
|
98
123
|
verbose: bool = True,
|
|
99
124
|
) -> Self:
|
|
100
|
-
r"""Creates a :class:`
|
|
125
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
101
126
|
:class:`pandas.DataFrame` objects.
|
|
102
127
|
|
|
103
|
-
Automatically infers table metadata and links.
|
|
128
|
+
Automatically infers table metadata and links by default.
|
|
104
129
|
|
|
105
130
|
.. code-block:: python
|
|
106
131
|
|
|
107
|
-
|
|
108
|
-
import
|
|
109
|
-
|
|
110
|
-
# Load data frames into memory:
|
|
111
|
-
df1 = pd.DataFrame(...)
|
|
112
|
-
df2 = pd.DataFrame(...)
|
|
113
|
-
df3 = pd.DataFrame(...)
|
|
114
|
-
|
|
115
|
-
# Create a graph from a dictionary of data frames:
|
|
116
|
-
graph = rfm.LocalGraph.from_data({
|
|
117
|
-
"table1": df1,
|
|
118
|
-
"table2": df2,
|
|
119
|
-
"table3": df3,
|
|
120
|
-
})
|
|
132
|
+
>>> # doctest: +SKIP
|
|
133
|
+
>>> import pandas as pd
|
|
134
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
121
135
|
|
|
122
|
-
#
|
|
123
|
-
|
|
124
|
-
|
|
136
|
+
>>> # Load data frames into memory:
|
|
137
|
+
>>> df1 = pd.DataFrame(...)
|
|
138
|
+
>>> df2 = pd.DataFrame(...)
|
|
139
|
+
>>> df3 = pd.DataFrame(...)
|
|
125
140
|
|
|
126
|
-
#
|
|
127
|
-
graph.
|
|
141
|
+
>>> # Create a graph from a dictionary of data frames:
|
|
142
|
+
>>> graph = rfm.Graph.from_data({
|
|
143
|
+
... "table1": df1,
|
|
144
|
+
... "table2": df2,
|
|
145
|
+
... "table3": df3,
|
|
146
|
+
... })
|
|
128
147
|
|
|
129
148
|
Args:
|
|
130
149
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
131
150
|
of the tables and the values hold table data.
|
|
151
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
152
|
+
add to the graph. If not provided, edges will be automatically
|
|
153
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
132
154
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
133
155
|
graph.
|
|
156
|
+
verbose: Whether to print verbose output.
|
|
157
|
+
"""
|
|
158
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
|
+
|
|
161
|
+
graph = cls(tables, edges=edges or [])
|
|
162
|
+
|
|
163
|
+
if infer_metadata:
|
|
164
|
+
graph.infer_metadata(False)
|
|
165
|
+
|
|
166
|
+
if edges is None:
|
|
167
|
+
graph.infer_links(False)
|
|
168
|
+
|
|
169
|
+
if verbose:
|
|
170
|
+
graph.print_metadata()
|
|
171
|
+
graph.print_links()
|
|
172
|
+
|
|
173
|
+
return graph
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def from_sqlite(
|
|
177
|
+
cls,
|
|
178
|
+
connection: Union[
|
|
179
|
+
'AdbcSqliteConnection',
|
|
180
|
+
SqliteConnectionConfig,
|
|
181
|
+
str,
|
|
182
|
+
Path,
|
|
183
|
+
Dict[str, Any],
|
|
184
|
+
],
|
|
185
|
+
table_names: Optional[Sequence[str]] = None,
|
|
186
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
187
|
+
infer_metadata: bool = True,
|
|
188
|
+
verbose: bool = True,
|
|
189
|
+
) -> Self:
|
|
190
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
191
|
+
|
|
192
|
+
Automatically infers table metadata and links by default.
|
|
193
|
+
|
|
194
|
+
.. code-block:: python
|
|
195
|
+
|
|
196
|
+
>>> # doctest: +SKIP
|
|
197
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
198
|
+
|
|
199
|
+
>>> # Create a graph from a SQLite database:
|
|
200
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
connection: An open connection from
|
|
204
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
|
+
path to the database file.
|
|
206
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
207
|
+
all tables present in the database.
|
|
134
208
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
135
209
|
add to the graph. If not provided, edges will be automatically
|
|
136
|
-
inferred from the data
|
|
210
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
212
|
+
graph.
|
|
137
213
|
verbose: Whether to print verbose output.
|
|
214
|
+
"""
|
|
215
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
216
|
+
Connection,
|
|
217
|
+
SQLiteTable,
|
|
218
|
+
connect,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
if not isinstance(connection, Connection):
|
|
222
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
223
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
225
|
+
assert isinstance(connection, Connection)
|
|
226
|
+
|
|
227
|
+
if table_names is None:
|
|
228
|
+
with connection.cursor() as cursor:
|
|
229
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
230
|
+
"WHERE type='table'")
|
|
231
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
232
|
+
|
|
233
|
+
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
138
234
|
|
|
139
|
-
|
|
140
|
-
This method will automatically infer metadata and links for the
|
|
141
|
-
graph.
|
|
235
|
+
graph = cls(tables, edges=edges or [])
|
|
142
236
|
|
|
143
|
-
|
|
237
|
+
if infer_metadata:
|
|
238
|
+
graph.infer_metadata(False)
|
|
239
|
+
|
|
240
|
+
if edges is None:
|
|
241
|
+
graph.infer_links(False)
|
|
242
|
+
|
|
243
|
+
if verbose:
|
|
244
|
+
graph.print_metadata()
|
|
245
|
+
graph.print_links()
|
|
246
|
+
|
|
247
|
+
return graph
|
|
248
|
+
|
|
249
|
+
@classmethod
|
|
250
|
+
def from_snowflake(
|
|
251
|
+
cls,
|
|
252
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
253
|
+
database: Optional[str] = None,
|
|
254
|
+
schema: Optional[str] = None,
|
|
255
|
+
table_names: Optional[Sequence[str]] = None,
|
|
256
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
257
|
+
infer_metadata: bool = True,
|
|
258
|
+
verbose: bool = True,
|
|
259
|
+
) -> Self:
|
|
260
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
261
|
+
schema.
|
|
262
|
+
|
|
263
|
+
Automatically infers table metadata and links by default.
|
|
264
|
+
|
|
265
|
+
.. code-block:: python
|
|
266
|
+
|
|
267
|
+
>>> # doctest: +SKIP
|
|
144
268
|
>>> import kumoai.experimental.rfm as rfm
|
|
145
|
-
|
|
146
|
-
>>>
|
|
147
|
-
>>>
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
269
|
+
|
|
270
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
271
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
connection: An open connection from
|
|
275
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
276
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
277
|
+
connection. If ``None``, will re-use an active session in case
|
|
278
|
+
it exists, or create a new connection from credentials stored
|
|
279
|
+
in environment variables.
|
|
280
|
+
database: The database.
|
|
281
|
+
schema: The schema.
|
|
282
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
283
|
+
all tables present in the database.
|
|
284
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
285
|
+
add to the graph. If not provided, edges will be automatically
|
|
286
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
287
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
288
|
+
graph.
|
|
289
|
+
verbose: Whether to print verbose output.
|
|
154
290
|
"""
|
|
155
|
-
|
|
291
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
292
|
+
Connection,
|
|
293
|
+
SnowTable,
|
|
294
|
+
connect,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
if not isinstance(connection, Connection):
|
|
298
|
+
connection = connect(**(connection or {}))
|
|
299
|
+
assert isinstance(connection, Connection)
|
|
300
|
+
|
|
301
|
+
if table_names is None:
|
|
302
|
+
with connection.cursor() as cursor:
|
|
303
|
+
if database is None and schema is None:
|
|
304
|
+
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
305
|
+
"CURRENT_SCHEMA()")
|
|
306
|
+
result = cursor.fetchone()
|
|
307
|
+
database = database or result[0]
|
|
308
|
+
schema = schema or result[1]
|
|
309
|
+
cursor.execute(f"""
|
|
310
|
+
SELECT TABLE_NAME
|
|
311
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
312
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
313
|
+
""")
|
|
314
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
315
|
+
|
|
316
|
+
tables = [
|
|
317
|
+
SnowTable(
|
|
318
|
+
connection,
|
|
319
|
+
name=table_name,
|
|
320
|
+
database=database,
|
|
321
|
+
schema=schema,
|
|
322
|
+
) for table_name in table_names
|
|
323
|
+
]
|
|
156
324
|
|
|
157
325
|
graph = cls(tables, edges=edges or [])
|
|
158
326
|
|
|
159
327
|
if infer_metadata:
|
|
160
|
-
graph.infer_metadata(
|
|
328
|
+
graph.infer_metadata(False)
|
|
161
329
|
|
|
162
330
|
if edges is None:
|
|
163
|
-
graph.infer_links(
|
|
331
|
+
graph.infer_links(False)
|
|
332
|
+
|
|
333
|
+
if verbose:
|
|
334
|
+
graph.print_metadata()
|
|
335
|
+
graph.print_links()
|
|
336
|
+
|
|
337
|
+
return graph
|
|
338
|
+
|
|
339
|
+
@classmethod
|
|
340
|
+
def from_snowflake_semantic_view(
|
|
341
|
+
cls,
|
|
342
|
+
semantic_view_name: str,
|
|
343
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
344
|
+
verbose: bool = True,
|
|
345
|
+
) -> Self:
|
|
346
|
+
import yaml
|
|
347
|
+
|
|
348
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
349
|
+
Connection,
|
|
350
|
+
SnowTable,
|
|
351
|
+
connect,
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
if not isinstance(connection, Connection):
|
|
355
|
+
connection = connect(**(connection or {}))
|
|
356
|
+
assert isinstance(connection, Connection)
|
|
357
|
+
|
|
358
|
+
with connection.cursor() as cursor:
|
|
359
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
360
|
+
f"'{semantic_view_name}')")
|
|
361
|
+
view = yaml.safe_load(cursor.fetchone()[0])
|
|
362
|
+
|
|
363
|
+
graph = cls(tables=[])
|
|
364
|
+
|
|
365
|
+
for table_desc in view['tables']:
|
|
366
|
+
primary_key: Optional[str] = None
|
|
367
|
+
if ('primary_key' in table_desc # NOTE No composite keys yet.
|
|
368
|
+
and len(table_desc['primary_key']['columns']) == 1):
|
|
369
|
+
primary_key = table_desc['primary_key']['columns'][0]
|
|
370
|
+
|
|
371
|
+
table = SnowTable(
|
|
372
|
+
connection,
|
|
373
|
+
name=table_desc['base_table']['table'],
|
|
374
|
+
database=table_desc['base_table']['database'],
|
|
375
|
+
schema=table_desc['base_table']['schema'],
|
|
376
|
+
primary_key=primary_key,
|
|
377
|
+
)
|
|
378
|
+
graph.add_table(table)
|
|
379
|
+
|
|
380
|
+
# TODO Find a solution to register time columns!
|
|
381
|
+
|
|
382
|
+
for relations in view['relationships']:
|
|
383
|
+
if len(relations['relationship_columns']) != 1:
|
|
384
|
+
continue # NOTE No composite keys yet.
|
|
385
|
+
graph.link(
|
|
386
|
+
src_table=relations['left_table'],
|
|
387
|
+
fkey=relations['relationship_columns'][0]['left_column'],
|
|
388
|
+
dst_table=relations['right_table'],
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
if verbose:
|
|
392
|
+
graph.print_metadata()
|
|
393
|
+
graph.print_links()
|
|
164
394
|
|
|
165
395
|
return graph
|
|
166
396
|
|
|
@@ -172,7 +402,7 @@ class LocalGraph:
|
|
|
172
402
|
"""
|
|
173
403
|
return name in self.tables
|
|
174
404
|
|
|
175
|
-
def table(self, name: str) ->
|
|
405
|
+
def table(self, name: str) -> Table:
|
|
176
406
|
r"""Returns the table with name ``name`` in the graph.
|
|
177
407
|
|
|
178
408
|
Raises:
|
|
@@ -183,11 +413,11 @@ class LocalGraph:
|
|
|
183
413
|
return self.tables[name]
|
|
184
414
|
|
|
185
415
|
@property
|
|
186
|
-
def tables(self) -> Dict[str,
|
|
416
|
+
def tables(self) -> Dict[str, Table]:
|
|
187
417
|
r"""Returns the dictionary of table objects."""
|
|
188
418
|
return self._tables
|
|
189
419
|
|
|
190
|
-
def add_table(self, table:
|
|
420
|
+
def add_table(self, table: Table) -> Self:
|
|
191
421
|
r"""Adds a table to the graph.
|
|
192
422
|
|
|
193
423
|
Args:
|
|
@@ -196,17 +426,21 @@ class LocalGraph:
|
|
|
196
426
|
Raises:
|
|
197
427
|
KeyError: If a table with the same name already exists in the
|
|
198
428
|
graph.
|
|
429
|
+
ValueError: If the table belongs to a different backend than the
|
|
430
|
+
rest of the tables in the graph.
|
|
199
431
|
"""
|
|
200
|
-
if len(self.tables) >= 15:
|
|
201
|
-
raise ValueError("Cannot create a graph with more than 15 "
|
|
202
|
-
"tables. Please create a feature request at "
|
|
203
|
-
"'https://github.com/kumo-ai/kumo-rfm' if you "
|
|
204
|
-
"must go beyond this for your use-case.")
|
|
205
|
-
|
|
206
432
|
if table.name in self._tables:
|
|
207
433
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
208
434
|
f"this graph; table names must be globally unique.")
|
|
209
435
|
|
|
436
|
+
if len(self._tables) > 0:
|
|
437
|
+
cls = next(iter(self._tables.values())).__class__
|
|
438
|
+
if table.__class__ != cls:
|
|
439
|
+
raise ValueError(f"Cannot register a "
|
|
440
|
+
f"'{table.__class__.__name__}' to this "
|
|
441
|
+
f"graph since other tables are of type "
|
|
442
|
+
f"'{cls.__name__}'.")
|
|
443
|
+
|
|
210
444
|
self._tables[table.name] = table
|
|
211
445
|
|
|
212
446
|
return self
|
|
@@ -237,16 +471,17 @@ class LocalGraph:
|
|
|
237
471
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
238
472
|
information about the tables in this graph.
|
|
239
473
|
|
|
240
|
-
The returned dataframe has columns ``name``, ``primary_key``,
|
|
241
|
-
``time_column``, which provide an aggregate
|
|
242
|
-
the tables of this graph.
|
|
474
|
+
The returned dataframe has columns ``name``, ``primary_key``,
|
|
475
|
+
``time_column``, and ``end_time_column``, which provide an aggregate
|
|
476
|
+
view of the properties of the tables of this graph.
|
|
243
477
|
|
|
244
478
|
Example:
|
|
479
|
+
>>> # doctest: +SKIP
|
|
245
480
|
>>> import kumoai.experimental.rfm as rfm
|
|
246
|
-
>>> graph = rfm.
|
|
247
|
-
>>> graph.metadata
|
|
248
|
-
name
|
|
249
|
-
0 users
|
|
481
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
482
|
+
>>> graph.metadata # doctest: +SKIP
|
|
483
|
+
name primary_key time_column end_time_column
|
|
484
|
+
0 users user_id - -
|
|
250
485
|
"""
|
|
251
486
|
tables = list(self.tables.values())
|
|
252
487
|
|
|
@@ -257,13 +492,22 @@ class LocalGraph:
|
|
|
257
492
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
258
493
|
'time_column':
|
|
259
494
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
495
|
+
'end_time_column':
|
|
496
|
+
pd.Series(
|
|
497
|
+
dtype=str,
|
|
498
|
+
data=[t._end_time_column or '-' for t in tables],
|
|
499
|
+
),
|
|
260
500
|
})
|
|
261
501
|
|
|
262
502
|
def print_metadata(self) -> None:
|
|
263
|
-
r"""Prints the :meth:`~
|
|
264
|
-
if
|
|
503
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
504
|
+
if in_snowflake_notebook():
|
|
505
|
+
import streamlit as st
|
|
506
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
507
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
508
|
+
elif in_notebook():
|
|
265
509
|
from IPython.display import Markdown, display
|
|
266
|
-
display(Markdown(
|
|
510
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
267
511
|
df = self.metadata
|
|
268
512
|
try:
|
|
269
513
|
if hasattr(df.style, 'hide'):
|
|
@@ -284,7 +528,7 @@ class LocalGraph:
|
|
|
284
528
|
|
|
285
529
|
Note:
|
|
286
530
|
For more information, please see
|
|
287
|
-
:meth:`kumoai.experimental.rfm.
|
|
531
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
288
532
|
"""
|
|
289
533
|
for table in self.tables.values():
|
|
290
534
|
table.infer_metadata(verbose=False)
|
|
@@ -302,37 +546,47 @@ class LocalGraph:
|
|
|
302
546
|
return self._edges
|
|
303
547
|
|
|
304
548
|
def print_links(self) -> None:
|
|
305
|
-
r"""Prints the :meth:`~
|
|
549
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
306
550
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
307
551
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
308
552
|
edges = sorted(edges)
|
|
309
553
|
|
|
310
|
-
if
|
|
554
|
+
if in_snowflake_notebook():
|
|
555
|
+
import streamlit as st
|
|
556
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
557
|
+
if len(edges) > 0:
|
|
558
|
+
st.markdown('\n'.join([
|
|
559
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
560
|
+
for edge in edges
|
|
561
|
+
]))
|
|
562
|
+
else:
|
|
563
|
+
st.markdown("*No links registered*")
|
|
564
|
+
elif in_notebook():
|
|
311
565
|
from IPython.display import Markdown, display
|
|
312
|
-
display(Markdown(
|
|
566
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
313
567
|
if len(edges) > 0:
|
|
314
568
|
display(
|
|
315
569
|
Markdown('\n'.join([
|
|
316
|
-
f
|
|
570
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
317
571
|
for edge in edges
|
|
318
572
|
])))
|
|
319
573
|
else:
|
|
320
|
-
display(Markdown(
|
|
574
|
+
display(Markdown("*No links registered*"))
|
|
321
575
|
else:
|
|
322
576
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
323
577
|
if len(edges) > 0:
|
|
324
578
|
print('\n'.join([
|
|
325
|
-
f
|
|
579
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
326
580
|
for edge in edges
|
|
327
581
|
]))
|
|
328
582
|
else:
|
|
329
|
-
print(
|
|
583
|
+
print("No links registered")
|
|
330
584
|
|
|
331
585
|
def link(
|
|
332
586
|
self,
|
|
333
|
-
src_table: Union[str,
|
|
587
|
+
src_table: Union[str, Table],
|
|
334
588
|
fkey: str,
|
|
335
|
-
dst_table: Union[str,
|
|
589
|
+
dst_table: Union[str, Table],
|
|
336
590
|
) -> Self:
|
|
337
591
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
338
592
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -355,11 +609,11 @@ class LocalGraph:
|
|
|
355
609
|
table does not exist in the graph, if the source key does not
|
|
356
610
|
exist in the source table.
|
|
357
611
|
"""
|
|
358
|
-
if isinstance(src_table,
|
|
612
|
+
if isinstance(src_table, Table):
|
|
359
613
|
src_table = src_table.name
|
|
360
614
|
assert isinstance(src_table, str)
|
|
361
615
|
|
|
362
|
-
if isinstance(dst_table,
|
|
616
|
+
if isinstance(dst_table, Table):
|
|
363
617
|
dst_table = dst_table.name
|
|
364
618
|
assert isinstance(dst_table, str)
|
|
365
619
|
|
|
@@ -393,9 +647,9 @@ class LocalGraph:
|
|
|
393
647
|
|
|
394
648
|
def unlink(
|
|
395
649
|
self,
|
|
396
|
-
src_table: Union[str,
|
|
650
|
+
src_table: Union[str, Table],
|
|
397
651
|
fkey: str,
|
|
398
|
-
dst_table: Union[str,
|
|
652
|
+
dst_table: Union[str, Table],
|
|
399
653
|
) -> Self:
|
|
400
654
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
401
655
|
|
|
@@ -407,11 +661,11 @@ class LocalGraph:
|
|
|
407
661
|
Raises:
|
|
408
662
|
ValueError: if the edge is not present in the graph.
|
|
409
663
|
"""
|
|
410
|
-
if isinstance(src_table,
|
|
664
|
+
if isinstance(src_table, Table):
|
|
411
665
|
src_table = src_table.name
|
|
412
666
|
assert isinstance(src_table, str)
|
|
413
667
|
|
|
414
|
-
if isinstance(dst_table,
|
|
668
|
+
if isinstance(dst_table, Table):
|
|
415
669
|
dst_table = dst_table.name
|
|
416
670
|
assert isinstance(dst_table, str)
|
|
417
671
|
|
|
@@ -425,17 +679,13 @@ class LocalGraph:
|
|
|
425
679
|
return self
|
|
426
680
|
|
|
427
681
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
428
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
682
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
683
|
+
graph.
|
|
429
684
|
|
|
430
685
|
Args:
|
|
431
686
|
verbose: Whether to print verbose output.
|
|
432
|
-
|
|
433
|
-
Note:
|
|
434
|
-
This function expects graph edges to be undefined upfront.
|
|
435
687
|
"""
|
|
436
|
-
|
|
437
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
438
|
-
return self
|
|
688
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
439
689
|
|
|
440
690
|
# A list of primary key candidates (+score) for every column:
|
|
441
691
|
candidate_dict: dict[
|
|
@@ -460,6 +710,9 @@ class LocalGraph:
|
|
|
460
710
|
src_table_name = src_table.name.lower()
|
|
461
711
|
|
|
462
712
|
for src_key in src_table.columns:
|
|
713
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
714
|
+
continue
|
|
715
|
+
|
|
463
716
|
if src_key == src_table.primary_key:
|
|
464
717
|
continue # Cannot link to primary key.
|
|
465
718
|
|
|
@@ -525,7 +778,9 @@ class LocalGraph:
|
|
|
525
778
|
score += 1.0
|
|
526
779
|
|
|
527
780
|
# Cardinality ratio:
|
|
528
|
-
if
|
|
781
|
+
if (src_table._num_rows is not None
|
|
782
|
+
and dst_table._num_rows is not None
|
|
783
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
529
784
|
score += 1.0
|
|
530
785
|
|
|
531
786
|
if score < 5.0:
|
|
@@ -580,13 +835,17 @@ class LocalGraph:
|
|
|
580
835
|
# Check that the destination table defines a primary key:
|
|
581
836
|
if dst_key is None:
|
|
582
837
|
raise ValueError(f"Edge {edge} is invalid since table "
|
|
583
|
-
f"'{dst_table}' does not have a primary key"
|
|
838
|
+
f"'{dst_table}' does not have a primary key. "
|
|
839
|
+
f"Add either a primary key or remove the "
|
|
840
|
+
f"link before proceeding.")
|
|
584
841
|
|
|
585
842
|
# Ensure that foreign key is not a primary key:
|
|
586
843
|
src_pkey = self[src_table].primary_key
|
|
587
844
|
if src_pkey is not None and src_pkey.name == fkey:
|
|
588
845
|
raise ValueError(f"Cannot treat the primary key of table "
|
|
589
|
-
f"'{src_table}' as a foreign key"
|
|
846
|
+
f"'{src_table}' as a foreign key. Remove "
|
|
847
|
+
f"either the primary key or the link before "
|
|
848
|
+
f"before proceeding.")
|
|
590
849
|
|
|
591
850
|
# Check that fkey/pkey have valid and consistent data types:
|
|
592
851
|
assert src_key.dtype is not None
|
|
@@ -604,8 +863,8 @@ class LocalGraph:
|
|
|
604
863
|
raise ValueError(f"{edge} is invalid as foreign key "
|
|
605
864
|
f"'{fkey}' and primary key '{dst_key.name}' "
|
|
606
865
|
f"have incompatible data types (got "
|
|
607
|
-
f"fkey.dtype '{
|
|
608
|
-
f"pkey.dtype '{
|
|
866
|
+
f"fkey.dtype '{src_key.dtype}' and "
|
|
867
|
+
f"pkey.dtype '{dst_key.dtype}')")
|
|
609
868
|
|
|
610
869
|
return self
|
|
611
870
|
|
|
@@ -638,19 +897,19 @@ class LocalGraph:
|
|
|
638
897
|
|
|
639
898
|
return True
|
|
640
899
|
|
|
641
|
-
# Check basic dependency:
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
900
|
+
try: # Check basic dependency:
|
|
901
|
+
import graphviz
|
|
902
|
+
except ImportError as e:
|
|
903
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
904
|
+
"visualization") from e
|
|
905
|
+
|
|
906
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
646
907
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
647
908
|
"executables are not installed. These "
|
|
648
909
|
"dependencies are required in addition to the "
|
|
649
910
|
"'graphviz' Python package. Please install "
|
|
650
911
|
"them as described at "
|
|
651
912
|
"https://graphviz.org/download/.")
|
|
652
|
-
else:
|
|
653
|
-
import graphviz
|
|
654
913
|
|
|
655
914
|
format: Optional[str] = None
|
|
656
915
|
if isinstance(path, str):
|
|
@@ -678,6 +937,11 @@ class LocalGraph:
|
|
|
678
937
|
]
|
|
679
938
|
if time_column := table.time_column:
|
|
680
939
|
keys += [f'{time_column.name}: Time ({time_column.dtype})']
|
|
940
|
+
if end_time_column := table.end_time_column:
|
|
941
|
+
keys += [
|
|
942
|
+
f'{end_time_column.name}: '
|
|
943
|
+
f'End Time ({end_time_column.dtype})'
|
|
944
|
+
]
|
|
681
945
|
key_repr = left_align(keys)
|
|
682
946
|
|
|
683
947
|
columns = []
|
|
@@ -685,9 +949,9 @@ class LocalGraph:
|
|
|
685
949
|
columns += [
|
|
686
950
|
f'{column.name}: {column.stype} ({column.dtype})'
|
|
687
951
|
for column in table.columns
|
|
688
|
-
if column.name not in fkeys_dict[table_name]
|
|
689
|
-
and column.name != table.
|
|
690
|
-
and column.name != table.
|
|
952
|
+
if column.name not in fkeys_dict[table_name] and
|
|
953
|
+
column.name != table._primary_key and column.name != table.
|
|
954
|
+
_time_column and column.name != table._end_time_column
|
|
691
955
|
]
|
|
692
956
|
column_repr = left_align(columns)
|
|
693
957
|
|
|
@@ -729,6 +993,9 @@ class LocalGraph:
|
|
|
729
993
|
graph.render(path, cleanup=True)
|
|
730
994
|
elif isinstance(path, io.BytesIO):
|
|
731
995
|
path.write(graph.pipe())
|
|
996
|
+
elif in_snowflake_notebook():
|
|
997
|
+
import streamlit as st
|
|
998
|
+
st.graphviz_chart(graph)
|
|
732
999
|
elif in_notebook():
|
|
733
1000
|
from IPython.display import display
|
|
734
1001
|
display(graph)
|
|
@@ -754,16 +1021,18 @@ class LocalGraph:
|
|
|
754
1021
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
755
1022
|
tables: Dict[str, TableDefinition] = {}
|
|
756
1023
|
col_groups: List[ColumnKeyGroup] = []
|
|
757
|
-
for
|
|
758
|
-
tables[
|
|
1024
|
+
for table_name, table in self.tables.items():
|
|
1025
|
+
tables[table_name] = table._to_api_table_definition()
|
|
759
1026
|
if table.primary_key is None:
|
|
760
1027
|
continue
|
|
761
|
-
keys = [ColumnKey(
|
|
1028
|
+
keys = [ColumnKey(table_name, table.primary_key.name)]
|
|
762
1029
|
for edge in self.edges:
|
|
763
|
-
if edge.dst_table ==
|
|
1030
|
+
if edge.dst_table == table_name:
|
|
764
1031
|
keys.append(ColumnKey(edge.src_table, edge.fkey))
|
|
765
|
-
keys = sorted(
|
|
766
|
-
|
|
1032
|
+
keys = sorted(
|
|
1033
|
+
list(set(keys)),
|
|
1034
|
+
key=lambda x: f'{x.table_name}.{x.col_name}',
|
|
1035
|
+
)
|
|
767
1036
|
if len(keys) > 1:
|
|
768
1037
|
col_groups.append(ColumnKeyGroup(keys))
|
|
769
1038
|
return GraphDefinition(tables, col_groups)
|
|
@@ -776,7 +1045,7 @@ class LocalGraph:
|
|
|
776
1045
|
def __contains__(self, name: str) -> bool:
|
|
777
1046
|
return self.has_table(name)
|
|
778
1047
|
|
|
779
|
-
def __getitem__(self, name: str) ->
|
|
1048
|
+
def __getitem__(self, name: str) -> Table:
|
|
780
1049
|
return self.table(name)
|
|
781
1050
|
|
|
782
1051
|
def __delitem__(self, name: str) -> None:
|