kumoai 2.14.0.dev202512151351__cp313-cp313-macosx_11_0_arm64.whl → 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33):
  1. kumoai/_version.py +1 -1
  2. kumoai/experimental/rfm/__init__.py +33 -8
  3. kumoai/experimental/rfm/authenticate.py +3 -4
  4. kumoai/experimental/rfm/backend/local/graph_store.py +25 -25
  5. kumoai/experimental/rfm/backend/local/table.py +16 -21
  6. kumoai/experimental/rfm/backend/snow/sampler.py +22 -34
  7. kumoai/experimental/rfm/backend/snow/table.py +67 -33
  8. kumoai/experimental/rfm/backend/sqlite/__init__.py +2 -2
  9. kumoai/experimental/rfm/backend/sqlite/sampler.py +21 -26
  10. kumoai/experimental/rfm/backend/sqlite/table.py +54 -26
  11. kumoai/experimental/rfm/base/__init__.py +8 -0
  12. kumoai/experimental/rfm/base/column.py +14 -12
  13. kumoai/experimental/rfm/base/column_expression.py +50 -0
  14. kumoai/experimental/rfm/base/sql_sampler.py +31 -3
  15. kumoai/experimental/rfm/base/sql_table.py +229 -0
  16. kumoai/experimental/rfm/base/table.py +162 -143
  17. kumoai/experimental/rfm/graph.py +242 -95
  18. kumoai/experimental/rfm/infer/__init__.py +6 -4
  19. kumoai/experimental/rfm/infer/dtype.py +3 -3
  20. kumoai/experimental/rfm/infer/pkey.py +4 -2
  21. kumoai/experimental/rfm/infer/stype.py +35 -0
  22. kumoai/experimental/rfm/infer/time_col.py +1 -2
  23. kumoai/experimental/rfm/pquery/executor.py +27 -27
  24. kumoai/experimental/rfm/pquery/pandas_executor.py +29 -31
  25. kumoai/experimental/rfm/rfm.py +86 -80
  26. kumoai/experimental/rfm/sagemaker.py +4 -4
  27. kumoai/utils/__init__.py +1 -2
  28. kumoai/utils/progress_logger.py +178 -12
  29. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/METADATA +2 -1
  30. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/RECORD +33 -30
  31. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/WHEEL +0 -0
  32. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/licenses/LICENSE +0 -0
  33. {kumoai-2.14.0.dev202512151351.dist-info → kumoai-2.14.0.dev202512211732.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,13 @@
1
1
  import contextlib
2
+ import copy
2
3
  import io
3
4
  import warnings
4
5
  from collections import defaultdict
6
+ from collections.abc import Sequence
5
7
  from dataclasses import dataclass, field
8
+ from itertools import chain
6
9
  from pathlib import Path
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
10
+ from typing import TYPE_CHECKING, Any, Union
8
11
 
9
12
  import pandas as pd
10
13
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -13,7 +16,12 @@ from kumoapi.typing import Stype
13
16
  from typing_extensions import Self
14
17
 
15
18
  from kumoai import in_notebook, in_snowflake_notebook
16
- from kumoai.experimental.rfm.base import DataBackend, Table
19
+ from kumoai.experimental.rfm.base import (
20
+ ColumnExpressionSpec,
21
+ DataBackend,
22
+ SQLTable,
23
+ Table,
24
+ )
17
25
  from kumoai.graph import Edge
18
26
  from kumoai.mixin import CastMixin
19
27
 
@@ -25,8 +33,8 @@ if TYPE_CHECKING:
25
33
 
26
34
  @dataclass
27
35
  class SqliteConnectionConfig(CastMixin):
28
- uri: Union[str, Path]
29
- kwargs: Dict[str, Any] = field(default_factory=dict)
36
+ uri: str | Path
37
+ kwargs: dict[str, Any] = field(default_factory=dict)
30
38
 
31
39
 
32
40
  class Graph:
@@ -86,27 +94,38 @@ class Graph:
86
94
  def __init__(
87
95
  self,
88
96
  tables: Sequence[Table],
89
- edges: Optional[Sequence[Edge]] = None,
97
+ edges: Sequence[Edge] | None = None,
90
98
  ) -> None:
91
99
 
92
- self._tables: Dict[str, Table] = {}
93
- self._edges: List[Edge] = []
100
+ self._tables: dict[str, Table] = {}
101
+ self._edges: list[Edge] = []
94
102
 
95
103
  for table in tables:
96
104
  self.add_table(table)
97
105
 
98
106
  for table in tables:
107
+ if not isinstance(table, SQLTable):
108
+ continue
109
+ if '_source_column_dict' not in table.__dict__:
110
+ continue
99
111
  for fkey in table._source_foreign_key_dict.values():
100
- if fkey.name not in table or fkey.dst_table not in self:
112
+ if fkey.name not in table:
113
+ continue
114
+ if not table[fkey.name].is_physical:
115
+ continue
116
+ dst_table_names = [
117
+ table.name for table in self.tables.values()
118
+ if isinstance(table, SQLTable)
119
+ and table._source_name == fkey.dst_table
120
+ ]
121
+ if len(dst_table_names) != 1:
122
+ continue
123
+ dst_table = self[dst_table_names[0]]
124
+ if dst_table._primary_key != fkey.primary_key:
125
+ continue
126
+ if not dst_table[fkey.primary_key].is_physical:
101
127
  continue
102
- if self[fkey.dst_table].primary_key is None:
103
- self[fkey.dst_table].primary_key = fkey.primary_key
104
- elif self[fkey.dst_table]._primary_key != fkey.primary_key:
105
- raise ValueError(f"Found duplicate primary key definition "
106
- f"'{self[fkey.dst_table]._primary_key}' "
107
- f"and '{fkey.primary_key}' in table "
108
- f"'{fkey.dst_table}'.")
109
- self.link(table.name, fkey.name, fkey.dst_table)
128
+ self.link(table.name, fkey.name, dst_table.name)
110
129
 
111
130
  for edge in (edges or []):
112
131
  _edge = Edge._cast(edge)
@@ -117,8 +136,8 @@ class Graph:
117
136
  @classmethod
118
137
  def from_data(
119
138
  cls,
120
- df_dict: Dict[str, pd.DataFrame],
121
- edges: Optional[Sequence[Edge]] = None,
139
+ df_dict: dict[str, pd.DataFrame],
140
+ edges: Sequence[Edge] | None = None,
122
141
  infer_metadata: bool = True,
123
142
  verbose: bool = True,
124
143
  ) -> Self:
@@ -156,15 +175,17 @@ class Graph:
156
175
  verbose: Whether to print verbose output.
157
176
  """
158
177
  from kumoai.experimental.rfm.backend.local import LocalTable
159
- tables = [LocalTable(df, name) for name, df in df_dict.items()]
160
178
 
161
- graph = cls(tables, edges=edges or [])
179
+ graph = cls(
180
+ tables=[LocalTable(df, name) for name, df in df_dict.items()],
181
+ edges=edges or [],
182
+ )
162
183
 
163
184
  if infer_metadata:
164
- graph.infer_metadata(False)
185
+ graph.infer_metadata(verbose=False)
165
186
 
166
187
  if edges is None:
167
- graph.infer_links(False)
188
+ graph.infer_links(verbose=False)
168
189
 
169
190
  if verbose:
170
191
  graph.print_metadata()
@@ -180,10 +201,10 @@ class Graph:
180
201
  SqliteConnectionConfig,
181
202
  str,
182
203
  Path,
183
- Dict[str, Any],
204
+ dict[str, Any],
184
205
  ],
185
- table_names: Optional[Sequence[str]] = None,
186
- edges: Optional[Sequence[Edge]] = None,
206
+ tables: Sequence[str | dict[str, Any]] | None = None,
207
+ edges: Sequence[Edge] | None = None,
187
208
  infer_metadata: bool = True,
188
209
  verbose: bool = True,
189
210
  ) -> Self:
@@ -199,17 +220,25 @@ class Graph:
199
220
  >>> # Create a graph from a SQLite database:
200
221
  >>> graph = rfm.Graph.from_sqlite('data.db')
201
222
 
223
+ >>> # Fine-grained control over table specification:
224
+ >>> graph = rfm.Graph.from_sqlite('data.db', tables=[
225
+ ... 'USERS',
226
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
227
+ ... dict(name='ITEMS', primary_key='ITEM_ID'),
228
+ ... ])
229
+
202
230
  Args:
203
231
  connection: An open connection from
204
232
  :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
205
233
  path to the database file.
206
- table_names: Set of table names to include. If ``None``, will add
207
- all tables present in the database.
234
+ tables: Set of table names or :class:`SQLiteTable` keyword
235
+ arguments to include. If ``None``, will add all tables present
236
+ in the database.
208
237
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
209
238
  add to the graph. If not provided, edges will be automatically
210
239
  inferred from the data in case ``infer_metadata=True``.
211
- infer_metadata: Whether to infer metadata for all tables in the
212
- graph.
240
+ infer_metadata: Whether to infer missing metadata for all tables in
241
+ the graph.
213
242
  verbose: Whether to print verbose output.
214
243
  """
215
244
  from kumoai.experimental.rfm.backend.sqlite import (
@@ -226,24 +255,33 @@ class Graph:
226
255
  internal_connection = True
227
256
  assert isinstance(connection, Connection)
228
257
 
229
- if table_names is None:
258
+ if tables is None:
230
259
  with connection.cursor() as cursor:
231
260
  cursor.execute("SELECT name FROM sqlite_master "
232
261
  "WHERE type='table'")
233
- table_names = [row[0] for row in cursor.fetchall()]
234
-
235
- tables = [SQLiteTable(connection, name) for name in table_names]
262
+ tables = [row[0] for row in cursor.fetchall()]
236
263
 
237
- graph = cls(tables, edges=edges or [])
264
+ table_kwargs: list[dict[str, Any]] = []
265
+ for table in tables:
266
+ kwargs = dict(name=table) if isinstance(table, str) else table
267
+ table_kwargs.append(kwargs)
268
+
269
+ graph = cls(
270
+ tables=[
271
+ SQLiteTable(connection=connection, **kwargs)
272
+ for kwargs in table_kwargs
273
+ ],
274
+ edges=edges or [],
275
+ )
238
276
 
239
277
  if internal_connection:
240
278
  graph._connection = connection # type: ignore
241
279
 
242
280
  if infer_metadata:
243
- graph.infer_metadata(False)
281
+ graph.infer_metadata(verbose=False)
244
282
 
245
283
  if edges is None:
246
- graph.infer_links(False)
284
+ graph.infer_links(verbose=False)
247
285
 
248
286
  if verbose:
249
287
  graph.print_metadata()
@@ -254,11 +292,11 @@ class Graph:
254
292
  @classmethod
255
293
  def from_snowflake(
256
294
  cls,
257
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
258
- database: Optional[str] = None,
259
- schema: Optional[str] = None,
260
- table_names: Optional[Sequence[str]] = None,
261
- edges: Optional[Sequence[Edge]] = None,
295
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
296
+ tables: Sequence[str | dict[str, Any]] | None = None,
297
+ database: str | None = None,
298
+ schema: str | None = None,
299
+ edges: Sequence[Edge] | None = None,
262
300
  infer_metadata: bool = True,
263
301
  verbose: bool = True,
264
302
  ) -> Self:
@@ -275,6 +313,13 @@ class Graph:
275
313
  >>> # Create a graph directly in a Snowflake notebook:
276
314
  >>> graph = rfm.Graph.from_snowflake(schema='my_schema')
277
315
 
316
+ >>> # Fine-grained control over table specification:
317
+ >>> graph = rfm.Graph.from_snowflake(tables=[
318
+ ... 'USERS',
319
+ ... dict(name='ORDERS', source_name='ORDERS_SNAPSHOT'),
320
+ ... dict(name='ITEMS', schema='OTHER_SCHEMA'),
321
+ ... ], database='DEFAULT_DB', schema='DEFAULT_SCHEMA')
322
+
278
323
  Args:
279
324
  connection: An open connection from
280
325
  :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
@@ -282,10 +327,11 @@ class Graph:
282
327
  connection. If ``None``, will re-use an active session in case
283
328
  it exists, or create a new connection from credentials stored
284
329
  in environment variables.
330
+ tables: Set of table names or :class:`SnowTable` keyword arguments
331
+ to include. If ``None``, will add all tables present in the
332
+ current database and schema.
285
333
  database: The database.
286
334
  schema: The schema.
287
- table_names: Set of table names to include. If ``None``, will add
288
- all tables present in the database.
289
335
  edges: An optional list of :class:`~kumoai.graph.Edge` objects to
290
336
  add to the graph. If not provided, edges will be automatically
291
337
  inferred from the data in case ``infer_metadata=True``.
@@ -303,37 +349,50 @@ class Graph:
303
349
  connection = connect(**(connection or {}))
304
350
  assert isinstance(connection, Connection)
305
351
 
306
- if table_names is None:
352
+ if database is None or schema is None:
353
+ with connection.cursor() as cursor:
354
+ cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
355
+ result = cursor.fetchone()
356
+ database = database or result[0]
357
+ assert database is not None
358
+ schema = schema or result[1]
359
+
360
+ if tables is None:
361
+ if schema is None:
362
+ raise ValueError("No current 'schema' set. Please specify the "
363
+ "Snowflake schema manually")
364
+
307
365
  with connection.cursor() as cursor:
308
- if database is None and schema is None:
309
- cursor.execute("SELECT CURRENT_DATABASE(), "
310
- "CURRENT_SCHEMA()")
311
- result = cursor.fetchone()
312
- database = database or result[0]
313
- schema = schema or result[1]
314
366
  cursor.execute(f"""
315
367
  SELECT TABLE_NAME
316
368
  FROM {database}.INFORMATION_SCHEMA.TABLES
317
369
  WHERE TABLE_SCHEMA = '{schema}'
318
370
  """)
319
- table_names = [row[0] for row in cursor.fetchall()]
320
-
321
- tables = [
322
- SnowTable(
323
- connection,
324
- name=table_name,
325
- database=database,
326
- schema=schema,
327
- ) for table_name in table_names
328
- ]
371
+ tables = [row[0] for row in cursor.fetchall()]
329
372
 
330
- graph = cls(tables, edges=edges or [])
373
+ table_kwargs: list[dict[str, Any]] = []
374
+ for table in tables:
375
+ if isinstance(table, str):
376
+ kwargs = dict(name=table, database=database, schema=schema)
377
+ else:
378
+ kwargs = copy.copy(table)
379
+ kwargs.setdefault('database', database)
380
+ kwargs.setdefault('schema', schema)
381
+ table_kwargs.append(kwargs)
382
+
383
+ graph = cls(
384
+ tables=[
385
+ SnowTable(connection=connection, **kwargs)
386
+ for kwargs in table_kwargs
387
+ ],
388
+ edges=edges or [],
389
+ )
331
390
 
332
391
  if infer_metadata:
333
- graph.infer_metadata(False)
392
+ graph.infer_metadata(verbose=False)
334
393
 
335
394
  if edges is None:
336
- graph.infer_links(False)
395
+ graph.infer_links(verbose=False)
337
396
 
338
397
  if verbose:
339
398
  graph.print_metadata()
@@ -345,7 +404,7 @@ class Graph:
345
404
  def from_snowflake_semantic_view(
346
405
  cls,
347
406
  semantic_view_name: str,
348
- connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
407
+ connection: Union['SnowflakeConnection', dict[str, Any], None] = None,
349
408
  verbose: bool = True,
350
409
  ) -> Self:
351
410
  import yaml
@@ -363,40 +422,128 @@ class Graph:
363
422
  with connection.cursor() as cursor:
364
423
  cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
365
424
  f"'{semantic_view_name}')")
366
- view = yaml.safe_load(cursor.fetchone()[0])
425
+ cfg = yaml.safe_load(cursor.fetchone()[0])
367
426
 
368
427
  graph = cls(tables=[])
369
428
 
370
- for table_desc in view['tables']:
371
- primary_key: Optional[str] = None
372
- if ('primary_key' in table_desc # NOTE No composite keys yet.
373
- and len(table_desc['primary_key']['columns']) == 1):
374
- primary_key = table_desc['primary_key']['columns'][0]
429
+ msgs = []
430
+ table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
431
+ for table_cfg in cfg['tables']:
432
+ table_name = table_cfg['name']
433
+ source_table_name = table_cfg['base_table']['table']
434
+ database = table_cfg['base_table']['database']
435
+ schema = table_cfg['base_table']['schema']
436
+
437
+ primary_key: str | None = None
438
+ if 'primary_key' in table_cfg:
439
+ primary_key_cfg = table_cfg['primary_key']
440
+ if len(primary_key_cfg['columns']) == 1:
441
+ primary_key = primary_key_cfg['columns'][0]
442
+ elif len(primary_key_cfg['columns']) > 1:
443
+ msgs.append(f"Failed to add primary key for table "
444
+ f"'{table_name}' since composite primary keys "
445
+ f"are not yet supported")
446
+
447
+ columns: list[str] = []
448
+ unsupported_columns: list[str] = []
449
+ column_expression_specs: list[ColumnExpressionSpec] = []
450
+ for column_cfg in chain(
451
+ table_cfg.get('dimensions', []),
452
+ table_cfg.get('time_dimensions', []),
453
+ table_cfg.get('facts', []),
454
+ ):
455
+ column_name = column_cfg['name']
456
+ column_expr = column_cfg.get('expr', None)
457
+ column_data_type = column_cfg.get('data_type', None)
458
+
459
+ if column_expr is None:
460
+ columns.append(column_name)
461
+ continue
462
+
463
+ column_expr = column_expr.replace(f'{table_name}.', '')
464
+
465
+ if column_expr == column_name:
466
+ columns.append(column_name)
467
+ continue
468
+
469
+ # Drop expressions that reference other tables (for now):
470
+ if any(f'{name}.' in column_expr for name in table_names):
471
+ unsupported_columns.append(column_name)
472
+ continue
473
+
474
+ spec = ColumnExpressionSpec(
475
+ name=column_name,
476
+ expr=column_expr,
477
+ dtype=SnowTable.to_dtype(column_data_type),
478
+ )
479
+ column_expression_specs.append(spec)
480
+
481
+ if len(unsupported_columns) == 1:
482
+ msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
483
+ f"of table '{table_name}' since its expression "
484
+ f"references other tables")
485
+ elif len(unsupported_columns) > 1:
486
+ msgs.append(f"Failed to add columns '{unsupported_columns}' "
487
+ f"of table '{table_name}' since their expressions "
488
+ f"reference other tables")
375
489
 
