kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/jobs.py +2 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/snow/sampler.py +83 -14
- kumoai/experimental/rfm/backend/sqlite/sampler.py +68 -12
- kumoai/experimental/rfm/base/mapper.py +67 -0
- kumoai/experimental/rfm/base/sampler.py +21 -0
- kumoai/experimental/rfm/base/sql_sampler.py +233 -10
- kumoai/experimental/rfm/base/table.py +41 -53
- kumoai/experimental/rfm/graph.py +57 -60
- kumoai/experimental/rfm/infer/dtype.py +2 -1
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +529 -303
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/RECORD +24 -20
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512271732.dist-info → kumoai-2.14.0rc2.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
2
4
|
import copy
|
|
3
5
|
import io
|
|
@@ -19,6 +21,7 @@ from kumoai import in_notebook, in_snowflake_notebook
|
|
|
19
21
|
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
20
22
|
from kumoai.graph import Edge
|
|
21
23
|
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
22
25
|
|
|
23
26
|
if TYPE_CHECKING:
|
|
24
27
|
import graphviz
|
|
@@ -536,6 +539,35 @@ class Graph:
|
|
|
536
539
|
|
|
537
540
|
return graph
|
|
538
541
|
|
|
542
|
+
@classmethod
def from_relbench(
    cls,
    dataset: str,
    verbose: bool = True,
) -> Graph:
    r"""Builds a :class:`Graph` from a
    `RelBench <https://relbench.stanford.edu>`_ dataset.

    The actual loading is delegated to
    :func:`kumoai.experimental.rfm.relbench.from_relbench`. When
    :obj:`verbose` is set, the metadata and the links of the resulting
    graph are printed before it is returned.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_relbench("f1")

    Args:
        dataset: The RelBench dataset name.
        verbose: Whether to print verbose output. The flag is also
            forwarded to the underlying loader.

    Returns:
        The graph produced by the RelBench loader.
    """
    # Imported lazily inside the method; the loader module itself imports
    # ``Graph``, so a module-level import would be circular.
    from kumoai.experimental.rfm.relbench import (
        from_relbench as _load_relbench,
    )

    loaded = _load_relbench(dataset, verbose=verbose)
    if not verbose:
        return loaded

    loaded.print_metadata()
    loaded.print_links()
    return loaded
