kumoai 2.13.0.dev202511261731__cp310-cp310-win_amd64.whl → 2.13.0.dev202512021731__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/_version.py +1 -1
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +20 -45
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +244 -0
- kumoai/experimental/rfm/backend/snow/__init__.py +32 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +30 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +124 -0
- kumoai/experimental/rfm/base/__init__.py +7 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +71 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +144 -57
- kumoai/experimental/rfm/infer/__init__.py +2 -0
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/local_graph_store.py +12 -11
- kumoai/experimental/rfm/rfm.py +5 -5
- kumoai/experimental/rfm/sagemaker.py +11 -3
- kumoai/experimental/rfm/utils.py +1 -120
- kumoai/kumolib.cp310-win_amd64.pyd +0 -0
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/METADATA +8 -8
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/RECORD +26 -17
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511261731.dist-info → kumoai-2.13.0.dev202512021731.dist-info}/top_level.txt +0 -0
|
@@ -3,7 +3,7 @@ import io
|
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from importlib.util import find_spec
|
|
6
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
6
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -12,19 +12,19 @@ from kumoapi.typing import Stype
|
|
|
12
12
|
from typing_extensions import Self
|
|
13
13
|
|
|
14
14
|
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai.experimental.rfm import Table
|
|
16
16
|
from kumoai.graph import Edge
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
19
19
|
import graphviz
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class
|
|
23
|
-
r"""A graph of :class:`
|
|
22
|
+
class Graph:
|
|
23
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
24
|
tables in a relational database.
|
|
25
25
|
|
|
26
26
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
27
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
28
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
29
|
|
|
30
30
|
.. code-block:: python
|
|
@@ -44,7 +44,7 @@ class LocalGraph:
|
|
|
44
44
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
45
|
|
|
46
46
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
47
|
+
>>> graph = rfm.Graph({
|
|
48
48
|
... "table1": table1,
|
|
49
49
|
... "table2": table2,
|
|
50
50
|
... "table3": table3,
|
|
@@ -75,33 +75,44 @@ class LocalGraph:
|
|
|
75
75
|
|
|
76
76
|
def __init__(
|
|
77
77
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges: Optional[
|
|
78
|
+
tables: Sequence[Table],
|
|
79
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
80
80
|
) -> None:
|
|
81
81
|
|
|
82
|
-
self._tables: Dict[str,
|
|
82
|
+
self._tables: Dict[str, Table] = {}
|
|
83
83
|
self._edges: List[Edge] = []
|
|
84
84
|
|
|
85
85
|
for table in tables:
|
|
86
86
|
self.add_table(table)
|
|
87
87
|
|
|
88
|
+
for table in tables:
|
|
89
|
+
for fkey, dst_table, pkey in table._get_source_foreign_keys():
|
|
90
|
+
if self[dst_table].primary_key is None:
|
|
91
|
+
self[dst_table].primary_key = pkey
|
|
92
|
+
elif self[dst_table]._primary_key != pkey:
|
|
93
|
+
raise ValueError(f"Found duplicate primary key definition "
|
|
94
|
+
f"'{self[dst_table]._primary_key}' and "
|
|
95
|
+
f"'{pkey}' in table '{dst_table}'.")
|
|
96
|
+
self.link(table.name, fkey, dst_table)
|
|
97
|
+
|
|
88
98
|
for edge in (edges or []):
|
|
89
99
|
_edge = Edge._cast(edge)
|
|
90
100
|
assert _edge is not None
|
|
91
|
-
self.
|
|
101
|
+
if _edge not in self._edges:
|
|
102
|
+
self.link(*_edge)
|
|
92
103
|
|
|
93
104
|
@classmethod
|
|
94
105
|
def from_data(
|
|
95
106
|
cls,
|
|
96
107
|
df_dict: Dict[str, pd.DataFrame],
|
|
97
|
-
edges: Optional[
|
|
108
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
98
109
|
infer_metadata: bool = True,
|
|
99
110
|
verbose: bool = True,
|
|
100
111
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
112
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
113
|
:class:`pandas.DataFrame` objects.
|
|
103
114
|
|
|
104
|
-
Automatically infers table metadata and links.
|
|
115
|
+
Automatically infers table metadata and links by default.
|
|
105
116
|
|
|
106
117
|
.. code-block:: python
|
|
107
118
|
|
|
@@ -115,7 +126,7 @@ class LocalGraph:
|
|
|
115
126
|
>>> df3 = pd.DataFrame(...)
|
|
116
127
|
|
|
117
128
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
129
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
130
|
... "table1": df1,
|
|
120
131
|
... "table2": df2,
|
|
121
132
|
... "table3": df3,
|
|
@@ -131,39 +142,103 @@ class LocalGraph:
|
|
|
131
142
|
Args:
|
|
132
143
|
df_dict: A dictionary of data frames, where the keys are the names
|
|
133
144
|
of the tables and the values hold table data.
|
|
134
|
-
infer_metadata: Whether to infer metadata for all tables in the
|
|
135
|
-
graph.
|
|
136
145
|
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
137
146
|
add to the graph. If not provided, edges will be automatically
|
|
138
|
-
inferred from the data
|
|
147
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
148
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
149
|
+
graph.
|
|
139
150
|
verbose: Whether to print verbose output.
|
|
151
|
+
"""
|
|
152
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
153
|
+
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
140
154
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
155
|
+
graph = cls(tables, edges=edges or [])
|
|
156
|
+
|
|
157
|
+
if infer_metadata:
|
|
158
|
+
graph.infer_metadata(False)
|
|
159
|
+
|
|
160
|
+
if edges is None:
|
|
161
|
+
graph.infer_links(False)
|
|
162
|
+
|
|
163
|
+
if verbose:
|
|
164
|
+
graph.print_metadata()
|
|
165
|
+
graph.print_links()
|
|
166
|
+
|
|
167
|
+
return graph
|
|
168
|
+
|
|
169
|
+
@classmethod
|
|
170
|
+
def from_sqlite(
|
|
171
|
+
cls,
|
|
172
|
+
uri: Any,
|
|
173
|
+
table_names: Optional[Sequence[str]] = None,
|
|
174
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
175
|
+
infer_metadata: bool = True,
|
|
176
|
+
verbose: bool = True,
|
|
177
|
+
conn_kwargs: Optional[Dict[str, Any]] = None,
|
|
178
|
+
) -> Self:
|
|
179
|
+
r"""Creates a :class:`Graph` from a :class:`sqlite` database.
|
|
180
|
+
|
|
181
|
+
Automatically infers table metadata and links by default.
|
|
182
|
+
|
|
183
|
+
.. code-block:: python
|
|
144
184
|
|
|
145
|
-
Example:
|
|
146
185
|
>>> # doctest: +SKIP
|
|
147
186
|
>>> import kumoai.experimental.rfm as rfm
|
|
148
|
-
|
|
149
|
-
>>>
|
|
150
|
-
>>>
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
...
|
|
155
|
-
|
|
156
|
-
>>> graph
|
|
187
|
+
|
|
188
|
+
>>> # Create a graph from a SQLite database:
|
|
189
|
+
>>> graph = rfm.Graph.from_sqlite('data.db')
|
|
190
|
+
|
|
191
|
+
>>> # Inspect table metadata:
|
|
192
|
+
>>> for table in graph.tables.values():
|
|
193
|
+
... table.print_metadata()
|
|
194
|
+
|
|
195
|
+
>>> # Visualize graph (if graphviz is installed):
|
|
196
|
+
>>> graph.visualize()
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
uri: The path to the database file or an open connection obtained
|
|
200
|
+
from :meth:`~kumoai.experimental.rfm.backend.sqlite.connect`.
|
|
201
|
+
table_names: Set of table names to include. If ``None``, will add
|
|
202
|
+
all tables present in the database.
|
|
203
|
+
edges: An optional list of :class:`~kumoai.graph.Edge` objects to
|
|
204
|
+
add to the graph. If not provided, edges will be automatically
|
|
205
|
+
inferred from the data in case ``infer_metadata=True``.
|
|
206
|
+
infer_metadata: Whether to infer metadata for all tables in the
|
|
207
|
+
graph.
|
|
208
|
+
verbose: Whether to print verbose output.
|
|
209
|
+
conn_kwargs: Additional connection arguments, following the
|
|
210
|
+
:class:`adbc_driver_sqlite` protocol.
|
|
157
211
|
"""
|
|
158
|
-
|
|
212
|
+
from kumoai.experimental.rfm.backend.sqlite import (
|
|
213
|
+
Connection,
|
|
214
|
+
SQLiteTable,
|
|
215
|
+
connect,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
if not isinstance(uri, Connection):
|
|
219
|
+
connection = connect(uri, **(conn_kwargs or {}))
|
|
220
|
+
else:
|
|
221
|
+
connection = uri
|
|
222
|
+
|
|
223
|
+
if table_names is None:
|
|
224
|
+
with connection.cursor() as cursor:
|
|
225
|
+
cursor.execute("SELECT name FROM sqlite_master "
|
|
226
|
+
"WHERE type='table'")
|
|
227
|
+
table_names = [row[0] for row in cursor.fetchall()]
|
|
228
|
+
|
|
229
|
+
tables = [SQLiteTable(connection, name) for name in table_names]
|
|
159
230
|
|
|
160
231
|
graph = cls(tables, edges=edges or [])
|
|
161
232
|
|
|
162
233
|
if infer_metadata:
|
|
163
|
-
graph.infer_metadata(
|
|
234
|
+
graph.infer_metadata(False)
|
|
164
235
|
|
|
165
236
|
if edges is None:
|
|
166
|
-
graph.infer_links(
|
|
237
|
+
graph.infer_links(False)
|
|
238
|
+
|
|
239
|
+
if verbose:
|
|
240
|
+
graph.print_metadata()
|
|
241
|
+
graph.print_links()
|
|
167
242
|
|
|
168
243
|
return graph
|
|
169
244
|
|
|
@@ -175,7 +250,7 @@ class LocalGraph:
|
|
|
175
250
|
"""
|
|
176
251
|
return name in self.tables
|
|
177
252
|
|
|
178
|
-
def table(self, name: str) ->
|
|
253
|
+
def table(self, name: str) -> Table:
|
|
179
254
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
255
|
|
|
181
256
|
Raises:
|
|
@@ -186,11 +261,11 @@ class LocalGraph:
|
|
|
186
261
|
return self.tables[name]
|
|
187
262
|
|
|
188
263
|
@property
|
|
189
|
-
def tables(self) -> Dict[str,
|
|
264
|
+
def tables(self) -> Dict[str, Table]:
|
|
190
265
|
r"""Returns the dictionary of table objects."""
|
|
191
266
|
return self._tables
|
|
192
267
|
|
|
193
|
-
def add_table(self, table:
|
|
268
|
+
def add_table(self, table: Table) -> Self:
|
|
194
269
|
r"""Adds a table to the graph.
|
|
195
270
|
|
|
196
271
|
Args:
|
|
@@ -199,11 +274,21 @@ class LocalGraph:
|
|
|
199
274
|
Raises:
|
|
200
275
|
KeyError: If a table with the same name already exists in the
|
|
201
276
|
graph.
|
|
277
|
+
ValueError: If the table belongs to a different backend than the
|
|
278
|
+
rest of the tables in the graph.
|
|
202
279
|
"""
|
|
203
280
|
if table.name in self._tables:
|
|
204
281
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
282
|
f"this graph; table names must be globally unique.")
|
|
206
283
|
|
|
284
|
+
if len(self._tables) > 0:
|
|
285
|
+
cls = next(iter(self._tables.values())).__class__
|
|
286
|
+
if table.__class__ != cls:
|
|
287
|
+
raise ValueError(f"Cannot register a "
|
|
288
|
+
f"'{table.__class__.__name__}' to this "
|
|
289
|
+
f"graph since other tables are of type "
|
|
290
|
+
f"'{cls.__name__}'.")
|
|
291
|
+
|
|
207
292
|
self._tables[table.name] = table
|
|
208
293
|
|
|
209
294
|
return self
|
|
@@ -241,7 +326,7 @@ class LocalGraph:
|
|
|
241
326
|
Example:
|
|
242
327
|
>>> # doctest: +SKIP
|
|
243
328
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
329
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
330
|
>>> graph.metadata # doctest: +SKIP
|
|
246
331
|
name primary_key time_column end_time_column
|
|
247
332
|
0 users user_id - -
|
|
@@ -263,7 +348,7 @@ class LocalGraph:
|
|
|
263
348
|
})
|
|
264
349
|
|
|
265
350
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
351
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
267
352
|
if in_notebook():
|
|
268
353
|
from IPython.display import Markdown, display
|
|
269
354
|
display(Markdown('### 🗂️ Graph Metadata'))
|
|
@@ -287,7 +372,7 @@ class LocalGraph:
|
|
|
287
372
|
|
|
288
373
|
Note:
|
|
289
374
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
375
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
376
|
"""
|
|
292
377
|
for table in self.tables.values():
|
|
293
378
|
table.infer_metadata(verbose=False)
|
|
@@ -305,7 +390,7 @@ class LocalGraph:
|
|
|
305
390
|
return self._edges
|
|
306
391
|
|
|
307
392
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
393
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
394
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
395
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
396
|
edges = sorted(edges)
|
|
@@ -333,9 +418,9 @@ class LocalGraph:
|
|
|
333
418
|
|
|
334
419
|
def link(
|
|
335
420
|
self,
|
|
336
|
-
src_table: Union[str,
|
|
421
|
+
src_table: Union[str, Table],
|
|
337
422
|
fkey: str,
|
|
338
|
-
dst_table: Union[str,
|
|
423
|
+
dst_table: Union[str, Table],
|
|
339
424
|
) -> Self:
|
|
340
425
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
426
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +443,11 @@ class LocalGraph:
|
|
|
358
443
|
table does not exist in the graph, if the source key does not
|
|
359
444
|
exist in the source table.
|
|
360
445
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
446
|
+
if isinstance(src_table, Table):
|
|
362
447
|
src_table = src_table.name
|
|
363
448
|
assert isinstance(src_table, str)
|
|
364
449
|
|
|
365
|
-
if isinstance(dst_table,
|
|
450
|
+
if isinstance(dst_table, Table):
|
|
366
451
|
dst_table = dst_table.name
|
|
367
452
|
assert isinstance(dst_table, str)
|
|
368
453
|
|
|
@@ -396,9 +481,9 @@ class LocalGraph:
|
|
|
396
481
|
|
|
397
482
|
def unlink(
|
|
398
483
|
self,
|
|
399
|
-
src_table: Union[str,
|
|
484
|
+
src_table: Union[str, Table],
|
|
400
485
|
fkey: str,
|
|
401
|
-
dst_table: Union[str,
|
|
486
|
+
dst_table: Union[str, Table],
|
|
402
487
|
) -> Self:
|
|
403
488
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
489
|
|
|
@@ -410,11 +495,11 @@ class LocalGraph:
|
|
|
410
495
|
Raises:
|
|
411
496
|
ValueError: if the edge is not present in the graph.
|
|
412
497
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
498
|
+
if isinstance(src_table, Table):
|
|
414
499
|
src_table = src_table.name
|
|
415
500
|
assert isinstance(src_table, str)
|
|
416
501
|
|
|
417
|
-
if isinstance(dst_table,
|
|
502
|
+
if isinstance(dst_table, Table):
|
|
418
503
|
dst_table = dst_table.name
|
|
419
504
|
assert isinstance(dst_table, str)
|
|
420
505
|
|
|
@@ -428,17 +513,13 @@ class LocalGraph:
|
|
|
428
513
|
return self
|
|
429
514
|
|
|
430
515
|
def infer_links(self, verbose: bool = True) -> Self:
|
|
431
|
-
r"""Infers links for the tables and adds them as edges to the
|
|
516
|
+
r"""Infers missing links for the tables and adds them as edges to the
|
|
517
|
+
graph.
|
|
432
518
|
|
|
433
519
|
Args:
|
|
434
520
|
verbose: Whether to print verbose output.
|
|
435
|
-
|
|
436
|
-
Note:
|
|
437
|
-
This function expects graph edges to be undefined upfront.
|
|
438
521
|
"""
|
|
439
|
-
|
|
440
|
-
warnings.warn("Cannot infer links if graph edges already exist")
|
|
441
|
-
return self
|
|
522
|
+
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
442
523
|
|
|
443
524
|
# A list of primary key candidates (+score) for every column:
|
|
444
525
|
candidate_dict: dict[
|
|
@@ -463,6 +544,9 @@ class LocalGraph:
|
|
|
463
544
|
src_table_name = src_table.name.lower()
|
|
464
545
|
|
|
465
546
|
for src_key in src_table.columns:
|
|
547
|
+
if (src_table.name, src_key.name) in known_edges:
|
|
548
|
+
continue
|
|
549
|
+
|
|
466
550
|
if src_key == src_table.primary_key:
|
|
467
551
|
continue # Cannot link to primary key.
|
|
468
552
|
|
|
@@ -528,7 +612,10 @@ class LocalGraph:
|
|
|
528
612
|
score += 1.0
|
|
529
613
|
|
|
530
614
|
# Cardinality ratio:
|
|
531
|
-
|
|
615
|
+
src_num_rows = src_table._num_rows()
|
|
616
|
+
dst_num_rows = dst_table._num_rows()
|
|
617
|
+
if (src_num_rows is not None and dst_num_rows is not None
|
|
618
|
+
and src_num_rows > dst_num_rows):
|
|
532
619
|
score += 1.0
|
|
533
620
|
|
|
534
621
|
if score < 5.0:
|
|
@@ -790,7 +877,7 @@ class LocalGraph:
|
|
|
790
877
|
def __contains__(self, name: str) -> bool:
|
|
791
878
|
return self.has_table(name)
|
|
792
879
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
880
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
881
|
return self.table(name)
|
|
795
882
|
|
|
796
883
|
def __delitem__(self, name: str) -> None:
|
|
@@ -2,10 +2,12 @@ from .id import contains_id
|
|
|
2
2
|
from .timestamp import contains_timestamp
|
|
3
3
|
from .categorical import contains_categorical
|
|
4
4
|
from .multicategorical import contains_multicategorical
|
|
5
|
+
from .stype import infer_stype
|
|
5
6
|
|
|
6
7
|
__all__ = [
|
|
7
8
|
'contains_id',
|
|
8
9
|
'contains_timestamp',
|
|
9
10
|
'contains_categorical',
|
|
10
11
|
'contains_multicategorical',
|
|
12
|
+
'infer_stype',
|
|
11
13
|
]
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from kumoapi.typing import Dtype, Stype
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.infer import (
|
|
5
|
+
contains_categorical,
|
|
6
|
+
contains_id,
|
|
7
|
+
contains_multicategorical,
|
|
8
|
+
contains_timestamp,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
|
|
13
|
+
r"""Infers the semantic type of a column.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
ser: A :class:`pandas.Series` to analyze.
|
|
17
|
+
column_name: The name of the column (used for pattern matching).
|
|
18
|
+
dtype: The data type.
|
|
19
|
+
|
|
20
|
+
Returns:
|
|
21
|
+
The semantic type.
|
|
22
|
+
"""
|
|
23
|
+
if contains_id(ser, column_name, dtype):
|
|
24
|
+
return Stype.ID
|
|
25
|
+
|
|
26
|
+
if contains_timestamp(ser, column_name, dtype):
|
|
27
|
+
return Stype.timestamp
|
|
28
|
+
|
|
29
|
+
if contains_multicategorical(ser, column_name, dtype):
|
|
30
|
+
return Stype.multicategorical
|
|
31
|
+
|
|
32
|
+
if contains_categorical(ser, column_name, dtype):
|
|
33
|
+
return Stype.categorical
|
|
34
|
+
|
|
35
|
+
return dtype.default_stype
|
|
@@ -6,7 +6,7 @@ import pandas as pd
|
|
|
6
6
|
from kumoapi.rfm.context import Subgraph
|
|
7
7
|
from kumoapi.typing import Stype
|
|
8
8
|
|
|
9
|
-
from kumoai.experimental.rfm import
|
|
9
|
+
from kumoai.experimental.rfm import Graph, LocalTable
|
|
10
10
|
from kumoai.experimental.rfm.utils import normalize_text
|
|
11
11
|
from kumoai.utils import InteractiveProgressLogger, ProgressLogger
|
|
12
12
|
|
|
@@ -20,7 +20,7 @@ except ImportError:
|
|
|
20
20
|
class LocalGraphStore:
|
|
21
21
|
def __init__(
|
|
22
22
|
self,
|
|
23
|
-
graph:
|
|
23
|
+
graph: Graph,
|
|
24
24
|
preprocess: bool = False,
|
|
25
25
|
verbose: Union[bool, ProgressLogger] = True,
|
|
26
26
|
) -> None:
|
|
@@ -105,7 +105,7 @@ class LocalGraphStore:
|
|
|
105
105
|
|
|
106
106
|
def sanitize(
|
|
107
107
|
self,
|
|
108
|
-
graph:
|
|
108
|
+
graph: Graph,
|
|
109
109
|
preprocess: bool = False,
|
|
110
110
|
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, np.ndarray]]:
|
|
111
111
|
r"""Sanitizes raw data according to table schema definition:
|
|
@@ -120,10 +120,11 @@ class LocalGraphStore:
|
|
|
120
120
|
data for faster model processing. In particular, it:
|
|
121
121
|
* tokenizes any text column that is not a foreign key
|
|
122
122
|
"""
|
|
123
|
-
df_dict: Dict[str, pd.DataFrame] = {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
123
|
+
df_dict: Dict[str, pd.DataFrame] = {}
|
|
124
|
+
for table_name, table in graph.tables.items():
|
|
125
|
+
assert isinstance(table, LocalTable)
|
|
126
|
+
df = table._data
|
|
127
|
+
df_dict[table_name] = df.copy(deep=False).reset_index(drop=True)
|
|
127
128
|
|
|
128
129
|
foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
|
|
129
130
|
|
|
@@ -165,7 +166,7 @@ class LocalGraphStore:
|
|
|
165
166
|
|
|
166
167
|
return df_dict, mask_dict
|
|
167
168
|
|
|
168
|
-
def get_stype_dict(self, graph:
|
|
169
|
+
def get_stype_dict(self, graph: Graph) -> Dict[str, Dict[str, Stype]]:
|
|
169
170
|
stype_dict: Dict[str, Dict[str, Stype]] = {}
|
|
170
171
|
foreign_keys = {(edge.src_table, edge.fkey) for edge in graph.edges}
|
|
171
172
|
for table in graph.tables.values():
|
|
@@ -180,7 +181,7 @@ class LocalGraphStore:
|
|
|
180
181
|
|
|
181
182
|
def get_pkey_data(
|
|
182
183
|
self,
|
|
183
|
-
graph:
|
|
184
|
+
graph: Graph,
|
|
184
185
|
) -> Tuple[
|
|
185
186
|
Dict[str, str],
|
|
186
187
|
Dict[str, pd.DataFrame],
|
|
@@ -218,7 +219,7 @@ class LocalGraphStore:
|
|
|
218
219
|
|
|
219
220
|
def get_time_data(
|
|
220
221
|
self,
|
|
221
|
-
graph:
|
|
222
|
+
graph: Graph,
|
|
222
223
|
) -> Tuple[
|
|
223
224
|
Dict[str, str],
|
|
224
225
|
Dict[str, str],
|
|
@@ -259,7 +260,7 @@ class LocalGraphStore:
|
|
|
259
260
|
|
|
260
261
|
def get_csc(
|
|
261
262
|
self,
|
|
262
|
-
graph:
|
|
263
|
+
graph: Graph,
|
|
263
264
|
) -> Tuple[
|
|
264
265
|
Dict[Tuple[str, str, str], np.ndarray],
|
|
265
266
|
Dict[Tuple[str, str, str], np.ndarray],
|
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -32,7 +32,7 @@ from kumoapi.task import TaskType
|
|
|
32
32
|
|
|
33
33
|
from kumoai.client.rfm import RFMAPI
|
|
34
34
|
from kumoai.exceptions import HTTPException
|
|
35
|
-
from kumoai.experimental.rfm import
|
|
35
|
+
from kumoai.experimental.rfm import Graph
|
|
36
36
|
from kumoai.experimental.rfm.local_graph_sampler import LocalGraphSampler
|
|
37
37
|
from kumoai.experimental.rfm.local_graph_store import LocalGraphStore
|
|
38
38
|
from kumoai.experimental.rfm.local_pquery_driver import (
|
|
@@ -123,17 +123,17 @@ class KumoRFM:
|
|
|
123
123
|
:class:`KumoRFM` is a foundation model to generate predictions for any
|
|
124
124
|
relational dataset without training.
|
|
125
125
|
The model is pre-trained and the class provides an interface to query the
|
|
126
|
-
model from a :class:`
|
|
126
|
+
model from a :class:`Graph` object.
|
|
127
127
|
|
|
128
128
|
.. code-block:: python
|
|
129
129
|
|
|
130
|
-
from kumoai.experimental.rfm import
|
|
130
|
+
from kumoai.experimental.rfm import Graph, KumoRFM
|
|
131
131
|
|
|
132
132
|
df_users = pd.DataFrame(...)
|
|
133
133
|
df_items = pd.DataFrame(...)
|
|
134
134
|
df_orders = pd.DataFrame(...)
|
|
135
135
|
|
|
136
|
-
graph =
|
|
136
|
+
graph = Graph.from_data({
|
|
137
137
|
'users': df_users,
|
|
138
138
|
'items': df_items,
|
|
139
139
|
'orders': df_orders,
|
|
@@ -163,7 +163,7 @@ class KumoRFM:
|
|
|
163
163
|
"""
|
|
164
164
|
def __init__(
|
|
165
165
|
self,
|
|
166
|
-
graph:
|
|
166
|
+
graph: Graph,
|
|
167
167
|
preprocess: bool = False,
|
|
168
168
|
verbose: Union[bool, ProgressLogger] = True,
|
|
169
169
|
) -> None:
|
|
@@ -2,15 +2,22 @@ import base64
|
|
|
2
2
|
import json
|
|
3
3
|
from typing import Any, Dict, List, Tuple
|
|
4
4
|
|
|
5
|
-
import boto3
|
|
6
5
|
import requests
|
|
7
|
-
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
8
|
-
from mypy_boto3_sagemaker_runtime.type_defs import InvokeEndpointOutputTypeDef
|
|
9
6
|
|
|
10
7
|
from kumoai.client import KumoClient
|
|
11
8
|
from kumoai.client.endpoints import Endpoint, HTTPMethod
|
|
12
9
|
from kumoai.exceptions import HTTPException
|
|
13
10
|
|
|
11
|
+
try:
|
|
12
|
+
# isort: off
|
|
13
|
+
from mypy_boto3_sagemaker_runtime.client import SageMakerRuntimeClient
|
|
14
|
+
from mypy_boto3_sagemaker_runtime.type_defs import (
|
|
15
|
+
InvokeEndpointOutputTypeDef, )
|
|
16
|
+
# isort: on
|
|
17
|
+
except ImportError:
|
|
18
|
+
SageMakerRuntimeClient = Any
|
|
19
|
+
InvokeEndpointOutputTypeDef = Any
|
|
20
|
+
|
|
14
21
|
|
|
15
22
|
class SageMakerResponseAdapter(requests.Response):
|
|
16
23
|
def __init__(self, sm_response: InvokeEndpointOutputTypeDef):
|
|
@@ -34,6 +41,7 @@ class SageMakerResponseAdapter(requests.Response):
|
|
|
34
41
|
|
|
35
42
|
class KumoClient_SageMakerAdapter(KumoClient):
|
|
36
43
|
def __init__(self, region: str, endpoint_name: str):
|
|
44
|
+
import boto3
|
|
37
45
|
self._client: SageMakerRuntimeClient = boto3.client(
|
|
38
46
|
service_name="sagemaker-runtime", region_name=region)
|
|
39
47
|
self._endpoint_name = endpoint_name
|