376
490
  table = SnowTable(
377
491
  connection,
378
- name=table_desc['base_table']['table'],
379
- database=table_desc['base_table']['database'],
380
- schema=table_desc['base_table']['schema'],
492
+ name=table_name,
493
+ source_name=source_table_name,
494
+ database=database,
495
+ schema=schema,
496
+ columns=columns,
497
+ column_expressions=column_expression_specs,
381
498
  primary_key=primary_key,
382
499
  )
500
+
501
+ # TODO Add a way to register time columns without heuristic usage.
502
+ table.infer_time_column(verbose=False)
503
+
383
504
  graph.add_table(table)
384
505
 
385
- # TODO Find a solution to register time columns!
506
+ for relation_cfg in cfg.get('relationships', []):
507
+ name = relation_cfg['name']
508
+ if len(relation_cfg['relationship_columns']) != 1:
509
+ msgs.append(f"Failed to add relationship '{name}' since "
510
+ f"composite key references are not yet supported")
511
+ continue
512
+
513
+ left_table = relation_cfg['left_table']
514
+ left_key = relation_cfg['relationship_columns'][0]['left_column']
515
+ right_table = relation_cfg['right_table']
516
+ right_key = relation_cfg['relationship_columns'][0]['right_column']
386
517
 
387
- for relations in view['relationships']:
388
- if len(relations['relationship_columns']) != 1:
389
- continue # NOTE No composite keys yet.
390
- graph.link(
391
- src_table=relations['left_table'],
392
- fkey=relations['relationship_columns'][0]['left_column'],
393
- dst_table=relations['right_table'],
394
- )
518
+ if graph[right_table]._primary_key != right_key:
519
+ # Semantic view error - this should never be triggered:
520
+ msgs.append(f"Failed to add relationship '{name}' since the "
521
+ f"referenced key '{right_key}' of table "
522
+ f"'{right_table}' is not a primary key")
523
+ continue
524
+
525
+ if graph[left_table]._primary_key == left_key:
526
+ msgs.append(f"Failed to add relationship '{name}' since the "
527
+ f"referencing key '{left_key}' of table "
528
+ f"'{left_table}' is a primary key")
529
+ continue
530
+
531
+ if left_key not in graph[left_table]:
532
+ graph[left_table].add_column(left_key)
533
+
534
+ graph.link(left_table, left_key, right_table)
535
+
536
+ graph.validate()
395
537
 
