kumoai 2.14.0.dev202512181731__cp312-cp312-macosx_11_0_arm64.whl → 2.14.0.dev202512301731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +24 -0
- kumoai/experimental/rfm/__init__.py +22 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +25 -24
- kumoai/experimental/rfm/backend/snow/sampler.py +106 -61
- kumoai/experimental/rfm/backend/snow/table.py +146 -51
- kumoai/experimental/rfm/backend/sqlite/sampler.py +127 -78
- kumoai/experimental/rfm/backend/sqlite/table.py +94 -47
- kumoai/experimental/rfm/base/__init__.py +6 -7
- kumoai/experimental/rfm/base/column.py +97 -5
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/sampler.py +5 -17
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +68 -9
- kumoai/experimental/rfm/base/table.py +284 -120
- kumoai/experimental/rfm/graph.py +139 -86
- kumoai/experimental/rfm/infer/__init__.py +6 -4
- kumoai/experimental/rfm/infer/dtype.py +6 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/stype.py +35 -0
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +4 -20
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +51 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/METADATA +1 -1
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/RECORD +33 -30
- kumoai/experimental/rfm/base/column_expression.py +0 -16
- kumoai/experimental/rfm/base/sql_table.py +0 -113
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512181731.dist-info → kumoai-2.14.0.dev202512301731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
2
4
|
import copy
|
|
3
5
|
import io
|
|
@@ -16,9 +18,10 @@ from kumoapi.typing import Stype
|
|
|
16
18
|
from typing_extensions import Self
|
|
17
19
|
|
|
18
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
-
from kumoai.experimental.rfm.base import
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
20
22
|
from kumoai.graph import Edge
|
|
21
23
|
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
22
25
|
|
|
23
26
|
if TYPE_CHECKING:
|
|
24
27
|
import graphviz
|
|
@@ -98,24 +101,25 @@ class Graph:
|
|
|
98
101
|
for table in tables:
|
|
99
102
|
self.add_table(table)
|
|
100
103
|
|
|
101
|
-
for table in tables:
|
|
102
|
-
if not
|
|
104
|
+
for table in tables: # Use links from source metadata:
|
|
105
|
+
if not any(column.is_source for column in table.columns):
|
|
103
106
|
continue
|
|
104
107
|
for fkey in table._source_foreign_key_dict.values():
|
|
105
108
|
if fkey.name not in table:
|
|
106
109
|
continue
|
|
107
|
-
|
|
110
|
+
if not table[fkey.name].is_source:
|
|
111
|
+
continue
|
|
108
112
|
dst_table_names = [
|
|
109
113
|
table.name for table in self.tables.values()
|
|
110
|
-
if
|
|
111
|
-
and table._source_name == fkey.dst_table
|
|
114
|
+
if table.source_name == fkey.dst_table
|
|
112
115
|
]
|
|
113
116
|
if len(dst_table_names) != 1:
|
|
114
117
|
continue
|
|
115
118
|
dst_table = self[dst_table_names[0]]
|
|
116
119
|
if dst_table._primary_key != fkey.primary_key:
|
|
117
120
|
continue
|
|
118
|
-
|
|
121
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
122
|
+
continue
|
|
119
123
|
self.link(table.name, fkey.name, dst_table.name)
|
|
120
124
|
|
|
121
125
|
for edge in (edges or []):
|
|
@@ -418,6 +422,7 @@ class Graph:
|
|
|
418
422
|
graph = cls(tables=[])
|
|
419
423
|
|
|
420
424
|
msgs = []
|
|
425
|
+
table_names = {table_cfg['name'] for table_cfg in cfg['tables']}
|
|
421
426
|
for table_cfg in cfg['tables']:
|
|
422
427
|
table_name = table_cfg['name']
|
|
423
428
|
source_table_name = table_cfg['base_table']['table']
|
|
@@ -434,14 +439,47 @@ class Graph:
|
|
|
434
439
|
f"'{table_name}' since composite primary keys "
|
|
435
440
|
f"are not yet supported")
|
|
436
441
|
|
|
437
|
-
columns: list[
|
|
442
|
+
columns: list[ColumnSpec] = []
|
|
443
|
+
unsupported_columns: list[str] = []
|
|
438
444
|
for column_cfg in chain(
|
|
439
445
|
table_cfg.get('dimensions', []),
|
|
440
446
|
table_cfg.get('time_dimensions', []),
|
|
441
447
|
table_cfg.get('facts', []),
|
|
442
448
|
):
|
|
443
|
-
|
|
444
|
-
|
|
449
|
+
column_name = column_cfg['name']
|
|
450
|
+
column_expr = column_cfg.get('expr', None)
|
|
451
|
+
column_data_type = column_cfg.get('data_type', None)
|
|
452
|
+
|
|
453
|
+
if column_expr is None:
|
|
454
|
+
columns.append(ColumnSpec(name=column_name))
|
|
455
|
+
continue
|
|
456
|
+
|
|
457
|
+
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
458
|
+
|
|
459
|
+
if column_expr == column_name:
|
|
460
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
|
+
continue
|
|
462
|
+
|
|
463
|
+
# Drop expressions that reference other tables (for now):
|
|
464
|
+
if any(f'{name}.' in column_expr for name in table_names):
|
|
465
|
+
unsupported_columns.append(column_name)
|
|
466
|
+
continue
|
|
467
|
+
|
|
468
|
+
column = ColumnSpec(
|
|
469
|
+
name=column_name,
|
|
470
|
+
expr=column_expr,
|
|
471
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
472
|
+
)
|
|
473
|
+
columns.append(column)
|
|
474
|
+
|
|
475
|
+
if len(unsupported_columns) == 1:
|
|
476
|
+
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
477
|
+
f"of table '{table_name}' since its expression "
|
|
478
|
+
f"references other tables")
|
|
479
|
+
elif len(unsupported_columns) > 1:
|
|
480
|
+
msgs.append(f"Failed to add columns '{unsupported_columns}' "
|
|
481
|
+
f"of table '{table_name}' since their expressions "
|
|
482
|
+
f"reference other tables")
|
|
445
483
|
|
|
446
484
|
table = SnowTable(
|
|
447
485
|
connection,
|
|
@@ -501,6 +539,35 @@ class Graph:
|
|
|
501
539
|
|
|
502
540
|
return graph
|
|
503
541
|
|
|
542
|
+
@classmethod
def from_relbench(
    cls,
    dataset: str,
    verbose: bool = True,
) -> Graph:
    r"""Creates a :class:`Graph` from a `RelBench
    <https://relbench.stanford.edu>`_ dataset.

    The actual loading is delegated to
    :func:`kumoai.experimental.rfm.relbench.from_relbench`. When
    :obj:`verbose` is set, the metadata and the links of the resulting
    graph are printed before it is returned.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_relbench("f1")

    Args:
        dataset: The RelBench dataset name.
        verbose: Whether to print verbose output.

    Returns:
        The graph returned by the RelBench loader.
    """
    # Imported lazily: the `relbench` module imports `Graph` itself, so a
    # module-level import here would presumably be circular.
    from kumoai.experimental.rfm.relbench import from_relbench

    loaded = from_relbench(dataset, verbose=verbose)
    if not verbose:
        return loaded

    loaded.print_metadata()
    loaded.print_links()
    return loaded
|
|
570
|
+
|
|
504
571
|
# Backend #################################################################
|
|
505
572
|
|
|
506
573
|
@property
|
|
@@ -612,24 +679,8 @@ class Graph:
|
|
|
612
679
|
|
|
613
680
|
def print_metadata(self) -> None:
|
|
614
681
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
618
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
619
|
-
elif in_notebook():
|
|
620
|
-
from IPython.display import Markdown, display
|
|
621
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
622
|
-
df = self.metadata
|
|
623
|
-
try:
|
|
624
|
-
if hasattr(df.style, 'hide'):
|
|
625
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
626
|
-
else:
|
|
627
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
628
|
-
except ImportError:
|
|
629
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
630
|
-
else:
|
|
631
|
-
print("🗂️ Graph Metadata:")
|
|
632
|
-
print(self.metadata.to_string(index=False))
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
633
684
|
|
|
634
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
635
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -658,40 +709,21 @@ class Graph:
|
|
|
658
709
|
|
|
659
710
|
def print_links(self) -> None:
|
|
660
711
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
661
|
-
edges = [(
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
st.markdown("*No links registered*")
|
|
675
|
-
elif in_notebook():
|
|
676
|
-
from IPython.display import Markdown, display
|
|
677
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
678
|
-
if len(edges) > 0:
|
|
679
|
-
display(
|
|
680
|
-
Markdown('\n'.join([
|
|
681
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
682
|
-
for edge in edges
|
|
683
|
-
])))
|
|
684
|
-
else:
|
|
685
|
-
display(Markdown("*No links registered*"))
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
686
725
|
else:
|
|
687
|
-
|
|
688
|
-
if len(edges) > 0:
|
|
689
|
-
print('\n'.join([
|
|
690
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
691
|
-
for edge in edges
|
|
692
|
-
]))
|
|
693
|
-
else:
|
|
694
|
-
print("No links registered")
|
|
726
|
+
display.italic("No links registered")
|
|
695
727
|
|
|
696
728
|
def link(
|
|
697
729
|
self,
|
|
@@ -798,6 +830,30 @@ class Graph:
|
|
|
798
830
|
"""
|
|
799
831
|
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
800
832
|
|
|
833
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
834
|
+
if not any(column.is_source for column in table.columns):
|
|
835
|
+
continue
|
|
836
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
837
|
+
if fkey.name not in table:
|
|
838
|
+
continue
|
|
839
|
+
if not table[fkey.name].is_source:
|
|
840
|
+
continue
|
|
841
|
+
if (table.name, fkey.name) in known_edges:
|
|
842
|
+
continue
|
|
843
|
+
dst_table_names = [
|
|
844
|
+
table.name for table in self.tables.values()
|
|
845
|
+
if table.source_name == fkey.dst_table
|
|
846
|
+
]
|
|
847
|
+
if len(dst_table_names) != 1:
|
|
848
|
+
continue
|
|
849
|
+
dst_table = self[dst_table_names[0]]
|
|
850
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
851
|
+
continue
|
|
852
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
853
|
+
continue
|
|
854
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
855
|
+
known_edges.add((table.name, fkey.name))
|
|
856
|
+
|
|
801
857
|
# A list of primary key candidates (+score) for every column:
|
|
802
858
|
candidate_dict: dict[
|
|
803
859
|
tuple[str, str],
|
|
@@ -897,13 +953,8 @@ class Graph:
|
|
|
897
953
|
if score < 5.0:
|
|
898
954
|
continue
|
|
899
955
|
|
|
900
|
-
candidate_dict[(
|
|
901
|
-
|
|
902
|
-
src_key.name,
|
|
903
|
-
)].append((
|
|
904
|
-
dst_table.name,
|
|
905
|
-
score,
|
|
906
|
-
))
|
|
956
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
957
|
+
(dst_table.name, score))
|
|
907
958
|
|
|
908
959
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
909
960
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -962,24 +1013,26 @@ class Graph:
|
|
|
962
1013
|
f"either the primary key or the link before "
|
|
963
1014
|
f"before proceeding.")
|
|
964
1015
|
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
1016
|
+
if self.backend == DataBackend.LOCAL:
|
|
1017
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1018
|
+
assert src_key.dtype is not None
|
|
1019
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1020
|
+
src_string = src_key.dtype.is_string()
|
|
1021
|
+
assert dst_key.dtype is not None
|
|
1022
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1023
|
+
dst_string = dst_key.dtype.is_string()
|
|
1024
|
+
|
|
1025
|
+
if not src_number and not src_string:
|
|
1026
|
+
raise ValueError(
|
|
1027
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1028
|
+
f"or string (got '{src_key.dtype}'")
|
|
1029
|
+
|
|
1030
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1031
|
+
raise ValueError(
|
|
1032
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1033
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1034
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1035
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
983
1036
|
|
|
984
1037
|
return self
|
|
985
1038
|
|
|
@@ -1,17 +1,19 @@
|
|
|
1
1
|
from .dtype import infer_dtype
|
|
2
|
-
from .pkey import infer_primary_key
|
|
3
|
-
from .time_col import infer_time_column
|
|
4
2
|
from .id import contains_id
|
|
5
3
|
from .timestamp import contains_timestamp
|
|
6
4
|
from .categorical import contains_categorical
|
|
7
5
|
from .multicategorical import contains_multicategorical
|
|
6
|
+
from .stype import infer_stype
|
|
7
|
+
from .pkey import infer_primary_key
|
|
8
|
+
from .time_col import infer_time_column
|
|
8
9
|
|
|
9
10
|
__all__ = [
|
|
10
11
|
'infer_dtype',
|
|
11
|
-
'infer_primary_key',
|
|
12
|
-
'infer_time_column',
|
|
13
12
|
'contains_id',
|
|
14
13
|
'contains_timestamp',
|
|
15
14
|
'contains_categorical',
|
|
16
15
|
'contains_multicategorical',
|
|
16
|
+
'infer_stype',
|
|
17
|
+
'infer_primary_key',
|
|
18
|
+
'infer_time_column',
|
|
17
19
|
]
|
|
@@ -10,6 +10,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
|
10
10
|
'int16': Dtype.int,
|
|
11
11
|
'int32': Dtype.int,
|
|
12
12
|
'int64': Dtype.int,
|
|
13
|
+
'float': Dtype.float,
|
|
14
|
+
'double': Dtype.float,
|
|
13
15
|
'float16': Dtype.float,
|
|
14
16
|
'float32': Dtype.float,
|
|
15
17
|
'float64': Dtype.float,
|
|
@@ -18,6 +20,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
|
18
20
|
'string[python]': Dtype.string,
|
|
19
21
|
'string[pyarrow]': Dtype.string,
|
|
20
22
|
'binary': Dtype.binary,
|
|
23
|
+
'binary[python]': Dtype.binary,
|
|
24
|
+
'binary[pyarrow]': Dtype.binary,
|
|
21
25
|
}
|
|
22
26
|
|
|
23
27
|
|
|
@@ -48,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
48
52
|
ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
|
|
49
53
|
|
|
50
54
|
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
51
|
-
if pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
55
|
+
if (pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
56
|
+
or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
|
|
52
57
|
elem_dtype = ser.dtype.pyarrow_dtype.value_type
|
|
53
58
|
if pa.types.is_integer(elem_dtype):
|
|
54
59
|
return Dtype.intlist
|
|
@@ -40,7 +40,7 @@ def contains_multicategorical(
|
|
|
40
40
|
sep = max(candidates, key=candidates.get) # type: ignore
|
|
41
41
|
ser = ser.str.split(sep)
|
|
42
42
|
|
|
43
|
-
num_unique_multi = ser.explode().nunique()
|
|
43
|
+
num_unique_multi = ser.astype('object').explode().nunique()
|
|
44
44
|
|
|
45
45
|
if dtype.is_list():
|
|
46
46
|
return num_unique_multi <= MAX_CAT
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from kumoapi.typing import Dtype, Stype
|
|
3
|
+
|
|
4
|
+
from kumoai.experimental.rfm.infer import (
|
|
5
|
+
contains_categorical,
|
|
6
|
+
contains_id,
|
|
7
|
+
contains_multicategorical,
|
|
8
|
+
contains_timestamp,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def infer_stype(ser: pd.Series, column_name: str, dtype: Dtype) -> Stype:
    """Determines the semantic type (:class:`Stype`) of a column.

    A fixed sequence of detectors is tried in order and the first one that
    matches decides the result. If none matches, the default semantic type
    of :obj:`dtype` is used.

    Args:
        ser: The :class:`pandas.Series` holding the column values.
        column_name: The name of the column.
        dtype: The data type of the column.

    Returns:
        The inferred semantic type.
    """
    # Order matters: earlier detectors take precedence over later ones.
    detectors = (
        (contains_id, Stype.ID),
        (contains_timestamp, Stype.timestamp),
        (contains_multicategorical, Stype.multicategorical),
        (contains_categorical, Stype.categorical),
    )

    for detect, stype in detectors:
        if detect(ser, column_name, dtype):
            return stype

    return dtype.default_stype
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import json
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from urllib.request import urlopen
|
|
5
|
+
|
|
6
|
+
import pooch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm import Graph
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
11
|
+
|
|
12
|
+
PREFIX = 'rel-'
|
|
13
|
+
CACHE_DIR = pooch.os_cache('relbench')
|
|
14
|
+
HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
|
|
15
|
+
'relbench/datasets/hashes.json')
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache
def get_registry() -> pooch.Pooch:
    """Builds the :mod:`pooch` registry used to download RelBench files.

    The file hashes are read as JSON from :obj:`HASH_URL`. The result is
    memoized via :func:`functools.lru_cache`, so the hash file is only
    requested on the first call.

    Returns:
        A :class:`pooch.Pooch` instance that stores files in
        :obj:`CACHE_DIR`.
    """
    with urlopen(HASH_URL) as response:
        payload = response.read()

    known_hashes = json.loads(payload)

    return pooch.create(
        path=CACHE_DIR,
        base_url='https://relbench.stanford.edu/download/',
        registry=known_hashes,
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def from_relbench(dataset: str, verbose: bool = True) -> Graph:
    """Downloads a RelBench dataset and loads it into a :class:`Graph`.

    The dataset archive is fetched (and unzipped) through the registry
    returned by :func:`get_registry`. Every ``*.parquet`` file of the
    extracted database becomes one :class:`LocalTable`, whose primary key,
    time column and foreign keys are read from the Parquet schema metadata.
    Links are only registered after all tables have been added.

    Args:
        dataset: The RelBench dataset name, with or without the ``rel-``
            prefix. Matching is case-insensitive.
        verbose: Whether to show a progress bar while downloading.

    Returns:
        The graph holding all tables and links of the dataset.

    Raises:
        ValueError: If :obj:`dataset` is not part of the registry.
    """
    # `import pyarrow` does not load the `parquet` submodule by itself, so
    # import it explicitly rather than relying on `pa.parquet` having been
    # populated as a side effect of another import.
    import pyarrow.parquet as pq

    dataset = dataset.lower()
    if dataset.startswith(PREFIX):
        dataset = dataset[len(PREFIX):]

    registry = get_registry()

    # The same dataset may show up under several registry keys, so
    # de-duplicate (keeping registry order) to keep the error message clean:
    datasets = list(
        dict.fromkeys(
            key.split('/')[0][len(PREFIX):] for key in registry.registry))
    if dataset not in datasets:
        matches = difflib.get_close_matches(dataset, datasets, n=1)
        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
                         f"datasets are {str(datasets)[1:-1]}.")

    registry.fetch(
        f'{PREFIX}{dataset}/db.zip',
        processor=pooch.Unzip(extract_dir='.'),
        progressbar=verbose,
    )

    metadata_keys = (b"fkey_col_to_pkey_table", b"pkey_col", b"time_col")

    graph = Graph(tables=[])
    edges: list[tuple[str, str, str]] = []
    db_dir = CACHE_DIR / f'{PREFIX}{dataset}' / 'db'
    # Sort the files since `glob` order depends on the file system:
    for path in sorted(db_dir.glob('*.parquet')):
        data = pq.read_table(path)

        # `schema.metadata` is `None` if the file carries no metadata:
        raw_metadata = data.schema.metadata or {}
        metadata = {
            key.decode('utf-8'): json.loads(value.decode('utf-8'))
            for key, value in raw_metadata.items() if key in metadata_keys
        }

        table = LocalTable(
            df=data.to_pandas(),
            name=path.stem,
            # Missing entries are treated like an explicit JSON `null`:
            primary_key=metadata.get('pkey_col'),
            time_column=metadata.get('time_col'),
        )
        graph.add_table(table)

        fkey_to_table = metadata.get('fkey_col_to_pkey_table') or {}
        edges.extend((path.stem, fkey, dst_table)
                     for fkey, dst_table in fkey_to_table.items())

    # Link only now, once every destination table has been added:
    for edge in edges:
        graph.link(*edge)

    return graph
|
kumoai/experimental/rfm/rfm.py
CHANGED
|
@@ -28,13 +28,12 @@ from kumoapi.rfm import (
|
|
|
28
28
|
from kumoapi.task import TaskType
|
|
29
29
|
from kumoapi.typing import AggregationType, Stype
|
|
30
30
|
|
|
31
|
-
from kumoai import in_notebook, in_snowflake_notebook
|
|
32
31
|
from kumoai.client.rfm import RFMAPI
|
|
33
32
|
from kumoai.exceptions import HTTPException
|
|
34
33
|
from kumoai.experimental.rfm import Graph
|
|
35
34
|
from kumoai.experimental.rfm.base import DataBackend, Sampler
|
|
36
35
|
from kumoai.mixin import CastMixin
|
|
37
|
-
from kumoai.utils import ProgressLogger
|
|
36
|
+
from kumoai.utils import ProgressLogger, display
|
|
38
37
|
|
|
39
38
|
_RANDOM_SEED = 42
|
|
40
39
|
|
|
@@ -104,23 +103,8 @@ class Explanation:
|
|
|
104
103
|
|
|
105
104
|
def print(self) -> None:
|
|
106
105
|
r"""Prints the explanation."""
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
st.dataframe(self.prediction, hide_index=True)
|
|
110
|
-
st.markdown(self.summary)
|
|
111
|
-
elif in_notebook():
|
|
112
|
-
from IPython.display import Markdown, display
|
|
113
|
-
try:
|
|
114
|
-
if hasattr(self.prediction.style, 'hide'):
|
|
115
|
-
display(self.prediction.hide(axis='index')) # pandas=2
|
|
116
|
-
else:
|
|
117
|
-
display(self.prediction.hide_index()) # pandas <1.3
|
|
118
|
-
except ImportError:
|
|
119
|
-
print(self.prediction.to_string(index=False)) # missing jinja2
|
|
120
|
-
display(Markdown(self.summary))
|
|
121
|
-
else:
|
|
122
|
-
print(self.prediction.to_string(index=False))
|
|
123
|
-
print(self.summary)
|
|
106
|
+
display.dataframe(self.prediction)
|
|
107
|
+
display.message(self.summary)
|
|
124
108
|
|
|
125
109
|
def _ipython_display_(self) -> None:
|
|
126
110
|
self.print()
|
|
@@ -714,7 +698,7 @@ class KumoRFM:
|
|
|
714
698
|
f"to have a time column")
|
|
715
699
|
|
|
716
700
|
train, test = self._sampler.sample_target(
|
|
717
|
-
query=
|
|
701
|
+
query=query_def,
|
|
718
702
|
num_train_examples=0,
|
|
719
703
|
train_anchor_time=anchor_time,
|
|
720
704
|
num_train_trials=0,
|