kumoai 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512041731__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +12 -0
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/backend/local/table.py +32 -167
- kumoai/experimental/rfm/backend/snow/__init__.py +3 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +58 -81
- kumoai/experimental/rfm/base/__init__.py +5 -0
- kumoai/experimental/rfm/base/sampler.py +134 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +95 -27
- kumoai/experimental/rfm/graph.py +220 -52
- kumoai/experimental/rfm/infer/__init__.py +6 -2
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/{utils.py → infer/pkey.py} +2 -101
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +42 -1
- kumoai/experimental/rfm/local_graph_store.py +1 -16
- kumoai/experimental/rfm/rfm.py +1 -11
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/METADATA +2 -1
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/RECORD +24 -20
- kumoai/experimental/rfm/infer/stype.py +0 -35
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512021731.dist-info → kumoai-2.13.0.dev202512041731.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from kumoapi.typing import Dtype
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
class SourceColumn:
    r"""Metadata of a single column as reported by the underlying source
    table.

    Instances are produced by the backend-specific
    ``Table._get_source_columns`` implementations and are looked up by name
    when columns are added to a table.
    """
    # Name of the column in the source table (used as the lookup key):
    name: str
    # Data type of the column:
    dtype: Dtype
    # Whether the source declares this column as a primary key:
    is_primary_key: bool
    # Whether the source declares this column as a unique key:
    is_unique_key: bool
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class SourceForeignKey:
    r"""Metadata of a foreign key relationship as reported by the underlying
    source table.
    """
    # Name of the foreign key column in the source table:
    name: str
    # Name of the table the foreign key points to:
    dst_table: str
    # Name of the referenced primary key column in ``dst_table``:
    primary_key: str
