kumoai-2.14.0.dev202512191731-cp311-cp311-macosx_11_0_arm64.whl → kumoai-2.14.0.dev202601051732-cp311-cp311-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (36)
  1. kumoai/__init__.py +23 -26
  2. kumoai/_version.py +1 -1
  3. kumoai/client/client.py +6 -0
  4. kumoai/client/jobs.py +24 -0
  5. kumoai/connector/utils.py +21 -7
  6. kumoai/experimental/rfm/__init__.py +24 -22
  7. kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
  8. kumoai/experimental/rfm/backend/local/sampler.py +0 -3
  9. kumoai/experimental/rfm/backend/local/table.py +24 -25
  10. kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
  11. kumoai/experimental/rfm/backend/snow/table.py +137 -64
  12. kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
  13. kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
  14. kumoai/experimental/rfm/base/__init__.py +6 -9
  15. kumoai/experimental/rfm/base/column.py +95 -11
  16. kumoai/experimental/rfm/base/expression.py +44 -0
  17. kumoai/experimental/rfm/base/sampler.py +5 -17
  18. kumoai/experimental/rfm/base/source.py +1 -1
  19. kumoai/experimental/rfm/base/sql_sampler.py +69 -9
  20. kumoai/experimental/rfm/base/table.py +258 -97
  21. kumoai/experimental/rfm/graph.py +106 -98
  22. kumoai/experimental/rfm/infer/dtype.py +4 -1
  23. kumoai/experimental/rfm/infer/multicategorical.py +1 -1
  24. kumoai/experimental/rfm/relbench.py +76 -0
  25. kumoai/experimental/rfm/rfm.py +394 -241
  26. kumoai/experimental/rfm/task_table.py +290 -0
  27. kumoai/trainer/distilled_trainer.py +175 -0
  28. kumoai/utils/display.py +51 -0
  29. kumoai/utils/progress_logger.py +13 -1
  30. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/METADATA +1 -1
  31. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/RECORD +34 -31
  32. kumoai/experimental/rfm/base/column_expression.py +0 -50
  33. kumoai/experimental/rfm/base/sql_table.py +0 -229
  34. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/WHEEL +0 -0
  35. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/licenses/LICENSE +0 -0
  36. {kumoai-2.14.0.dev202512191731.dist-info → kumoai-2.14.0.dev202601051732.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import contextlib
2
4
  import copy
3
5
  import io
@@ -16,14 +18,10 @@ from kumoapi.typing import Stype
16
18
  from typing_extensions import Self
17
19
 
18
20
  from kumoai import in_notebook, in_snowflake_notebook
19
- from kumoai.experimental.rfm.base import (
20
- ColumnExpressionSpec,
21
- DataBackend,
22
- SQLTable,
23
- Table,
24
- )
21
+ from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
25
22
  from kumoai.graph import Edge
26
23
  from kumoai.mixin import CastMixin
24
+ from kumoai.utils import display
27
25
 
28
26
  if TYPE_CHECKING:
29
27
  import graphviz
@@ -103,27 +101,24 @@ class Graph:
103
101
  for table in tables:
104
102
  self.add_table(table)
105
103
 
106
- for table in tables:
107
- if not isinstance(table, SQLTable):
108
- continue
109
- if '_source_column_dict' not in table.__dict__:
104
+ for table in tables: # Use links from source metadata:
105
+ if not any(column.is_source for column in table.columns):
110
106
  continue
111
107
  for fkey in table._source_foreign_key_dict.values():
112
108
  if fkey.name not in table:
113
109
  continue
114
- if not table[fkey.name].is_physical:
110
+ if not table[fkey.name].is_source:
115
111
  continue
116
112
  dst_table_names = [
117
113
  table.name for table in self.tables.values()
118
- if isinstance(table, SQLTable)
119
- and table._source_name == fkey.dst_table
114
+ if table.source_name == fkey.dst_table
120
115
  ]
121
116
  if len(dst_table_names) != 1:
122
117
  continue
123
118
  dst_table = self[dst_table_names[0]]
124
119
  if dst_table._primary_key != fkey.primary_key:
125
120
  continue
126
- if not dst_table[fkey.primary_key].is_physical:
121
+ if not dst_table[fkey.primary_key].is_source:
127
122
  continue
128
123
  self.link(table.name, fkey.name, dst_table.name)
129
124
 
@@ -444,9 +439,8 @@ class Graph:
444
439
  f"'{table_name}' since composite primary keys "
445
440
  f"are not yet supported")
446
441
 
