kumoai 2.13.0.dev202511201731__cp312-cp312-win_amd64.whl → 2.13.0.dev202512040651__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. kumoai/_version.py +1 -1
  2. kumoai/connector/utils.py +23 -2
  3. kumoai/experimental/rfm/__init__.py +20 -45
  4. kumoai/experimental/rfm/backend/__init__.py +0 -0
  5. kumoai/experimental/rfm/backend/local/__init__.py +38 -0
  6. kumoai/experimental/rfm/backend/local/table.py +109 -0
  7. kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
  8. kumoai/experimental/rfm/backend/snow/table.py +117 -0
  9. kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
  10. kumoai/experimental/rfm/backend/sqlite/table.py +101 -0
  11. kumoai/experimental/rfm/base/__init__.py +10 -0
  12. kumoai/experimental/rfm/base/column.py +66 -0
  13. kumoai/experimental/rfm/base/source.py +18 -0
  14. kumoai/experimental/rfm/{local_table.py → base/table.py} +134 -139
  15. kumoai/experimental/rfm/{local_graph.py → graph.py} +301 -62
  16. kumoai/experimental/rfm/infer/__init__.py +6 -0
  17. kumoai/experimental/rfm/infer/dtype.py +79 -0
  18. kumoai/experimental/rfm/infer/pkey.py +126 -0
  19. kumoai/experimental/rfm/infer/time_col.py +62 -0
  20. kumoai/experimental/rfm/local_graph_sampler.py +42 -1
  21. kumoai/experimental/rfm/local_graph_store.py +13 -27
  22. kumoai/experimental/rfm/rfm.py +16 -17
  23. kumoai/experimental/rfm/sagemaker.py +11 -3
  24. kumoai/kumolib.cp312-win_amd64.pyd +0 -0
  25. kumoai/testing/decorators.py +1 -1
  26. {kumoai-2.13.0.dev202511201731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/METADATA +9 -8
  27. {kumoai-2.13.0.dev202511201731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/RECORD +30 -18
  28. kumoai/experimental/rfm/utils.py +0 -344
  29. {kumoai-2.13.0.dev202511201731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/WHEEL +0 -0
  30. {kumoai-2.13.0.dev202511201731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/licenses/LICENSE +0 -0
  31. {kumoai-2.13.0.dev202511201731.dist-info → kumoai-2.13.0.dev202512040651.dist-info}/top_level.txt +0 -0
@@ -2,8 +2,10 @@ import contextlib
2
2
  import io
3
3
  import warnings
4
4
  from collections import defaultdict
5
+ from dataclasses import dataclass, field
5
6
  from importlib.util import find_spec
6
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
7
9
 
8
10
  import pandas as pd
9
11
  from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -12,19 +14,28 @@ from kumoapi.typing import Stype
12
14
  from typing_extensions import Self
13
15
 
14
16
  from kumoai import in_notebook
15
- from kumoai.experimental.rfm import LocalTable
17
+ from kumoai.experimental.rfm import Table
16
18
  from kumoai.graph import Edge
19
+ from kumoai.mixin import CastMixin
17
20
 
18
21
  if TYPE_CHECKING:
19
22
  import graphviz
23
+ from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
24
+ from snowflake.connector import SnowflakeConnection
20
25
 
21
26
 
22
- class LocalGraph:
23
- r"""A graph of :class:`LocalTable` objects, akin to relationships between
27
@dataclass
class SqliteConnectionConfig(CastMixin):
    r"""Settings needed to open a SQLite database connection.

    :meth:`Graph.from_sqlite` casts plain inputs into this config and then
    opens the connection via ``connect(config.uri, **config.kwargs)``.
    """
    # Location of the database, handed to `connect` as its first argument.
    uri: Union[str, Path]
    # Additional keyword arguments forwarded verbatim to `connect`.
    kwargs: Dict[str, Any] = field(default_factory=dict)
31
+
32
+
33
+ class Graph:
34
+ r"""A graph of :class:`Table` objects, akin to relationships between
24
35
  tables in a relational database.
25
36
 
26
37
  Creating a graph is the final step of data definition; after a
27
- :class:`LocalGraph` is created, you can use it to initialize the
38
+ :class:`Graph` is created, you can use it to initialize the
28
39
  Kumo Relational Foundation Model (:class:`KumoRFM`).
29
40
 
30
41
  .. code-block:: python
@@ -44,7 +55,7 @@ class LocalGraph:
44
55
  >>> table3 = rfm.LocalTable(name="table3", data=df3)
45
56
 
46
57
  >>> # Create a graph from a dictionary of tables:
47
- >>> graph = rfm.LocalGraph({
58
+ >>> graph = rfm.Graph({
48
59
  ... "table1": table1,
49
60
  ... "table2": table2,
50
61
  ... "table3": table3,
@@ -75,33 +86,47 @@ class LocalGraph:
75
86
 
76
87
def __init__(
    self,
    tables: Sequence[Table],
    edges: Optional[Sequence[Edge]] = None,
) -> None:
    r"""Builds the graph from ``tables`` and wires up its edges.

    Edges come from two places: foreign keys that the tables themselves
    report via ``_source_foreign_key_dict``, and the explicitly passed
    ``edges``.

    Args:
        tables: The tables to register in the graph.
        edges: Optional edges to add on top of the ones derived from the
            tables' own foreign keys.

    Raises:
        ValueError: If a source foreign key points to a primary key that
            differs from the one already set on the destination table.
    """
    self._tables: Dict[str, Table] = {}
    self._edges: List[Edge] = []

    for table in tables:
        self.add_table(table)

    # Second pass, once all tables are registered: turn the foreign keys
    # reported by each table into edges of the graph.
    for src in tables:
        for fkey in src._source_foreign_key_dict.values():
            linkable = fkey.name in src and fkey.dst_table in self
            if not linkable:
                continue

            dst = self[fkey.dst_table]
            if dst.primary_key is None:
                # Destination has no primary key yet; adopt the one the
                # foreign key refers to.
                dst.primary_key = fkey.primary_key
            elif dst._primary_key != fkey.primary_key:
                raise ValueError(f"Found duplicate primary key definition "
                                 f"'{dst._primary_key}' "
                                 f"and '{fkey.primary_key}' in table "
                                 f"'{fkey.dst_table}'.")

            self.link(src.name, fkey.name, fkey.dst_table)

    # User-supplied edges; skip the ones the pass above already created.
    for raw_edge in (edges or []):
        edge = Edge._cast(raw_edge)
        assert edge is not None
        if edge in self._edges:
            continue
        self.link(*edge)
92
117
 
93
118
  @classmethod
94
119
  def from_data(
95
120
  cls,
96
121
  df_dict: Dict[str, pd.DataFrame],
97
- edges: Optional[List[Edge]] = None,
122
+ edges: Optional[Sequence[Edge]] = None,
98
123
  infer_metadata: bool = True,
99
124
  verbose: bool = True,
100
125
  ) -> Self:
101
- r"""Creates a :class:`LocalGraph` from a dictionary of
126
+ r"""Creates a :class:`Graph` from a dictionary of
102
127
  :class:`pandas.DataFrame` objects.
103
128
 
104
- Automatically infers table metadata and links.
129
+ Automatically infers table metadata and links by default.
105
130
 
106
131
  .. code-block:: python
107
132
 
@@ -115,55 +140,258 @@ class LocalGraph:
115
140
  >>> df3 = pd.DataFrame(...)
116
141
 
117
142
  >>> # Create a graph from a dictionary of data frames:
118
- >>> graph = rfm.LocalGraph.from_data({
143
+ >>> graph = rfm.Graph.from_data({
119
144
  ... "table1": df1,
120
145
  ... "table2": df2,
121
146
  ... "table3": df3,
122
147
  ... })
123
148
 
124
- >>> # Inspect table metadata:
125
- >>> for table in graph.tables.values():
126
- ... table.print_metadata()
127
-
128
- >>> # Visualize graph (if graphviz is installed):
129
- >>> graph.visualize()
130
-
131
149
  Args:
132
150
  df_dict: A dictionary of data frames, where the keys are the names
133
151
  of the tables and the values hold table data.
152
+ edges: An optional list of :class:`~kumoai.graph.Edge` objects to
153
+ add to the graph. If not provided, edges will be automatically
154
+ inferred from the data in case ``infer_metadata=True``.
134
155
  infer_metadata: Whether to infer metadata for all tables in the
135
156
  graph.
157
+ verbose: Whether to print verbose output.
158
+ """
159
+ from kumoai.experimental.rfm.backend.local import LocalTable
160
+ tables = [LocalTable(df, name) for name, df in df_dict.items()]
161
+
162
+ graph = cls(tables, edges=edges or [])
163
+
164
+ if infer_metadata:
165
+ graph.infer_metadata(False)
166
+
167
+ if edges is None:
168
+ graph.infer_links(False)
169
+
170
+ if verbose:
171
+ graph.print_metadata()
172
+ graph.print_links()
173
+
174
+ return graph
175
+
176
@classmethod
def from_sqlite(
    cls,
    connection: Union[
        'AdbcSqliteConnection',
        SqliteConnectionConfig,
        str,
        Path,
        Dict[str, Any],
    ],
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`sqlite` database.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph from a SQLite database:
        >>> graph = rfm.Graph.from_sqlite('data.db')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
            path to the database file.
        table_names: Set of table names to include. If ``None``, will add
            all user tables present in the database. SQLite's internal
            tables (names starting with ``sqlite_``) are not included.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.
    """
    from kumoai.experimental.rfm.backend.sqlite import (
        Connection,
        SQLiteTable,
        connect,
    )

    # Anything that is not already an open connection is normalized into
    # a `SqliteConnectionConfig` first and then opened.
    if not isinstance(connection, Connection):
        connection = SqliteConnectionConfig._cast(connection)
        assert isinstance(connection, SqliteConnectionConfig)
        connection = connect(connection.uri, **connection.kwargs)
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # `sqlite_master` also lists SQLite's own bookkeeping tables
            # (e.g., `sqlite_sequence`, `sqlite_stat1`) as type 'table'.
            # The `sqlite_` prefix is reserved for them, so filter it out
            # (`_` is escaped because it is a LIKE wildcard).
            cursor.execute("SELECT name FROM sqlite_master "
                           "WHERE type='table' "
                           "AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'")
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [SQLiteTable(connection, name) for name in table_names]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        # Links are only inferred when the caller did not pass any edges
        # (an explicit empty list disables inference).
        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
249
+
250
@classmethod
def from_snowflake(
    cls,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    database: Optional[str] = None,
    schema: Optional[str] = None,
    table_names: Optional[Sequence[str]] = None,
    edges: Optional[Sequence[Edge]] = None,
    infer_metadata: bool = True,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` database and
    schema.

    Automatically infers table metadata and links by default.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> # Create a graph directly in a Snowflake notebook:
        >>> graph = rfm.Graph.from_snowflake(schema='my_schema')

    Args:
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, will re-use an active session in case
            it exists, or create a new connection from credentials stored
            in environment variables.
        database: The database. When tables are discovered automatically
            and this is ``None``, the session's current database is used.
        schema: The schema. When tables are discovered automatically and
            this is ``None``, the session's current schema is used.
        table_names: Set of table names to include. If ``None``, will add
            all tables present in the schema.
        edges: An optional list of :class:`~kumoai.graph.Edge` objects to
            add to the graph. If not provided, edges will be automatically
            inferred from the data in case ``infer_metadata=True``.
        infer_metadata: Whether to infer metadata for all tables in the
            graph.
        verbose: Whether to print verbose output.

    Raises:
        ValueError: If tables need to be discovered but neither the
            arguments nor the active session determine the database and
            schema.
    """
    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    if table_names is None:
        with connection.cursor() as cursor:
            # Fill in whichever of the two is missing from the session.
            # This has to trigger when *either* is missing: with only one
            # given, the other would otherwise be formatted into the
            # query below as the literal text "None".
            if database is None or schema is None:
                cursor.execute("SELECT CURRENT_DATABASE(), "
                               "CURRENT_SCHEMA()")
                result = cursor.fetchone()
                database = database or result[0]
                schema = schema or result[1]

            if database is None or schema is None:
                raise ValueError(
                    "Cannot list tables since no database and/or schema "
                    "is specified and the current session does not "
                    "define them. Please pass 'database' and 'schema' "
                    "explicitly.")

            # Double single quotes so that a schema name containing one
            # cannot terminate the string literal.
            schema_literal = schema.replace("'", "''")
            cursor.execute(f"""
                SELECT TABLE_NAME
                FROM {database}.INFORMATION_SCHEMA.TABLES
                WHERE TABLE_SCHEMA = '{schema_literal}'
            """)
            table_names = [row[0] for row in cursor.fetchall()]

    tables = [
        SnowTable(
            connection,
            name=table_name,
            database=database,
            schema=schema,
        ) for table_name in table_names
    ]

    graph = cls(tables, edges=edges or [])

    if infer_metadata:
        graph.infer_metadata(False)

        # Links are only inferred when the caller did not pass any edges
        # (an explicit empty list disables inference).
        if edges is None:
            graph.infer_links(False)

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
339
+
340
@classmethod
def from_snowflake_semantic_view(
    cls,
    semantic_view_name: str,
    connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
    verbose: bool = True,
) -> Self:
    r"""Creates a :class:`Graph` from a :class:`snowflake` semantic view.

    The YAML definition of the view is read via
    ``SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW``. Every entry under ``tables``
    becomes a table of the graph (pointing to its ``base_table``), and
    every entry under ``relationships`` becomes a link. Composite primary
    keys and multi-column relationships are skipped.

    Args:
        semantic_view_name: The name of the semantic view.
        connection: An open connection from
            :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
            :class:`snowflake` connector keyword arguments to open a new
            connection. If ``None``, ``connect()`` is called without
            arguments.
        verbose: Whether to print verbose output.
    """
    import yaml

    from kumoai.experimental.rfm.backend.snow import (
        Connection,
        SnowTable,
        connect,
    )

    if not isinstance(connection, Connection):
        connection = connect(**(connection or {}))
    assert isinstance(connection, Connection)

    # Double single quotes so that the name cannot terminate the string
    # literal it is embedded in.
    view_literal = semantic_view_name.replace("'", "''")
    with connection.cursor() as cursor:
        cursor.execute(f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
                       f"'{view_literal}')")
        view = yaml.safe_load(cursor.fetchone()[0])

    graph = cls(tables=[])

    # Tables are registered under the name of their base table, while
    # relationships presumably refer to a table by its `name` entry in
    # the view (NOTE(review): confirm against the semantic view YAML
    # spec). Track the mapping so both can differ.
    table_name_dict: Dict[str, str] = {}

    for table_desc in view['tables']:
        primary_key: Optional[str] = None
        if ('primary_key' in table_desc  # NOTE No composite keys yet.
                and len(table_desc['primary_key']['columns']) == 1):
            primary_key = table_desc['primary_key']['columns'][0]

        table = SnowTable(
            connection,
            name=table_desc['base_table']['table'],
            database=table_desc['base_table']['database'],
            schema=table_desc['base_table']['schema'],
            primary_key=primary_key,
        )
        graph.add_table(table)

        if 'name' in table_desc:
            table_name_dict[table_desc['name']] = table.name

    # TODO Find a solution to register time columns!

    # A view without relationships may omit the section altogether.
    for relations in view.get('relationships') or []:
        if len(relations['relationship_columns']) != 1:
            continue  # NOTE No composite keys yet.
        src_name = relations['left_table']
        dst_name = relations['right_table']
        graph.link(
            # Fall back to the given name in case it is unknown to the
            # mapping, which preserves the previous behavior.
            src_table=table_name_dict.get(src_name, src_name),
            fkey=relations['relationship_columns'][0]['left_column'],
            dst_table=table_name_dict.get(dst_name, dst_name),
        )

    if verbose:
        graph.print_metadata()
        graph.print_links()

    return graph
169
397
 
@@ -175,7 +403,7 @@ class LocalGraph:
175
403
  """
176
404
  return name in self.tables
177
405
 
178
- def table(self, name: str) -> LocalTable:
406
+ def table(self, name: str) -> Table:
179
407
  r"""Returns the table with name ``name`` in the graph.
180
408
 
181
409
  Raises:
@@ -186,11 +414,11 @@ class LocalGraph:
186
414
  return self.tables[name]
187
415
 
188
416
  @property
189
- def tables(self) -> Dict[str, LocalTable]:
417
+ def tables(self) -> Dict[str, Table]:
190
418
  r"""Returns the dictionary of table objects."""
191
419
  return self._tables
192
420
 
193
- def add_table(self, table: LocalTable) -> Self:
421
+ def add_table(self, table: Table) -> Self:
194
422
  r"""Adds a table to the graph.
195
423
 
196
424
  Args:
@@ -199,11 +427,21 @@ class LocalGraph:
199
427
  Raises:
200
428
  KeyError: If a table with the same name already exists in the
201
429
  graph.
430
+ ValueError: If the table belongs to a different backend than the
431
+ rest of the tables in the graph.
202
432
  """
203
433
  if table.name in self._tables:
204
434
  raise KeyError(f"Cannot add table with name '{table.name}' to "
205
435
  f"this graph; table names must be globally unique.")
206
436
 
437
+ if len(self._tables) > 0:
438
+ cls = next(iter(self._tables.values())).__class__
439
+ if table.__class__ != cls:
440
+ raise ValueError(f"Cannot register a "
441
+ f"'{table.__class__.__name__}' to this "
442
+ f"graph since other tables are of type "
443
+ f"'{cls.__name__}'.")
444
+
207
445
  self._tables[table.name] = table
208
446
 
209
447
  return self
@@ -241,7 +479,7 @@ class LocalGraph:
241
479
  Example:
242
480
  >>> # doctest: +SKIP
243
481
  >>> import kumoai.experimental.rfm as rfm
244
- >>> graph = rfm.LocalGraph(tables=...).infer_metadata()
482
+ >>> graph = rfm.Graph(tables=...).infer_metadata()
245
483
  >>> graph.metadata # doctest: +SKIP
246
484
  name primary_key time_column end_time_column
247
485
  0 users user_id - -
@@ -263,7 +501,7 @@ class LocalGraph:
263
501
  })
264
502
 
265
503
  def print_metadata(self) -> None:
266
- r"""Prints the :meth:`~LocalGraph.metadata` of the graph."""
504
+ r"""Prints the :meth:`~Graph.metadata` of the graph."""
267
505
  if in_notebook():
268
506
  from IPython.display import Markdown, display
269
507
  display(Markdown('### 🗂️ Graph Metadata'))
@@ -287,7 +525,7 @@ class LocalGraph:
287
525
 
288
526
  Note:
289
527
  For more information, please see
290
- :meth:`kumoai.experimental.rfm.LocalTable.infer_metadata`.
528
+ :meth:`kumoai.experimental.rfm.Table.infer_metadata`.
291
529
  """
292
530
  for table in self.tables.values():
293
531
  table.infer_metadata(verbose=False)
@@ -305,7 +543,7 @@ class LocalGraph:
305
543
  return self._edges
306
544
 
307
545
  def print_links(self) -> None:
308
- r"""Prints the :meth:`~LocalGraph.edges` of the graph."""
546
+ r"""Prints the :meth:`~Graph.edges` of the graph."""
309
547
  edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
310
548
  edge.src_table, edge.fkey) for edge in self.edges]
311
549
  edges = sorted(edges)
@@ -333,9 +571,9 @@ class LocalGraph:
333
571
 
334
572
  def link(
335
573
  self,
336
- src_table: Union[str, LocalTable],
574
+ src_table: Union[str, Table],
337
575
  fkey: str,
338
- dst_table: Union[str, LocalTable],
576
+ dst_table: Union[str, Table],
339
577
  ) -> Self:
340
578
  r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
341
579
  key ``fkey`` in the source table to the primary key in the destination
@@ -358,11 +596,11 @@ class LocalGraph:
358
596
  table does not exist in the graph, if the source key does not
359
597
  exist in the source table.
360
598
  """
361
- if isinstance(src_table, LocalTable):
599
+ if isinstance(src_table, Table):
362
600
  src_table = src_table.name
363
601
  assert isinstance(src_table, str)
364
602
 
365
- if isinstance(dst_table, LocalTable):
603
+ if isinstance(dst_table, Table):
366
604
  dst_table = dst_table.name
367
605
  assert isinstance(dst_table, str)
368
606
 
@@ -396,9 +634,9 @@ class LocalGraph:
396
634
 
397
635
  def unlink(
398
636
  self,
399
- src_table: Union[str, LocalTable],
637
+ src_table: Union[str, Table],
400
638
  fkey: str,
401
- dst_table: Union[str, LocalTable],
639
+ dst_table: Union[str, Table],
402
640
  ) -> Self:
403
641
  r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
404
642
 
@@ -410,11 +648,11 @@ class LocalGraph:
410
648
  Raises:
411
649
  ValueError: if the edge is not present in the graph.
412
650
  """
413
- if isinstance(src_table, LocalTable):
651
+ if isinstance(src_table, Table):
414
652
  src_table = src_table.name
415
653
  assert isinstance(src_table, str)
416
654
 
417
- if isinstance(dst_table, LocalTable):
655
+ if isinstance(dst_table, Table):
418
656
  dst_table = dst_table.name
419
657
  assert isinstance(dst_table, str)
420
658
 
@@ -428,17 +666,13 @@ class LocalGraph:
428
666
  return self
429
667
 
430
668
  def infer_links(self, verbose: bool = True) -> Self:
431
- r"""Infers links for the tables and adds them as edges to the graph.
669
+ r"""Infers missing links for the tables and adds them as edges to the
670
+ graph.
432
671
 
433
672
  Args:
434
673
  verbose: Whether to print verbose output.
435
-
436
- Note:
437
- This function expects graph edges to be undefined upfront.
438
674
  """
439
- if len(self.edges) > 0:
440
- warnings.warn("Cannot infer links if graph edges already exist")
441
- return self
675
+ known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
442
676
 
443
677
  # A list of primary key candidates (+score) for every column:
444
678
  candidate_dict: dict[
@@ -463,6 +697,9 @@ class LocalGraph:
463
697
  src_table_name = src_table.name.lower()
464
698
 
465
699
  for src_key in src_table.columns:
700
+ if (src_table.name, src_key.name) in known_edges:
701
+ continue
702
+
466
703
  if src_key == src_table.primary_key:
467
704
  continue # Cannot link to primary key.
468
705
 
@@ -528,7 +765,9 @@ class LocalGraph:
528
765
  score += 1.0
529
766
 
530
767
  # Cardinality ratio:
531
- if len(src_table._data) > len(dst_table._data):
768
+ if (src_table._num_rows is not None
769
+ and dst_table._num_rows is not None
770
+ and src_table._num_rows > dst_table._num_rows):
532
771
  score += 1.0
533
772
 
534
773
  if score < 5.0:
@@ -790,7 +1029,7 @@ class LocalGraph:
790
1029
  def __contains__(self, name: str) -> bool:
791
1030
  return self.has_table(name)
792
1031
 
793
- def __getitem__(self, name: str) -> LocalTable:
1032
+ def __getitem__(self, name: str) -> Table:
794
1033
  return self.table(name)
795
1034
 
796
1035
  def __delitem__(self, name: str) -> None:
@@ -1,9 +1,15 @@
1
+ from .dtype import infer_dtype
2
+ from .pkey import infer_primary_key
3
+ from .time_col import infer_time_column
1
4
  from .id import contains_id
2
5
  from .timestamp import contains_timestamp
3
6
  from .categorical import contains_categorical
4
7
  from .multicategorical import contains_multicategorical
5
8
 
6
9
  __all__ = [
10
+ 'infer_dtype',
11
+ 'infer_primary_key',
12
+ 'infer_time_column',
7
13
  'contains_id',
8
14
  'contains_timestamp',
9
15
  'contains_categorical',
@@ -0,0 +1,79 @@
1
+ from typing import Dict
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import pyarrow as pa
6
+ from kumoapi.typing import Dtype
7
+
8
# Lookup table from pandas dtype names to :class:`Dtype`, grouped by the
# target type (entries are listed in the same order as before).
PANDAS_TO_DTYPE: Dict[str, Dtype] = {
    **dict.fromkeys(('bool', 'boolean'), Dtype.bool),
    **dict.fromkeys(('int8', 'int16', 'int32', 'int64'), Dtype.int),
    **dict.fromkeys(('float16', 'float32', 'float64'), Dtype.float),
    **dict.fromkeys(
        ('object', 'string', 'string[python]', 'string[pyarrow]'),
        Dtype.string,
    ),
    'binary': Dtype.binary,
}
24
+
25
+
26
def infer_dtype(ser: pd.Series) -> Dtype:
    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.

    Date-time, time-delta and categorical series are recognized directly.
    Object series whose first non-null value (within the first 1000 rows)
    is list-like are converted to an Arrow-backed sample to determine the
    element type of the list. Everything else is resolved by dtype name
    via ``PANDAS_TO_DTYPE``.

    Args:
        ser: A :class:`pandas.Series` to analyze.

    Returns:
        The data type.

    Raises:
        ValueError: If the data type of ``ser`` is not supported.
    """
    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
        return Dtype.date
    if pd.api.types.is_timedelta64_dtype(ser.dtype):
        return Dtype.timedelta
    if isinstance(ser.dtype, pd.CategoricalDtype):
        return Dtype.string

    max_rows = 1000  # Upper bound on the rows inspected/converted below.

    if (pd.api.types.is_object_dtype(ser.dtype)
            and not isinstance(ser.dtype, pd.ArrowDtype)):
        # Find the first non-null value by *position*. Looking it up by
        # index label breaks for duplicated labels: `ser[label]` then
        # returns a (list-like) sub-series and `Index.get_loc` a slice or
        # mask instead of an integer.
        is_valid = ser.iloc[:max_rows].notna().to_numpy()
        if is_valid.any():
            pos = int(is_valid.argmax())
            if pd.api.types.is_list_like(ser.iloc[pos]):
                # Let pyarrow determine the list type from a sample:
                sample = ser.iloc[pos:pos + max_rows].dropna()
                arr = pa.array(sample.tolist())
                ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))

    if isinstance(ser.dtype, pd.ArrowDtype):
        if pa.types.is_list(ser.dtype.pyarrow_dtype):
            elem_dtype = ser.dtype.pyarrow_dtype.value_type
            if pa.types.is_integer(elem_dtype):
                return Dtype.intlist
            if pa.types.is_floating(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_decimal(elem_dtype):
                return Dtype.floatlist
            if pa.types.is_string(elem_dtype):
                return Dtype.stringlist
            if pa.types.is_null(elem_dtype):
                # Only empty lists in the sample; treated as float list.
                return Dtype.floatlist

    # NOTE The dtype of a series is either a numpy dtype or a pandas
    # extension dtype (never a bare `pa.DataType`).
    if isinstance(ser.dtype, np.dtype):
        dtype_str = str(ser.dtype).lower()
    elif isinstance(ser.dtype, pd.api.extensions.ExtensionDtype):
        dtype_str = ser.dtype.name.lower()
        dtype_str = dtype_str.split('[')[0]  # Remove backend metadata
    else:
        dtype_str = 'object'

    if dtype_str not in PANDAS_TO_DTYPE:
        raise ValueError(f"Unsupported data type '{ser.dtype}'")

    return PANDAS_TO_DTYPE[dtype_str]