|
|
@@ -1,15 +1,25 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
-
from
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import Dict, List, Optional, Sequence, Set
|
|
3
5
|
|
|
4
6
|
import pandas as pd
|
|
5
7
|
from kumoapi.source_table import UnavailableSourceTable
|
|
6
8
|
from kumoapi.table import Column as ColumnDefinition
|
|
7
9
|
from kumoapi.table import TableDefinition
|
|
8
|
-
from kumoapi.typing import
|
|
10
|
+
from kumoapi.typing import Stype
|
|
9
11
|
from typing_extensions import Self
|
|
10
12
|
|
|
11
|
-
from kumoai import in_notebook
|
|
12
|
-
from kumoai.experimental.rfm.base import Column
|
|
13
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
14
|
+
from kumoai.experimental.rfm.base import Column, SourceColumn, SourceForeignKey
|
|
15
|
+
from kumoai.experimental.rfm.infer import (
|
|
16
|
+
contains_categorical,
|
|
17
|
+
contains_id,
|
|
18
|
+
contains_multicategorical,
|
|
19
|
+
contains_timestamp,
|
|
20
|
+
infer_primary_key,
|
|
21
|
+
infer_time_column,
|
|
22
|
+
)
|
|
13
23
|
|
|
14
24
|
|
|
15
25
|
class Table(ABC):
|
|
@@ -39,8 +49,30 @@ class Table(ABC):
|
|
|
39
49
|
self._time_column: Optional[str] = None
|
|
40
50
|
self._end_time_column: Optional[str] = None
|
|
41
51
|
|
|
52
|
+
if len(self._source_column_dict) == 0:
|
|
53
|
+
raise ValueError(f"Table '{name}' does not hold any column with "
|
|
54
|
+
f"a supported data type")
|
|
55
|
+
|
|
56
|
+
primary_keys = [
|
|
57
|
+
column.name for column in self._source_column_dict.values()
|
|
58
|
+
if column.is_primary_key
|
|
59
|
+
]
|
|
60
|
+
if len(primary_keys) == 1: # NOTE No composite keys yet.
|
|
61
|
+
if primary_key is not None and primary_key != primary_keys[0]:
|
|
62
|
+
raise ValueError(f"Found duplicate primary key "
|
|
63
|
+
f"definition '{primary_key}' and "
|
|
64
|
+
f"'{primary_keys[0]}' in table '{name}'")
|
|
65
|
+
primary_key = primary_keys[0]
|
|
66
|
+
|
|
67
|
+
unique_keys = [
|
|
68
|
+
column.name for column in self._source_column_dict.values()
|
|
69
|
+
if column.is_unique_key
|
|
70
|
+
]
|
|
71
|
+
if primary_key is None and len(unique_keys) == 1:
|
|
72
|
+
primary_key = unique_keys[0]
|
|
73
|
+
|
|
42
74
|
self._columns: Dict[str, Column] = {}
|
|
43
|
-
for column_name in columns or
|
|
75
|
+
for column_name in columns or list(self._source_column_dict.keys()):
|
|
44
76
|
self.add_column(column_name)
|
|
45
77
|
|
|
46
78
|
if primary_key is not None:
|
|
@@ -104,12 +136,12 @@ class Table(ABC):
|
|
|
104
136
|
raise KeyError(f"Column '{name}' already exists in table "
|
|
105
137
|
f"'{self.name}'")
|
|
106
138
|
|
|
107
|
-
if not self.
|
|
139
|
+
if name not in self._source_column_dict:
|
|
108
140
|
raise KeyError(f"Column '{name}' does not exist in the underlying "
|
|
109
141
|
f"source table")
|
|
110
142
|
|
|
111
143
|
try:
|
|
112
|
-
dtype = self.
|
|
144
|
+
dtype = self._source_column_dict[name].dtype
|
|
113
145
|
except Exception as e:
|
|
114
146
|
raise RuntimeError(f"Could not obtain data type for column "
|
|
115
147
|
f"'{name}' in table '{self.name}'. Change "
|
|
@@ -117,7 +149,17 @@ class Table(ABC):
|
|
|
117
149
|
f"table or remove it from the table.") from e
|
|
118
150
|
|
|
119
151
|
try:
|
|
120
|
-
|
|
152
|
+
ser = self._sample_df[name]
|
|
153
|
+
if contains_id(ser, name, dtype):
|
|
154
|
+
stype = Stype.ID
|
|
155
|
+
elif contains_timestamp(ser, name, dtype):
|
|
156
|
+
stype = Stype.timestamp
|
|
157
|
+
elif contains_multicategorical(ser, name, dtype):
|
|
158
|
+
stype = Stype.multicategorical
|
|
159
|
+
elif contains_categorical(ser, name, dtype):
|
|
160
|
+
stype = Stype.categorical
|
|
161
|
+
else:
|
|
162
|
+
stype = dtype.default_stype
|
|
121
163
|
except Exception as e:
|
|
122
164
|
raise RuntimeError(f"Could not obtain semantic type for column "
|
|
123
165
|
f"'{name}' in table '{self.name}'. Change "
|
|
@@ -338,10 +380,16 @@ class Table(ABC):
|
|
|
338
380
|
|
|
339
381
|
def print_metadata(self) -> None:
|
|
340
382
|
r"""Prints the :meth:`~metadata` of this table."""
|
|
341
|
-
|
|
342
|
-
|
|
383
|
+
num_rows_repr = ''
|
|
384
|
+
if self._num_rows is not None:
|
|
385
|
+
num_rows_repr = ' ({self._num_rows:,} rows)'
|
|
343
386
|
|
|
344
|
-
if
|
|
387
|
+
if in_snowflake_notebook():
|
|
388
|
+
import streamlit as st
|
|
389
|
+
md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
|
|
390
|
+
st.markdown(md_repr)
|
|
391
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
392
|
+
elif in_notebook():
|
|
345
393
|
from IPython.display import Markdown, display
|
|
346
394
|
md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
|
|
347
395
|
display(Markdown(md_repr))
|
|
@@ -384,7 +432,11 @@ class Table(ABC):
|
|
|
384
432
|
column.name for column in self.columns if is_candidate(column)
|
|
385
433
|
]
|
|
386
434
|
|
|
387
|
-
if primary_key :=
|
|
435
|
+
if primary_key := infer_primary_key(
|
|
436
|
+
table_name=self.name,
|
|
437
|
+
df=self._sample_df,
|
|
438
|
+
candidates=candidates,
|
|
439
|
+
):
|
|
388
440
|
self.primary_key = primary_key
|
|
389
441
|
logs.append(f"primary key '{primary_key}'")
|
|
390
442
|
|
|
@@ -395,7 +447,10 @@ class Table(ABC):
|
|
|
395
447
|
if column.stype == Stype.timestamp
|
|
396
448
|
and column.name != self._end_time_column
|
|
397
449
|
]
|
|
398
|
-
if time_column :=
|
|
450
|
+
if time_column := infer_time_column(
|
|
451
|
+
df=self._sample_df,
|
|
452
|
+
candidates=candidates,
|
|
453
|
+
):
|
|
399
454
|
self.time_column = time_column
|
|
400
455
|
logs.append(f"time column '{time_column}'")
|
|
401
456
|
|
|
@@ -446,32 +501,45 @@ class Table(ABC):
|
|
|
446
501
|
f' end_time_column={self._end_time_column},\n'
|
|
447
502
|
f')')
|
|
448
503
|
|
|
449
|
-
# Abstract
|
|
504
|
+
# Abstract Methods ########################################################
|
|
450
505
|
|
|
451
|
-
@
|
|
452
|
-
def
|
|
453
|
-
|
|
506
|
+
@cached_property
def _source_column_dict(self) -> Dict[str, SourceColumn]:
    r"""Returns the source columns of this table, keyed by column name.

    The result of :meth:`_get_source_columns` is cached on first access.
    """
    source_columns = self._get_source_columns()
    return {column.name: column for column in source_columns}