447
- columns: list[str] = []
442
+ columns: list[ColumnSpec] = []
448
443
  unsupported_columns: list[str] = []
449
- column_expression_specs: list[ColumnExpressionSpec] = []
450
444
  for column_cfg in chain(
451
445
  table_cfg.get('dimensions', []),
452
446
  table_cfg.get('time_dimensions', []),
@@ -457,13 +451,13 @@ class Graph:
457
451
  column_data_type = column_cfg.get('data_type', None)
458
452
 
459
453
  if column_expr is None:
460
- columns.append(column_name)
454
+ columns.append(ColumnSpec(name=column_name))
461
455
  continue
462
456
 
463
457
  column_expr = column_expr.replace(f'{table_name}.', '')
464
458
 
465
459
  if column_expr == column_name:
466
- columns.append(column_name)
460
+ columns.append(ColumnSpec(name=column_name))
467
461
  continue
468
462
 
469
463
  # Drop expressions that reference other tables (for now):
@@ -471,12 +465,12 @@ class Graph:
471
465
  unsupported_columns.append(column_name)
472
466
  continue
473
467
 
474
- spec = ColumnExpressionSpec(
468
+ column = ColumnSpec(
475
469
  name=column_name,
476
470
  expr=column_expr,
477
- dtype=SnowTable.to_dtype(column_data_type),
471
+ dtype=SnowTable._to_dtype(column_data_type),
478
472
  )
479
- column_expression_specs.append(spec)
473
+ columns.append(column)
480
474
 
481
475
  if len(unsupported_columns) == 1:
482
476
  msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
@@ -494,7 +488,6 @@ class Graph:
494
488
  database=database,
495
489
  schema=schema,
496
490
  columns=columns,
497
- column_expressions=column_expression_specs,
498
491
  primary_key=primary_key,
499
492
  )
500
493
 
@@ -546,6 +539,35 @@ class Graph:
546
539
 
547
540
  return graph
548
541
 
542
+ @classmethod
543
+ def from_relbench(
544
+ cls,
545
+ dataset: str,
546
+ verbose: bool = True,
547
+ ) -> Graph:
548
+ r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
549
+ :class:`Graph` instance.
550
+
551
+ .. code-block:: python
552
+
553
+ >>> # doctest: +SKIP
554
+ >>> import kumoai.experimental.rfm as rfm
555
+
556
+ >>> graph = rfm.Graph.from_relbench("f1")
557
+
558
+ Args:
559
+ dataset: The RelBench dataset name.
560
+ verbose: Whether to print verbose output.
561
+ """
562
+ from kumoai.experimental.rfm.relbench import from_relbench
563
+ graph = from_relbench(dataset, verbose=verbose)
564
+
565
+ if verbose:
566
+ graph.print_metadata()
567
+ graph.print_links()
568
+
569
+ return graph
570
+
549
571
  # Backend #################################################################
550
572
 
551
573
  @property
@@ -657,24 +679,8 @@ class Graph:
657
679
 
658
680
  def print_metadata(self) -> None:
659
681
  r"""Prints the :meth:`~Graph.metadata` of the graph."""
660
- if in_snowflake_notebook():
661
- import streamlit as st
662
- st.markdown("### 🗂️ Graph Metadata")
663
- st.dataframe(self.metadata, hide_index=True)
664
- elif in_notebook():
665
- from IPython.display import Markdown, display
666
- display(Markdown("### 🗂️ Graph Metadata"))
667
- df = self.metadata
668
- try:
669
- if hasattr(df.style, 'hide'):
670
- display(df.style.hide(axis='index')) # pandas=2
671
- else:
672
- display(df.style.hide_index()) # pandas<1.3
673
- except ImportError:
674
- print(df.to_string(index=False)) # missing jinja2
675
- else:
676
- print("🗂️ Graph Metadata:")
677
- print(self.metadata.to_string(index=False))
682
+ display.title("🗂️ Graph Metadata")
683
+ display.dataframe(self.metadata)
678
684
 
679
685
  def infer_metadata(self, verbose: bool = True) -> Self:
680
686
  r"""Infers metadata for all tables in the graph.
@@ -703,40 +709,21 @@ class Graph:
703
709
 
704
710
  def print_links(self) -> None:
705
711
  r"""Prints the :meth:`~Graph.edges` of the graph."""
706
- edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
707
- edge.src_table, edge.fkey) for edge in self.edges]
708
- edges = sorted(edges)
709
-
710
- if in_snowflake_notebook():
711
- import streamlit as st
712
- st.markdown("### 🕸️ Graph Links (FK ↔️ PK)")
713
- if len(edges) > 0:
714
- st.markdown('\n'.join([
715
- f"- {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
716
- for edge in edges
717
- ]))
718
- else:
719
- st.markdown("*No links registered*")
720
- elif in_notebook():
721
- from IPython.display import Markdown, display
722
- display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
723
- if len(edges) > 0:
724
- display(
725
- Markdown('\n'.join([
726
- f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
727
- for edge in edges
728
- ])))
729
- else:
730
- display(Markdown("*No links registered*"))
712
+ edges = sorted([(
713
+ edge.dst_table,
714
+ self[edge.dst_table]._primary_key,
715
+ edge.src_table,
716
+ edge.fkey,
717
+ ) for edge in self.edges])
718
+
719
+ display.title("🕸️ Graph Links (FK ↔️ PK)")
720
+ if len(edges) > 0:
721
+ display.unordered_list(items=[
722
+ f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
723
+ for edge in edges
724
+ ])
731
725
  else:
