kumoai 2.13.0.dev202512091732__cp311-cp311-macosx_11_0_arm64.whl → 2.14.0.dev202512191731__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. kumoai/_version.py +1 -1
  2. kumoai/client/pquery.py +6 -2
  3. kumoai/experimental/rfm/__init__.py +33 -8
  4. kumoai/experimental/rfm/authenticate.py +3 -4
  5. kumoai/experimental/rfm/backend/local/graph_store.py +40 -83
  6. kumoai/experimental/rfm/backend/local/sampler.py +128 -55
  7. kumoai/experimental/rfm/backend/local/table.py +21 -16
  8. kumoai/experimental/rfm/backend/snow/__init__.py +2 -0
  9. kumoai/experimental/rfm/backend/snow/sampler.py +252 -0
  10. kumoai/experimental/rfm/backend/snow/table.py +101 -49
  11. kumoai/experimental/rfm/backend/sqlite/__init__.py +4 -2
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +349 -0
  13. kumoai/experimental/rfm/backend/sqlite/table.py +84 -31
  14. kumoai/experimental/rfm/base/__init__.py +24 -5
  15. kumoai/experimental/rfm/base/column.py +14 -12
  16. kumoai/experimental/rfm/base/column_expression.py +50 -0
  17. kumoai/experimental/rfm/base/sampler.py +429 -30
  18. kumoai/experimental/rfm/base/source.py +1 -0
  19. kumoai/experimental/rfm/base/sql_sampler.py +84 -0
  20. kumoai/experimental/rfm/base/sql_table.py +229 -0
  21. kumoai/experimental/rfm/base/table.py +165 -135
  22. kumoai/experimental/rfm/graph.py +266 -102
  23. kumoai/experimental/rfm/infer/__init__.py +6 -4
  24. kumoai/experimental/rfm/infer/dtype.py +3 -3
  25. kumoai/experimental/rfm/infer/pkey.py +4 -2
  26. kumoai/experimental/rfm/infer/stype.py +35 -0
  27. kumoai/experimental/rfm/infer/time_col.py +1 -2
  28. kumoai/experimental/rfm/pquery/executor.py +27 -27
  29. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  30. kumoai/experimental/rfm/rfm.py +299 -230
  31. kumoai/experimental/rfm/sagemaker.py +4 -4
  32. kumoai/pquery/predictive_query.py +10 -6
  33. kumoai/testing/snow.py +50 -0
  34. kumoai/utils/__init__.py +3 -2
  35. kumoai/utils/progress_logger.py +178 -12
  36. kumoai/utils/sql.py +3 -0
  37. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/METADATA +3 -2
  38. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/RECORD +41 -35
  39. kumoai/experimental/rfm/local_graph_sampler.py +0 -223
  40. kumoai/experimental/rfm/local_pquery_driver.py +0 -689
  41. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/WHEEL +0 -0
  42. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/licenses/LICENSE +0 -0
  43. {kumoai-2.13.0.dev202512091732.dist-info → kumoai-2.14.0.dev202512191731.dist-info}/top_level.txt +0 -0
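The most visible API change in this release is that Graph.from_sqlite and Graph.from_snowflake take a tables= argument (replacing table_names=) that accepts plain table names as well as per-table keyword dictionaries. A minimal sketch of the new call style, based on the docstring examples further down in this diff; the database path and table names are illustrative:

    from kumoai.experimental import rfm

    # Plain names and per-table keyword dicts can be mixed in one call:
    graph = rfm.Graph.from_sqlite('data.db', tables=[
        'USERS',
        dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
        dict(name='ITEMS', primary_key='ITEM_ID'),
    ])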
kumoai/experimental/rfm/graph.py
@@ -1,10 +1,13 @@
 import contextlib
+import copy
 import io
 import warnings
 from collections import defaultdict
+from collections.abc import Sequence
 from dataclasses import dataclass, field
+from itertools import chain
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Union

 import pandas as pd
 from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,7 +16,12 @@ from kumoapi.typing import Stype
 from typing_extensions import Self

 from kumoai import in_notebook, in_snowflake_notebook
-from kumoai.experimental.rfm import Table
+from kumoai.experimental.rfm.base import (
+    ColumnExpressionSpec,
+    DataBackend,
+    SQLTable,
+    Table,
+)
 from kumoai.graph import Edge
 from kumoai.mixin import CastMixin

@@ -25,8 +33,8 @@ if TYPE_CHECKING:

 @dataclass
 class SqliteConnectionConfig(CastMixin):
-    uri: Union[str, Path]
-    kwargs: Dict[str, Any] = field(default_factory=dict)
+    uri: str | Path
+    kwargs: dict[str, Any] = field(default_factory=dict)


 class Graph:
@@ -86,27 +94,38 @@ class Graph:
     def __init__(
         self,
         tables: Sequence[Table],
-        edges: Optional[Sequence[Edge]] = None,
+        edges: Sequence[Edge] | None = None,
     ) -> None:

-        self._tables: Dict[str, Table] = {}
-        self._edges: List[Edge] = []
+        self._tables: dict[str, Table] = {}
+        self._edges: list[Edge] = []

         for table in tables:
             self.add_table(table)

         for table in tables:
+            if not isinstance(table, SQLTable):
+                continue
+            if '_source_column_dict' not in table.__dict__:
+                continue
             for fkey in table._source_foreign_key_dict.values():
-                if fkey.name not in table or fkey.dst_table not in self:
+                if fkey.name not in table:
+                    continue
+                if not table[fkey.name].is_physical:
+                    continue
+                dst_table_names = [
+                    table.name for table in self.tables.values()
+                    if isinstance(table, SQLTable)
+                    and table._source_name == fkey.dst_table
+                ]
+                if len(dst_table_names) != 1:
+                    continue
+                dst_table = self[dst_table_names[0]]
+                if dst_table._primary_key != fkey.primary_key:
+                    continue
+                if not dst_table[fkey.primary_key].is_physical:
                     continue
-                if self[fkey.dst_table].primary_key is None:
-                    self[fkey.dst_table].primary_key = fkey.primary_key
-                elif self[fkey.dst_table]._primary_key != fkey.primary_key:
-                    raise ValueError(f"Found duplicate primary key definition "
-                                     f"'{self[fkey.dst_table]._primary_key}' "
-                                     f"and '{fkey.primary_key}' in table "
-                                     f"'{fkey.dst_table}'.")
-                self.link(table.name, fkey.name, fkey.dst_table)
+                self.link(table.name, fkey.name, dst_table.name)

         for edge in (edges or []):
             _edge = Edge._cast(edge)
@@ -117,8 +136,8 @@ class Graph:
     @classmethod
     def from_data(
         cls,
-        df_dict: Dict[str, pd.DataFrame],
-        edges: Optional[Sequence[Edge]] = None,
+        df_dict: dict[str, pd.DataFrame],
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -156,15 +175,17 @@
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.local import LocalTable
-        tables = [LocalTable(df, name) for name, df in df_dict.items()]

-        graph = cls(tables, edges=edges or [])
+        graph = cls(
+            tables=[LocalTable(df, name) for name, df in df_dict.items()],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -180,10 +201,10 @@
             SqliteConnectionConfig,
             str,
             Path,
-            Dict[str, Any],
+            dict[str, Any],
         ],
-        table_names: Optional[Sequence[str]] = None,
-        edges: Optional[Sequence[Edge]] = None,
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -199,17 +220,25 @@
         >>> # Create a graph from a SQLite database:
         >>> graph = rfm.Graph.from_sqlite('data.db')

+        >>> # Fine-grained control over table specification:
+        >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
+        ...     'USERS',
+        ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+        ...     dict(name='ITEMS', primary_key='ITEM_ID'),
+        ... ])
+
         Args:
             connection: An open connection from
                 :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
                 path to the database file.
-            table_names: Set of table names to include. If ``None``, will add
-                all tables present in the database.
+            tables: Set of table names or :class:`SQLiteTable` keyword
+                arguments to include. If ``None``, will add all tables present
+                in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
-            infer_metadata: Whether to infer metadata for all tables in the
-                graph.
+            infer_metadata: Whether to infer missing metadata for all tables in
+                the graph.
             verbose: Whether to print verbose output.
         """
         from kumoai.experimental.rfm.backend.sqlite import (
@@ -218,27 +247,41 @@
             Connection,
             connect,
         )
+        internal_connection = False
         if not isinstance(connection, Connection):
             connection = SqliteConnectionConfig._cast(connection)
             assert isinstance(connection, SqliteConnectionConfig)
             connection = connect(connection.uri, **connection.kwargs)
+            internal_connection = True
         assert isinstance(connection, Connection)

-        if table_names is None:
+        if tables is None:
             with connection.cursor() as cursor:
                 cursor.execute("SELECT name FROM sqlite_master "
                                "WHERE type='table'")
-                table_names = [row[0] for row in cursor.fetchall()]
+                tables = [row[0] for row in cursor.fetchall()]

-        tables = [SQLiteTable(connection, name) for name in table_names]
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            kwargs = dict(name=table) if isinstance(table, str) else table
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SQLiteTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

-        graph = cls(tables, edges=edges or [])
+        if internal_connection:
+            graph._connection = connection  # type: ignore

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -249,11 +292,11 @@
     @classmethod
     def from_snowflake(
         cls,
-        connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
-        database: Optional[str] = None,
-        schema: Optional[str] = None,
-        table_names: Optional[Sequence[str]] = None,
-        edges: Optional[Sequence[Edge]] = None,
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
+        tables: Sequence[str | dict[str, Any]] | None = None,
+        database: str | None = None,
+        schema: str | None = None,
+        edges: Sequence[Edge] | None = None,
         infer_metadata: bool = True,
         verbose: bool = True,
     ) -> Self:
@@ -270,6 +313,13 @@
         >>> # Create a graph directly in a Snowflake notebook:
         >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

+        >>> # Fine-grained control over table specification:
+        >>> graph = rfm.Graph.from_snowflake(tables=[
+        ...     'USERS',
+        ...     dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
+        ...     dict(name='ITEMS', schema='OTHER_SCHEMA'),
+        ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
+
         Args:
             connection: An open connection from
                 :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
@@ -277,10 +327,11 @@
                 connection. If ``None``, will re-use an active session in case
                 it exists, or create a new connection from credentials stored
                 in environment variables.
+            tables: Set of table names or :class:`SnowTable` keyword arguments
+                to include. If ``None``, will add all tables present in the
+                current database and schema.
             database: The database.
             schema: The schema.
-            table_names: Set of table names to include. If ``None``, will add
-                all tables present in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
                 inferred from the data in case ``infer_metadata=True``.
@@ -298,37 +349,50 @@
             connection = connect(**(connection or {}))
         assert isinstance(connection, Connection)

-        if table_names is None:
+        if database is None or schema is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                result = cursor.fetchone()
+                database = database or result[0]
+                assert database is not None
+                schema = schema or result[1]
+
+        if tables is None:
+            if schema is None:
+                raise ValueError("No current 'schema' set. Please specify the "
+                                 "Snowflake schema manually")
+
             with connection.cursor() as cursor:
-                if database is None and schema is None:
-                    cursor.execute("SELECT CURRENT_DATABASE(), "
-                                   "CURRENT_SCHEMA()")
-                    result = cursor.fetchone()
-                    database = database or result[0]
-                    schema = schema or result[1]
                 cursor.execute(f"""
                     SELECT TABLE_NAME
                     FROM {database}.INFORMATION_SCHEMA.TABLES
                     WHERE TABLE_SCHEMA = '{schema}'
                 """)
-                table_names = [row[0] for row in cursor.fetchall()]
+                tables = [row[0] for row in cursor.fetchall()]

-        tables = [
-            SnowTable(
-                connection,
-                name=table_name,
-                database=database,
-                schema=schema,
-            ) for table_name in table_names
-        ]
-
-        graph = cls(tables, edges=edges or [])
+        table_kwargs: list[dict[str, Any]] = []
+        for table in tables:
+            if isinstance(table, str):
+                kwargs = dict(name=table, database=database, schema=schema)
+            else:
+                kwargs = copy.copy(table)
+                kwargs.setdefault('database', database)
+                kwargs.setdefault('schema', schema)
+            table_kwargs.append(kwargs)
+
+        graph = cls(
+            tables=[
+                SnowTable(connection=connection, **kwargs)
+                for kwargs in table_kwargs
+            ],
+            edges=edges or [],
+        )

         if infer_metadata:
-            graph.infer_metadata(False)
+            graph.infer_metadata(verbose=False)

         if edges is None:
-            graph.infer_links(False)
+            graph.infer_links(verbose=False)

         if verbose:
             graph.print_metadata()
@@ -340,7 +404,7 @@
     def from_snowflake_semantic_view(
         cls,
         semantic_view_name: str,
-        connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
+        connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
         verbose: bool = True,
     ) -> Self:
         import yaml
@@ -358,43 +422,138 @@
         with connection.cursor() as cursor:
             cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                            f"'{semantic_view_name}')")
-            view = yaml.safe_load(cursor.fetchone()[0])
+            cfg = yaml.safe_load(cursor.fetchone()[0])

         graph = cls(tables=[])

-        for table_desc in view['tables']:
-            primary_key: Optional[str] = None
-            if ('primary_key' in table_desc  # NOTE No composite keys yet.
-                    and len(table_desc['primary_key']['columns']) == 1):
-                primary_key = table_desc['primary_key']['columns'][0]
+        msgs = []
+        table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
+        for table_cfg in cfg['tables']:
+            table_name = table_cfg['name']
+            source_table_name = table_cfg['base_table']['table']
+            database = table_cfg['base_table']['database']
+            schema = table_cfg['base_table']['schema']
+
+            primary_key: str | None = None
+            if 'primary_key' in table_cfg:
+                primary_key_cfg = table_cfg['primary_key']
+                if len(primary_key_cfg['columns']) == 1:
+                    primary_key = primary_key_cfg['columns'][0]
+                elif len(primary_key_cfg['columns']) > 1:
+                    msgs.append(f"Failed to add primary key for table "
+                                f"'{table_name}' since composite primary keys "
+                                f"are not yet supported")
+
+            columns: list[str] = []
+            unsupported_columns: list[str] = []
+            column_expression_specs: list[ColumnExpressionSpec] = []
+            for column_cfg in chain(
+                    table_cfg.get('dimensions', []),
+                    table_cfg.get('time_dimensions', []),
+                    table_cfg.get('facts', []),
+            ):
+                column_name = column_cfg['name']
+                column_expr = column_cfg.get('expr', None)
+                column_data_type = column_cfg.get('data_type', None)
+
+                if column_expr is None:
+                    columns.append(column_name)
+                    continue
+
+                column_expr = column_expr.replace(f'{table_name}.', '')
+
+                if column_expr == column_name:
+                    columns.append(column_name)
+                    continue
+
+                # Drop expressions that reference other tables (for now):
+                if any(f'{name}.' in column_expr for name in table_names):
+                    unsupported_columns.append(column_name)
+                    continue
+
+                spec = ColumnExpressionSpec(
+                    name=column_name,
+                    expr=column_expr,
+                    dtype=SnowTable.to_dtype(column_data_type),
+                )
+                column_expression_specs.append(spec)
+
+            if len(unsupported_columns) == 1:
+                msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
+                            f"of table '{table_name}' since its expression "
+                            f"references other tables")
+            elif len(unsupported_columns) > 1:
+                msgs.append(f"Failed to add columns '{unsupported_columns}' "
+                            f"of table '{table_name}' since their expressions "
+                            f"reference other tables")

             table = SnowTable(
                 connection,
-                name=table_desc['base_table']['table'],
-                database=table_desc['base_table']['database'],
-                schema=table_desc['base_table']['schema'],
+                name=table_name,
+                source_name=source_table_name,
+                database=database,
+                schema=schema,
+                columns=columns,
+                column_expressions=column_expression_specs,
                 primary_key=primary_key,
             )
+
+            # TODO Add a way to register time columns without heuristic usage.
+            table.infer_time_column(verbose=False)
+
             graph.add_table(table)

-        # TODO Find a solution to register time columns!
+        for relation_cfg in cfg.get('relationships', []):
+            name = relation_cfg['name']
+            if len(relation_cfg['relationship_columns']) != 1:
+                msgs.append(f"Failed to add relationship '{name}' since "
+                            f"composite key references are not yet supported")
+                continue

-        for relations in view['relationships']:
-            if len(relations['relationship_columns']) != 1:
-                continue  # NOTE No composite keys yet.
-            graph.link(
-                src_table=relations['left_table'],
-                fkey=relations['relationship_columns'][0]['left_column'],
-                dst_table=relations['right_table'],
-            )
+            left_table = relation_cfg['left_table']
+            left_key = relation_cfg['relationship_columns'][0]['left_column']
+            right_table = relation_cfg['right_table']
+            right_key = relation_cfg['relationship_columns'][0]['right_column']
+
+            if graph[right_table]._primary_key != right_key:
+                # Semantic view error - this should never be triggered:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referenced key '{right_key}' of table "
+                            f"'{right_table}' is not a primary key")
+                continue
+
+            if graph[left_table]._primary_key == left_key:
+                msgs.append(f"Failed to add relationship '{name}' since the "
+                            f"referencing key '{left_key}' of table "
+                            f"'{left_table}' is a primary key")
+                continue
+
+            if left_key not in graph[left_table]:
+                graph[left_table].add_column(left_key)
+
+            graph.link(left_table, left_key, right_table)
+
+        graph.validate()

         if verbose:
             graph.print_metadata()
             graph.print_links()

+        if len(msgs) > 0:
+            title = (f"Could not fully convert the semantic view definition "
+                     f"'{semantic_view_name}' into a graph:\n")
+            warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
+
         return graph

-    # Tables ##############################################################
+    # Backend #################################################################
+
+    @property
+    def backend(self) -> DataBackend | None:
+        backends = [table.backend for table in self._tables.values()]
+        return backends[0] if len(backends) > 0 else None
+
+    # Tables ##################################################################

     def has_table(self, name: str) -> bool:
         r"""Returns ``True`` if the graph has a table with name ``name``;
@@ -413,7 +572,7 @@
         return self.tables[name]

     @property
-    def tables(self) -> Dict[str, Table]:
+    def tables(self) -> dict[str, Table]:
         r"""Returns the dictionary of table objects."""
         return self._tables

@@ -433,13 +592,10 @@
             raise KeyError(f"Cannot add table with name '{table.name}' to "
                            f"this graph; table names must be globally unique.")

-        if len(self._tables) > 0:
-            cls = next(iter(self._tables.values())).__class__
-            if table.__class__ != cls:
-                raise ValueError(f"Cannot register a "
-                                 f"'{table.__class__.__name__}' to this "
-                                 f"graph since other tables are of type "
-                                 f"'{cls.__name__}'.")
+        if self.backend is not None and table.backend != self.backend:
+            raise ValueError(f"Cannot register a table with backend "
+                             f"'{table.backend}' to this graph since other "
+                             f"tables have backend '{self.backend}'.")

         self._tables[table.name] = table

@@ -541,7 +697,7 @@
     # Edges ###################################################################

     @property
-    def edges(self) -> List[Edge]:
+    def edges(self) -> list[Edge]:
         r"""Returns the edges of the graph."""
         return self._edges

@@ -556,7 +712,7 @@
             st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
             if len(edges) > 0:
                 st.markdown('\n'.join([
-                    f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
+                    f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
                     for edge in edges
                 ]))
             else:
@@ -584,9 +740,9 @@

     def link(
         self,
-        src_table: Union[str, Table],
+        src_table: str | Table,
         fkey: str,
-        dst_table: Union[str, Table],
+        dst_table: str | Table,
     ) -> Self:
         r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
         key ``fkey`` in the source table to the primary key in the destination
@@ -647,9 +803,9 @@

     def unlink(
         self,
-        src_table: Union[str, Table],
+        src_table: str | Table,
         fkey: str,
-        dst_table: Union[str, Table],
+        dst_table: str | Table,
     ) -> Self:
         r"""Removes an :class:`~kumoai.graph.Edge` from the graph.

@@ -826,6 +982,10 @@
             raise ValueError("At least one table needs to be added to the "
                              "graph")

+        backends = {table.backend for table in self._tables.values()}
+        if len(backends) != 1:
+            raise ValueError("Found multiple table backends in the graph")
+
         for edge in self.edges:
             src_table, fkey, dst_table = edge

@@ -872,7 +1032,7 @@

     def visualize(
         self,
-        path: Optional[Union[str, io.BytesIO]] = None,
+        path: str | io.BytesIO | None = None,
         show_columns: bool = True,
     ) -> 'graphviz.Graph':
         r"""Visualizes the tables and edges in this graph using the
@@ -911,19 +1071,19 @@
                               "them as described at "
                               "https://graphviz.org/download/.")

-        format: Optional[str] = None
+        format: str | None = None
         if isinstance(path, str):
             format = path.split('.')[-1]
         elif isinstance(path, io.BytesIO):
             format = 'svg'
         graph = graphviz.Graph(format=format)

-        def left_align(keys: List[str]) -> str:
+        def left_align(keys: list[str]) -> str:
             if len(keys) == 0:
                 return ""
             return '\\l'.join(keys) + '\\l'

-        fkeys_dict: Dict[str, List[str]] = defaultdict(list)
+        fkeys_dict: dict[str, list[str]] = defaultdict(list)
         for src_table_name, fkey_name, _ in self.edges:
             fkeys_dict[src_table_name].append(fkey_name)

@@ -1019,8 +1179,8 @@
     # Helpers #################################################################

     def _to_api_graph_definition(self) -> GraphDefinition:
-        tables: Dict[str, TableDefinition] = {}
-        col_groups: List[ColumnKeyGroup] = []
+        tables: dict[str, TableDefinition] = {}
+        col_groups: list[ColumnKeyGroup] = []
         for table_name, table in self.tables.items():
             tables[table_name] = table._to_api_table_definition()
             if table.primary_key is None:
@@ -1063,3 +1223,7 @@
                 f'  tables={tables},\n'
                 f'  edges={edges},\n'
                 f')')
+
+    def __del__(self) -> None:
+        if hasattr(self, '_connection'):
+            self._connection.close()
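The constructor and validate() above now key table compatibility on the new backend property instead of the concrete table class. A short sketch using the pandas-backed Graph.from_data path, which goes through the same constructor; the frames and column names are made up for illustration:

    import pandas as pd
    from kumoai.experimental import rfm

    users = pd.DataFrame({'USER_ID': [1, 2, 3], 'AGE': [25, 31, 47]})
    orders = pd.DataFrame({'ORDER_ID': [10, 11], 'USER_ID': [1, 3]})

    # All tables share the local (pandas) backend, so graph.backend is
    # well-defined; mixing tables from different backends now raises.
    graph = rfm.Graph.from_data({'USERS': users, 'ORDERS': orders})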
kumoai/experimental/rfm/infer/__init__.py
@@ -1,17 +1,19 @@
 from .dtype import infer_dtype
-from .pkey import infer_primary_key
-from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
+from .stype import infer_stype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column

 __all__ = [
     'infer_dtype',
-    'infer_primary_key',
-    'infer_time_column',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
     'contains_multicategorical',
+    'infer_stype',
+    'infer_primary_key',
+    'infer_time_column',
 ]
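The reshuffled imports above also export the new infer_stype helper (backed by the new kumoai/experimental/rfm/infer/stype.py in the file list) alongside the existing inference entry points. Assuming an installed build of this wheel, it is importable as:

    from kumoai.experimental.rfm.infer import infer_dtype, infer_stype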
kumoai/experimental/rfm/infer/dtype.py
@@ -1,17 +1,17 @@
-from typing import Dict
-
 import numpy as np
 import pandas as pd
 import pyarrow as pa
 from kumoapi.typing import Dtype

-PANDAS_TO_DTYPE: Dict[str, Dtype] = {
+PANDAS_TO_DTYPE: dict[str, Dtype] = {
     'bool': Dtype.bool,
     'boolean': Dtype.bool,
     'int8': Dtype.int,
     'int16': Dtype.int,
     'int32': Dtype.int,
     'int64': Dtype.int,
+    'float': Dtype.float,
+    'double': Dtype.float,
     'float16': Dtype.float,
     'float32': Dtype.float,
     'float64': Dtype.float,
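The two new PANDAS_TO_DTYPE entries use pyarrow's own names for 32- and 64-bit floats, which presumably covers Arrow-backed pandas columns whose dtype strings are 'float' and 'double' rather than 'float32'/'float64'. A quick check of the pyarrow naming this relies on:

    import pyarrow as pa

    print(str(pa.float32()))  # 'float'
    print(str(pa.float64()))  # 'double'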