|
|
454
509
|
|
|
455
510
|
@abstractmethod
|
|
456
|
-
def
|
|
511
|
+
def _get_source_columns(self) -> List[SourceColumn]:
|
|
457
512
|
pass
|
|
458
513
|
|
|
459
|
-
@
|
|
460
|
-
def
|
|
461
|
-
|
|
514
|
+
@cached_property
def _source_foreign_key_dict(self) -> Dict[str, SourceForeignKey]:
    r"""Returns the usable source foreign keys of this table, keyed by
    foreign key column name.

    The result of :meth:`_get_source_foreign_keys` is filtered and cached on
    first access.
    """
    source_fkeys = self._get_source_foreign_keys()

    # Collect the set of primary keys referenced per destination table:
    pkeys_per_table: Dict[str, Set[str]] = defaultdict(set)
    for source_fkey in source_fkeys:
        pkeys_per_table[source_fkey.dst_table].add(source_fkey.primary_key)

    # NOTE Drop all keys that link to different primary keys in the same
    # table since we don't support composite keys yet:
    fkey_dict: Dict[str, SourceForeignKey] = {}
    for source_fkey in source_fkeys:
        if len(pkeys_per_table[source_fkey.dst_table]) != 1:
            continue
        fkey_dict[source_fkey.name] = source_fkey

    return fkey_dict
