kumoai 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202512031731__cp312-cp312-macosx_11_0_arm64.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
- kumoai/_version.py +1 -1
- kumoai/experimental/rfm/backend/local/table.py +18 -74
- kumoai/experimental/rfm/backend/snow/__init__.py +35 -0
- kumoai/experimental/rfm/backend/snow/table.py +95 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +7 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +56 -79
- kumoai/experimental/rfm/base/__init__.py +3 -0
- kumoai/experimental/rfm/base/source.py +18 -0
- kumoai/experimental/rfm/base/table.py +88 -21
- kumoai/experimental/rfm/graph.py +192 -39
- kumoai/experimental/rfm/infer/__init__.py +6 -0
- kumoai/experimental/rfm/infer/dtype.py +90 -0
- kumoai/experimental/rfm/infer/pkey.py +126 -0
- kumoai/experimental/rfm/infer/time_col.py +62 -0
- kumoai/experimental/rfm/local_graph_sampler.py +42 -1
- kumoai/experimental/rfm/local_graph_store.py +1 -16
- kumoai/experimental/rfm/rfm.py +1 -11
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.13.0.dev202512031731.dist-info}/METADATA +3 -1
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.13.0.dev202512031731.dist-info}/RECORD +22 -17
- kumoai/experimental/rfm/utils.py +0 -344
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.13.0.dev202512031731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.13.0.dev202512031731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202512011731.dist-info → kumoai-2.13.0.dev202512031731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
@@ -2,8 +2,10 @@ import contextlib
 import io
 import warnings
 from collections import defaultdict
+from dataclasses import dataclass, field
 from importlib.util import find_spec
-from
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 
 import pandas as pd
 from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
@@ -14,9 +16,18 @@ from typing_extensions import Self
 from kumoai import in_notebook
 from kumoai.experimental.rfm import Table
 from kumoai.graph import Edge
+from kumoai.mixin import CastMixin
 
 if TYPE_CHECKING:
     import graphviz
+    from adbc_driver_sqlite.dbapi import AdbcSqliteConnection
+    from snowflake.connector import SnowflakeConnection
+
+
+@dataclass
+class SqliteConnectionConfig(CastMixin):
+    uri: Union[str, Path]
+    kwargs: Dict[str, Any] = field(default_factory=dict)
 
 
 class Graph:
@@ -85,10 +96,24 @@ class Graph:
         for table in tables:
             self.add_table(table)
 
+        for table in tables:
+            for fkey in table._source_foreign_key_dict.values():
+                if fkey.name not in table or fkey.dst_table not in self:
+                    continue
+                if self[fkey.dst_table].primary_key is None:
+                    self[fkey.dst_table].primary_key = fkey.primary_key
+                elif self[fkey.dst_table]._primary_key != fkey.primary_key:
+                    raise ValueError(f"Found duplicate primary key definition "
+                                     f"'{self[fkey.dst_table]._primary_key}' "
+                                     f"and '{fkey.primary_key}' in table "
+                                     f"'{fkey.dst_table}'.")
+                self.link(table.name, fkey.name, fkey.dst_table)
+
         for edge in (edges or []):
             _edge = Edge._cast(edge)
             assert _edge is not None
-            self.
+            if _edge not in self._edges:
+                self.link(*_edge)
 
     @classmethod
     def from_data(
@@ -101,7 +126,7 @@ class Graph:
         r"""Creates a :class:`Graph` from a dictionary of
         :class:`pandas.DataFrame` objects.
 
-        Automatically infers table metadata and links.
+        Automatically infers table metadata and links by default.
 
         .. code-block:: python
 
@@ -121,50 +146,180 @@ class Graph:
             ...     "table3": df3,
             ... })
 
-            >>> # Inspect table metadata:
-            >>> for table in graph.tables.values():
-            ...     table.print_metadata()
-
-            >>> # Visualize graph (if graphviz is installed):
-            >>> graph.visualize()
-
         Args:
             df_dict: A dictionary of data frames, where the keys are the names
                 of the tables and the values hold table data.
+            edges: An optional list of :class:`~kumoai.graph.Edge` objects to
+                add to the graph. If not provided, edges will be automatically
+                inferred from the data in case ``infer_metadata=True``.
             infer_metadata: Whether to infer metadata for all tables in the
                 graph.
+            verbose: Whether to print verbose output.
+        """
+        from kumoai.experimental.rfm.backend.local import LocalTable
+        tables = [LocalTable(df, name) for name, df in df_dict.items()]
+
+        graph = cls(tables, edges=edges or [])
+
+        if infer_metadata:
+            graph.infer_metadata(False)
+
+        if edges is None:
+            graph.infer_links(False)
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
+
+        return graph
+
+    @classmethod
+    def from_sqlite(
+        cls,
+        connection: Union[
+            'AdbcSqliteConnection',
+            SqliteConnectionConfig,
+            str,
+            Path,
+            Dict[str, Any],
+        ],
+        table_names: Optional[Sequence[str]] = None,
+        edges: Optional[Sequence[Edge]] = None,
+        infer_metadata: bool = True,
+        verbose: bool = True,
+    ) -> Self:
+        r"""Creates a :class:`Graph` from a :class:`sqlite` database.
+
+        Automatically infers table metadata and links by default.
+
+        .. code-block:: python
+
+            >>> # doctest: +SKIP
+            >>> import kumoai.experimental.rfm as rfm
+
+            >>> # Create a graph from a SQLite database:
+            >>> graph = rfm.Graph.from_sqlite('data.db')
+
+        Args:
+            connection: An open connection from
+                :meth:`~kumoai.experimental.rfm.backend.sqlite.connect` or the
+                path to the database file.
+            table_names: Set of table names to include. If ``None``, will add
+                all tables present in the database.
             edges: An optional list of :class:`~kumoai.graph.Edge` objects to
                 add to the graph. If not provided, edges will be automatically
-                inferred from the data
+                inferred from the data in case ``infer_metadata=True``.
+            infer_metadata: Whether to infer metadata for all tables in the
+                graph.
             verbose: Whether to print verbose output.
+        """
+        from kumoai.experimental.rfm.backend.sqlite import (
+            Connection,
+            SQLiteTable,
+            connect,
+        )
+
+        if not isinstance(connection, Connection):
+            connection = SqliteConnectionConfig._cast(connection)
+            assert isinstance(connection, SqliteConnectionConfig)
+            connection = connect(connection.uri, **connection.kwargs)
+        assert isinstance(connection, Connection)
+
+        if table_names is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT name FROM sqlite_master "
+                               "WHERE type='table'")
+                table_names = [row[0] for row in cursor.fetchall()]
+
+        tables = [SQLiteTable(connection, name) for name in table_names]
 
-
-
-
+        graph = cls(tables, edges=edges or [])
+
+        if infer_metadata:
+            graph.infer_metadata(False)
+
+        if edges is None:
+            graph.infer_links(False)
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
+
+        return graph
+
+    @classmethod
+    def from_snowflake(
+        cls,
+        connection: Union['SnowflakeConnection', Dict[str, Any], None] = None,
+        table_names: Optional[Sequence[str]] = None,
+        edges: Optional[Sequence[Edge]] = None,
+        infer_metadata: bool = True,
+        verbose: bool = True,
+    ) -> Self:
+        r"""Creates a :class:`Graph` from a :class:`snowflake` database and
+        schema.
+
+        Automatically infers table metadata and links by default.
+
+        .. code-block:: python
 
-        Example:
             >>> # doctest: +SKIP
             >>> import kumoai.experimental.rfm as rfm
-
-            >>>
-            >>>
-
-
-
-
-
-
+
+            >>> # Create a graph directly in a Snowflake notebook:
+            >>> graph = rfm.Graph.from_snowflake()
+
+        Args:
+            connection: An open connection from
+                :meth:`~kumoai.experimental.rfm.backend.snow.connect` or the
+                :class:`snowflake` connector keyword arguments to open a new
+                connection. If ``None``, will re-use an active session in case
+                it exists, or create a new connection from credentials stored
+                in environment variables.
+            table_names: Set of table names to include. If ``None``, will add
+                all tables present in the database.
+            edges: An optional list of :class:`~kumoai.graph.Edge` objects to
+                add to the graph. If not provided, edges will be automatically
+                inferred from the data in case ``infer_metadata=True``.
+            infer_metadata: Whether to infer metadata for all tables in the
+                graph.
+            verbose: Whether to print verbose output.
         """
-        from kumoai.experimental.rfm import
-
+        from kumoai.experimental.rfm.backend.snow import (
+            Connection,
+            SnowTable,
+            connect,
+        )
+
+        if not isinstance(connection, Connection):
+            connection = connect(**(connection or {}))
+        assert isinstance(connection, Connection)
+
+        if table_names is None:
+            with connection.cursor() as cursor:
+                cursor.execute("SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()")
+                database, schema = cursor.fetchone()
+                query = f"""
+                    SELECT TABLE_NAME
+                    FROM {database}.INFORMATION_SCHEMA.TABLES
+                    WHERE TABLE_SCHEMA = '{schema}'
+                """
+                cursor.execute(query)
+                table_names = [row[0] for row in cursor.fetchall()]
+
+        tables = [SnowTable(connection, name) for name in table_names]
 
         graph = cls(tables, edges=edges or [])
 
         if infer_metadata:
-            graph.infer_metadata(
+            graph.infer_metadata(False)
 
         if edges is None:
-            graph.infer_links(
+            graph.infer_links(False)
+
+        if verbose:
+            graph.print_metadata()
+            graph.print_links()
 
         return graph
 
@@ -439,17 +594,13 @@ class Graph:
         return self
 
     def infer_links(self, verbose: bool = True) -> Self:
-        r"""Infers links for the tables and adds them as edges to the
+        r"""Infers missing links for the tables and adds them as edges to the
+        graph.
 
         Args:
            verbose: Whether to print verbose output.
-
-        Note:
-            This function expects graph edges to be undefined upfront.
         """
-
-        warnings.warn("Cannot infer links if graph edges already exist")
-        return self
+        known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
 
        # A list of primary key candidates (+score) for every column:
        candidate_dict: dict[
@@ -474,6 +625,9 @@ class Graph:
             src_table_name = src_table.name.lower()
 
             for src_key in src_table.columns:
+                if (src_table.name, src_key.name) in known_edges:
+                    continue
+
                 if src_key == src_table.primary_key:
                     continue  # Cannot link to primary key.
 
@@ -539,10 +693,9 @@ class Graph:
                     score += 1.0
 
                 # Cardinality ratio:
-
-
-
-                        and src_num_rows > dst_num_rows):
+                if (src_table._num_rows is not None
+                        and dst_table._num_rows is not None
+                        and src_table._num_rows > dst_table._num_rows):
                     score += 1.0
 
                 if score < 5.0:
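Taken together, this release adds two backend constructors alongside `from_data`. A minimal usage sketch based on the docstrings above; the DataFrame contents and the `data.db` path are illustrative, and `from_snowflake()` assumes an active session or credentials in environment variables:

import pandas as pd
import kumoai.experimental.rfm as rfm

# In-memory route: metadata and links are inferred from the raw frames.
users = pd.DataFrame({'user_id': [1, 2, 3], 'age': [25, 30, 41]})
orders = pd.DataFrame({'order_id': [10, 11], 'user_id': [1, 2]})
graph = rfm.Graph.from_data({'users': users, 'orders': orders})

# SQLite route: pass a path (or an open connection/config); tables are
# discovered via sqlite_master when `table_names` is not given.
graph = rfm.Graph.from_sqlite('data.db', table_names=['users', 'orders'])

# Snowflake route: with no arguments, re-uses an active session or builds
# a connection from credentials stored in environment variables.
graph = rfm.Graph.from_snowflake()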
kumoai/experimental/rfm/infer/__init__.py
CHANGED
@@ -1,9 +1,15 @@
+from .dtype import infer_dtype
+from .pkey import infer_primary_key
+from .time_col import infer_time_column
 from .id import contains_id
 from .timestamp import contains_timestamp
 from .categorical import contains_categorical
 from .multicategorical import contains_multicategorical
 
 __all__ = [
+    'infer_dtype',
+    'infer_primary_key',
+    'infer_time_column',
     'contains_id',
     'contains_timestamp',
     'contains_categorical',
kumoai/experimental/rfm/infer/dtype.py
ADDED
@@ -0,0 +1,90 @@
+from typing import Any, Dict
+
+import numpy as np
+import pandas as pd
+import pyarrow as pa
+from kumoapi.typing import Dtype
+
+PANDAS_TO_DTYPE: Dict[Any, Dtype] = {
+    np.dtype('bool'): Dtype.bool,
+    pd.BooleanDtype(): Dtype.bool,
+    pa.bool_(): Dtype.bool,
+    np.dtype('byte'): Dtype.int,
+    pd.UInt8Dtype(): Dtype.int,
+    np.dtype('int16'): Dtype.int,
+    pd.Int16Dtype(): Dtype.int,
+    np.dtype('int32'): Dtype.int,
+    pd.Int32Dtype(): Dtype.int,
+    np.dtype('int64'): Dtype.int,
+    pd.Int64Dtype(): Dtype.int,
+    np.dtype('float32'): Dtype.float,
+    pd.Float32Dtype(): Dtype.float,
+    np.dtype('float64'): Dtype.float,
+    pd.Float64Dtype(): Dtype.float,
+    np.dtype('object'): Dtype.string,
+    pd.StringDtype(storage='python'): Dtype.string,
+    pd.StringDtype(storage='pyarrow'): Dtype.string,
+    pa.string(): Dtype.string,
+    pa.binary(): Dtype.binary,
+    np.dtype('datetime64[ns]'): Dtype.date,
+    np.dtype('timedelta64[ns]'): Dtype.timedelta,
+    pa.list_(pa.float32()): Dtype.floatlist,
+    pa.list_(pa.int64()): Dtype.intlist,
+    pa.list_(pa.string()): Dtype.stringlist,
+}
+
+
+def infer_dtype(ser: pd.Series) -> Dtype:
+    """Extracts the :class:`Dtype` from a :class:`pandas.Series`.
+
+    Args:
+        ser: A :class:`pandas.Series` to analyze.
+
+    Returns:
+        The data type.
+    """
+    if pd.api.types.is_datetime64_any_dtype(ser.dtype):
+        return Dtype.date
+
+    if isinstance(ser.dtype, pd.CategoricalDtype):
+        return Dtype.string
+
+    if pd.api.types.is_object_dtype(ser.dtype):
+        index = ser.iloc[:1000].first_valid_index()
+        if index is not None and pd.api.types.is_list_like(ser[index]):
+            pos = ser.index.get_loc(index)
+            assert isinstance(pos, int)
+            ser = ser.iloc[pos:pos + 1000].dropna()
+
+            if not ser.map(pd.api.types.is_list_like).all():
+                raise ValueError("Data contains a mix of list-like and "
+                                 "non-list-like values")
+
+            # Remove all empty Python lists without known data type:
+            ser = ser[ser.map(lambda x: not isinstance(x, list) or len(x) > 0)]
+
+            # Infer unique data types in this series:
+            dtypes = ser.apply(lambda x: PANDAS_TO_DTYPE.get(
+                np.array(x).dtype, Dtype.string)).unique().tolist()
+
+            invalid_dtypes = set(dtypes) - {
+                Dtype.string,
+                Dtype.int,
+                Dtype.float,
+            }
+            if len(invalid_dtypes) > 0:
+                raise ValueError(f"Data contains unsupported list data types: "
+                                 f"{list(invalid_dtypes)}")
+
+            if Dtype.string in dtypes:
+                return Dtype.stringlist
+
+            if dtypes == [Dtype.int]:
+                return Dtype.intlist
+
+            return Dtype.floatlist
+
+    if ser.dtype not in PANDAS_TO_DTYPE:
+        raise ValueError(f"Unsupported data type '{ser.dtype}'")
+
+    return PANDAS_TO_DTYPE[ser.dtype]
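A short sketch of how `infer_dtype` behaves, following the `PANDAS_TO_DTYPE` mapping and the object-dtype branch above. The sample Series are illustrative, and the import path assumes the re-exports added to `kumoai.experimental.rfm.infer` in this release:

import pandas as pd
from kumoai.experimental.rfm.infer import infer_dtype

infer_dtype(pd.Series([1, 2, 3]))                       # Dtype.int
infer_dtype(pd.Series([1.5, None]))                     # Dtype.float
infer_dtype(pd.Series(['a', 'b']))                      # Dtype.string
infer_dtype(pd.Series([[1, 2], [3]]))                   # Dtype.intlist
infer_dtype(pd.Series(pd.to_datetime(['2024-01-01'])))  # Dtype.date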
kumoai/experimental/rfm/infer/pkey.py
ADDED
@@ -0,0 +1,126 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_primary_key(
+    table_name: str,
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential primary key column.
+
+    Args:
+        table_name: The table name.
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected primary key, or ``None`` if not found.
+    """
+    # A list of (potentially modified) table names that are eligible to match
+    # with a primary key, i.e.:
+    # - UserInfo -> User
+    # - snakecase <-> camelcase
+    # - camelcase <-> snakecase
+    # - plural <-> singular (users -> user, eligibilities -> eligibility)
+    # - verb -> noun (qualifying -> qualify)
+    _table_names = {table_name}
+    if table_name.lower().endswith('_info'):
+        _table_names.add(table_name[:-5])
+    elif table_name.lower().endswith('info'):
+        _table_names.add(table_name[:-4])
+
+    table_names = set()
+    for _table_name in _table_names:
+        table_names.add(_table_name.lower())
+        snakecase = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', _table_name)
+        snakecase = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', snakecase)
+        table_names.add(snakecase.lower())
+        camelcase = _table_name.replace('_', '')
+        table_names.add(camelcase.lower())
+        if _table_name.lower().endswith('s'):
+            table_names.add(_table_name.lower()[:-1])
+            table_names.add(snakecase.lower()[:-1])
+            table_names.add(camelcase.lower()[:-1])
+        else:
+            table_names.add(_table_name.lower() + 's')
+            table_names.add(snakecase.lower() + 's')
+            table_names.add(camelcase.lower() + 's')
+        if _table_name.lower().endswith('ies'):
+            table_names.add(_table_name.lower()[:-3] + 'y')
+            table_names.add(snakecase.lower()[:-3] + 'y')
+            table_names.add(camelcase.lower()[:-3] + 'y')
+        elif _table_name.lower().endswith('y'):
+            table_names.add(_table_name.lower()[:-1] + 'ies')
+            table_names.add(snakecase.lower()[:-1] + 'ies')
+            table_names.add(camelcase.lower()[:-1] + 'ies')
+        if _table_name.lower().endswith('ing'):
+            table_names.add(_table_name.lower()[:-3])
+            table_names.add(snakecase.lower()[:-3])
+            table_names.add(camelcase.lower()[:-3])
+
+    scores: list[tuple[str, int]] = []
+    for col_name in candidates:
+        col_name_lower = col_name.lower()
+
+        score = 0
+
+        if col_name_lower == 'id':
+            score += 4
+
+        for table_name_lower in table_names:
+
+            if col_name_lower == table_name_lower:
+                score += 4  # USER -> USER
+                break
+
+            for suffix in ['id', 'hash', 'key', 'code', 'uuid']:
+                if not col_name_lower.endswith(suffix):
+                    continue
+
+                if col_name_lower == f'{table_name_lower}_{suffix}':
+                    score += 5  # USER -> USER_ID
+                    break
+
+                if col_name_lower == f'{table_name_lower}{suffix}':
+                    score += 5  # User -> UserId
+                    break
+
+                if col_name_lower.endswith(f'{table_name_lower}_{suffix}'):
+                    score += 2
+
+                if col_name_lower.endswith(f'{table_name_lower}{suffix}'):
+                    score += 2
+
+        # `rel-bench` hard-coding :(
+        if table_name == 'studies' and col_name == 'nct_id':
+            score += 1
+
+        ser = df[col_name].iloc[:1_000_000]
+        score += 3 * (ser.nunique() / len(ser))
+
+        scores.append((col_name, score))
+
+    scores = [x for x in scores if x[-1] >= 4]
+    scores.sort(key=lambda x: x[-1], reverse=True)
+
+    if len(scores) == 0:
+        return None
+
+    if len(scores) == 1:
+        return scores[0][0]
+
+    # In case of multiple candidates, only return one if its score is unique:
+    if scores[0][1] != scores[1][1]:
+        return scores[0][0]
+
+    max_score = max(scores, key=lambda x: x[1])[1]
+    candidates = [col_name for col_name, score in scores if score == max_score]
+    warnings.warn(f"Found multiple potential primary keys in table "
+                  f"'{table_name}': {candidates}. Please specify the primary "
+                  f"key for this table manually.")
+
+    return None
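A sketch of the scoring heuristic on made-up data: `user_id` matches the singularized table name plus an `id` suffix (+5) and is fully unique (+3), while `signup_id` stays below the 4.0 cut-off:

import pandas as pd
from kumoai.experimental.rfm.infer import infer_primary_key

df = pd.DataFrame({
    'user_id': [1, 2, 3, 4],    # unique, matches 'user' + 'id'
    'signup_id': [7, 7, 8, 8],  # 'id' suffix, but no table-name match
    'age': [25, 30, 22, 41],
})
infer_primary_key('users', df, ['user_id', 'signup_id'])  # 'user_id'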
kumoai/experimental/rfm/infer/time_col.py
ADDED
@@ -0,0 +1,62 @@
+import re
+import warnings
+from typing import Optional
+
+import pandas as pd
+
+
+def infer_time_column(
+    df: pd.DataFrame,
+    candidates: list[str],
+) -> Optional[str]:
+    r"""Auto-detect potential time column.
+
+    Args:
+        df: The pandas DataFrame to analyze.
+        candidates: A list of potential candidates.
+
+    Returns:
+        The name of the detected time column, or ``None`` if not found.
+    """
+    candidates = [  # Exclude all candidates with `*last*` in column names:
+        col_name for col_name in candidates
+        if not re.search(r'(^|_)last(_|$)', col_name, re.IGNORECASE)
+    ]
+
+    if len(candidates) == 0:
+        return None
+
+    if len(candidates) == 1:
+        return candidates[0]
+
+    # If there exists a dedicated `create*` column, use it as time column:
+    create_candidates = [
+        candidate for candidate in candidates
+        if candidate.lower().startswith('create')
+    ]
+    if len(create_candidates) == 1:
+        return create_candidates[0]
+    if len(create_candidates) > 1:
+        candidates = create_candidates
+
+    # Find the most optimal time column. Usually, it is the one pointing to
+    # the oldest timestamps:
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', message='Could not infer format')
+        min_timestamp_dict = {
+            key: pd.to_datetime(df[key].iloc[:10_000], 'coerce')
+            for key in candidates
+        }
+        min_timestamp_dict = {
+            key: value.min().tz_localize(None)
+            for key, value in min_timestamp_dict.items()
+        }
+        min_timestamp_dict = {
+            key: value
+            for key, value in min_timestamp_dict.items() if not pd.isna(value)
+        }
+
+    if len(min_timestamp_dict) == 0:
+        return None
+
+    return min(min_timestamp_dict, key=min_timestamp_dict.get)  # type: ignore
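A sketch with illustrative columns: `last_login` is excluded by the `*last*` filter, and `created_at` is returned as the single `create*` candidate without reaching the oldest-timestamp fallback:

import pandas as pd
from kumoai.experimental.rfm.infer import infer_time_column

df = pd.DataFrame({
    'created_at': pd.to_datetime(['2020-01-01', '2020-06-01']),
    'updated_at': pd.to_datetime(['2021-01-01', '2021-06-01']),
    'last_login': pd.to_datetime(['2022-01-01', '2022-06-01']),
})
infer_time_column(df, ['created_at', 'updated_at', 'last_login'])  # 'created_at'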
kumoai/experimental/rfm/local_graph_sampler.py
CHANGED
@@ -1,3 +1,4 @@
+import re
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
@@ -7,7 +8,47 @@ from kumoapi.typing import Stype
 
 import kumoai.kumolib as kumolib
 from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
-
+
+PUNCTUATION = re.compile(r"[\'\"\.,\(\)\!\?\;\:]")
+MULTISPACE = re.compile(r"\s+")
+
+
+def normalize_text(
+    ser: pd.Series,
+    max_words: Optional[int] = 50,
+) -> pd.Series:
+    r"""Normalizes text into a list of lower-case words.
+
+    Args:
+        ser: The :class:`pandas.Series` to normalize.
+        max_words: The maximum number of words to return.
+            This will auto-shrink any large text column to avoid blowing up
+            context size.
+    """
+    if len(ser) == 0 or pd.api.types.is_list_like(ser.iloc[0]):
+        return ser
+
+    def normalize_fn(line: str) -> list[str]:
+        line = PUNCTUATION.sub(" ", line)
+        line = re.sub(r"<br\s*/?>", " ", line)  # Handle <br /> or <br>
+        line = MULTISPACE.sub(" ", line)
+        words = line.split()
+        if max_words is not None:
+            words = words[:max_words]
+        return words
+
+    ser = ser.fillna('').astype(str)
+
+    if max_words is not None:
+        # We estimate the number of words as 5 characters + 1 space in an
+        # English text on average. We need this pre-filter here, as word
+        # splitting on a giant text can be very expensive:
+        ser = ser.str[:6 * max_words]
+
+    ser = ser.str.lower()
+    ser = ser.map(normalize_fn)
+
+    return ser
 
 
 class LocalGraphSampler:
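A usage sketch for the new `normalize_text` helper with an illustrative review string; punctuation and `<br />` tags become spaces, and missing values become empty word lists:

import pandas as pd
from kumoai.experimental.rfm.local_graph_sampler import normalize_text

ser = pd.Series(["Great product!! Would buy again.<br />Fast shipping.", None])
normalize_text(ser, max_words=None).tolist()
# [['great', 'product', 'would', 'buy', 'again', 'fast', 'shipping'], []]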