396
538
  if verbose:
397
539
  graph.print_metadata()
398
540
  graph.print_links()
399
541
 
542
+ if len(msgs) > 0:
543
+ title = (f"Could not fully convert the semantic view definition "
544
+ f"'{semantic_view_name}' into a graph:\n")
545
+ warnings.warn(title + '\n'.join(f'- {msg}' for msg in msgs))
546
+
400
547
  return graph
401
548
 
402
549
  # Backend #################################################################
@@ -425,7 +572,7 @@ class Graph:
425
572
  return self.tables[name]
426
573
 
427
574
  @property
428
- def tables(self) -> Dict[str, Table]:
575
+ def tables(self) -> dict[str, Table]:
429
576
  r"""Returns the dictionary of table objects."""
430
577
  return self._tables
431
578
 
@@ -550,7 +697,7 @@ class Graph:
550
697
  # Edges ###################################################################
551
698
 
552
699
  @property
553
- def edges(self) -> List[Edge]:
700
+ def edges(self) -> list[Edge]:
554
701
  r"""Returns the edges of the graph."""
555
702
  return self._edges
556
703
 
@@ -565,7 +712,7 @@ class Graph:
565
712
  st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
566
713
  if len(edges) > 0:
567
714
  st.markdown('\n'.join([
568
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
715
+ f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
569
716
  for edge in edges
570
717
  ]))
