kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/client/pquery.py +6 -2
- kumoai/experimental/rfm/__init__.py +33 -8
- kumoai/experimental/rfm/authenticate.py +3 -4
- kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
- kumoai/experimental/rfm/backend/local/sampler.py +128 -55
- kumoai/experimental/rfm/backend/local/table.py +21 -16
- kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
- kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
- kumoai/experimental/rfm/backend/snow/table.py +101 -49
- kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
- kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
- kumoai/experimental/rfm/base/__init__.py +24 -5
- kumoai/experimental/rfm/base/column.py +14 -12
- kumoai/experimental/rfm/base/column_expression.py +50 -0
- kumoai/experimental/rfm/base/sampler.py +429 -30
- kumoai/experimental/rfm/base/source.py +1 -0
- kumoai/experimental/rfm/base/sql_sampler.py +84 -0
- kumoai/experimental/rfm/base/sql_table.py +229 -0
- kumoai/experimental/rfm/base/table.py +165 -135
- kumoai/experimental/rfm/graph.py +266 -102
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +3 -3
- kumoai/experimental/rfm/infer/pkey.py +4 -2
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/infer/time_col.py +1 -2
- kumoai/experimental/rfm/pquery/executor.py +27 -27
- kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
- kumoai/experimental/rfm/rfm.py +299 -230
- kumoai/experimental/rfm/sagemaker.py +4 -4
- kumoai/pquery/predictive_query.py +10 -6
- kumoai/testing/snow.py +50 -0
- kumoai/utils/__init__.py +3 -2
- kumoai/utils/progress_logger.py +178 -12
- kumoai/utils/sql.py +3 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/METADATA +3 -2
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/RECORD +41 -35
- kumoai/experimental/rfm/local_graph_sampler.py +0 -223
- kumoai/experimental/rfm/local_pquery_driver.py +0 -689
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
+
import copy
|
|
2
3
|
import io
|
|
3
4
|
import warnings
|
|
4
5
|
from collections import defaultdict
|
|
6
|
+
from collections.abc import Sequence
|
|
5
7
|
from dataclasses import dataclass, field
|
|
8
|
+
from itertools import chain
|
|
6
9
|
from pathlib import Path
|
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
|
10
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
8
11
|
|
|
9
12
|
import pandas as pd
|
|
10
13
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -13,7 +16,12 @@ from kumoapi.typing import Stype
|
|
|
13
16
|
from typing_extensions import Self
|
|
14
17
|
|
|
15
18
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
16
|
-
from kumoai.experimental.rfm import
|
|
19
|
+
from kumoai.experimental.rfm.base import (
|
|
20
|
+
ColumnExpressionSpec,
|
|
21
|
+
DataBackend,
|
|
22
|
+
SQLTable,
|
|
23
|
+
Table,
|
|
24
|
+
)
|
|
17
25
|
from kumoai.graph import Edge
|
|
18
26
|
from kumoai.mixin import CastMixin
|
|
19
27
|
|
|
@@ -25,8 +33,8 @@ if TYPE_CHECKING:
|
|
|
25
33
|
|
|
26
34
|
@dataclass
|
|
27
35
|
class SqliteConnectionConfig(CastMixin):
|
|
28
|
-
uri:
|
|
29
|
-
kwargs:
|
|
36
|
+
uri: str | Path
|
|
37
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
30
38
|
|
|
31
39
|
|
|
32
40
|
class Graph:
|
|
@@ -86,27 +94,38 @@ class Graph:
|
|
|
86
94
|
def __init__(
|
|
87
95
|
self,
|
|
88
96
|
tables: Sequence[Table],
|
|
89
|
-
edges:
|
|
97
|
+
edges: Sequence[Edge] | None = None,
|
|
90
98
|
) -> None:
|
|
91
99
|
|
|
92
|
-
self._tables:
|
|
93
|
-
self._edges:
|
|
100
|
+
self._tables: dict[str, Table] = {}
|
|
101
|
+
self._edges: list[Edge] = []
|
|
94
102
|
|
|
95
103
|
for table in tables:
|
|
96
104
|
self.add_table(table)
|
|
97
105
|
|
|
98
106
|
for table in tables:
|
|
107
|
+
if not isinstance(table, SQLTable):
|
|
108
|
+
continue
|
|
109
|
+
if '_source_column_dict' not in table.__dict__:
|
|
110
|
+
continue
|
|
99
111
|
for fkey in table._source_foreign_key_dict.values():
|
|
100
|
-
if fkey.name not in table
|
|
112
|
+
if fkey.name not in table:
|
|
113
|
+
continue
|
|
114
|
+
if not table[fkey.name].is_physical:
|
|
115
|
+
continue
|
|
116
|
+
dst_table_names = [
|
|
117
|
+
table.name for table in self.tables.values()
|
|
118
|
+
if isinstance(table, SQLTable)
|
|
119
|
+
and table._source_name == fkey.dst_table
|
|
120
|
+
]
|
|
121
|
+
if len(dst_table_names) != 1:
|
|
122
|
+
continue
|
|
123
|
+
dst_table = self[dst_table_names[0]]
|
|
124
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
125
|
+
continue
|
|
126
|
+
if not dst_table[fkey.primary_key].is_physical:
|
|
101
127
|
continue
|
|
102
|
-
|
|
103
|
-
self[fkey.dst_table].primary_key = fkey.primary_key
|
|
104
|
-
elif self[fkey.dst_table]._primary_key != fkey.primary_key:
|
|
105
|
-
raise ValueError(f"Found duplicate primary key definition "
|
|
106
|
-
f"'{self[fkey.dst_table]._primary_key}' "
|
|
107
|
-
f"and '{fkey.primary_key}' in table "
|
|
108
|
-
f"'{fkey.dst_table}'.")
|
|
109
|
-
self.link(table.name, fkey.name, fkey.dst_table)
|
|
128
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
110
129
|
|
|
111
130
|
for edge in (edges or []):
|
|
112
131
|
_edge = Edge._cast(edge)
|
|
@@ -117,8 +136,8 @@ class Graph:
|
|
|
117
136
|
@classmethod
|
|
118
137
|
def from_data(
|
|
119
138
|
cls,
|
|
120
|
-
df_dict:
|
|
121
|
-
edges:
|
|
139
|
+
df_dict: dict[str, pd.DataFrame],
|
|
140
|
+
edges: Sequence[Edge] | None = None,
|
|
122
141
|
infer_metadata: bool = True,
|
|
123
142
|
verbose: bool = True,
|
|
124
143
|
) -> Self:
|
|
@@ -156,15 +175,17 @@ class Graph:
|
|
|
156
175
|
verbose: Whether to print verbose output.
|
|
157
176
|
"""
|
|
158
177
|
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
159
|
-
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
160
178
|
|
|
161
|
-
graph = cls(
|
|
179
|
+
graph = cls(
|
|
180
|
+
tables=[LocalTable(df, name) for name, df in df_dict.items()],
|
|
181
|
+
edges=edges or [],
|
|
182
|
+
)
|
|
162
183
|
|
|
163
184
|
if infer_metadata:
|
|
164
|
-
graph.infer_metadata(False)
|
|
185
|
+
graph.infer_metadata(verbose=False)
|
|
165
186
|
|
|
166
187
|
if edges is None:
|
|
167
|
-
graph.infer_links(False)
|
|
188
|
+
graph.infer_links(verbose=False)
|
|
168
189
|
|
|
169
190
|
if verbose:
|
|
170
191
|
graph.print_metadata()
|
|
@@ -180,10 +201,10 @@ class Graph:
|
|
|
180
201
|
SqliteConnectionConfig,
|
|
181
202
|
str,
|
|
182
203
|
Path,
|
|
183
|
-
|
|
204
|
+
dict[str, Any],
|
|
184
205
|
],
|
|
185
|
-
|
|
186
|
-
edges:
|
|
206
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
207
|
+
edges: Sequence[Edge] | None = None,
|
|
187
208
|
infer_metadata: bool = True,
|
|
188
209
|
verbose: bool = True,
|
|
189
210
|
) -> Self:
|
|
@@ -199,17 +220,25 @@ class Graph:
|
|
|
199
220
|
>>> # Create a graph from a SQLite database:
|
|
200
221
|
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
201
222
|
|
|
223
|
+
>>> # Fine-grained control over table specification:
|
|
224
|
+
>>> graph = rfm.Graph.from_sqlite('data.db', tables=[
|
|
225
|
+
... 'USERS',
|
|
226
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
227
|
+
... dict(name='ITEMS', primary_key='ITEM_ID'),
|
|
228
|
+
... ])
|
|
229
|
+
|
|
202
230
|
Args:
|
|
203
231
|
connection: An open connection from
|
|
204
232
|
:meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
|
|
205
233
|
path to the database file.
|
|
206
|
-
|
|
207
|
-
all tables present
|
|
234
|
+
tables: Set of table names or :class:`SQLiteTable` keyword
|
|
235
|
+
arguments to include. If ``None``, will add all tables present
|
|
236
|
+
in the database.
|
|
208
237
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
209
238
|
add to the graph. If not provided, edges will be automatically
|
|
210
239
|
inferred from the data in case ``infer_metadata=True``.
|
|
211
|
-
infer_metadata: Whether to infer metadata for all tables in
|
|
212
|
-
graph.
|
|
240
|
+
infer_metadata: Whether to infer missing metadata for all tables in
|
|
241
|
+
the graph.
|
|
213
242
|
verbose: Whether to print verbose output.
|
|
214
243
|
"""
|
|
215
244
|
from kumoai.experimental.rfm.backend.sqlite import (
|
|
@@ -218,27 +247,41 @@ class Graph:
|
|
|
218
247
|
connect,
|
|
219
248
|
)
|
|
220
249
|
|
|
250
|
+
internal_connection = False
|
|
221
251
|
if not isinstance(connection, Connection):
|
|
222
252
|
connection = SqliteConnectionConfig._cast(connection)
|
|
223
253
|
assert isinstance(connection, SqliteConnectionConfig)
|
|
224
254
|
connection = connect(connection.uri, **connection.kwargs)
|
|
255
|
+
internal_connection = True
|
|
225
256
|
assert isinstance(connection, Connection)
|
|
226
257
|
|
|
227
|
-
if
|
|
258
|
+
if tables is None:
|
|
228
259
|
with connection.cursor() as cursor:
|
|
229
260
|
cursor.execute("SELECT name FROM sqlite_master "
|
|
230
261
|
"WHERE type='table'")
|
|
231
|
-
|
|
262
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
232
263
|
|
|
233
|
-
|
|
264
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
265
|
+
for table in tables:
|
|
266
|
+
kwargs = dict(name=table) if isinstance(table, str) else table
|
|
267
|
+
table_kwargs.append(kwargs)
|
|
268
|
+
|
|
269
|
+
graph = cls(
|
|
270
|
+
tables=[
|
|
271
|
+
SQLiteTable(connection=connection, **kwargs)
|
|
272
|
+
for kwargs in table_kwargs
|
|
273
|
+
],
|
|
274
|
+
edges=edges or [],
|
|
275
|
+
)
|
|
234
276
|
|
|
235
|
-
|
|
277
|
+
if internal_connection:
|
|
278
|
+
graph._connection = connection # type: ignore
|
|
236
279
|
|
|
237
280
|
if infer_metadata:
|
|
238
|
-
graph.infer_metadata(False)
|
|
281
|
+
graph.infer_metadata(verbose=False)
|
|
239
282
|
|
|
240
283
|
if edges is None:
|
|
241
|
-
graph.infer_links(False)
|
|
284
|
+
graph.infer_links(verbose=False)
|
|
242
285
|
|
|
243
286
|
if verbose:
|
|
244
287
|
graph.print_metadata()
|
|
@@ -249,11 +292,11 @@ class Graph:
|
|
|
249
292
|
@classmethod
|
|
250
293
|
def from_snowflake(
|
|
251
294
|
cls,
|
|
252
|
-
connection: Union['SnowflakeConnection',
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
edges:
|
|
295
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
296
|
+
tables: Sequence[str | dict[str, Any]] | None = None,
|
|
297
|
+
database: str | None = None,
|
|
298
|
+
schema: str | None = None,
|
|
299
|
+
edges: Sequence[Edge] | None = None,
|
|
257
300
|
infer_metadata: bool = True,
|
|
258
301
|
verbose: bool = True,
|
|
259
302
|
) -> Self:
|
|
@@ -270,6 +313,13 @@ class Graph:
|
|
|
270
313
|
>>> # Create a graph directly in a Snowflake notebook:
|
|
271
314
|
>>> graph = rfm.Graph.from_snowflake(schema='my_schema')
|
|
272
315
|
|
|
316
|
+
>>> # Fine-grained control over table specification:
|
|
317
|
+
>>> graph = rfm.Graph.from_snowflake(tables=[
|
|
318
|
+
... 'USERS',
|
|
319
|
+
... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
|
|
320
|
+
... dict(name='ITEMS', schema='OTHER_SCHEMA'),
|
|
321
|
+
... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
|
|
322
|
+
|
|
273
323
|
Args:
|
|
274
324
|
connection: An open connection from
|
|
275
325
|
:meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
|
|
@@ -277,10 +327,11 @@ class Graph:
|
|
|
277
327
|
connection. If ``None``, will re-use an active session in case
|
|
278
328
|
it exists, or create a new connection from credentials stored
|
|
279
329
|
in environment variables.
|
|
330
|
+
tables: Set of table names or :class:`SnowTable` keyword arguments
|
|
331
|
+
to include. If ``None``, will add all tables present in the
|
|
332
|
+
current database and schema.
|
|
280
333
|
database: The database.
|
|
281
334
|
schema: The schema.
|
|
282
|
-
table_names: Set of table names to include. If ``None``, will add
|
|
283
|
-
all tables present in the database.
|
|
284
335
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
285
336
|
add to the graph. If not provided, edges will be automatically
|
|
286
337
|
inferred from the data in case ``infer_metadata=True``.
|
|
@@ -298,37 +349,50 @@ class Graph:
|
|
|
298
349
|
connection = connect(**(connection or {}))
|
|
299
350
|
assert isinstance(connection, Connection)
|
|
300
351
|
|
|
301
|
-
if
|
|
352
|
+
if database is None or schema is None:
|
|
353
|
+
with connection.cursor() as cursor:
|
|
354
|
+
cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
|
|
355
|
+
result = cursor.fetchone()
|
|
356
|
+
database = database or result[0]
|
|
357
|
+
assert database is not None
|
|
358
|
+
schema = schema or result[1]
|
|
359
|
+
|
|
360
|
+
if tables is None:
|
|
361
|
+
if schema is None:
|
|
362
|
+
raise ValueError("No current 'schema' set. Please specify the "
|
|
363
|
+
"Snowflake schema manually")
|
|
364
|
+
|
|
302
365
|
with connection.cursor() as cursor:
|
|
303
|
-
if database is None and schema is None:
|
|
304
|
-
cursor.execute("SELECT CURRENT_DATABASE(), "
|
|
305
|
-
"CURRENT_SCHEMA()")
|
|
306
|
-
result = cursor.fetchone()
|
|
307
|
-
database = database or result[0]
|
|
308
|
-
schema = schema or result[1]
|
|
309
366
|
cursor.execute(f"""
|
|
310
367
|
SELECT TABLE_NAME
|
|
311
368
|
FROM {database}.INFORMATION_SCHEMA.TABLES
|
|
312
369
|
WHERE TABLE_SCHEMA = '{schema}'
|
|
313
370
|
""")
|
|
314
|
-
|
|
371
|
+
tables = [row[0] for row in cursor.fetchall()]
|
|
315
372
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
name=
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
373
|
+
table_kwargs: list[dict[str, Any]] = []
|
|
374
|
+
for table in tables:
|
|
375
|
+
if isinstance(table, str):
|
|
376
|
+
kwargs = dict(name=table, database=database, schema=schema)
|
|
377
|
+
else:
|
|
378
|
+
kwargs = copy.copy(table)
|
|
379
|
+
kwargs.setdefault('database', database)
|
|
380
|
+
kwargs.setdefault('schema', schema)
|
|
381
|
+
table_kwargs.append(kwargs)
|
|
382
|
+
|
|
383
|
+
graph = cls(
|
|
384
|
+
tables=[
|
|
385
|
+
SnowTable(connection=connection, **kwargs)
|
|
386
|
+
for kwargs in table_kwargs
|
|
387
|
+
],
|
|
388
|
+
edges=edges or [],
|
|
389
|
+
)
|
|
326
390
|
|
|
327
391
|
if infer_metadata:
|
|
328
|
-
graph.infer_metadata(False)
|
|
392
|
+
graph.infer_metadata(verbose=False)
|
|
329
393
|
|
|
330
394
|
if edges is None:
|
|
331
|
-
graph.infer_links(False)
|
|
395
|
+
graph.infer_links(verbose=False)
|
|
332
396
|
|
|
333
397
|
if verbose:
|
|
334
398
|
graph.print_metadata()
|
|
@@ -340,7 +404,7 @@ class Graph:
|
|
|
340
404
|
def from_snowflake_semantic_view(
|
|
341
405
|
cls,
|
|
342
406
|
semantic_view_name: str,
|
|
343
|
-
connection: Union['SnowflakeConnection',
|
|
407
|
+
connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
|
|
344
408
|
verbose: bool = True,
|
|
345
409
|
) -> Self:
|
|
346
410
|
import yaml
|
|
@@ -358,43 +422,138 @@ class Graph:
|
|
|
358
422
|
with connection.cursor() as cursor:
|
|
359
423
|
cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
360
424
|
f"'{semantic_view_name}')")
|
|
361
|
-
|
|
425
|
+
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
362
426
|
|
|
363
427
|
graph = cls(tables=[])
|
|
364
428
|
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
429
|
+
msgs = []
|
|
430
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
431
|
+
for table_cfg in cfg['tables']:
|
|
432
|
+
table_name = table_cfg['name']
|
|
433
|
+
source_table_name = table_cfg['base_table']['table']
|
|
434
|
+
database = table_cfg['base_table']['database']
|
|
435
|
+
schema = table_cfg['base_table']['schema']
|
|
436
|
+
|
|
437
|
+
primary_key: str | None = None
|
|
438
|
+
if 'primary_key' in table_cfg:
|
|
439
|
+
primary_key_cfg = table_cfg['primary_key']
|
|
440
|
+
if len(primary_key_cfg['columns']) == 1:
|
|
441
|
+
primary_key = primary_key_cfg['columns'][0]
|
|
442
|
+
elif len(primary_key_cfg['columns']) > 1:
|
|
443
|
+
msgs.append(f"Failed to add primary key for table "
|
|
444
|
+
f"'{table_name}' since composite primary keys "
|
|
445
|
+
f"are not yet supported")
|
|
446
|
+
|
|
447
|
+
columns: list[str] = []
|
|
448
|
+
unsupported_columns: list[str] = []
|
|
449
|
+
column_expression_specs: list[ColumnExpressionSpec] = []
|
|
450
|
+
for column_cfg in chain(
|
|
451
|
+
table_cfg.get('dimensions', []),
|
|
452
|
+
table_cfg.get('time_dimensions', []),
|
|
453
|
+
table_cfg.get('facts', []),
|
|
454
|
+
):
|
|
455
|
+
column_name = column_cfg['name']
|
|
456
|
+
column_expr = column_cfg.get('expr', None)
|
|
457
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
458
|
+
|
|
459
|
+
if column_expr is None:
|
|
460
|
+
columns.append(column_name)
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
464
|
+
|
|
465
|
+
if column_expr == column_name:
|
|
466
|
+
columns.append(column_name)
|
|
467
|
+
continue
|
|
468
|
+
|
|
469
|
+
# Drop expressions that reference other tables (for now):
|
|
470
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
471
|
+
unsupported_columns.append(column_name)
|
|
472
|
+
continue
|
|
473
|
+
|
|
474
|
+
spec = ColumnExpressionSpec(
|
|
475
|
+
name=column_name,
|
|
476
|
+
expr=column_expr,
|
|
477
|
+
dtype=SnowTable.to_dtype(column_data_type),
|
|
478
|
+
)
|
|
479
|
+
column_expression_specs.append(spec)
|
|
480
|
+
|
|
481
|
+
if len(unsupported_columns) == 1:
|
|
482
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
483
|
+
f"of table '{table_name}' since its expression "
|
|
484
|
+
f"references other tables")
|
|
485
|
+
elif len(unsupported_columns) > 1:
|
|
486
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
487
|
+
f"of table '{table_name}' since their expressions "
|
|
488
|
+
f"reference other tables")
|
|
370
489
|
|
|
371
490
|
table = SnowTable(
|
|
372
491
|
connection,
|
|
373
|
-
name=
|
|
374
|
-
|
|
375
|
-
|
|
492
|
+
name=table_name,
|
|
493
|
+
source_name=source_table_name,
|
|
494
|
+
database=database,
|
|
495
|
+
schema=schema,
|
|
496
|
+
columns=columns,
|
|
497
|
+
column_expressions=column_expression_specs,
|
|
376
498
|
primary_key=primary_key,
|
|
377
499
|
)
|
|
500
|
+
|
|
501
|
+
# TODO Add a way to register time columns without heuristic usage.
|
|
502
|
+
table.infer_time_column(verbose=False)
|
|
503
|
+
|
|
378
504
|
graph.add_table(table)
|
|
379
505
|
|
|
380
|
-
|
|
506
|
+
for relation_cfg in cfg.get('relationships', []):
|
|
507
|
+
name = relation_cfg['name']
|
|
508
|
+
if len(relation_cfg['relationship_columns']) != 1:
|
|
509
|
+
msgs.append(f"Failed to add relationship '{name}' since "
|
|
510
|
+
f"composite key references are not yet supported")
|
|
511
|
+
continue
|
|
381
512
|
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
513
|
+
left_table = relation_cfg['left_table']
|
|
514
|
+
left_key = relation_cfg['relationship_columns'][0]['left_column']
|
|
515
|
+
right_table = relation_cfg['right_table']
|
|
516
|
+
right_key = relation_cfg['relationship_columns'][0]['right_column']
|
|
517
|
+
|
|
518
|
+
if graph[right_table]._primary_key != right_key:
|
|
519
|
+
# Semantic view error - this should never be triggered:
|
|
520
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
521
|
+
f"referenced key '{right_key}' of table "
|
|
522
|
+
f"'{right_table}' is not a primary key")
|
|
523
|
+
continue
|
|
524
|
+
|
|
525
|
+
if graph[left_table]._primary_key == left_key:
|
|
526
|
+
msgs.append(f"Failed to add relationship '{name}' since the "
|
|
527
|
+
f"referencing key '{left_key}' of table "
|
|
528
|
+
f"'{left_table}' is a primary key")
|
|
529
|
+
continue
|
|
530
|
+
|
|
531
|
+
if left_key not in graph[left_table]:
|
|
532
|
+
graph[left_table].add_column(left_key)
|
|
533
|
+
|
|
534
|
+
graph.link(left_table, left_key, right_table)
|
|
535
|
+
|
|
536
|
+
graph.validate()
|
|
390
537
|
|
|
391
538
|
if verbose:
|
|
392
539
|
graph.print_metadata()
|
|
393
540
|
graph.print_links()
|
|
394
541
|
|
|
542
|
+
if len(msgs) > 0:
|
|
543
|
+
title = (f"Could not fully convert the semantic view definition "
|
|
544
|
+
f"'{semantic_view_name}' into a graph:\n")
|
|
545
|
+
warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
|
|
546
|
+
|
|
395
547
|
return graph
|
|
396
548
|
|
|
397
|
-
#
|
|
549
|
+
# Backend #################################################################
|
|
550
|
+
|
|
551
|
+
@property
|
|
552
|
+
def backend(self) -> DataBackend | None:
|
|
553
|
+
backends = [table.backend for table in self._tables.values()]
|
|
554
|
+
return backends[0] if len(backends) > 0 else None
|
|
555
|
+
|
|
556
|
+
# Tables ##################################################################
|
|
398
557
|
|
|
399
558
|
def has_table(self, name: str) -> bool:
|
|
400
559
|
r"""Returns ``True`` if the graph has a table with name ``name``;
|
|
@@ -413,7 +572,7 @@ class Graph:
|
|
|
413
572
|
return self.tables[name]
|
|
414
573
|
|
|
415
574
|
@property
|
|
416
|
-
def tables(self) ->
|
|
575
|
+
def tables(self) -> dict[str, Table]:
|
|
417
576
|
r"""Returns the dictionary of table objects."""
|
|
418
577
|
return self._tables
|
|
419
578
|
|
|
@@ -433,13 +592,10 @@ class Graph:
|
|
|
433
592
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
434
593
|
f"this graph; table names must be globally unique.")
|
|
435
594
|
|
|
436
|
-
if
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
f"'{table.__class__.__name__}' to this "
|
|
441
|
-
f"graph since other tables are of type "
|
|
442
|
-
f"'{cls.__name__}'.")
|
|
595
|
+
if self.backend is not None and table.backend != self.backend:
|
|
596
|
+
raise ValueError(f"Cannot register a table with backend "
|
|
597
|
+
f"'{table.backend}' to this graph since other "
|
|
598
|
+
f"tables have backend '{self.backend}'.")
|
|
443
599
|
|
|
444
600
|
self._tables[table.name] = table
|
|
445
601
|
|
|
@@ -541,7 +697,7 @@ class Graph:
|
|
|
541
697
|
# Edges ###################################################################
|
|
542
698
|
|
|
543
699
|
@property
|
|
544
|
-
def edges(self) ->
|
|
700
|
+
def edges(self) -> list[Edge]:
|
|
545
701
|
r"""Returns the edges of the graph."""
|
|
546
702
|
return self._edges
|
|
547
703
|
|
|
@@ -556,7 +712,7 @@ class Graph:
|
|
|
556
712
|
st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
|
|
557
713
|
if len(edges) > 0:
|
|
558
714
|
st.markdown('\n'.join([
|
|
559
|
-
f"-
|
|
715
|
+
f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
560
716
|
for edge in edges
|
|
561
717
|
]))
|
|
562
718
|
else:
|
|
@@ -584,9 +740,9 @@ class Graph:
|
|
|
584
740
|
|
|
585
741
|
def link(
|
|
586
742
|
self,
|
|
587
|
-
src_table:
|
|
743
|
+
src_table: str | Table,
|
|
588
744
|
fkey: str,
|
|
589
|
-
dst_table:
|
|
745
|
+
dst_table: str | Table,
|
|
590
746
|
) -> Self:
|
|
591
747
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
592
748
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -647,9 +803,9 @@ class Graph:
|
|
|
647
803
|
|
|
648
804
|
def unlink(
|
|
649
805
|
self,
|
|
650
|
-
src_table:
|
|
806
|
+
src_table: str | Table,
|
|
651
807
|
fkey: str,
|
|
652
|
-
dst_table:
|
|
808
|
+
dst_table: str | Table,
|
|
653
809
|
) -> Self:
|
|
654
810
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
655
811
|
|
|
@@ -826,6 +982,10 @@ class Graph:
|
|
|
826
982
|
raise ValueError("At least one table needs to be added to the "
|
|
827
983
|
"graph")
|
|
828
984
|
|
|
985
|
+
backends = {table.backend for table in self._tables.values()}
|
|
986
|
+
if len(backends) != 1:
|
|
987
|
+
raise ValueError("Found multiple table backends in the graph")
|
|
988
|
+
|
|
829
989
|
for edge in self.edges:
|
|
830
990
|
src_table, fkey, dst_table = edge
|
|
831
991
|
|
|
@@ -872,7 +1032,7 @@ class Graph:
|
|
|
872
1032
|
|
|
873
1033
|
def visualize(
|
|
874
1034
|
self,
|
|
875
|
-
path:
|
|
1035
|
+
path: str | io.BytesIO | None = None,
|
|
876
1036
|
show_columns: bool = True,
|
|
877
1037
|
) -> 'graphviz.Graph':
|
|
878
1038
|
r"""Visualizes the tables and edges in this graph using the
|
|
@@ -911,19 +1071,19 @@ class Graph:
|
|
|
911
1071
|
"them as described at "
|
|
912
1072
|
"https://graphviz.org/download/.")
|
|
913
1073
|
|
|
914
|
-
format:
|
|
1074
|
+
format: str | None = None
|
|
915
1075
|
if isinstance(path, str):
|
|
916
1076
|
format = path.split('.')[-1]
|
|
917
1077
|
elif isinstance(path, io.BytesIO):
|
|
918
1078
|
format = 'svg'
|
|
919
1079
|
graph = graphviz.Graph(format=format)
|
|
920
1080
|
|
|
921
|
-
def left_align(keys:
|
|
1081
|
+
def left_align(keys: list[str]) -> str:
|
|
922
1082
|
if len(keys) == 0:
|
|
923
1083
|
return ""
|
|
924
1084
|
return '\\l'.join(keys) + '\\l'
|
|
925
1085
|
|
|
926
|
-
fkeys_dict:
|
|
1086
|
+
fkeys_dict: dict[str, list[str]] = defaultdict(list)
|
|
927
1087
|
for src_table_name, fkey_name, _ in self.edges:
|
|
928
1088
|
fkeys_dict[src_table_name].append(fkey_name)
|
|
929
1089
|
|
|
@@ -1019,8 +1179,8 @@ class Graph:
|
|
|
1019
1179
|
# Helpers #################################################################
|
|
1020
1180
|
|
|
1021
1181
|
def _to_api_graph_definition(self) -> GraphDefinition:
|
|
1022
|
-
tables:
|
|
1023
|
-
col_groups:
|
|
1182
|
+
tables: dict[str, TableDefinition] = {}
|
|
1183
|
+
col_groups: list[ColumnKeyGroup] = []
|
|
1024
1184
|
for table_name, table in self.tables.items():
|
|
1025
1185
|
tables[table_name] = table._to_api_table_definition()
|
|
1026
1186
|
if table.primary_key is None:
|
|
@@ -1063,3 +1223,7 @@ class Graph:
|
|
|
1063
1223
|
f' tables={tables},\n'
|
|
1064
1224
|
f' edges={edges},\n'
|
|
1065
1225
|
f')')
|
|
1226
|
+
|
|
1227
|
+
def __del__(self) -> None:
|
|
1228
|
+
if hasattr(self, '_connection'):
|
|
1229
|
+
self._connection.close()
|
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
from .dtype import infer_dtype
|
|
2
|
-
from .pkey import infer_primary_key
|
|
3
|
-
from .time_col import infer_time_column
|
|
4
2
|
from .id import contains_id
|
|
5
3
|
from .timestamp import contains_timestamp
|
|
6
4
|
from .categorical import contains_categorical
|
|
7
5
|
from .multicategorical import contains_multicategorical
|
|
6
|
+
from .stype import infer_stype
|
|
7
|
+
from .pkey import infer_primary_key
|
|
8
|
+
from .time_col import infer_time_column
|
|
8
9
|
|
|
9
10
|
__all__ = [
|
|
10
11
|
'infer_dtype',
|
|
11
|
-
'infer_primary_key',
|
|
12
|
-
'infer_time_column',
|
|
13
12
|
'contains_id',
|
|
14
13
|
'contains_timestamp',
|
|
15
14
|
'contains_categorical',
|
|
16
15
|
'contains_multicategorical',
|
|
16
|
+
'infer_stype',
|
|
17
|
+
'infer_primary_key',
|
|
18
|
+
'infer_time_column',
|
|
17
19
|
]
|
|
@@ -1,17 +1,17 @@
|
|
|
1
|
-
from typing import Dict
|
|
2
|
-
|
|
3
1
|
import numpy as np
|
|
4
2
|
import pandas as pd
|
|
5
3
|
import pyarrow as pa
|
|
6
4
|
from kumoapi.typing import Dtype
|
|
7
5
|
|
|
8
|
-
PANDAS_TO_DTYPE:
|
|
6
|
+
PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
9
7
|
'bool': Dtype.bool,
|
|
10
8
|
'boolean': Dtype.bool,
|
|
11
9
|
'int8': Dtype.int,
|
|
12
10
|
'int16': Dtype.int,
|
|
13
11
|
'int32': Dtype.int,
|
|
14
12
|
'int64': Dtype.int,
|
|
13
|
+
'float': Dtype.float,
|
|
14
|
+
'double': Dtype.float,
|
|
15
15
|
'float16': Dtype.float,
|
|
16
16
|
'float32': Dtype.float,
|
|
17
17
|
'float64': Dtype.float,
|