kumoai 2.13.0.dev202511271731__cp312-cp312-win_amd64.whl → 2.13.0.dev202512040651__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +109 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +117 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
- kumoai/experimental/rfm/base/__init__.py +10 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +134 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +301 -62
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +79 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +42 -1
- kumoai/experimental/rfm/local_graph_store.py +13 -27
- kumoai/experimental/rfm/rfm.py +6 -16
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/kumolib.cp312-win_amd64.pyd +0 -0
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/METADATA +9 -8
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/RECORD +30 -18
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511271731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/top_level.txt +0 -0
|
@@ -2,8 +2,10 @@ import contextlib
|
|
|
2
2
|
import io
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
|
+
from dataclasses import dataclass, field
|
|
5
6
|
from importlib.util import find_spec
|
|
6
|
-
from
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
9
|
|
|
8
10
|
import pandas as pd
|
|
9
11
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -12,19 +14,28 @@ from kumoapi.typing import Stype
|
|
|
12
14
|
from typing_extensions import Self
|
|
13
15
|
|
|
14
16
|
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
17
|
+
from kumoai.experimental.rfm import Table
|
|
16
18
|
from kumoai.graph import Edge
|
|
19
|
+
from kumoai.mixin import CastMixin
|
|
17
20
|
|
|
18
21
|
if TYPE_CHECKING:
|
|
19
22
|
import graphviz
|
|
23
|
+
from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
|
|
24
|
+
from snowflake.connector import SnowflakeConnection
|
|
20
25
|
|
|
21
26
|
|
|
22
|
-
|
|
23
|
-
|
|
27
|
+
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Configuration describing how to open a SQLite connection.

    Used by :meth:`Graph.from_sqlite`, which opens the connection via
    ``connect(config.uri, **config.kwargs)``.

    Args:
        uri: The path or URI of the SQLite database.
        kwargs: Additional keyword arguments passed along when opening the
            connection.
    """
    # Location of the database (a filesystem path or a URI string).
    uri: Union[str, Path]
    # Extra connection options; every instance gets its own fresh dict.
    kwargs: Dict[str, Any] = field(default_factory=dict)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Graph:
|
|
34
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
35
|
tables in a relational database.
|
|
25
36
|
|
|
26
37
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
38
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
39
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
40
|
|
|
30
41
|
.. code-block:: python
|
|
@@ -44,7 +55,7 @@ class LocalGraph:
|
|
|
44
55
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
56
|
|
|
46
57
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
58
|
+
>>> graph = rfm.Graph({
|
|
48
59
|
... "table1": table1,
|
|
49
60
|
... "table2": table2,
|
|
50
61
|
... "table3": table3,
|
|
@@ -75,33 +86,47 @@ class LocalGraph:
|
|
|
75
86
|
|
|
76
87
|
def __init__(
    self,
    tables: Sequence[Table],
    edges: Optional[Sequence[Edge]] = None,
) -> None:
    r"""Initializes the graph from a sequence of tables and optional edges.

    All tables are registered first. Afterwards, foreign keys declared by
    the tables' underlying sources are linked whenever both the foreign key
    column and the destination table are part of the graph; the destination
    table's primary key is set from the foreign key definition if it is not
    defined yet. Finally, the explicitly passed ``edges`` are linked unless
    they are already present.

    Args:
        tables: The tables to add to the graph.
        edges: An optional sequence of edges to add to the graph.

    Raises:
        ValueError: If a source foreign key refers to a primary key that
            conflicts with the primary key already set on the destination
            table.
    """
    self._tables: Dict[str, Table] = {}
    self._edges: List[Edge] = []

    for tbl in tables:
        self.add_table(tbl)

    # Link the foreign keys that the underlying data sources already know:
    for tbl in tables:
        for source_fkey in tbl._source_foreign_key_dict.values():
            if source_fkey.name not in tbl:
                continue
            if source_fkey.dst_table not in self:
                continue

            dst = self[source_fkey.dst_table]
            if dst.primary_key is None:
                dst.primary_key = source_fkey.primary_key
            elif dst._primary_key != source_fkey.primary_key:
                raise ValueError(f"Found duplicate primary key definition "
                                 f"'{dst._primary_key}' "
                                 f"and '{source_fkey.primary_key}' in table "
                                 f"'{source_fkey.dst_table}'.")

            self.link(tbl.name, source_fkey.name, source_fkey.dst_table)

    # Explicit edges are added last; skip those already linked above:
    for raw_edge in (edges or []):
        edge = Edge._cast(raw_edge)
        assert edge is not None
        if edge in self._edges:
            continue
        self.link(*edge)
|
|
92
117
|
|
|
93
118
|
@classmethod
|
|
94
119
|
def from_data(
|
|
95
120
|
cls,
|
|
96
121
|
df_dict: Dict[str, pd.DataFrame],
|
|
97
|
-
edges: Optional[
|
|
122
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
98
123
|
infer_metadata: bool = True,
|
|
99
124
|
verbose: bool = True,
|
|
100
125
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
126
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
127
|
:class:`pandas.DataFrame` objects.
|
|
103
128
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
129
|
+
Automatically infers table metadata and links by default.
|
|
105
130
|
|
|
106
131
|
.. code-block:: python
|
|
107
132
|
|
|
@@ -115,55 +140,258 @@ class LocalGraph:
|
|
|
115
140
|
>>> df3 = pd.DataFrame(...)
|
|
116
141
|
|
|
117
142
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
143
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
144
|
... "table1": df1,
|
|
120
145
|
... "table2": df2,
|
|
121
146
|
... "table3": df3,
|
|
122
147
|
... })
|
|
123
148
|
|
|
124
|
-
>>> # Inspect table metadata:
|
|
125
|
-
>>> for table in graph.tables.values():
|
|
126
|
-
... table.print_metadata()
|
|
127
|
-
|
|
128
|
-
>>> # Visualize graph (if graphviz is installed):
|
|
129
|
-
>>> graph.visualize()
|
|
130
|
-
|
|
131
149
|
Args:
|
|
132
150
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
151
|
of the tables and the values hold table data.
|
|
152
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
153
|
+
add to the graph. If not provided, edges will be automatically
|
|
154
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
134
155
|
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
156
|
graph.
|
|
157
|
+
verbose: Whether to print verbose output.
|
|
158
|
+
"""
|
|
159
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
160
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
161
|
+
|
|
162
|
+
graph = cls(tables, edges=edges or [])
|
|
163
|
+
|
|
164
|
+
if infer_metadata:
|
|
165
|
+
graph.infer_metadata(False)
|
|
166
|
+
|
|
167
|
+
if edges is None:
|
|
168
|
+
graph.infer_links(False)
|
|
169
|
+
|
|
170
|
+
if verbose:
|
|
171
|
+
graph.print_metadata()
|
|
172
|
+
graph.print_links()
|
|
173
|
+
|
|
174
|
+
return graph
|
|
175
|
+
|
|
176
|
+
@classmethod
def from_sqlite(
    cls,
    connection: Union[
        'AdbcSqliteConnection',
        SqliteConnectionConfig,
        str,
        Path,
        Dict[str, Any],
    ],
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`sqlite` database.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph from a SQLite database:
        >>> graph = rfm.Graph.from_sqlite('data.db')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
            path to the database file.
        table_names: Set of table names to include. If ``None``, will add
            all user tables present in the database. SQLite-internal
            tables (names starting with ``sqlite_``) are skipped.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.
    """
    from kumoai.experimental.rfm.backend.sqlite import (
        Connection,
        SQLiteTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = SqliteConnectionConfig._cast(connection)
        assert isinstance(connection, SqliteConnectionConfig)
        connection = connect(connection.uri, **connection.kwargs)
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            cursor.execute("SELECT name FROM sqlite_master "
                           "WHERE type='table'")
            table_names = [
                row[0] for row in cursor.fetchall()
                # Names starting with "sqlite_" are reserved for SQLite's
                # own bookkeeping tables (e.g. `sqlite_sequence`, created
                # for AUTOINCREMENT columns) and must not enter the graph:
                if not row[0].lower().startswith('sqlite_')
            ]

    tables = [SQLiteTable(connection, name) for name in table_names]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
|
|
249
|
+
|
|
250
|
+
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, will re-use an active session in case
            it exists, or create a new connection from credentials stored
            in environment variables.
        database: The database. If ``None`` and ``table_names`` is not
            given, the current database of the session is used.
        schema: The schema. If ``None`` and ``table_names`` is not given,
            the current schema of the session is used.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the database.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Raises:
        ValueError: If ``table_names`` is ``None`` and the database or
            schema can neither be taken from the arguments nor from the
            current session.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # Fill in *each* missing piece from the session; previously the
            # fallback only ran when both were missing, producing queries
            # against `None.INFORMATION_SCHEMA` or `TABLE_SCHEMA = 'None'`.
            if database is None or schema is None:
                cursor.execute("SELECT CURRENT_DATABASE(), "
                               "CURRENT_SCHEMA()")
                result = cursor.fetchone()
                database = database or result[0]
                schema = schema or result[1]
            if database is None or schema is None:
                raise ValueError("Could not determine the Snowflake "
                                 "database and schema to read tables from. "
                                 "Please pass 'database' and 'schema' "
                                 "explicitly.")
            # Escape single quotes so the schema stays a valid literal:
            schema_literal = schema.replace("'", "''")
            cursor.execute(f"""
                SELECT TABLE_NAME
                FROM {database}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{schema_literal}'
            """)
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [
        SnowTable(
            connection,
            name=table_name,
            database=database,
            schema=schema,
        ) for table_name in table_names
    ]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
|
|
339
|
+
|
|
340
|
+
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a Snowflake semantic view.

    The YAML definition of the semantic view is read via
    ``SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW``. Every table of the view is
    added as a table backed by its base table, with its primary key set in
    case it consists of a single column. Relationships over a single column
    are added as links; composite keys are skipped.

    Note:
        Time columns are not registered yet.

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    # Escape single quotes so the name stays a valid SQL string literal:
    view_name_literal = semantic_view_name.replace("'", "''")
    with connection.cursor() as cursor:
        cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                       f"'{view_name_literal}')")
        view = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    for table_desc in view['tables']:
        primary_key: Optional[str] = None
        if ('primary_key' in table_desc  # NOTE No composite keys yet.
                and len(table_desc['primary_key']['columns']) == 1):
            primary_key = table_desc['primary_key']['columns'][0]

        table = SnowTable(
            connection,
            name=table_desc['base_table']['table'],
            database=table_desc['base_table']['database'],
            schema=table_desc['base_table']['schema'],
            primary_key=primary_key,
        )
        graph.add_table(table)

    # TODO Find a solution to register time columns!

    # A semantic view without relationships may omit the section entirely:
    for relations in (view.get('relationships') or []):
        if len(relations['relationship_columns']) != 1:
            continue  # NOTE No composite keys yet.
        graph.link(
            src_table=relations['left_table'],
            fkey=relations['relationship_columns'][0]['left_column'],
            dst_table=relations['right_table'],
        )

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
|
|
169
397
|
|
|
@@ -175,7 +403,7 @@ class LocalGraph:
|
|
|
175
403
|
"""
|
|
176
404
|
return name in self.tables
|
|
177
405
|
|
|
178
|
-
def table(self, name: str) ->
|
|
406
|
+
def table(self, name: str) -> Table:
|
|
179
407
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
408
|
|
|
181
409
|
Raises:
|
|
@@ -186,11 +414,11 @@ class LocalGraph:
|
|
|
186
414
|
return self.tables[name]
|
|
187
415
|
|
|
188
416
|
@property
def tables(self) -> Dict[str, Table]:
    r"""Returns the dictionary of table objects, keyed by table name."""
    table_dict = self._tables
    return table_dict
|
|
192
420
|
|
|
193
|
-
def add_table(self, table:
|
|
421
|
+
def add_table(self, table: Table) -> Self:
|
|
194
422
|
r"""Adds a table to the graph.
|
|
195
423
|
|
|
196
424
|
Args:
|
|
@@ -199,11 +427,21 @@ class LocalGraph:
|
|
|
199
427
|
Raises:
|
|
200
428
|
KeyError: If a table with the same name already exists in the
|
|
201
429
|
graph.
|
|
430
|
+
ValueError: If the table belongs to a different backend than the
|
|
431
|
+
rest of the tables in the graph.
|
|
202
432
|
"""
|
|
203
433
|
if table.name in self._tables:
|
|
204
434
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
435
|
f"this graph; table names must be globally unique.")
|
|
206
436
|
|
|
437
|
+
if len(self._tables) > 0:
|
|
438
|
+
cls = next(iter(self._tables.values())).__class__
|
|
439
|
+
if table.__class__ != cls:
|
|
440
|
+
raise ValueError(f"Cannot register a "
|
|
441
|
+
f"'{table.__class__.__name__}' to this "
|
|
442
|
+
f"graph since other tables are of type "
|
|
443
|
+
f"'{cls.__name__}'.")
|
|
444
|
+
|
|
207
445
|
self._tables[table.name] = table
|
|
208
446
|
|
|
209
447
|
return self
|
|
@@ -241,7 +479,7 @@ class LocalGraph:
|
|
|
241
479
|
Example:
|
|
242
480
|
>>> # doctest: +SKIP
|
|
243
481
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
482
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
483
|
>>> graph.metadata # doctest: +SKIP
|
|
246
484
|
name primary_key time_column end_time_column
|
|
247
485
|
0 users user_id - -
|
|
@@ -263,7 +501,7 @@ class LocalGraph:
|
|
|
263
501
|
})
|
|
264
502
|
|
|
265
503
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
504
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
267
505
|
if in_notebook():
|
|
268
506
|
from IPython.display import Markdown, display
|
|
269
507
|
display(Markdown('### 🗂️ Graph Metadata'))
|
|
@@ -287,7 +525,7 @@ class LocalGraph:
|
|
|
287
525
|
|
|
288
526
|
Note:
|
|
289
527
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
528
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
529
|
"""
|
|
292
530
|
for table in self.tables.values():
|
|
293
531
|
table.infer_metadata(verbose=False)
|
|
@@ -305,7 +543,7 @@ class LocalGraph:
|
|
|
305
543
|
return self._edges
|
|
306
544
|
|
|
307
545
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
546
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
547
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
548
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
549
|
edges = sorted(edges)
|
|
@@ -333,9 +571,9 @@ class LocalGraph:
|
|
|
333
571
|
|
|
334
572
|
def link(
|
|
335
573
|
self,
|
|
336
|
-
src_table: Union[str,
|
|
574
|
+
src_table: Union[str, Table],
|
|
337
575
|
fkey: str,
|
|
338
|
-
dst_table: Union[str,
|
|
576
|
+
dst_table: Union[str, Table],
|
|
339
577
|
) -> Self:
|
|
340
578
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
579
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +596,11 @@ class LocalGraph:
|
|
|
358
596
|
table does not exist in the graph, if the source key does not
|
|
359
597
|
exist in the source table.
|
|
360
598
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
599
|
+
if isinstance(src_table, Table):
|
|
362
600
|
src_table = src_table.name
|
|
363
601
|
assert isinstance(src_table, str)
|
|
364
602
|
|
|
365
|
-
if isinstance(dst_table,
|
|
603
|
+
if isinstance(dst_table, Table):
|
|
366
604
|
dst_table = dst_table.name
|
|
367
605
|
assert isinstance(dst_table, str)
|
|
368
606
|
|
|
@@ -396,9 +634,9 @@ class LocalGraph:
|
|
|
396
634
|
|
|
397
635
|
def unlink(
|
|
398
636
|
self,
|
|
399
|
-
src_table: Union[str,
|
|
637
|
+
src_table: Union[str, Table],
|
|
400
638
|
fkey: str,
|
|
401
|
-
dst_table: Union[str,
|
|
639
|
+
dst_table: Union[str, Table],
|
|
402
640
|
) -> Self:
|
|
403
641
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
642
|
|
|
@@ -410,11 +648,11 @@ class LocalGraph:
|
|
|
410
648
|
Raises:
|
|
411
649
|
ValueError: if the edge is not present in the graph.
|
|
412
650
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
651
|
+
if isinstance(src_table, Table):
|
|
414
652
|
src_table = src_table.name
|
|
415
653
|
assert isinstance(src_table, str)
|
|
416
654
|
|
|
417
|
-
if isinstance(dst_table,
|
|
655
|
+
if isinstance(dst_table, Table):
|
|
418
656
|
dst_table = dst_table.name
|
|
419
657
|
assert isinstance(dst_table, str)
|
|
420
658
|
|
|
@@ -428,17 +666,13 @@ class LocalGraph:
|
|
|
428
666
|
return self
|
|
429
667
|
|
|
430
668
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
669
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
670
|
+
graph.
|
|
432
671
|
|
|
433
672
|
Args:
|
|
434
673
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
674
|
"""
|
|
439
|
-
|
|
440
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
441
|
-
return self
|
|
675
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
442
676
|
|
|
443
677
|
# A list of primary key candidates (+score) for every column:
|
|
444
678
|
candidate_dict: dict[
|
|
@@ -463,6 +697,9 @@ class LocalGraph:
|
|
|
463
697
|
src_table_name = src_table.name.lower()
|
|
464
698
|
|
|
465
699
|
for src_key in src_table.columns:
|
|
700
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
701
|
+
continue
|
|
702
|
+
|
|
466
703
|
if src_key == src_table.primary_key:
|
|
467
704
|
continue # Cannot link to primary key.
|
|
468
705
|
|
|
@@ -528,7 +765,9 @@ class LocalGraph:
|
|
|
528
765
|
score += 1.0
|
|
529
766
|
|
|
530
767
|
# Cardinality ratio:
|
|
531
|
-
if
|
|
768
|
+
if (src_table._num_rows is not None
|
|
769
|
+
and dst_table._num_rows is not None
|
|
770
|
+
and src_table._num_rows > dst_table._num_rows):
|
|
532
771
|
score += 1.0
|
|
533
772
|
|
|
534
773
|
if score < 5.0:
|
|
@@ -790,7 +1029,7 @@ class LocalGraph:
|
|
|
790
1029
|
def __contains__(self, name: str) -> bool:
    r"""Returns whether a table named ``name`` is part of the graph."""
    found = self.has_table(name)
    return found
|
|
792
1031
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
1032
|
+
def __getitem__(self, name: str) -> Table:
    r"""Returns the table named ``name``; delegates to :meth:`table`."""
    return self.table(name=name)
|
|
795
1034
|
|
|
796
1035
|
def __delitem__(self, name: str) -> None:
|
|
@@ -1,9 +1,15 @@
|
|
|
1
|
+
from .dtype import infer_dtype
|
|
2
|
+
from .pkey import infer_primary_key
|
|
3
|
+
from .time_col import infer_time_column
|
|
1
4
|
from .id import contains_id
|
|
2
5
|
from .timestamp import contains_timestamp
|
|
3
6
|
from .categorical import contains_categorical
|
|
4
7
|
from .multicategorical import contains_multicategorical
|
|
5
8
|
|
|
6
9
|
__all__ = [
|
|
10
|
+
'infer_dtype',
|
|
11
|
+
'infer_primary_key',
|
|
12
|
+
'infer_time_column',
|
|
7
13
|
'contains_id',
|
|
8
14
|
'contains_timestamp',
|
|
9
15
|
'contains_categorical',
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from typing import Dict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
import pyarrow as pa
|
|
6
|
+
from kumoapi.typing import Dtype
|
|
7
|
+
|
|
8
|
+
# Maps the lower-cased dtype name of a pandas series (with any trailing
# "[...]" backend suffix removed, e.g. "double[pyarrow]" -> "double") to the
# corresponding Kumo :class:`Dtype`.
PANDAS_TO_DTYPE: Dict[str, Dtype] = {
    'bool': Dtype.bool,
    'boolean': Dtype.bool,
    'int8': Dtype.int,
    'int16': Dtype.int,
    'int32': Dtype.int,
    'int64': Dtype.int,
    # Unsigned integers (numpy `uint*` and nullable pandas `UInt*`):
    'uint8': Dtype.int,
    'uint16': Dtype.int,
    'uint32': Dtype.int,
    'uint64': Dtype.int,
    'float16': Dtype.float,
    'float32': Dtype.float,
    'float64': Dtype.float,
    # Arrow spells its floating point types differently:
    'halffloat': Dtype.float,
    'float': Dtype.float,
    'double': Dtype.float,
    'object': Dtype.string,
    'string': Dtype.string,
    'string[python]': Dtype.string,
    'string[pyarrow]': Dtype.string,
    'large_string': Dtype.string,
    'binary': Dtype.binary,
    'large_binary': Dtype.binary,
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def infer_dtype(ser: pd.Series) -> Dtype:
    r"""Extracts the :class:`Dtype` from a :class:`pandas.Series`.

    Datetime series map to ``Dtype.date``, timedelta series to
    ``Dtype.timedelta`` and categorical series to ``Dtype.string``. For
    object series whose first non-null value (searched within the first
    1000 rows) is list-like, up to 1000 non-null rows starting at that value
    are converted to an Arrow list series so that the list element type can
    be detected. All remaining dtypes are looked up by name in
    ``PANDAS_TO_DTYPE``.

    Args:
        ser: A :class:`pandas.Series` to analyze.

    Returns:
        The data type.

    Raises:
        ValueError: If the data type is not supported.
    """
    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
        return Dtype.date
    if pd.api.types.is_timedelta64_dtype(ser.dtype):
        return Dtype.timedelta
    if isinstance(ser.dtype, pd.CategoricalDtype):
        return Dtype.string

    if (pd.api.types.is_object_dtype(ser.dtype)
            and not isinstance(ser.dtype, pd.ArrowDtype)):
        # Locate the first non-null value *positionally*: a label-based
        # lookup returns a Series/slice (and fails) on non-unique indices.
        valid = ser.iloc[:1000].notna().to_numpy()
        if valid.any():
            pos = int(valid.argmax())
            if pd.api.types.is_list_like(ser.iloc[pos]):
                sample = ser.iloc[pos:pos + 1000].dropna()
                arr = pa.array(sample.tolist())
                ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    if isinstance(ser.dtype, pd.ArrowDtype):
        arrow_type = ser.dtype.pyarrow_dtype
        if (pa.types.is_list(arrow_type)
                or pa.types.is_large_list(arrow_type)
                or pa.types.is_fixed_size_list(arrow_type)):
            elem_dtype = arrow_type.value_type
            if pa.types.is_integer(elem_dtype):
                return Dtype.intlist
            if pa.types.is_floating(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_decimal(elem_dtype):
                return Dtype.floatlist
            if (pa.types.is_string(elem_dtype)
                    or pa.types.is_large_string(elem_dtype)):
                return Dtype.stringlist
            if pa.types.is_null(elem_dtype):
                return Dtype.floatlist

    # `Series.dtype` is always a numpy dtype or a pandas extension dtype:
    if isinstance(ser.dtype, np.dtype):
        dtype_str = str(ser.dtype).lower()
    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        dtype_str = ser.dtype.name.lower()
        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
    else:
        dtype_str = 'object'

    if dtype_str not in PANDAS_TO_DTYPE:
        raise ValueError(f"Unsupported data type '{ser.dtype}'")

    return PANDAS_TO_DTYPE[dtype_str]
|