kumoai 2.14.0.dev202512271732__cp310-cp310-macosx_11_0_arm64.whl → 2.14.0rc2__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import copy
3
5
  import io
@@ -19,6 +21,7 @@ from kumoai import in_notebook, in_snowflake_notebook
19
21
  from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
20
22
  from kumoai.graph import Edge
21
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
22
25
 
23
26
  if TYPE_CHECKING:
24
27
  import graphviz
@@ -536,6 +539,35 @@ class Graph:
536
539
 
537
540
  return graph
538
541
 
542
@classmethod
def from_relbench(
    cls,
    dataset: str,
    verbose: bool = True,
) -> Graph:
    r"""Creates a :class:`Graph` from a
    `RelBench <https://relbench.stanford.edu>`_ dataset.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_relbench("f1")

    Args:
        dataset: Name of the RelBench dataset to load.
        verbose: If ``True``, forwards verbosity to the loader and prints the
            resulting graph's metadata and links afterwards.
    """
    # Deferred import: the relbench module itself imports ``Graph`` from this
    # package (and pulls in ``pooch``), so importing it at module level here
    # would likely create a circular import.
    from kumoai.experimental.rfm.relbench import from_relbench as _load

    loaded = _load(dataset, verbose=verbose)
    if not verbose:
        return loaded

    loaded.print_metadata()
    loaded.print_links()
    return loaded
570
+
539
571
  # Backend #################################################################
540
572
 
541
573
  @property
@@ -617,28 +649,28 @@ class Graph:
617
649
  r"""Returns a :class:`pandas.DataFrame` object containing metadata
618
650
  information about the tables in this graph.
619
651
 
620
- The returned dataframe has columns ``name``, ``primary_key``,
621
- ``time_column``, and ``end_time_column``, which provide an aggregate
622
- view of the properties of the tables of this graph.
652
+ The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
653
+ ``"Time Column"``, and ``"End Time Column"``, which provide an
654
+ aggregated view of the properties of the tables of this graph.
623
655
 
624
656
  Example:
625
657
  >>> # doctest: +SKIP
626
658
  >>> import kumoai.experimental.rfm as rfm
627
659
  >>> graph = rfm.Graph(tables=...).infer_metadata()
628
660
  >>> graph.metadata # doctest: +SKIP
629
- name primary_key time_column end_time_column
630
- 0 users user_id - -
661
+ Name Primary Key Time Column End Time Column
662
+ 0 users user_id - -
631
663
  """
632
664
  tables = list(self.tables.values())
633
665
 