571
718
  else:
@@ -593,9 +740,9 @@ class Graph:
593
740
 
594
741
  def link(
595
742
  self,
596
- src_table: Union[str, Table],
743
+ src_table: str | Table,
597
744
  fkey: str,
598
- dst_table: Union[str, Table],
745
+ dst_table: str | Table,
599
746
  ) -> Self:
600
747
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
601
748
  key ``fkey`` in the source table to the primary key in the destination
@@ -656,9 +803,9 @@ class Graph:
656
803
 
657
804
  def unlink(
658
805
  self,
659
- src_table: Union[str, Table],
806
+ src_table: str | Table,
660
807
  fkey: str,
661
- dst_table: Union[str, Table],
808
+ dst_table: str | Table,
662
809
  ) -> Self:
663
810
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
664
811
 
@@ -885,7 +1032,7 @@ class Graph:
885
1032
 
886
1033
  def visualize(
887
1034
  self,
888
- path: Optional[Union[str, io.BytesIO]] = None,
1035
+ path: str | io.BytesIO | None = None,
889
1036
  show_columns: bool = True,
890
1037
  ) -> 'graphviz.Graph':
891
1038
  r"""Visualizes the tables and edges in this graph using the
@@ -924,19 +1071,19 @@ class Graph:
924
1071
  "them as described at "
925
1072
  "https://graphviz.org/download/.")
926
1073
 
927
- format: Optional[str] = None
1074
+ format: str | None = None
928
1075
  if isinstance(path, str):
929
1076
  format = path.split('.')[-1]
930
1077
  elif isinstance(path, io.BytesIO):
931
1078
  format = 'svg'
932
1079
  graph = graphviz.Graph(format=format)
933
1080
 
934
- def left_align(keys: List[str]) -> str:
1081
+ def left_align(keys: list[str]) -> str:
935
1082
  if len(keys) == 0:
936
1083
  return ""
937
1084
  return '\\l'.join(keys) + '\\l'
938
1085
 
939
- fkeys_dict: Dict[str, List[str]] = defaultdict(list)
1086
+ fkeys_dict: dict[str, list[str]] = defaultdict(list)
940
1087
  for src_table_name, fkey_name, _ in self.edges:
941
1088
  fkeys_dict[src_table_name].append(fkey_name)
942
1089
 
@@ -1032,8 +1179,8 @@ class Graph:
1032
1179
  # Helpers #################################################################
1033
1180
 
1034
1181
  def _to_api_graph_definition(self) -> GraphDefinition:
1035
- tables: Dict[str, TableDefinition] = {}
1036
- col_groups: List[ColumnKeyGroup] = []
1182
+ tables: dict[str, TableDefinition] = {}
1183
+ col_groups: list[ColumnKeyGroup] = []
1037
1184
  for table_name, table in self.tables.items():
1038
1185
  tables[table_name] = table._to_api_table_definition()
1039
1186
  if table.primary_key is None:
@@ -1,17 +1,19 @@
1
1
  from .dtype import infer_dtype
2
- from .pkey import infer_primary_key
3
- from .time_col import infer_time_column
4
2
  from .id import contains_id
5
3
  from .timestamp import contains_timestamp
6
4
  from .categorical import contains_categorical
7
5
  from .multicategorical import contains_multicategorical
6
+ from .stype import infer_stype
7
+ from .pkey import infer_primary_key
8
+ from .time_col import infer_time_column
8
9
 
9
10
  __all__ = [
10
11
  'infer_dtype',
11
- 'infer_primary_key',
12
- 'infer_time_column',
13
12
  'contains_id',
14
13
  'contains_timestamp',
15
14
  'contains_categorical',
16
15
  'contains_multicategorical',
16
+ 'infer_stype',
17
+ 'infer_primary_key',
18
+ 'infer_time_column',
17
19
  ]
@@ -1,17 +1,17 @@
1
- from typing import Dict
2
-
3
1
  import numpy as np
4
2
  import pandas as pd
5
3
  import pyarrow as pa
6
4
  from kumoapi.typing import Dtype
7
5
 
8
- PANDAS_TO_DTYPE: Dict[str, Dtype] = {
6
+ PANDAS_TO_DTYPE: dict[str, Dtype] = {
9
7
  'bool': Dtype.bool,
10
8
  'boolean': Dtype.bool,
11
9
  'int8': Dtype.int,
12
10
  'int16': Dtype.int,
13
11
  'int32': Dtype.int,
14
12
  'int64': Dtype.int,
13
+ 'float': Dtype.float,
14
+ 'double': Dtype.float,
15
15
  'float16': Dtype.float,
16
16
  'float32': Dtype.float,
17
17
  'float64': Dtype.float,
@@ -1,6 +1,5 @@
1
1
  import re
2
2
  import warnings
3
- from typing import Optional
4
3
 
5
4
  import pandas as pd
6
5
 
@@ -9,7 +8,7 @@ def infer_primary_key(
9
8
  table_name: str,
10
9
  df: pd.DataFrame,
11
10
  candidates: list[str],
12
- ) -> Optional[str]:
11
+ ) -> str | None:
13
12
  r"""Auto-detect potential primary key column.