|
|
462
526
|
|
|
463
527
|
@abstractmethod
|
|
464
|
-
def _get_source_foreign_keys(self) -> List[
|
|
528
|
+
def _get_source_foreign_keys(self) -> List[SourceForeignKey]:
|
|
465
529
|
pass
|
|
466
530
|
|
|
467
|
-
@
|
|
468
|
-
def
|
|
469
|
-
|
|
531
|
+
@cached_property
def _sample_df(self) -> pd.DataFrame:
    r"""Returns a sample of the table data, as provided by
    :meth:`_get_sample_df`.

    The data frame is cached on first access.
    """
    sample_df = self._get_sample_df()
    return sample_df
|
|
470
534
|
|
|
471
535
|
@abstractmethod
|
|
472
|
-
def
|
|
536
|
+
def _get_sample_df(self) -> pd.DataFrame:
|
|
473
537
|
pass
|
|
474
538
|
|
|
475
|
-
@
|
|
539
|
+
@cached_property
|
|
476
540
|
def _num_rows(self) -> Optional[int]:
|
|
541
|
+
return self._get_num_rows()
|
|
542
|
+
|
|
543
|
+
@abstractmethod
|
|
544
|
+
def _get_num_rows(self) -> Optional[int]:
|
|
477
545
|
pass
|
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -2,7 +2,8 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
-
from
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from pathlib import Path
|
|
6
7
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
8
|
|
|
8
9
|
import pandas as pd
|
|
@@ -11,12 +12,21 @@ from kumoapi.table import TableDefinition
|
|
|
11
12
|
from kumoapi.typing import Stype
|
|
12
13
|
from typing_extensions import Self
|
|
13
14
|
|
|
14
|
-
from kumoai import in_notebook
|
|
15
|
+
from kumoai import in_notebook, in_snowflake_notebook
|
|
15
16
|
from kumoai.experimental.rfm import Table
|
|
16
17
|
from kumoai.graph import Edge
|
|
18
|
+
from kumoai.mixin import CastMixin
|
|
17
19
|
|
|
18
20
|
if TYPE_CHECKING:
|
|
19
21
|
import graphviz
|
|
22
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
23
|
+
from snowflake.connector import SnowflakeConnection
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Configuration from which a SQLite connection can be opened.

    :meth:`Graph.from_sqlite` casts its ``connection`` argument to this class
    and opens the connection via ``connect(uri, **kwargs)``.
    """
    # Location of the database, forwarded as first argument to `connect`:
    uri: Union[str, Path]
    # Additional keyword arguments forwarded to `connect`:
    kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
20
30
|
|
|
21
31
|
|
|
22
32
|
class Graph:
|
|
@@ -86,14 +96,17 @@ class Graph:
|
|
|
86
96
|
self.add_table(table)
|
|
87
97
|
|
|
88
98
|
for table in tables:
|
|
89
|
-
for fkey
|
|
90
|
-
if
|
|
91
|
-
|
|
92
|
-
|
|
99
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
+
if fkey.name not in table or fkey.dst_table not in self:
|
|
101
|
+
continue
|
|
102
|
+
if self[fkey.dst_table].primary_key is None:
|
|
103
|
+
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
+
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
93
105
|
raise ValueError(f"Found duplicate primary key definition "
|
|
94
|
-
f"'{self[dst_table]._primary_key}'
|
|
95
|
-
f"'{
|
|
96
|
-
|
|
106
|
+
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
+
f"and '{fkey.primary_key}' in table "
|
|
108
|
+
f"'{fkey.dst_table}'.")
|
|
109
|
+
self.link(table.name, fkey.name, fkey.dst_table)
|
|
97
110
|
|
|
98
111
|
for edge in (edges or []):
|
|
99
112
|
_edge = Edge._cast(edge)
|
|
@@ -132,13 +145,6 @@ class Graph:
|
|
|
132
145
|
... "table3": df3,
|
|
133
146
|
... })
|
|
134
147
|
|
|
135
|
-
>>> # Inspect table metadata:
|
|
136
|
-
>>> for table in graph.tables.values():
|
|
137
|
-
... table.print_metadata()
|
|
138
|
-
|
|
139
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
140
|
-
>>> graph.visualize()
|
|
141
|
-
|
|
142
148
|
Args:
|
|
143
149
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
144
150
|
of the tables and the values hold table data.
|
|
@@ -169,12 +175,17 @@ class Graph:
|
|
|
169
175
|
@classmethod
|
|
170
176
|
def from_sqlite(
|
|
171
177
|
cls,
|
|
172
|
-
|
|
178
|
+
connection: Union[
|
|
179
|
+
'AdbcSqliteConnection',
|
|
180
|
+
SqliteConnectionConfig,
|
|
181
|
+
str,
|
|
182
|
+
Path,
|
|
183
|
+
Dict[str, Any],
|
|
184
|
+
],
|
|
173
185
|
table_names: Optional[Sequence[str]] = None,
|
|
174
186
|
edges: Optional[Sequence[Edge]] = None,
|
|
175
187
|
infer_metadata: bool = True,
|
|
176
188
|
verbose: bool = True,
|
|
177
|
-
conn_kwargs: Optional[Dict[str, Any]] = None,
|
|
178
189
|
) -> Self:
|
|
179
190
|
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
180
191
|
|
|
@@ -188,16 +199,10 @@ class Graph:
|
|
|
188
199
|
>>> # Create a graph from a SQLite database:
|
|
189
200
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
190
201
|
|
|
191
|
-
>>> # Inspect table metadata:
|
|
192
|
-
>>> for table in graph.tables.values():
|
|
193
|
-
... table.print_metadata()
|
|
194
|
-
|
|
195
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
196
|
-
>>> graph.visualize()
|
|
197
|
-
|
|
198
202
|
Args:
|
|
199
|
-
|
|
200
|
-
|
|
203
|
+
connection: An open connection from
|
|
204
|
+
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
|
+
path to the database file.
|
|
201
206
|
table_names: Set of table names to include. If ``None``, will add
|
|
202
207
|
all tables present in the database.
|
|
203
208
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
@@ -206,8 +211,6 @@ class Graph:
|
|
|
206
211
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
207
212
|
graph.
|
|
208
213
|
verbose: Whether to print verbose output.
|
|
209
|
-
conn_kwargs: Additional connection arguments, following the
|
|
210
|
-
:class:`adbc_driver_sqlite` protocol.
|
|
211
214
|
"""
|
|
212
215
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
213
216
|
Connection,
|
|
@@ -215,10 +218,11 @@ class Graph:
|
|
|
215
218
|
connect,
|
|
216
219
|
)
|
|
217
220
|
|
|
218
|
-
if not isinstance(
|
|
219
|
-
connection =
|
|
220
|
-
|
|
221
|
-
connection = uri
|
|
221
|
+
if not isinstance(connection, Connection):
|
|
222
|
+
connection = SqliteConnectionConfig._cast(connection)
|
|
223
|
+
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
|
+
connection = connect(connection.uri, **connection.kwargs)
|
|
225
|
+
assert isinstance(connection, Connection)
|
|
222
226
|
|
|
223
227
|
if table_names is None:
|
|
224
228
|
with connection.cursor() as cursor:
|
|
@@ -242,6 +246,154 @@ class Graph:
|
|
|
242
246
|
|
|
243
247
|
return graph
|
|
244
248
|
|
|
249
|
+
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, will re-use an active session in case
            it exists, or create a new connection from credentials stored
            in environment variables.
        database: The database. If ``None`` and tables need to be
            discovered, falls back to the current database of the
            connection.
        schema: The schema. If ``None`` and tables need to be discovered,
            falls back to the current schema of the connection.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the schema.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If ``None``, links will be automatically
            inferred.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # Resolve whichever of `database`/`schema` is missing from the
            # session. Both are required to build the discovery query
            # below, so a single missing value must trigger the lookup:
            if database is None or schema is None:
                cursor.execute("SELECT CURRENT_DATABASE(), "
                               "CURRENT_SCHEMA()")
                result = cursor.fetchone()
                database = database or result[0]
                schema = schema or result[1]
            # NOTE(review): `database` and `schema` are interpolated
            # verbatim into the statement; only pass trusted values.
            cursor.execute(f"""
                SELECT TABLE_NAME
                FROM {database}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{schema}'
            """)
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [
        SnowTable(
            connection,
            name=table_name,
            database=database,
            schema=schema,
        ) for table_name in table_names
    ]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

    # Only infer links if the user did not specify any edges explicitly:
    if edges is None:
        graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
|
|
338
|
+
|
|
339
|
+
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` semantic view.

    The YAML definition of the semantic view is read via
    ``SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW``. Every entry in its ``tables``
    section is added as a table, and every single-column entry in its
    ``relationships`` section is added as a link. Composite primary keys
    and composite relationships are ignored. Time columns are not
    registered, and no metadata or link inference is performed.

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, ``connect()`` is called without
            arguments.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    with connection.cursor() as cursor:
        # NOTE(review): `semantic_view_name` is interpolated verbatim into
        # the statement; only pass trusted values.
        cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                       f"'{semantic_view_name}')")
        view = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    for table_desc in view['tables']:
        primary_key: Optional[str] = None
        if ('primary_key' in table_desc  # NOTE No composite keys yet.
                and len(table_desc['primary_key']['columns']) == 1):
            primary_key = table_desc['primary_key']['columns'][0]

        base_table = table_desc['base_table']
        table = SnowTable(
            connection,
            name=base_table['table'],
            database=base_table['database'],
            schema=base_table['schema'],
            primary_key=primary_key,
        )
        graph.add_table(table)

        # TODO Find a solution to register time columns!

    # A semantic view may not define any relationships, in which case the
    # section can be missing or empty:
    # NOTE(review): Links reference `left_table`/`right_table` while tables
    # are registered under their base table name - confirm that both always
    # agree.
    for relations in view.get('relationships') or []:
        if len(relations['relationship_columns']) != 1:
            continue  # NOTE No composite keys yet.
        graph.link(
            src_table=relations['left_table'],
            fkey=relations['relationship_columns'][0]['left_column'],
            dst_table=relations['right_table'],
        )

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
|
|
396
|
+
|
|
245
397
|
# Tables ##############################################################
|
|
246
398
|
|
|
247
399
|
def has_table(self, name: str) -> bool:
|
|
@@ -349,9 +501,13 @@ class Graph:
|
|
|
349
501
|
|
|
350
502
|
def print_metadata(self) -> None:
|
|
351
503
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
352
|
-
if
|
|
504
|
+
if in_snowflake_notebook():
|
|
505
|
+
import streamlit as st
|
|
506
|
+
st.markdown("### 🗂️ Graph Metadata")
|
|
507
|
+
st.dataframe(self.metadata, hide_index=True)
|
|
508
|
+
elif in_notebook():
|
|
353
509
|
from IPython.display import Markdown, display
|
|
354
|
-
display(Markdown(
|
|
510
|
+
display(Markdown("### 🗂️ Graph Metadata"))
|
|
355
511
|
df = self.metadata
|
|
356
512
|
try:
|
|
357
513
|
if hasattr(df.style, 'hide'):
|
|
@@ -395,26 +551,36 @@ class Graph:
|
|
|
395
551
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
396
552
|
edges = sorted(edges)
|
|
397
553
|
|
|
398
|
-
if
|
|
554
|
+
if in_snowflake_notebook():
|
|
555
|
+
import streamlit as st
|
|
556
|
+
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
557
|
+
if len(edges) > 0:
|
|
558
|
+
st.markdown('\n'.join([
|
|
559
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
560
|
+
for edge in edges
|
|
561
|
+
]))
|
|
562
|
+
else:
|
|
563
|
+
st.markdown("*No links registered*")
|
|
564
|
+
elif in_notebook():
|
|
399
565
|
from IPython.display import Markdown, display
|
|
400
|
-
display(Markdown(
|
|
566
|
+
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
401
567
|
if len(edges) > 0:
|
|
402
568
|
display(
|
|
403
569
|
Markdown('\n'.join([
|
|
404
|
-
f
|
|
570
|
+
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
405
571
|
for edge in edges
|
|
406
572
|
])))
|
|
407
573
|
else:
|
|
408
|
-
display(Markdown(
|
|
574
|
+
display(Markdown("*No links registered*"))
|
|
409
575
|
else:
|
|
410
576
|
print("🕸️ Graph Links (FK ↔️ PK):")
|
|
411
577
|
if len(edges) > 0:
|
|
412
578
|
print('\n'.join([
|
|
413
|
-
f
|
|
579
|
+
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
414
580
|
for edge in edges
|
|
415
581
|
]))
|
|
416
582
|
else:
|
|
417
|
-
print(
|
|
583
|
+
print("No links registered")
|
|
418
584
|
|
|
419
585
|
def link(
|
|
420
586
|
self,
|
|
@@ -612,10 +778,9 @@ class Graph:
|
|
|
612
778
|
score += 1.0
|
|
613
779
|
|
|
614
780
|
# Cardinality ratio:
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
and src_num_rows > dst_num_rows):
|
|
781
|
+
if (src_table._num_rows is not None
|
|
782
|
+
and dst_table._num_rows is not None
|
|
783
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
619
784
|
score += 1.0
|
|
620
785
|
|
|
621
786
|
if score < 5.0:
|
|
@@ -732,19 +897,19 @@ class Graph:
|
|
|
732
897
|
|
|
733
898
|
return True
|
|
734
899
|
|
|
735
|
-
# Check basic dependency:
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
900
|
+
try: # Check basic dependency:
|
|
901
|
+
import graphviz
|
|
902
|
+
except ImportError as e:
|
|
903
|
+
raise ImportError("The 'graphviz' package is required for "
|
|
904
|
+
"visualization") from e
|
|
905
|
+
|
|
906
|
+
if not in_snowflake_notebook() and not has_graphviz_executables():
|
|
740
907
|
raise RuntimeError("Could not visualize graph as 'graphviz' "
|
|
741
908
|
"executables are not installed. These "
|
|
742
909
|
"dependencies are required in addition to the "
|
|
743
910
|
"'graphviz' Python package. Please install "
|
|
744
911
|
"them as described at "
|
|
745
912
|
"https://graphviz.org/download/.")
|
|
746
|
-
else:
|
|
747
|
-
import graphviz
|
|
748
913
|
|
|
749
914
|
format: Optional[str] = None
|
|
750
915
|
if isinstance(path, str):
|
|
@@ -828,6 +993,9 @@ class Graph:
|
|
|
828
993
|
graph.render(path, cleanup=True)
|
|
829
994
|
elif isinstance(path, io.BytesIO):
|
|
830
995
|
path.write(graph.pipe())
|
|
996
|
+
elif in_snowflake_notebook():
|
|
997
|
+
import streamlit as st
|
|
998
|
+
st.graphviz_chart(graph)
|
|
831
999
|
elif in_notebook():
|
|
832
1000
|
from IPython.display import display
|
|
833
1001
|
display(graph)
|
|
@@ -1,13 +1,17 @@
|
|
|
1
|
+
from .dtype import infer_dtype
|
|
2
|
+
from .pkey import infer_primary_key
|
|
3
|
+
from .time_col import infer_time_column
|
|
1
4
|
from .id import contains_id
|
|
2
5
|
from .timestamp import contains_timestamp
|
|
3
6
|
from .categorical import contains_categorical
|
|
4
7
|
from .multicategorical import contains_multicategorical
|
|
5
|
-
from .stype import infer_stype
|
|
6
8
|
|
|
7
9
|
__all__ = [
|
|
10
|
+
'infer_dtype',
|
|
11
|
+
'infer_primary_key',
|
|
12
|
+
'infer_time_column',
|
|
8
13
|
'contains_id',
|
|
9
14
|
'contains_timestamp',
|
|
10
15
|
'contains_categorical',
|
|
11
16
|
'contains_multicategorical',
|
|
12
|
-
'infer_stype',
|
|
13
17
|
]
|