kumoai 2.13.0.dev202511261731__cp313-cp313-macosx_11_0_arm64.whl → 2.13.0.dev202512081731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +42 -0
- kumoai/experimental/rfm/{local_graph_store.py → backend/local/graph_store.py} +20 -30
- kumoai/experimental/rfm/backend/local/sampler.py +116 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +14 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/sampler.py +373 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +139 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +334 -79
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +43 -2
- kumoai/experimental/rfm/local_pquery_driver.py +1 -1
- kumoai/experimental/rfm/rfm.py +7 -17
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/RECORD +33 -19
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512081731.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,9 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
6
|
-
from
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
8
|
|
|
8
9
|
import pandas as pd
|
|
9
10
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -11,20 +12,29 @@ from kumoapi.table import TableDefinition
|
|
|
11
12
|
from kumoapi.typing import Stype
|
|
12
13
|
from typing_extensions import Self
|
|
13
14
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
+
from kumoai.experimental.rfm import Table
|
|
16
17
|
from kumoai.graph import Edge
|
|
18
|
+
from kumoai.mixin import CastMixin
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
import graphviz
|
|
22
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
23
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
24
|
|
|
21
25
|
|
|
22
|
-
|
|
23
|
-
|
|
26
|
+
@dataclass
|
|
27
|
+
class SqliteConnectionConfig(CastMixin):
|
|
28
|
+
uri: Union[str, Path]
|
|
29
|
+
kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Graph:
|
|
33
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
34
|
tables in a relational database.
|
|
25
35
|
|
|
26
36
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
37
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
38
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
39
|
|
|
30
40
|
.. code-block:: python
|
|
@@ -44,7 +54,7 @@ class LocalGraph:
|
|
|
44
54
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
55
|
|
|
46
56
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
57
|
+
>>> graph = rfm.Graph({
|
|
48
58
|
... "table1": table1,
|
|
49
59
|
... "table2": table2,
|
|
50
60
|
... "table3": table3,
|
|
@@ -75,33 +85,47 @@ class LocalGraph:
|
|
|
75
85
|
|
|
76
86
|
def __init__(
|
|
77
87
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges: Optional[
|
|
88
|
+
tables: Sequence[Table],
|
|
89
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
80
90
|
) -> None:
|
|
81
91
|
|
|
82
|
-
self._tables: Dict[str,
|
|
92
|
+
self._tables: Dict[str, Table] = {}
|
|
83
93
|
self._edges: List[Edge] = []
|
|
84
94
|
|
|
85
95
|
for table in tables:
|
|
86
96
|
self.add_table(table)
|
|
87
97
|
|
|
98
|
+
for table in tables:
|
|
99
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
+
if fkey.name not in table or fkey.dst_table not in self:
|
|
101
|
+
continue
|
|
102
|
+
if self[fkey.dst_table].primary_key is None:
|
|
103
|
+
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
+
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
+
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
+
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
+
f"and '{fkey.primary_key}' in table "
|
|
108
|
+
f"'{fkey.dst_table}'.")
|
|
109
|
+
self.link(table.name, fkey.name, fkey.dst_table)
|
|
110
|
+
|
|
88
111
|
for edge in (edges or []):
|
|
89
112
|
_edge = Edge._cast(edge)
|
|
90
113
|
assert _edge is not None
|
|
91
|
-
self.
|
|
114
|
+
if _edge not in self._edges:
|
|
115
|
+
self.link(*_edge)
|
|
92
116
|
|
|
93
117
|
@classmethod
|
|
94
118
|
def from_data(
|
|
95
119
|
cls,
|
|
96
120
|
df_dict: Dict[str, pd.DataFrame],
|
|
97
|
-
edges: Optional[
|
|
121
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
98
122
|
infer_metadata: bool = True,
|
|
99
123
|
verbose: bool = True,
|
|
100
124
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
125
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
126
|
:class:`pandas.DataFrame` objects.
|
|
103
127
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
128
|
+
Automatically infers table metadata and links by default.
|
|
105
129
|
|
|
106
130
|
.. code-block:: python
|
|
107
131
|
|
|
@@ -115,55 +139,258 @@ class LocalGraph:
|
|
|
115
139
|
>>> df3 = pd.DataFrame(...)
|
|
116
140
|
|
|
117
141
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
142
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
143
|
... "table1": df1,
|
|
120
144
|
... "table2": df2,
|
|
121
145
|
... "table3": df3,
|
|
122
146
|
... })
|
|
123
147
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
148
|
Args:
|
|
132
149
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
150
|
of the tables and the values hold table data.
|
|
151
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
152
|
+
add to the graph. If not provided, edges will be automatically
|
|
153
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
154
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
155
|
graph.
|
|
156
|
+
verbose: Whether to print verbose output.
|
|
157
|
+
"""
|
|
158
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
|
+
|
|
161
|
+
graph = cls(tables, edges=edges or [])
|
|
162
|
+
|
|
163
|
+
if infer_metadata:
|
|
164
|
+
graph.infer_metadata(False)
|
|
165
|
+
|
|
166
|
+
if edges is None:
|
|
167
|
+
graph.infer_links(False)
|
|
168
|
+
|
|
169
|
+
if verbose:
|
|
170
|
+
graph.print_metadata()
|
|
171
|
+
graph.print_links()
|
|
172
|
+
|
|
173
|
+
return graph
|
|
174
|
+
|
|
175
|
+
@classmethod
|
|
176
|
+
def from_sqlite(
|
|
177
|
+
cls,
|
|
178
|
+
connection: Union[
|
|
179
|
+
'AdbcSqliteConnection',
|
|
180
|
+
SqliteConnectionConfig,
|
|
181
|
+
str,
|
|
182
|
+
Path,
|
|
183
|
+
Dict[str, Any],
|
|
184
|
+
],
|
|
185
|
+
table_names: Optional[Sequence[str]] = None,
|
|
186
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
187
|
+
infer_metadata: bool = True,
|
|
188
|
+
verbose: bool = True,
|
|
189
|
+
) -> Self:
|
|
190
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
191
|
+
|
|
192
|
+
Automatically infers table metadata and links by default.
|
|
193
|
+
|
|
194
|
+
.. code-block:: python
|
|
195
|
+
|
|
196
|
+
>>> # doctest: +SKIP
|
|
197
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
198
|
+
|
|
199
|
+
>>> # Create a graph from a SQLite database:
|
|
200
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
|
+
|
|
202
|
+
Args:
|
|
203
|
+
connection: An open connection from
|
|
204
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
|
+
path to the database file.
|
|
206
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
207
|
+
all tables present in the database.
|
|
136
208
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
209
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
210
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
212
|
+
graph.
|
|
139
213
|
verbose: Whether to print verbose output.
|
|
214
|
+
"""
|
|
215
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
216
|
+
Connection,
|
|
217
|
+
SQLiteTable,
|
|
218
|
+
connect,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
if not isinstance(connection, Connection):
|
|
222
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
223
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
225
|
+
assert isinstance(connection, Connection)
|
|
226
|
+
|
|
227
|
+
if table_names is None:
|
|
228
|
+
with connection.cursor() as cursor:
|
|
229
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
230
|
+
"WHERE type='table'")
|
|
231
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
232
|
+
|
|
233
|
+
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
140
234
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
235
|
+
graph = cls(tables, edges=edges or [])
|
|
236
|
+
|
|
237
|
+
if infer_metadata:
|
|
238
|
+
graph.infer_metadata(False)
|
|
239
|
+
|
|
240
|
+
if edges is None:
|
|
241
|
+
graph.infer_links(False)
|
|
242
|
+
|
|
243
|
+
if verbose:
|
|
244
|
+
graph.print_metadata()
|
|
245
|
+
graph.print_links()
|
|
246
|
+
|
|
247
|
+
return graph
|
|
248
|
+
|
|
249
|
+
@classmethod
|
|
250
|
+
def from_snowflake(
|
|
251
|
+
cls,
|
|
252
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
253
|
+
database: Optional[str] = None,
|
|
254
|
+
schema: Optional[str] = None,
|
|
255
|
+
table_names: Optional[Sequence[str]] = None,
|
|
256
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
257
|
+
infer_metadata: bool = True,
|
|
258
|
+
verbose: bool = True,
|
|
259
|
+
) -> Self:
|
|
260
|
+
r"""Creates a :class:`Graph` from a :class:`snowflake` database and
|
|
261
|
+
schema.
|
|
262
|
+
|
|
263
|
+
Automatically infers table metadata and links by default.
|
|
264
|
+
|
|
265
|
+
.. code-block:: python
|
|
144
266
|
|
|
145
|
-
Example:
|
|
146
267
|
>>> # doctest: +SKIP
|
|
147
268
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
|
|
149
|
-
>>>
|
|
150
|
-
>>>
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
269
|
+
|
|
270
|
+
>>> # Create a graph directly in a Snowflake notebook:
|
|
271
|
+
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
connection: An open connection from
|
|
275
|
+
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
276
|
+
:class:`snowflake` connector keyword arguments to open a new
|
|
277
|
+
connection. If ``None``, will re-use an active session in case
|
|
278
|
+
it exists, or create a new connection from credentials stored
|
|
279
|
+
in environment variables.
|
|
280
|
+
database: The database.
|
|
281
|
+
schema: The schema.
|
|
282
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
283
|
+
all tables present in the database.
|
|
284
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
285
|
+
add to the graph. If not provided, edges will be automatically
|
|
286
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
287
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
288
|
+
graph.
|
|
289
|
+
verbose: Whether to print verbose output.
|
|
157
290
|
"""
|
|
158
|
-
|
|
291
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
292
|
+
Connection,
|
|
293
|
+
SnowTable,
|
|
294
|
+
connect,
|
|
295
|
+
)
|
|
296
|
+
|
|
297
|
+
if not isinstance(connection, Connection):
|
|
298
|
+
connection = connect(**(connection or {}))
|
|
299
|
+
assert isinstance(connection, Connection)
|
|
300
|
+
|
|
301
|
+
if table_names is None:
|
|
302
|
+
with connection.cursor() as cursor:
|
|
303
|
+
if database is None and schema is None:
|
|
304
|
+
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
305
|
+
"CURRENT_SCHEMA()")
|
|
306
|
+
result = cursor.fetchone()
|
|
307
|
+
database = database or result[0]
|
|
308
|
+
schema = schema or result[1]
|
|
309
|
+
cursor.execute(f"""
|
|
310
|
+
SELECT TABLE_NAME
|
|
311
|
+
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
312
|
+
WHERE TABLE_SCHEMA = '{schema}'
|
|
313
|
+
""")
|
|
314
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
315
|
+
|
|
316
|
+
tables = [
|
|
317
|
+
SnowTable(
|
|
318
|
+
connection,
|
|
319
|
+
name=table_name,
|
|
320
|
+
database=database,
|
|
321
|
+
schema=schema,
|
|
322
|
+
) for table_name in table_names
|
|
323
|
+
]
|
|
159
324
|
|
|
160
325
|
graph = cls(tables, edges=edges or [])
|
|
161
326
|
|
|
162
327
|
if infer_metadata:
|
|
163
|
-
graph.infer_metadata(
|
|
328
|
+
graph.infer_metadata(False)
|
|
164
329
|
|
|
165
330
|
if edges is None:
|
|
166
|
-
graph.infer_links(
|
|
331
|
+
graph.infer_links(False)
|
|
332
|
+
|
|
333
|
+
if verbose:
|
|
334
|
+
graph.print_metadata()
|
|
335
|
+
graph.print_links()
|
|
336
|
+
|
|
337
|
+
return graph
|
|
338
|
+
|
|
339
|
+
@classmethod
|
|
340
|
+
def from_snowflake_semantic_view(
|
|
341
|
+
cls,
|
|
342
|
+
semantic_view_name: str,
|
|
343
|
+
connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
|
|
344
|
+
verbose: bool = True,
|
|
345
|
+
) -> Self:
|
|
346
|
+
import yaml
|
|
347
|
+
|
|
348
|
+
from kumoai.experimental.rfm.backend.snow import (
|
|
349
|
+
Connection,
|
|
350
|
+
SnowTable,
|
|
351
|
+
connect,
|
|
352
|
+
)
|
|
353
|
+
|
|
354
|
+
if not isinstance(connection, Connection):
|
|
355
|
+
connection = connect(**(connection or {}))
|
|
356
|
+
assert isinstance(connection, Connection)
|
|
357
|
+
|
|
358
|
+
with connection.cursor() as cursor:
|
|
359
|
+
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
360
|
+
f"'{semantic_view_name}')")
|
|
361
|
+
view = yaml.safe_load(cursor.fetchone()[0])
|
|
362
|
+
|
|
363
|
+
graph = cls(tables=[])
|
|
364
|
+
|
|
365
|
+
for table_desc in view['tables']:
|
|
366
|
+
primary_key: Optional[str] = None
|
|
367
|
+
if ('primary_key' in table_desc # NOTE No composite keys yet.
|
|
368
|
+
and len(table_desc['primary_key']['columns']) == 1):
|
|
369
|
+
primary_key = table_desc['primary_key']['columns'][0]
|
|
370
|
+
|
|
371
|
+
table = SnowTable(
|
|
372
|
+
connection,
|
|
373
|
+
name=table_desc['base_table']['table'],
|
|
374
|
+
database=table_desc['base_table']['database'],
|
|
375
|
+
schema=table_desc['base_table']['schema'],
|
|
376
|
+
primary_key=primary_key,
|
|
377
|
+
)
|
|
378
|
+
graph.add_table(table)
|
|
379
|
+
|
|
380
|
+
# TODO Find a solution to register time columns!
|
|
381
|
+
|
|
382
|
+
for relations in view['relationships']:
|
|
383
|
+
if len(relations['relationship_columns']) != 1:
|
|
384
|
+
continue # NOTE No composite keys yet.
|
|
385
|
+
graph.link(
|
|
386
|
+
src_table=relations['left_table'],
|
|
387
|
+
fkey=relations['relationship_columns'][0]['left_column'],
|
|
388
|
+
dst_table=relations['right_table'],
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
if verbose:
|
|
392
|
+
graph.print_metadata()
|
|
393
|
+
graph.print_links()
|
|
167
394
|
|
|
168
395
|
return graph
|
|
169
396
|
|
|
@@ -175,7 +402,7 @@ class LocalGraph:
|
|
|
175
402
|
"""
|
|
176
403
|
return name in self.tables
|
|
177
404
|
|
|
178
|
-
def table(self, name: str) ->
|
|
405
|
+
def table(self, name: str) -> Table:
|
|
179
406
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
407
|
|
|
181
408
|
Raises:
|
|
@@ -186,11 +413,11 @@ class LocalGraph:
|
|
|
186
413
|
return self.tables[name]
|
|
187
414
|
|
|
188
415
|
@property
|
|
189
|
-
def tables(self) -> Dict[str,
|
|
416
|
+
def tables(self) -> Dict[str, Table]:
|
|
190
417
|
r"""Returns the dictionary of table objects."""
|
|
191
418
|
return self._tables
|
|
192
419
|
|
|
193
|
-
def add_table(self, table:
|
|
420
|
+
def add_table(self, table: Table) -> Self:
|
|
194
421
|
r"""Adds a table to the graph.
|
|
195
422
|
|
|
196
423
|
Args:
|
|
@@ -199,11 +426,21 @@ class LocalGraph:
|
|
|
199
426
|
Raises:
|
|
200
427
|
KeyError: If a table with the same name already exists in the
|
|
201
428
|
graph.
|
|
429
|
+
ValueError: If the table belongs to a different backend than the
|
|
430
|
+
rest of the tables in the graph.
|
|
202
431
|
"""
|
|
203
432
|
if table.name in self._tables:
|
|
204
433
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
434
|
f"this graph; table names must be globally unique.")
|
|
206
435
|
|
|
436
|
+
if len(self._tables) > 0:
|
|
437
|
+
cls = next(iter(self._tables.values())).__class__
|
|
438
|
+
if table.__class__ != cls:
|
|
439
|
+
raise ValueError(f"Cannot register a "
|
|
440
|
+
f"'{table.__class__.__name__}' to this "
|
|
441
|
+
f"graph since other tables are of type "
|
|
442
|
+
f"'{cls.__name__}'.")
|
|
443
|
+
|
|
207
444
|
self._tables[table.name] = table
|
|
208
445
|
|
|
209
446
|
return self
|
|
@@ -241,7 +478,7 @@ class LocalGraph:
|
|
|
241
478
|
Example:
|
|
242
479
|
>>> # doctest: +SKIP
|
|
243
480
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
481
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
482
|
>>> graph.metadata # doctest: +SKIP
|
|
246
483
|
name primary_key time_column end_time_column
|
|
247
484
|
0 users user_id - -
|
|
@@ -263,10 +500,14 @@ class LocalGraph:
|
|
|
263
500
|
})
|
|
264
501
|
|
|
265
502
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
267
|
-
if
|
|
503
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
504
|
+
if in_snowflake_notebook():
|
|
505
|
+
import streamlit as st
|
|
506
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
507
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
508
|
+
elif in_notebook():
|
|
268
509
|
from IPython.display import Markdown, display
|
|
269
|
-
display(Markdown(
|
|
510
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
270
511
|
df = self.metadata
|
|
271
512
|
try:
|
|
272
513
|
if hasattr(df.style, 'hide'):
|
|
@@ -287,7 +528,7 @@ class LocalGraph:
|
|
|
287
528
|
|
|
288
529
|
Note:
|
|
289
530
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
531
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
532
|
"""
|
|
292
533
|
for table in self.tables.values():
|
|
293
534
|
table.infer_metadata(verbose=False)
|
|
@@ -305,37 +546,47 @@ class LocalGraph:
|
|
|
305
546
|
return self._edges
|
|
306
547
|
|
|
307
548
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
549
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
550
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
551
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
552
|
edges = sorted(edges)
|
|
312
553
|
|
|
313
|
-
if
|
|
554
|
+
if in_snowflake_notebook():
|
|
555
|
+
import streamlit as st
|
|
556
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
557
|
+
if len(edges) > 0:
|
|
558
|
+
st.markdown('\n'.join([
|
|
559
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
560
|
+
for edge in edges
|
|
561
|
+
]))
|
|
562
|
+
else:
|
|
563
|
+
st.markdown("*No links registered*")
|
|
564
|
+
elif in_notebook():
|
|
314
565
|
from IPython.display import Markdown, display
|
|
315
|
-
display(Markdown(
|
|
566
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
316
567
|
if len(edges) > 0:
|
|
317
568
|
display(
|
|
318
569
|
Markdown('\n'.join([
|
|
319
|
-
f
|
|
570
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
320
571
|
for edge in edges
|
|
321
572
|
])))
|
|
322
573
|
else:
|
|
323
|
-
display(Markdown(
|
|
574
|
+
display(Markdown("*No links registered*"))
|
|
324
575
|
else:
|
|
325
576
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
326
577
|
if len(edges) > 0:
|
|
327
578
|
print('\n'.join([
|
|
328
|
-
f
|
|
579
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
329
580
|
for edge in edges
|
|
330
581
|
]))
|
|
331
582
|
else:
|
|
332
|
-
print(
|
|
583
|
+
print("No links registered")
|
|
333
584
|
|
|
334
585
|
def link(
|
|
335
586
|
self,
|
|
336
|
-
src_table: Union[str,
|
|
587
|
+
src_table: Union[str, Table],
|
|
337
588
|
fkey: str,
|
|
338
|
-
dst_table: Union[str,
|
|
589
|
+
dst_table: Union[str, Table],
|
|
339
590
|
) -> Self:
|
|
340
591
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
592
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +609,11 @@ class LocalGraph:
|
|
|
358
609
|
table does not exist in the graph, if the source key does not
|
|
359
610
|
exist in the source table.
|
|
360
611
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
612
|
+
if isinstance(src_table, Table):
|
|
362
613
|
src_table = src_table.name
|
|
363
614
|
assert isinstance(src_table, str)
|
|
364
615
|
|
|
365
|
-
if isinstance(dst_table,
|
|
616
|
+
if isinstance(dst_table, Table):
|
|
366
617
|
dst_table = dst_table.name
|
|
367
618
|
assert isinstance(dst_table, str)
|
|
368
619
|
|
|
@@ -396,9 +647,9 @@ class LocalGraph:
|
|
|
396
647
|
|
|
397
648
|
def unlink(
|
|
398
649
|
self,
|
|
399
|
-
src_table: Union[str,
|
|
650
|
+
src_table: Union[str, Table],
|
|
400
651
|
fkey: str,
|
|
401
|
-
dst_table: Union[str,
|
|
652
|
+
dst_table: Union[str, Table],
|
|
402
653
|
) -> Self:
|
|
403
654
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
655
|
|
|
@@ -410,11 +661,11 @@ class LocalGraph:
|
|
|
410
661
|
Raises:
|
|
411
662
|
ValueError: if the edge is not present in the graph.
|
|
412
663
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
664
|
+
if isinstance(src_table, Table):
|
|
414
665
|
src_table = src_table.name
|
|
415
666
|
assert isinstance(src_table, str)
|
|
416
667
|
|
|
417
|
-
if isinstance(dst_table,
|
|
668
|
+
if isinstance(dst_table, Table):
|
|
418
669
|
dst_table = dst_table.name
|
|
419
670
|
assert isinstance(dst_table, str)
|
|
420
671
|
|
|
@@ -428,17 +679,13 @@ class LocalGraph:
|
|
|
428
679
|
return self
|
|
429
680
|
|
|
430
681
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
682
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
683
|
+
graph.
|
|
432
684
|
|
|
433
685
|
Args:
|
|
434
686
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
687
|
"""
|
|
439
|
-
|
|
440
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
441
|
-
return self
|
|
688
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
442
689
|
|
|
443
690
|
# A list of primary key candidates (+score) for every column:
|
|
444
691
|
candidate_dict: dict[
|
|
@@ -463,6 +710,9 @@ class LocalGraph:
|
|
|
463
710
|
src_table_name = src_table.name.lower()
|
|
464
711
|
|
|
465
712
|
for src_key in src_table.columns:
|
|
713
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
714
|
+
continue
|
|
715
|
+
|
|
466
716
|
if src_key == src_table.primary_key:
|
|
467
717
|
continue # Cannot link to primary key.
|
|
468
718
|
|
|
@@ -528,7 +778,9 @@ class LocalGraph:
|
|
|
528
778
|
score += 1.0
|
|
529
779
|
|
|
530
780
|
# Cardinality ratio:
|
|
531
|
-
if
|
|
781
|
+
if (src_table._num_rows is not None
|
|
782
|
+
and dst_table._num_rows is not None
|
|
783
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
532
784
|
score += 1.0
|
|
533
785
|
|
|
534
786
|
if score < 5.0:
|
|
@@ -645,19 +897,19 @@ class LocalGraph:
|
|
|
645
897
|
|
|
646
898
|
return True
|
|
647
899
|
|
|
648
|
-
# Check basic dependency:
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
900
|
+
try: # Check basic dependency:
|
|
901
|
+
import graphviz
|
|
902
|
+
except ImportError as e:
|
|
903
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
904
|
+
"visualization") from e
|
|
905
|
+
|
|
906
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
653
907
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
654
908
|
"executables are not installed. These "
|
|
655
909
|
"dependencies are required in addition to the "
|
|
656
910
|
"'graphviz' Python package. Please install "
|
|
657
911
|
"them as described at "
|
|
658
912
|
"https://graphviz.org/download/.")
|
|
659
|
-
else:
|
|
660
|
-
import graphviz
|
|
661
913
|
|
|
662
914
|
format: Optional[str] = None
|
|
663
915
|
if isinstance(path, str):
|
|
@@ -741,6 +993,9 @@ class LocalGraph:
|
|
|
741
993
|
graph.render(path, cleanup=True)
|
|
742
994
|
elif isinstance(path, io.BytesIO):
|
|
743
995
|
path.write(graph.pipe())
|
|
996
|
+
elif in_snowflake_notebook():
|
|
997
|
+
import streamlit as st
|
|
998
|
+
st.graphviz_chart(graph)
|
|
744
999
|
elif in_notebook():
|
|
745
1000
|
from IPython.display import display
|
|
746
1001
|
display(graph)
|
|
@@ -790,7 +1045,7 @@ class LocalGraph:
|
|
|
790
1045
|
def __contains__(self, name: str) -> bool:
|
|
791
1046
|
return self.has_table(name)
|
|
792
1047
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
1048
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
1049
|
return self.table(name)
|
|
795
1050
|
|
|
796
1051
|
def __delitem__(self, name: str) -> None:
|
|
@@ -1,9 +1,15 @@
|
|
|
1
|
+
from .dtype import infer_dtype
|
|
2
|
+
from .pkey import infer_primary_key
|
|
3
|
+
from .time_col import infer_time_column
|
|
1
4
|
from .id import contains_id
|
|
2
5
|
from .timestamp import contains_timestamp
|
|
3
6
|
from .categorical import contains_categorical
|
|
4
7
|
from .multicategorical import contains_multicategorical
|
|
5
8
|
|
|
6
9
|
__all__ = [
|
|
10
|
+
'infer_dtype',
|
|
11
|
+
'infer_primary_key',
|
|
12
|
+
'infer_time_column',
|
|
7
13
|
'contains_id',
|
|
8
14
|
'contains_timestamp',
|
|
9
15
|
'contains_categorical',
|