kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
- kumoai/experimental/rfm/backend/local/table.py +16 -21
- kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
- kumoai/experimental/rfm/backend/snow/table.py +67 -33
- kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
- kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
- kumoai/experimental/rfm/base/__init__.py +8 -0
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sql_sampler.py +31 -3
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +162 -143
- kumoai/experimental/rfm/graph.py +242 -95
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/rfm.py +86 -80
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/utils/__init__.py +1 -2
- kumoai/utils/progress_logger.py +178 -12
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
+
import copy
|
|
2
3
|
import io
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
6
|
+
from collections.abc import Sequence
|
|
5
7
|
from dataclasses import dataclass, field
|
|
8
|
+
from itertools import chain
|
|
6
9
|
from pathlib import Path
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
8
11
|
|
|
9
12
|
import pandas as pd
|
|
10
13
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -13,7 +16,12 @@ from kumoapi.typing import Stype
|
|
|
13
16
|
from typing_extensions import Self
|
|
14
17
|
|
|
15
18
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
-
from kumoai.experimental.rfm.base import
|
|
19
|
+
from kumoai.experimental.rfm.base import (
|
|
20
|
+
ColumnExpressionSpec,
|
|
21
|
+
DataBackend,
|
|
22
|
+
SQLTable,
|
|
23
|
+
Table,
|
|
24
|
+
)
|
|
17
25
|
from kumoai.graph import Edge
|
|
18
26
|
from kumoai.mixin import CastMixin
|
|
19
27
|
|
|
@@ -25,8 +33,8 @@ if TYPE_CHECKING:
|
|
|
25
33
|
|
|
26
34
|
@dataclass
|
|
27
35
|
class SqliteConnectionConfig(CastMixin):
|
|
28
|
-
uri:
|
|
29
|
-
kwargs:
|
|
36
|
+
uri: str | Path
|
|
37
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
30
38
|
|
|
31
39
|
|
|
32
40
|
class Graph:
|
|
@@ -86,27 +94,38 @@ class Graph:
|
|
|
86
94
|
def __init__(
|
|
87
95
|
self,
|
|
88
96
|
tables: Sequence[Table],
|
|
89
|
-
edges:
|
|
97
|
+
edges: Sequence[Edge] | None = None,
|
|
90
98
|
) -> None:
|
|
91
99
|
|
|
92
|
-
self._tables:
|
|
93
|
-
self._edges:
|
|
100
|
+
self._tables: dict[str, Table] = {}
|
|
101
|
+
self._edges: list[Edge] = []
|
|
94
102
|
|
|
95
103
|
for table in tables:
|
|
96
104
|
self.add_table(table)
|
|
97
105
|
|
|
98
106
|
for table in tables:
|
|
107
|
+
if not isinstance(table, SQLTable):
|
|
108
|
+
continue
|
|
109
|
+
if '_source_column_dict' not in table.__dict__:
|
|
110
|
+
continue
|
|
99
111
|
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
-
if fkey.name not in table
|
|
112
|
+
if fkey.name not in table:
|
|
113
|
+
continue
|
|
114
|
+
if not table[fkey.name].is_physical:
|
|
115
|
+
continue
|
|
116
|
+
dst_table_names = [
|
|
117
|
+
table.name for table in self.tables.values()
|
|
118
|
+
if isinstance(table, SQLTable)
|
|
119
|
+
and table._source_name == fkey.dst_table
|
|
120
|
+
]
|
|
121
|
+
if len(dst_table_names) != 1:
|
|
122
|
+
continue
|
|
123
|
+
dst_table = self[dst_table_names[0]]
|
|
124
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
125
|
+
continue
|
|
126
|
+
if not dst_table[fkey.primary_key].is_physical:
|
|
101
127
|
continue
|
|
102
|
-
|
|
103
|
-
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
-
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
-
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
-
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
-
f"and '{fkey.primary_key}' in table "
|
|
108
|
-
f"'{fkey.dst_table}'.")
|
|
109
|
-
self.link(table.name, fkey.name, fkey.dst_table)
|
|
128
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
110
129
|
|
|
111
130
|
for edge in (edges or []):
|
|
112
131
|
_edge = Edge._cast(edge)
|
|
@@ -117,8 +136,8 @@ class Graph:
|
|
|
117
136
|
@classmethod
|
|
118
137
|
def from_data(
|
|
119
138
|
cls,
|
|
120
|
-
df_dict:
|
|
121
|
-
edges:
|
|
139
|
+
df_dict: dict[str, pd.DataFrame],
|
|
140
|
+
edges: Sequence[Edge] | None = None,
|
|
122
141
|
infer_metadata: bool = True,
|
|
123
142
|
verbose: bool = True,
|
|
124
143
|
) -> Self:
|
|
@@ -156,15 +175,17 @@ class Graph:
|
|
|
156
175
|
verbose: Whether to print verbose output.
|
|
157
176
|
"""
|
|
158
177
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
178
|
|
|
161
|
-
graph = cls(
|
|
179
|
+
graph = cls(
|
|
180
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
181
|
+
edges=edges or [],
|
|
182
|
+
)
|
|
162
183
|
|
|
163
184
|
if infer_metadata:
|
|
164
|
-
graph.infer_metadata(False)
|
|
185
|
+
graph.infer_metadata(verbose=False)
|
|
165
186
|
|
|
166
187
|
if edges is None:
|
|
167
|
-
graph.infer_links(False)
|
|
188
|
+
graph.infer_links(verbose=False)
|
|
168
189
|
|
|
169
190
|
if verbose:
|
|
170
191
|
graph.print_metadata()
|
|
@@ -180,10 +201,10 @@ class Graph:
|
|
|
180
201
|
SqliteConnectionConfig,
|
|
181
202
|
str,
|
|
182
203
|
Path,
|
|
183
|
-
|
|
204
|
+
dict[str, Any],
|
|
184
205
|
],
|
|
185
|
-
|
|
186
|
-
edges:
|
|
206
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
207
|
+
edges: Sequence[Edge] | None = None,
|
|
187
208
|
infer_metadata: bool = True,
|
|
188
209
|
verbose: bool = True,
|
|
189
210
|
) -> Self:
|
|
@@ -199,17 +220,25 @@ class Graph:
|
|
|
199
220
|
>>> # Create a graph from a SQLite database:
|
|
200
221
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
222
|
|
|
223
|
+
>>> # Fine-grained control over table specification:
|
|
224
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
225
|
+
... 'USERS',
|
|
226
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
227
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
228
|
+
... ])
|
|
229
|
+
|
|
202
230
|
Args:
|
|
203
231
|
connection: An open connection from
|
|
204
232
|
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
233
|
path to the database file.
|
|
206
|
-
|
|
207
|
-
all tables present
|
|
234
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
235
|
+
arguments to include. If ``None``, will add all tables present
|
|
236
|
+
in the database.
|
|
208
237
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
209
238
|
add to the graph. If not provided, edges will be automatically
|
|
210
239
|
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
-
infer_metadata: Whether to infer metadata for all tables in
|
|
212
|
-
graph.
|
|
240
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
241
|
+
the graph.
|
|
213
242
|
verbose: Whether to print verbose output.
|
|
214
243
|
"""
|
|
215
244
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
@@ -226,24 +255,33 @@ class Graph:
|
|
|
226
255
|
internal_connection = True
|
|
227
256
|
assert isinstance(connection, Connection)
|
|
228
257
|
|
|
229
|
-
if
|
|
258
|
+
if tables is None:
|
|
230
259
|
with connection.cursor() as cursor:
|
|
231
260
|
cursor.execute("SELECT name FROM sqlite_master "
|
|
232
261
|
"WHERE type='table'")
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
262
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
236
263
|
|
|
237
|
-
|
|
264
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
265
|
+
for table in tables:
|
|
266
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
267
|
+
table_kwargs.append(kwargs)
|
|
268
|
+
|
|
269
|
+
graph = cls(
|
|
270
|
+
tables=[
|
|
271
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
272
|
+
for kwargs in table_kwargs
|
|
273
|
+
],
|
|
274
|
+
edges=edges or [],
|
|
275
|
+
)
|
|
238
276
|
|
|
239
277
|
if internal_connection:
|
|
240
278
|
graph._connection = connection # type: ignore
|
|
241
279
|
|
|
242
280
|
if infer_metadata:
|
|
243
|
-
graph.infer_metadata(False)
|
|
281
|
+
graph.infer_metadata(verbose=False)
|
|
244
282
|
|
|
245
283
|
if edges is None:
|
|
246
|
-
graph.infer_links(False)
|
|
284
|
+
graph.infer_links(verbose=False)
|
|
247
285
|
|
|
248
286
|
if verbose:
|
|
249
287
|
graph.print_metadata()
|
|
@@ -254,11 +292,11 @@ class Graph:
|
|
|
254
292
|
@classmethod
|
|
255
293
|
def from_snowflake(
|
|
256
294
|
cls,
|
|
257
|
-
connection: Union['SnowflakeConnection',
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
edges:
|
|
295
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
296
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
297
|
+
database: str | None = None,
|
|
298
|
+
schema: str | None = None,
|
|
299
|
+
edges: Sequence[Edge] | None = None,
|
|
262
300
|
infer_metadata: bool = True,
|
|
263
301
|
verbose: bool = True,
|
|
264
302
|
) -> Self:
|
|
@@ -275,6 +313,13 @@ class Graph:
|
|
|
275
313
|
>>> # Create a graph directly in a Snowflake notebook:
|
|
276
314
|
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
277
315
|
|
|
316
|
+
>>> # Fine-grained control over table specification:
|
|
317
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
318
|
+
... 'USERS',
|
|
319
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
320
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
321
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
322
|
+
|
|
278
323
|
Args:
|
|
279
324
|
connection: An open connection from
|
|
280
325
|
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
@@ -282,10 +327,11 @@ class Graph:
|
|
|
282
327
|
connection. If ``None``, will re-use an active session in case
|
|
283
328
|
it exists, or create a new connection from credentials stored
|
|
284
329
|
in environment variables.
|
|
330
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
331
|
+
to include. If ``None``, will add all tables present in the
|
|
332
|
+
current database and schema.
|
|
285
333
|
database: The database.
|
|
286
334
|
schema: The schema.
|
|
287
|
-
table_names: Set of table names to include. If ``None``, will add
|
|
288
|
-
all tables present in the database.
|
|
289
335
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
290
336
|
add to the graph. If not provided, edges will be automatically
|
|
291
337
|
inferred from the data in case ``infer_metadata=True``.
|
|
@@ -303,37 +349,50 @@ class Graph:
|
|
|
303
349
|
connection = connect(**(connection or {}))
|
|
304
350
|
assert isinstance(connection, Connection)
|
|
305
351
|
|
|
306
|
-
if
|
|
352
|
+
if database is None or schema is None:
|
|
353
|
+
with connection.cursor() as cursor:
|
|
354
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
355
|
+
result = cursor.fetchone()
|
|
356
|
+
database = database or result[0]
|
|
357
|
+
assert database is not None
|
|
358
|
+
schema = schema or result[1]
|
|
359
|
+
|
|
360
|
+
if tables is None:
|
|
361
|
+
if schema is None:
|
|
362
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
363
|
+
"Snowflake schema manually")
|
|
364
|
+
|
|
307
365
|
with connection.cursor() as cursor:
|
|
308
|
-
if database is None and schema is None:
|
|
309
|
-
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
310
|
-
"CURRENT_SCHEMA()")
|
|
311
|
-
result = cursor.fetchone()
|
|
312
|
-
database = database or result[0]
|
|
313
|
-
schema = schema or result[1]
|
|
314
366
|
cursor.execute(f"""
|
|
315
367
|
SELECT TABLE_NAME
|
|
316
368
|
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
317
369
|
WHERE TABLE_SCHEMA = '{schema}'
|
|
318
370
|
""")
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
tables = [
|
|
322
|
-
SnowTable(
|
|
323
|
-
connection,
|
|
324
|
-
name=table_name,
|
|
325
|
-
database=database,
|
|
326
|
-
schema=schema,
|
|
327
|
-
) for table_name in table_names
|
|
328
|
-
]
|
|
371
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
329
372
|
|
|
330
|
-
|
|
373
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
374
|
+
for table in tables:
|
|
375
|
+
if isinstance(table, str):
|
|
376
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
377
|
+
else:
|
|
378
|
+
kwargs = copy.copy(table)
|
|
379
|
+
kwargs.setdefault('database', database)
|
|
380
|
+
kwargs.setdefault('schema', schema)
|
|
381
|
+
table_kwargs.append(kwargs)
|
|
382
|
+
|
|
383
|
+
graph = cls(
|
|
384
|
+
tables=[
|
|
385
|
+
SnowTable(connection=connection, **kwargs)
|
|
386
|
+
for kwargs in table_kwargs
|
|
387
|
+
],
|
|
388
|
+
edges=edges or [],
|
|
389
|
+
)
|
|
331
390
|
|
|
332
391
|
if infer_metadata:
|
|
333
|
-
graph.infer_metadata(False)
|
|
392
|
+
graph.infer_metadata(verbose=False)
|
|
334
393
|
|
|
335
394
|
if edges is None:
|
|
336
|
-
graph.infer_links(False)
|
|
395
|
+
graph.infer_links(verbose=False)
|
|
337
396
|
|
|
338
397
|
if verbose:
|
|
339
398
|
graph.print_metadata()
|
|
@@ -345,7 +404,7 @@ class Graph:
|
|
|
345
404
|
def from_snowflake_semantic_view(
|
|
346
405
|
cls,
|
|
347
406
|
semantic_view_name: str,
|
|
348
|
-
connection: Union['SnowflakeConnection',
|
|
407
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
349
408
|
verbose: bool = True,
|
|
350
409
|
) -> Self:
|
|
351
410
|
import yaml
|
|
@@ -363,40 +422,128 @@ class Graph:
|
|
|
363
422
|
with connection.cursor() as cursor:
|
|
364
423
|
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
365
424
|
f"'{semantic_view_name}')")
|
|
366
|
-
|
|
425
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
367
426
|
|
|
368
427
|
graph = cls(tables=[])
|
|
369
428
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
429
|
+
msgs = []
|
|
430
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
431
|
+
for table_cfg in cfg['tables']:
|
|
432
|
+
table_name = table_cfg['name']
|
|
433
|
+
source_table_name = table_cfg['base_table']['table']
|
|
434
|
+
database = table_cfg['base_table']['database']
|
|
435
|
+
schema = table_cfg['base_table']['schema']
|
|
436
|
+
|
|
437
|
+
primary_key: str | None = None
|
|
438
|
+
if 'primary_key' in table_cfg:
|
|
439
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
440
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
441
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
442
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
443
|
+
msgs.append(f"Failed to add primary key for table "
|
|
444
|
+
f"'{table_name}' since composite primary keys "
|
|
445
|
+
f"are not yet supported")
|
|
446
|
+
|
|
447
|
+
columns: list[str] = []
|
|
448
|
+
unsupported_columns: list[str] = []
|
|
449
|
+
column_expression_specs: list[ColumnExpressionSpec] = []
|
|
450
|
+
for column_cfg in chain(
|
|
451
|
+
table_cfg.get('dimensions', []),
|
|
452
|
+
table_cfg.get('time_dimensions', []),
|
|
453
|
+
table_cfg.get('facts', []),
|
|
454
|
+
):
|
|
455
|
+
column_name = column_cfg['name']
|
|
456
|
+
column_expr = column_cfg.get('expr', None)
|
|
457
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
458
|
+
|
|
459
|
+
if column_expr is None:
|
|
460
|
+
columns.append(column_name)
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
464
|
+
|
|
465
|
+
if column_expr == column_name:
|
|
466
|
+
columns.append(column_name)
|
|
467
|
+
continue
|
|
468
|
+
|
|
469
|
+
# Drop expressions that reference other tables (for now):
|
|
470
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
471
|
+
unsupported_columns.append(column_name)
|
|
472
|
+
continue
|
|
473
|
+
|
|
474
|
+
spec = ColumnExpressionSpec(
|
|
475
|
+
name=column_name,
|
|
476
|
+
expr=column_expr,
|
|
477
|
+
dtype=SnowTable.to_dtype(column_data_type),
|
|
478
|
+
)
|
|
479
|
+
column_expression_specs.append(spec)
|
|
480
|
+
|
|
481
|
+
if len(unsupported_columns) == 1:
|
|
482
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
483
|
+
f"of table '{table_name}' since its expression "
|
|
484
|
+
f"references other tables")
|
|
485
|
+
elif len(unsupported_columns) > 1:
|
|
486
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
487
|
+
f"of table '{table_name}' since their expressions "
|
|
488
|
+
f"reference other tables")
|
|
375
489
|
|
|
376
490
|
table = SnowTable(
|
|
377
491
|
connection,
|
|
378
|
-
name=
|
|
379
|
-
|
|
380
|
-
|
|
492
|
+
name=table_name,
|
|
493
|
+
source_name=source_table_name,
|
|
494
|
+
database=database,
|
|
495
|
+
schema=schema,
|
|
496
|
+
columns=columns,
|
|
497
|
+
column_expressions=column_expression_specs,
|
|
381
498
|
primary_key=primary_key,
|
|
382
499
|
)
|
|
500
|
+
|
|
501
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
502
|
+
table.infer_time_column(verbose=False)
|
|
503
|
+
|
|
383
504
|
graph.add_table(table)
|
|
384
505
|
|
|
385
|
-
|
|
506
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
507
|
+
name = relation_cfg['name']
|
|
508
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
509
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
510
|
+
f"composite key references are not yet supported")
|
|
511
|
+
continue
|
|
512
|
+
|
|
513
|
+
left_table = relation_cfg['left_table']
|
|
514
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
515
|
+
right_table = relation_cfg['right_table']
|
|
516
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
386
517
|
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
518
|
+
if graph[right_table]._primary_key != right_key:
|
|
519
|
+
# Semantic view error - this should never be triggered:
|
|
520
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
521
|
+
f"referenced key '{right_key}' of table "
|
|
522
|
+
f"'{right_table}' is not a primary key")
|
|
523
|
+
continue
|
|
524
|
+
|
|
525
|
+
if graph[left_table]._primary_key == left_key:
|
|
526
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
527
|
+
f"referencing key '{left_key}' of table "
|
|
528
|
+
f"'{left_table}' is a primary key")
|
|
529
|
+
continue
|
|
530
|
+
|
|
531
|
+
if left_key not in graph[left_table]:
|
|
532
|
+
graph[left_table].add_column(left_key)
|
|
533
|
+
|
|
534
|
+
graph.link(left_table, left_key, right_table)
|
|
535
|
+
|
|
536
|
+
graph.validate()
|
|
395
537
|
|
|
396
538
|
if verbose:
|
|
397
539
|
graph.print_metadata()
|
|
398
540
|
graph.print_links()
|
|
399
541
|
|
|
542
|
+
if len(msgs) > 0:
|
|
543
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
544
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
545
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
546
|
+
|
|
400
547
|
return graph
|
|
401
548
|
|
|
402
549
|
# Backend #################################################################
|
|
@@ -425,7 +572,7 @@ class Graph:
|
|
|
425
572
|
return self.tables[name]
|
|
426
573
|
|
|
427
574
|
@property
|
|
428
|
-
def tables(self) ->
|
|
575
|
+
def tables(self) -> dict[str, Table]:
|
|
429
576
|
r"""Returns the dictionary of table objects."""
|
|
430
577
|
return self._tables
|
|
431
578
|
|
|
@@ -550,7 +697,7 @@ class Graph:
|
|
|
550
697
|
# Edges ###################################################################
|
|
551
698
|
|
|
552
699
|
@property
|
|
553
|
-
def edges(self) ->
|
|
700
|
+
def edges(self) -> list[Edge]:
|
|
554
701
|
r"""Returns the edges of the graph."""
|
|
555
702
|
return self._edges
|
|
556
703
|
|
|
@@ -565,7 +712,7 @@ class Graph:
|
|
|
565
712
|
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
566
713
|
if len(edges) > 0:
|
|
567
714
|
st.markdown('\n'.join([
|
|
568
|
-
f"-
|
|
715
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
569
716
|
for edge in edges
|
|
570
717
|
]))
|
|
571
718
|
else:
|
|
@@ -593,9 +740,9 @@ class Graph:
|
|
|
593
740
|
|
|
594
741
|
def link(
|
|
595
742
|
self,
|
|
596
|
-
src_table:
|
|
743
|
+
src_table: str | Table,
|
|
597
744
|
fkey: str,
|
|
598
|
-
dst_table:
|
|
745
|
+
dst_table: str | Table,
|
|
599
746
|
) -> Self:
|
|
600
747
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
601
748
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -656,9 +803,9 @@ class Graph:
|
|
|
656
803
|
|
|
657
804
|
def unlink(
|
|
658
805
|
self,
|
|
659
|
-
src_table:
|
|
806
|
+
src_table: str | Table,
|
|
660
807
|
fkey: str,
|
|
661
|
-
dst_table:
|
|
808
|
+
dst_table: str | Table,
|
|
662
809
|
) -> Self:
|
|
663
810
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
664
811
|
|
|
@@ -885,7 +1032,7 @@ class Graph:
|
|
|
885
1032
|
|
|
886
1033
|
def visualize(
|
|
887
1034
|
self,
|
|
888
|
-
path:
|
|
1035
|
+
path: str | io.BytesIO | None = None,
|
|
889
1036
|
show_columns: bool = True,
|
|
890
1037
|
) -> 'graphviz.Graph':
|
|
891
1038
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -924,19 +1071,19 @@ class Graph:
|
|
|
924
1071
|
"them as described at "
|
|
925
1072
|
"https://graphviz.org/download/.")
|
|
926
1073
|
|
|
927
|
-
format:
|
|
1074
|
+
format: str | None = None
|
|
928
1075
|
if isinstance(path, str):
|
|
929
1076
|
format = path.split('.')[-1]
|
|
930
1077
|
elif isinstance(path, io.BytesIO):
|
|
931
1078
|
format = 'svg'
|
|
932
1079
|
graph = graphviz.Graph(format=format)
|
|
933
1080
|
|
|
934
|
-
def left_align(keys:
|
|
1081
|
+
def left_align(keys: list[str]) -> str:
|
|
935
1082
|
if len(keys) == 0:
|
|
936
1083
|
return ""
|
|
937
1084
|
return '\\l'.join(keys) + '\\l'
|
|
938
1085
|
|
|
939
|
-
fkeys_dict:
|
|
1086
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
940
1087
|
for src_table_name, fkey_name, _ in self.edges:
|
|
941
1088
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
942
1089
|
|
|
@@ -1032,8 +1179,8 @@ class Graph:
|
|
|
1032
1179
|
# Helpers #################################################################
|
|
1033
1180
|
|
|
1034
1181
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
1035
|
-
tables:
|
|
1036
|
-
col_groups:
|
|
1182
|
+
tables: dict[str, TableDefinition] = {}
|
|
1183
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
1037
1184
|
for table_name, table in self.tables.items():
|
|
1038
1185
|
tables[table_name] = table._to_api_table_definition()
|
|
1039
1186
|
if table.primary_key is None:
|
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
from .dtype import infer_dtype
|
|
2
|
-
from .pkey import infer_primary_key
|
|
3
|
-
from .time_col import infer_time_column
|
|
4
2
|
from .id import contains_id
|
|
5
3
|
from .timestamp import contains_timestamp
|
|
6
4
|
from .categorical import contains_categorical
|
|
7
5
|
from .multicategorical import contains_multicategorical
|
|
6
|
+
from .stype import infer_stype
|
|
7
|
+
from .pkey import infer_primary_key
|
|
8
|
+
from .time_col import infer_time_column
|
|
8
9
|
|
|
9
10
|
__all__ = [
|
|
10
11
|
'infer_dtype',
|
|
11
|
-
'infer_primary_key',
|
|
12
|
-
'infer_time_column',
|
|
13
12
|
'contains_id',
|
|
14
13
|
'contains_timestamp',
|
|
15
14
|
'contains_categorical',
|
|
16
15
|
'contains_multicategorical',
|
|
16
|
+
'infer_stype',
|
|
17
|
+
'infer_primary_key',
|
|
18
|
+
'infer_time_column',
|
|
17
19
|
]
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
from typing import Dict
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import pandas as pd
|
|
5
3
|
import pyarrow as pa
|
|
6
4
|
from kumoapi.typing import Dtype
|
|
7
5
|
|
|
8
|
-
PANDAS_TO_DTYPE:
|
|
6
|
+
PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
9
7
|
'bool': Dtype.bool,
|
|
10
8
|
'boolean': Dtype.bool,
|
|
11
9
|
'int8': Dtype.int,
|
|
12
10
|
'int16': Dtype.int,
|
|
13
11
|
'int32': Dtype.int,
|
|
14
12
|
'int64': Dtype.int,
|
|
13
|
+
'float': Dtype.float,
|
|
14
|
+
'double': Dtype.float,
|
|
15
15
|
'float16': Dtype.float,
|
|
16
16
|
'float32': Dtype.float,
|
|
17
17
|
'float64': Dtype.float,
|
|
@@ -1,6 +1,5 @@
|
|
|
1
1
|
import re
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Optional
|
|
4
3
|
|
|
5
4
|
import pandas as pd
|
|
6
5
|
|
|
@@ -9,7 +8,7 @@ def infer_primary_key(
|
|
|
9
8
|
table_name: str,
|
|
10
9
|
df: pd.DataFrame,
|
|
11
10
|
candidates: list[str],
|
|
12
|
-
) ->
|
|
11
|
+
) -> str | None:
|
|
13
12
|
r"""Auto-detect potential primary key column.
|
|
14
13
|
|
|
15
14
|
Args:
|
|
@@ -20,6 +19,9 @@ def infer_primary_key(
|
|
|
20
19
|
Returns:
|
|
21
20
|
The name of the detected primary key, or ``None`` if not found.
|
|
22
21
|
"""
|
|
22
|
+
if len(candidates) == 0:
|
|
23
|
+
return None
|
|
24
|
+
|
|
23
25
|
# A list of (potentially modified) table names that are eligible to match
|
|
24
26
|
# with a primary key, i.e.:
|
|
25
27
|
# - UserInfo -> User
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from kumoapi.typing import Dtype, Stype
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.infer import (
|
|
5
|
+
contains_categorical,
|
|
6
|
+
contains_id,
|
|
7
|
+
contains_multicategorical,
|
|
8
|
+
contains_timestamp,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
|
|
13
|
+
"""Infers the :class:`Stype` from a :class:`pandas.Series`.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
ser: A :class:`pandas.Series` to analyze.
|
|
17
|
+
column_name: The column name.
|
|
18
|
+
dtype: The data type.
|
|
19
|
+
|
|
20
|
+
Returns:
|
|
21
|
+
The semantic type.
|
|
22
|
+
"""
|
|
23
|
+
if contains_id(ser, column_name, dtype):
|
|
24
|
+
return Stype.ID
|
|
25
|
+
|
|
26
|
+
if contains_timestamp(ser, column_name, dtype):
|
|
27
|
+
return Stype.timestamp
|
|
28
|
+
|
|
29
|
+
if contains_multicategorical(ser, column_name, dtype):
|
|
30
|
+
return Stype.multicategorical
|
|
31
|
+
|
|
32
|
+
if contains_categorical(ser, column_name, dtype):
|
|
33
|
+
return Stype.categorical
|
|
34
|
+
|
|
35
|
+
return dtype.default_stype
|