732
- print("🕸️ Graph Links (FK ↔️ PK):")
733
- if len(edges) > 0:
734
- print('\n'.join([
735
- f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
736
- for edge in edges
737
- ]))
738
- else:
739
- print("No links registered")
726
+ display.italic("No links registered")
740
727
 
741
728
  def link(
742
729
  self,
@@ -843,6 +830,30 @@ class Graph:
843
830
  """
844
831
  known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
845
832
 
833
+ for table in self.tables.values(): # Use links from source metadata:
834
+ if not any(column.is_source for column in table.columns):
835
+ continue
836
+ for fkey in table._source_foreign_key_dict.values():
837
+ if fkey.name not in table:
838
+ continue
839
+ if not table[fkey.name].is_source:
840
+ continue
841
+ if (table.name, fkey.name) in known_edges:
842
+ continue
843
+ dst_table_names = [
844
+ table.name for table in self.tables.values()
845
+ if table.source_name == fkey.dst_table
846
+ ]
847
+ if len(dst_table_names) != 1:
848
+ continue
849
+ dst_table = self[dst_table_names[0]]
850
+ if dst_table._primary_key != fkey.primary_key:
851
+ continue
852
+ if not dst_table[fkey.primary_key].is_source:
853
+ continue
854
+ self.link(table.name, fkey.name, dst_table.name)
855
+ known_edges.add((table.name, fkey.name))
856
+
846
857
  # A list of primary key candidates (+score) for every column:
847
858
  candidate_dict: dict[
848
859
  tuple[str, str],
@@ -942,13 +953,8 @@ class Graph:
942
953
  if score < 5.0:
943
954
  continue
944
955
 
945
- candidate_dict[(
946
- src_table.name,
947
- src_key.name,
948
- )].append((
949
- dst_table.name,
950
- score,
951
- ))
956
+ candidate_dict[(src_table.name, src_key.name)].append(
957
+ (dst_table.name, score))
952
958
 
953
959
  for (src_table_name, src_key_name), scores in candidate_dict.items():
954
960
  scores.sort(key=lambda x: x[-1], reverse=True)
@@ -1007,24 +1013,26 @@ class Graph:
1007
1013
  f"either the primary key or the link before "
1008
1014
  f"before proceeding.")
1009
1015
 
1010
- # Check that fkey/pkey have valid and consistent data types:
1011
- assert src_key.dtype is not None
1012
- src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1013
- src_string = src_key.dtype.is_string()
1014
- assert dst_key.dtype is not None
1015
- dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1016
- dst_string = dst_key.dtype.is_string()
1017
-
1018
- if not src_number and not src_string:
1019
- raise ValueError(f"{edge} is invalid as foreign key must be a "
1020
- f"number or string (got '{src_key.dtype}'")
1021
-
1022
- if src_number != dst_number or src_string != dst_string:
1023
- raise ValueError(f"{edge} is invalid as foreign key "
1024
- f"'{fkey}' and primary key '{dst_key.name}' "
1025
- f"have incompatible data types (got "
1026
- f"fkey.dtype '{src_key.dtype}' and "
1027
- f"pkey.dtype '{dst_key.dtype}')")
1016
+ if self.backend == DataBackend.LOCAL:
1017
+ # Check that fkey/pkey have valid and consistent data types:
1018
+ assert src_key.dtype is not None
1019
+ src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
1020
+ src_string = src_key.dtype.is_string()
1021
+ assert dst_key.dtype is not None
1022
+ dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
1023
+ dst_string = dst_key.dtype.is_string()
1024
+
1025
+ if not src_number and not src_string:
1026
+ raise ValueError(
1027
+ f"{edge} is invalid as foreign key must be a number "
1028
+ f"or string (got '{src_key.dtype}'")
1029
+
1030
+ if src_number != dst_number or src_string != dst_string:
1031
+ raise ValueError(
1032
+ f"{edge} is invalid as foreign key '{fkey}' and "
1033
+ f"primary key '{dst_key.name}' have incompatible data "
1034
+ f"types (got foreign key data type '{src_key.dtype}' "
1035
+ f"and primary key data type '{dst_key.dtype}')")
1028
1036
 
1029
1037
  return self
1030
1038
 
@@ -20,6 +20,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
20
20
  'string[python]': Dtype.string,
21
21
  'string[pyarrow]': Dtype.string,
22
22
  'binary': Dtype.binary,
23
+ 'binary[python]': Dtype.binary,
24
+ 'binary[pyarrow]': Dtype.binary,
23
25
  }
24
26
 
25
27
 
@@ -50,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
50
52
  ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
51
53
 
52
54
  if isinstance(ser.dtype, pd.ArrowDtype):
53
- if pa.types.is_list(ser.dtype.pyarrow_dtype):
55
+ if (pa.types.is_list(ser.dtype.pyarrow_dtype)
56
+ or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
54
57
  elem_dtype = ser.dtype.pyarrow_dtype.value_type
55
58
  if pa.types.is_integer(elem_dtype):
56
59
  return Dtype.intlist
@@ -40,7 +40,7 @@ def contains_multicategorical(
40
40
  sep = max(candidates, key=candidates.get) # type: ignore
41
41
  ser = ser.str.split(sep)
42
42
 
43
- num_unique_multi = ser.explode().nunique()
43
+ num_unique_multi = ser.astype('object').explode().nunique()
44
44
 
45
45
  if dtype.is_list():
46
46
  return num_unique_multi <= MAX_CAT
@@ -0,0 +1,76 @@
1
+ import difflib
2
+ import json
3
+ from functools import lru_cache
4
+ from urllib.request import urlopen
5
+
6
+ import pooch
7
+ import pyarrow as pa
8
+
9
+ from kumoai.experimental.rfm import Graph
10
+ from kumoai.experimental.rfm.backend.local import LocalTable
11
+
12
+ PREFIX = 'rel-'
13
+ CACHE_DIR = pooch.os_cache('relbench')
14
+ HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
15
+ 'relbench/datasets/hashes.json')
16
+
17
+
18
+ @lru_cache
19
+ def get_registry() -> pooch.Pooch:
20
+ with urlopen(HASH_URL) as r:
21
+ hashes = json.load(r)
22
+
23
+ return pooch.create(
24
+ path=CACHE_DIR,
25
+ base_url='https://relbench.stanford.edu/download/',
26
+ registry=hashes,
27
+ )
28
+
29
+
30
+ def from_relbench(dataset: str, verbose: bool = True) -> Graph:
31
+ dataset = dataset.lower()
32
+ if dataset.startswith(PREFIX):
33
+ dataset = dataset[len(PREFIX):]
34
+
35
+ registry = get_registry()
36
+
37
+ datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
38
+ if dataset not in datasets:
39
+ matches = difflib.get_close_matches(dataset, datasets, n=1)
40
+ hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
41
+ raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
42
+ f"datasets are {str(datasets)[1:-1]}.")
43
+
44
+ registry.fetch(
45
+ f'{PREFIX}{dataset}/db.zip',
46
+ processor=pooch.Unzip(extract_dir='.'),
47
+ progressbar=verbose,
48
+ )
49
+
50
+ graph = Graph(tables=[])
51
+ edges: list[tuple[str, str, str]] = []
52
+ for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
53
+ data = pa.parquet.read_table(path)
54
+ metadata = {
55
+ key.decode('utf-8'): json.loads(value.decode('utf-8'))
56
+ for key, value in data.schema.metadata.items()
57
+ if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
58
+ }
59
+
60
+ table = LocalTable(
61
+ df=data.to_pandas(),
62
+ name=path.stem,
63
+ primary_key=metadata['pkey_col'],
64
+ time_column=metadata['time_col'],
65
+ )
66
+ graph.add_table(table)
67
+
68
+ edges.extend([
69
+ (path.stem, fkey, dst_table)
70
+ for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
71
+ ])
72
+
73
+ for edge in edges:
74
+ graph.link(*edge)
75
+
76
+ return graph