14
13
 
15
14
  Args:
@@ -20,6 +19,9 @@ def infer_primary_key(
20
19
  Returns:
21
20
  The name of the detected primary key, or ``None`` if not found.
22
21
  """
22
+ if len(candidates) == 0:
23
+ return None
24
+
23
25
  # A list of (potentially modified) table names that are eligible to match
24
26
  # with a primary key, i.e.:
25
27
  # - UserInfo -> User
@@ -0,0 +1,35 @@
1
+ import pandas as pd
2
+ from kumoapi.typing import Dtype, Stype
3
+
4
+ from kumoai.experimental.rfm.infer import (
5
+ contains_categorical,
6
+ contains_id,
7
+ contains_multicategorical,
8
+ contains_timestamp,
9
+ )
10
+
11
+
12
+ def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
13
+ """Infers the :class:`Stype` from a :class:`pandas.Series`.
14
+
15
+ Args:
16
+ ser: A :class:`pandas.Series` to analyze.
17
+ column_name: The column name.
18
+ dtype: The data type.
19
+
20
+ Returns:
21
+ The semantic type.
22
+ """
23
+ if contains_id(ser, column_name, dtype):
24
+ return Stype.ID
25
+
26
+ if contains_timestamp(ser, column_name, dtype):
27
+ return Stype.timestamp
28
+
29
+ if contains_multicategorical(ser, column_name, dtype):
30
+ return Stype.multicategorical
31
+
32
+ if contains_categorical(ser, column_name, dtype):
33
+ return Stype.categorical
34
+
35
+ return dtype.default_stype