634
666
  return pd.DataFrame({
635
- 'name':
667
+ 'Name':
636
668
  pd.Series(dtype=str, data=[t.name for t in tables]),
637
- 'primary_key':
669
+ 'Primary Key':
638
670
  pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
639
- 'time_column':
671
+ 'Time Column':
640
672
  pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
641
- 'end_time_column':
673
+ 'End Time Column':
642
674
  pd.Series(
643
675
  dtype=str,
644
676
  data=[t._end_time_column or '-' for t in tables],
@@ -647,24 +679,8 @@ class Graph:
647
679
 
648
680
  def print_metadata(self) -> None:
649
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
650
- if in_snowflake_notebook():
651
- import streamlit as st
652
- st.markdown("### 🗂️ Graph Metadata")
653
- st.dataframe(self.metadata, hide_index=True)
654
- elif in_notebook():
655
- from IPython.display import Markdown, display
656
- display(Markdown("### 🗂️ Graph Metadata"))
657
- df = self.metadata
658
- try:
659
- if hasattr(df.style, 'hide'):
660
- display(df.style.hide(axis='index')) # pandas=2
661
- else:
662
- display(df.style.hide_index()) # pandas<1.3
663
- except ImportError:
664
- print(df.to_string(index=False)) # missing jinja2
665
- else:
666
- print("🗂️ Graph Metadata:")
667
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
668
684
 
669
685
  def infer_metadata(self, verbose: bool = True) -> Self:
670
686
  r"""Infers metadata for all tables in the graph.
@@ -693,40 +709,21 @@ class Graph:
693
709
 
694
710
  def print_links(self) -> None:
695
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
696
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
697
- edge.src_table, edge.fkey) for edge in self.edges]
698
- edges = sorted(edges)
699
-
700
- if in_snowflake_notebook():
701
- import streamlit as st
702
- st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
703
- if len(edges) > 0:
704
- st.markdown('\n'.join([
705
- f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
706
- for edge in edges
707
- ]))
708
- else:
709
- st.markdown("*No links registered*")
710
- elif in_notebook():
711
- from IPython.display import Markdown, display
712
- display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
713
- if len(edges) > 0:
714
- display(
715
- Markdown('\n'.join([
716
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
717
- for edge in edges
718
- ])))
719
- else:
720
- display(Markdown("*No links registered*"))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
721
725
  else:
722
- print("🕸️ Graph Links (FK ↔️ PK):")
723
- if len(edges) > 0:
724
- print('\n'.join([
725
- f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
726
- for edge in edges
727
- ]))
728
- else:
729
- print("No links registered")
726
+ display.italic("No links registered")
730
727
 
731
728
  def link(
732
729
  self,
@@ -52,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
52
52
  ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
53
53
 
54
54
  if isinstance(ser.dtype, pd.ArrowDtype):
55
- if pa.types.is_list(ser.dtype.pyarrow_dtype):
55
+ if (pa.types.is_list(ser.dtype.pyarrow_dtype)
56
+ or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
56
57
  elem_dtype = ser.dtype.pyarrow_dtype.value_type
57
58
  if pa.types.is_integer(elem_dtype):
58
59
  return Dtype.intlist
@@ -0,0 +1,76 @@
1
+ import difflib
2
+ import json
3
+ from functools import lru_cache
4
+ from urllib.request import urlopen
5
+
6
+ import pooch
7
+ import pyarrow as pa
8
+
9
+ from kumoai.experimental.rfm import Graph
10
+ from kumoai.experimental.rfm.backend.local import LocalTable
11
+
12
# Directory-name prefix used by every dataset entry in the RelBench registry
# (e.g. ``'rel-f1/db.zip'``); user-facing names omit it.
PREFIX = 'rel-'

# OS-specific cache directory (via ``pooch.os_cache``) into which datasets are
# downloaded and extracted.
CACHE_DIR = pooch.os_cache('relbench')

# JSON registry (file name -> hash) consumed by ``pooch.create`` below.
HASH_URL = (
    'https://raw.githubusercontent.com/snap-stanford/relbench/main/'
    'relbench/datasets/hashes.json'
)
16
+
17
+
18
@lru_cache
def get_registry(timeout: float = 30.0) -> pooch.Pooch:
    r"""Returns a cached :class:`pooch.Pooch` for the RelBench download server.

    The file-name-to-hash registry is downloaded from ``HASH_URL`` on the
    first call; subsequent calls (with the same ``timeout``) reuse the cached
    result through :func:`functools.lru_cache`.

    Args:
        timeout: Seconds to wait on the registry download. Without an explicit
            timeout, ``urlopen`` falls back to the global socket default,
            which blocks indefinitely unless configured elsewhere.

    Raises:
        Any error raised by ``urlopen`` (e.g. on network failure or timeout)
        or by ``json.load`` (on a malformed registry) is propagated.
    """
    with urlopen(HASH_URL, timeout=timeout) as response:
        hashes = json.load(response)

    return pooch.create(
        path=CACHE_DIR,
        base_url='https://relbench.stanford.edu/download/',
        registry=hashes,
    )
28
+
29
+
30
def _available_datasets(keys, prefix: str) -> list[str]:
    r"""Returns the sorted, de-duplicated dataset names that ship a database.

    Args:
        keys: Registry file names such as ``'rel-f1/db.zip'`` or
            ``'rel-f1/tasks/driver-dnf.zip'``.
        prefix: Directory prefix to strip from each dataset name.
    """
    names = set()
    for key in keys:
        directory, _, filename = key.partition('/')
        # Only ``<prefix><name>/db.zip`` entries are loadable by
        # ``from_relbench``; task archives would otherwise add duplicates.
        if filename == 'db.zip' and directory.startswith(prefix):
            names.add(directory[len(prefix):])
    return sorted(names)


def _parse_table_metadata(raw) -> dict:
    r"""Decodes the JSON-encoded RelBench entries of Parquet schema metadata.

    Args:
        raw: The ``bytes -> bytes`` mapping from ``pyarrow.Schema.metadata``,
            or ``None`` if the file carries no schema metadata.

    Returns:
        A dict holding whichever of ``'fkey_col_to_pkey_table'``,
        ``'pkey_col'`` and ``'time_col'`` are present, with JSON-decoded
        values. Other metadata keys are ignored.
    """
    raw = raw or {}
    return {
        key.decode('utf-8'): json.loads(raw[key].decode('utf-8'))
        for key in (b'fkey_col_to_pkey_table', b'pkey_col', b'time_col')
        if key in raw
    }


def from_relbench(dataset: str, verbose: bool = True) -> 'Graph':
    r"""Downloads (if needed) a RelBench database and loads it as a Graph.

    Args:
        dataset: Dataset name, case-insensitive, with or without the
            ``'rel-'`` prefix (e.g. ``'f1'`` or ``'rel-f1'``).
        verbose: Whether to show a progress bar while downloading.

    Returns:
        A :class:`Graph` with one :class:`LocalTable` per Parquet file of the
        database, linked through the foreign keys recorded in each file's
        schema metadata.

    Raises:
        ValueError: If ``dataset`` is not a known RelBench dataset.
    """
    # ``import pyarrow`` does not load the ``parquet`` submodule, so
    # ``pa.parquet`` is only available if something else imported it first.
    import pyarrow.parquet as pq

    dataset = dataset.lower()
    if dataset.startswith(PREFIX):
        dataset = dataset[len(PREFIX):]

    registry = get_registry()

    datasets = _available_datasets(registry.registry, PREFIX)
    if dataset not in datasets:
        matches = difflib.get_close_matches(dataset, datasets, n=1)
        hint = f" Did you mean '{matches[0]}'?" if matches else ''
        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
                         f"datasets are {str(datasets)[1:-1]}.")

    registry.fetch(
        f'{PREFIX}{dataset}/db.zip',
        processor=pooch.Unzip(extract_dir='.'),
        progressbar=verbose,
    )

    graph = Graph(tables=[])
    edges: list[tuple[str, str, str]] = []
    db_dir = CACHE_DIR / f'{PREFIX}{dataset}' / 'db'
    # Sorted so table order does not depend on directory-listing order.
    for path in sorted(db_dir.glob('*.parquet')):
        data = pq.read_table(path)
        metadata = _parse_table_metadata(data.schema.metadata)

        graph.add_table(
            LocalTable(
                df=data.to_pandas(),
                name=path.stem,
                primary_key=metadata.get('pkey_col'),
                time_column=metadata.get('time_col'),
            ))

        fkeys = metadata.get('fkey_col_to_pkey_table') or {}
        edges.extend((path.stem, fkey, dst_table)
                     for fkey, dst_table in fkeys.items())

    # Links are added only after all tables exist -- presumably because
    # ``Graph.link`` requires both endpoints to be registered.
    for edge in edges:
        graph.link(*edge)

    return graph