|
|
570
|
+
|
|
539
571
|
# Backend #################################################################
|
|
540
572
|
|
|
541
573
|
@property
|
|
@@ -617,28 +649,28 @@ class Graph:
|
|
|
617
649
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
618
650
|
information about the tables in this graph.
|
|
619
651
|
|
|
620
|
-
The returned dataframe has columns ``
|
|
621
|
-
``
|
|
622
|
-
view of the properties of the tables of this graph.
|
|
652
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
653
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
654
|
+
aggregated view of the properties of the tables of this graph.
|
|
623
655
|
|
|
624
656
|
Example:
|
|
625
657
|
>>> # doctest: +SKIP
|
|
626
658
|
>>> import kumoai.experimental.rfm as rfm
|
|
627
659
|
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
628
660
|
>>> graph.metadata # doctest: +SKIP
|
|
629
|
-
|
|
630
|
-
0 users
|
|
661
|
+
Name Primary Key Time Column End Time Column
|
|
662
|
+
0 users user_id - -
|
|
631
663
|
"""
|
|
632
664
|
tables = list(self.tables.values())
|
|
633
665
|
|
|
634
666
|
return pd.DataFrame({
|
|
635
|
-
'
|
|
667
|
+
'Name':
|
|
636
668
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
637
|
-
'
|
|
669
|
+
'Primary Key':
|
|
638
670
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
639
|
-
'
|
|
671
|
+
'Time Column':
|
|
640
672
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
641
|
-
'
|
|
673
|
+
'End Time Column':
|
|
642
674
|
pd.Series(
|
|
643
675
|
dtype=str,
|
|
644
676
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -647,24 +679,8 @@ class Graph:
|
|
|
647
679
|
|
|
648
680
|
def print_metadata(self) -> None:
|
|
649
681
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
653
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
654
|
-
elif in_notebook():
|
|
655
|
-
from IPython.display import Markdown, display
|
|
656
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
657
|
-
df = self.metadata
|
|
658
|
-
try:
|
|
659
|
-
if hasattr(df.style, 'hide'):
|
|
660
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
661
|
-
else:
|
|
662
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
663
|
-
except ImportError:
|
|
664
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
665
|
-
else:
|
|
666
|
-
print("🗂️ Graph Metadata:")
|
|
667
|
-
print(self.metadata.to_string(index=False))
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
668
684
|
|
|
669
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
670
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -693,40 +709,21 @@ class Graph:
|
|
|
693
709
|
|
|
694
710
|
def print_links(self) -> None:
|
|
695
711
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
696
|
-
edges = [(
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
st.markdown("*No links registered*")
|
|
710
|
-
elif in_notebook():
|
|
711
|
-
from IPython.display import Markdown, display
|
|
712
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
713
|
-
if len(edges) > 0:
|
|
714
|
-
display(
|
|
715
|
-
Markdown('\n'.join([
|
|
716
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
717
|
-
for edge in edges
|
|
718
|
-
])))
|
|
719
|
-
else:
|
|
720
|
-
display(Markdown("*No links registered*"))
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
721
725
|
else:
|
|
722
|
-
|
|
723
|
-
if len(edges) > 0:
|
|
724
|
-
print('\n'.join([
|
|
725
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
726
|
-
for edge in edges
|
|
727
|
-
]))
|
|
728
|
-
else:
|
|
729
|
-
print("No links registered")
|
|
726
|
+
display.italic("No links registered")
|
|
730
727
|
|
|
731
728
|
def link(
|
|
732
729
|
self,
|
|
@@ -52,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
52
52
|
ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
|
|
53
53
|
|
|
54
54
|
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
55
|
-
if pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
55
|
+
if (pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
56
|
+
or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
|
|
56
57
|
elem_dtype = ser.dtype.pyarrow_dtype.value_type
|
|
57
58
|
if pa.types.is_integer(elem_dtype):
|
|
58
59
|
return Dtype.intlist
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import json
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from urllib.request import urlopen
|
|
5
|
+
|
|
6
|
+
import pooch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm import Graph
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
11
|
+
|
|
12
|
+
PREFIX = 'rel-'
|
|
13
|
+
CACHE_DIR = pooch.os_cache('relbench')
|
|
14
|
+
HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
|
|
15
|
+
'relbench/datasets/hashes.json')
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache
def get_registry() -> pooch.Pooch:
    """Return the ``pooch`` registry used to download RelBench files.

    The file hashes are read as JSON from ``HASH_URL`` and handed to
    ``pooch.create`` together with the local cache directory ``CACHE_DIR``.
    The function takes no arguments and is wrapped in ``lru_cache``, so the
    hash file is only requested on the first successful call; later calls
    return the same registry object.
    """
    with urlopen(HASH_URL) as response:
        known_hashes = json.loads(response.read())

    registry = pooch.create(
        path=CACHE_DIR,
        base_url='https://relbench.stanford.edu/download/',
        registry=known_hashes,
    )
    return registry
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def from_relbench(dataset: str, verbose: bool = True) -> Graph:
    r"""Downloads a RelBench dataset and assembles it into a :class:`Graph`.

    The dataset archive ``rel-<dataset>/db.zip`` is fetched (and unzipped)
    through the registry returned by :func:`get_registry`. Every
    ``*.parquet`` file found in the extracted ``db`` directory becomes one
    :class:`LocalTable`, named after the file stem. Primary key, time column
    and foreign keys are read from the JSON-encoded parquet schema metadata
    entries ``pkey_col``, ``time_col`` and ``fkey_col_to_pkey_table``.

    Args:
        dataset: The RelBench dataset name. Matching is case-insensitive and
            a leading ``"rel-"`` prefix is optional.
        verbose: Whether to show a progress bar while downloading.

    Returns:
        A graph holding one table per parquet file, linked according to the
        foreign keys recorded in the files.

    Raises:
        ValueError: If ``dataset`` is not present in the registry.
    """
    # ``import pyarrow`` alone does not guarantee that the ``parquet``
    # submodule is loaded, so ``pa.parquet`` may fail with an
    # ``AttributeError``. Import the submodule explicitly instead.
    import pyarrow.parquet as pq

    dataset = dataset.lower()
    if dataset.startswith(PREFIX):
        dataset = dataset[len(PREFIX):]

    registry = get_registry()

    # Registry keys are of the form ``rel-<name>/<file>``. A dataset may own
    # more than one entry, so collect the names in a set to avoid listing the
    # same dataset repeatedly in the error message, and ignore keys that do
    # not carry the expected prefix rather than slicing them blindly.
    datasets = sorted({
        key.split('/')[0][len(PREFIX):]
        for key in registry.registry
        if key.startswith(PREFIX)
    })
    if dataset not in datasets:
        matches = difflib.get_close_matches(dataset, datasets, n=1)
        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
                         f"datasets are {str(datasets)[1:-1]}.")

    registry.fetch(
        f'{PREFIX}{dataset}/db.zip',
        processor=pooch.Unzip(extract_dir='.'),
        progressbar=verbose,
    )

    graph = Graph(tables=[])
    edges: list[tuple[str, str, str]] = []
    metadata_keys = (b"fkey_col_to_pkey_table", b"pkey_col", b"time_col")
    db_dir = CACHE_DIR / f'{PREFIX}{dataset}' / 'db'

    # Sort the files so that tables and links are registered in the same
    # order on every platform (glob order is filesystem-dependent).
    for path in sorted(db_dir.glob('*.parquet')):
        data = pq.read_table(path)

        # The schema metadata is ``None`` when a file carries none at all.
        raw_metadata = data.schema.metadata or {}
        metadata = {
            key.decode('utf-8'): json.loads(raw_metadata[key].decode('utf-8'))
            for key in metadata_keys if key in raw_metadata
        }

        table = LocalTable(
            df=data.to_pandas(),
            name=path.stem,
            # A missing entry is treated like a JSON ``null`` entry.
            primary_key=metadata.get('pkey_col'),
            time_column=metadata.get('time_col'),
        )
        graph.add_table(table)

        # Links can only be created once all tables are registered, so
        # remember them as (source table, foreign key, destination table).
        fkey_to_table = metadata.get('fkey_col_to_pkey_table') or {}
        edges.extend(
            (path.stem, fkey, dst_table)
            for fkey, dst_table in fkey_to_table.items()
        )

    for edge in edges:
        graph.link(*edge)

    return graph
|