kumoai 2.14.0.dev202512211732__cp313-cp313-macosx_11_0_arm64.whl → 2.15.0.dev202601121731__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +190 -71
- kumoai/experimental/rfm/backend/snow/table.py +137 -64
- kumoai/experimental/rfm/backend/sqlite/sampler.py +192 -87
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +27 -0
- kumoai/experimental/rfm/graph.py +115 -107
- kumoai/experimental/rfm/infer/dtype.py +4 -1
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +13 -1
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/RECORD +39 -34
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601121731.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
2
4
|
import copy
|
|
3
5
|
import io
|
|
@@ -16,14 +18,10 @@ from kumoapi.typing import Stype
|
|
|
16
18
|
from typing_extensions import Self
|
|
17
19
|
|
|
18
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
-
from kumoai.experimental.rfm.base import
|
|
20
|
-
ColumnExpressionSpec,
|
|
21
|
-
DataBackend,
|
|
22
|
-
SQLTable,
|
|
23
|
-
Table,
|
|
24
|
-
)
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
25
22
|
from kumoai.graph import Edge
|
|
26
23
|
from kumoai.mixin import CastMixin
|
|
24
|
+
from kumoai.utils import display
|
|
27
25
|
|
|
28
26
|
if TYPE_CHECKING:
|
|
29
27
|
import graphviz
|
|
@@ -103,27 +101,24 @@ class Graph:
|
|
|
103
101
|
for table in tables:
|
|
104
102
|
self.add_table(table)
|
|
105
103
|
|
|
106
|
-
for table in tables:
|
|
107
|
-
if not
|
|
108
|
-
continue
|
|
109
|
-
if '_source_column_dict' not in table.__dict__:
|
|
104
|
+
for table in tables: # Use links from source metadata:
|
|
105
|
+
if not any(column.is_source for column in table.columns):
|
|
110
106
|
continue
|
|
111
107
|
for fkey in table._source_foreign_key_dict.values():
|
|
112
108
|
if fkey.name not in table:
|
|
113
109
|
continue
|
|
114
|
-
if not table[fkey.name].
|
|
110
|
+
if not table[fkey.name].is_source:
|
|
115
111
|
continue
|
|
116
112
|
dst_table_names = [
|
|
117
113
|
table.name for table in self.tables.values()
|
|
118
|
-
if
|
|
119
|
-
and table._source_name == fkey.dst_table
|
|
114
|
+
if table.source_name == fkey.dst_table
|
|
120
115
|
]
|
|
121
116
|
if len(dst_table_names) != 1:
|
|
122
117
|
continue
|
|
123
118
|
dst_table = self[dst_table_names[0]]
|
|
124
119
|
if dst_table._primary_key != fkey.primary_key:
|
|
125
120
|
continue
|
|
126
|
-
if not dst_table[fkey.primary_key].
|
|
121
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
127
122
|
continue
|
|
128
123
|
self.link(table.name, fkey.name, dst_table.name)
|
|
129
124
|
|
|
@@ -444,9 +439,8 @@ class Graph:
|
|
|
444
439
|
f"'{table_name}' since composite primary keys "
|
|
445
440
|
f"are not yet supported")
|
|
446
441
|
|
|
447
|
-
columns: list[
|
|
442
|
+
columns: list[ColumnSpec] = []
|
|
448
443
|
unsupported_columns: list[str] = []
|
|
449
|
-
column_expression_specs: list[ColumnExpressionSpec] = []
|
|
450
444
|
for column_cfg in chain(
|
|
451
445
|
table_cfg.get('dimensions', []),
|
|
452
446
|
table_cfg.get('time_dimensions', []),
|
|
@@ -457,13 +451,13 @@ class Graph:
|
|
|
457
451
|
column_data_type = column_cfg.get('data_type', None)
|
|
458
452
|
|
|
459
453
|
if column_expr is None:
|
|
460
|
-
columns.append(column_name)
|
|
454
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
455
|
continue
|
|
462
456
|
|
|
463
457
|
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
464
458
|
|
|
465
459
|
if column_expr == column_name:
|
|
466
|
-
columns.append(column_name)
|
|
460
|
+
columns.append(ColumnSpec(name=column_name))
|
|
467
461
|
continue
|
|
468
462
|
|
|
469
463
|
# Drop expressions that reference other tables (for now):
|
|
@@ -471,12 +465,12 @@ class Graph:
|
|
|
471
465
|
unsupported_columns.append(column_name)
|
|
472
466
|
continue
|
|
473
467
|
|
|
474
|
-
|
|
468
|
+
column = ColumnSpec(
|
|
475
469
|
name=column_name,
|
|
476
470
|
expr=column_expr,
|
|
477
|
-
dtype=SnowTable.
|
|
471
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
478
472
|
)
|
|
479
|
-
|
|
473
|
+
columns.append(column)
|
|
480
474
|
|
|
481
475
|
if len(unsupported_columns) == 1:
|
|
482
476
|
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
@@ -494,7 +488,6 @@ class Graph:
|
|
|
494
488
|
database=database,
|
|
495
489
|
schema=schema,
|
|
496
490
|
columns=columns,
|
|
497
|
-
column_expressions=column_expression_specs,
|
|
498
491
|
primary_key=primary_key,
|
|
499
492
|
)
|
|
500
493
|
|
|
@@ -546,6 +539,35 @@ class Graph:
|
|
|
546
539
|
|
|
547
540
|
return graph
|
|
548
541
|
|
|
542
|
+
@classmethod
def from_relbench(
    cls,
    dataset: str,
    verbose: bool = True,
) -> Graph:
    r"""Builds a :class:`Graph` from a
    `RelBench <https://relbench.stanford.edu>`_ dataset.

    .. code-block:: python

        >>> # doctest: +SKIP
        >>> import kumoai.experimental.rfm as rfm

        >>> graph = rfm.Graph.from_relbench("f1")

    Args:
        dataset: The RelBench dataset name.
        verbose: Whether to print verbose output. When enabled, the download
            progress is shown and the metadata and links of the resulting
            graph are printed.
    """
    # Deferred import: the RelBench loader imports `Graph` from
    # `kumoai.experimental.rfm`, so importing it at module level here would
    # likely be circular.
    from kumoai.experimental.rfm.relbench import from_relbench as _load

    loaded = _load(dataset, verbose=verbose)
    if not verbose:
        return loaded

    loaded.print_metadata()
    loaded.print_links()
    return loaded
|
|
570
|
+
|
|
549
571
|
# Backend #################################################################
|
|
550
572
|
|
|
551
573
|
@property
|
|
@@ -627,28 +649,28 @@ class Graph:
|
|
|
627
649
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
628
650
|
information about the tables in this graph.
|
|
629
651
|
|
|
630
|
-
The returned dataframe has columns ``
|
|
631
|
-
``
|
|
632
|
-
view of the properties of the tables of this graph.
|
|
652
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
653
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
654
|
+
aggregated view of the properties of the tables of this graph.
|
|
633
655
|
|
|
634
656
|
Example:
|
|
635
657
|
>>> # doctest: +SKIP
|
|
636
658
|
>>> import kumoai.experimental.rfm as rfm
|
|
637
659
|
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
638
660
|
>>> graph.metadata # doctest: +SKIP
|
|
639
|
-
|
|
640
|
-
0 users
|
|
661
|
+
Name Primary Key Time Column End Time Column
|
|
662
|
+
0 users user_id - -
|
|
641
663
|
"""
|
|
642
664
|
tables = list(self.tables.values())
|
|
643
665
|
|
|
644
666
|
return pd.DataFrame({
|
|
645
|
-
'
|
|
667
|
+
'Name':
|
|
646
668
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
647
|
-
'
|
|
669
|
+
'Primary Key':
|
|
648
670
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
649
|
-
'
|
|
671
|
+
'Time Column':
|
|
650
672
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
651
|
-
'
|
|
673
|
+
'End Time Column':
|
|
652
674
|
pd.Series(
|
|
653
675
|
dtype=str,
|
|
654
676
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -657,24 +679,8 @@ class Graph:
|
|
|
657
679
|
|
|
658
680
|
def print_metadata(self) -> None:
|
|
659
681
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
663
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
664
|
-
elif in_notebook():
|
|
665
|
-
from IPython.display import Markdown, display
|
|
666
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
667
|
-
df = self.metadata
|
|
668
|
-
try:
|
|
669
|
-
if hasattr(df.style, 'hide'):
|
|
670
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
671
|
-
else:
|
|
672
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
673
|
-
except ImportError:
|
|
674
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
675
|
-
else:
|
|
676
|
-
print("🗂️ Graph Metadata:")
|
|
677
|
-
print(self.metadata.to_string(index=False))
|
|
682
|
+
display.title("🗂️ Graph Metadata")
|
|
683
|
+
display.dataframe(self.metadata)
|
|
678
684
|
|
|
679
685
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
680
686
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -703,40 +709,21 @@ class Graph:
|
|
|
703
709
|
|
|
704
710
|
def print_links(self) -> None:
|
|
705
711
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
706
|
-
edges = [(
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
st.markdown("*No links registered*")
|
|
720
|
-
elif in_notebook():
|
|
721
|
-
from IPython.display import Markdown, display
|
|
722
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
723
|
-
if len(edges) > 0:
|
|
724
|
-
display(
|
|
725
|
-
Markdown('\n'.join([
|
|
726
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
727
|
-
for edge in edges
|
|
728
|
-
])))
|
|
729
|
-
else:
|
|
730
|
-
display(Markdown("*No links registered*"))
|
|
712
|
+
edges = sorted([(
|
|
713
|
+
edge.dst_table,
|
|
714
|
+
self[edge.dst_table]._primary_key,
|
|
715
|
+
edge.src_table,
|
|
716
|
+
edge.fkey,
|
|
717
|
+
) for edge in self.edges])
|
|
718
|
+
|
|
719
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
720
|
+
if len(edges) > 0:
|
|
721
|
+
display.unordered_list(items=[
|
|
722
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
723
|
+
for edge in edges
|
|
724
|
+
])
|
|
731
725
|
else:
|
|
732
|
-
|
|
733
|
-
if len(edges) > 0:
|
|
734
|
-
print('\n'.join([
|
|
735
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
736
|
-
for edge in edges
|
|
737
|
-
]))
|
|
738
|
-
else:
|
|
739
|
-
print("No links registered")
|
|
726
|
+
display.italic("No links registered")
|
|
740
727
|
|
|
741
728
|
def link(
|
|
742
729
|
self,
|
|
@@ -843,6 +830,30 @@ class Graph:
|
|
|
843
830
|
"""
|
|
844
831
|
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
845
832
|
|
|
833
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
834
|
+
if not any(column.is_source for column in table.columns):
|
|
835
|
+
continue
|
|
836
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
837
|
+
if fkey.name not in table:
|
|
838
|
+
continue
|
|
839
|
+
if not table[fkey.name].is_source:
|
|
840
|
+
continue
|
|
841
|
+
if (table.name, fkey.name) in known_edges:
|
|
842
|
+
continue
|
|
843
|
+
dst_table_names = [
|
|
844
|
+
table.name for table in self.tables.values()
|
|
845
|
+
if table.source_name == fkey.dst_table
|
|
846
|
+
]
|
|
847
|
+
if len(dst_table_names) != 1:
|
|
848
|
+
continue
|
|
849
|
+
dst_table = self[dst_table_names[0]]
|
|
850
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
851
|
+
continue
|
|
852
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
853
|
+
continue
|
|
854
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
855
|
+
known_edges.add((table.name, fkey.name))
|
|
856
|
+
|
|
846
857
|
# A list of primary key candidates (+score) for every column:
|
|
847
858
|
candidate_dict: dict[
|
|
848
859
|
tuple[str, str],
|
|
@@ -942,13 +953,8 @@ class Graph:
|
|
|
942
953
|
if score < 5.0:
|
|
943
954
|
continue
|
|
944
955
|
|
|
945
|
-
candidate_dict[(
|
|
946
|
-
|
|
947
|
-
src_key.name,
|
|
948
|
-
)].append((
|
|
949
|
-
dst_table.name,
|
|
950
|
-
score,
|
|
951
|
-
))
|
|
956
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
957
|
+
(dst_table.name, score))
|
|
952
958
|
|
|
953
959
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
954
960
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -1007,24 +1013,26 @@ class Graph:
|
|
|
1007
1013
|
f"either the primary key or the link before "
|
|
1008
1014
|
f"before proceeding.")
|
|
1009
1015
|
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1016
|
+
if self.backend == DataBackend.LOCAL:
|
|
1017
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1018
|
+
assert src_key.dtype is not None
|
|
1019
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1020
|
+
src_string = src_key.dtype.is_string()
|
|
1021
|
+
assert dst_key.dtype is not None
|
|
1022
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1023
|
+
dst_string = dst_key.dtype.is_string()
|
|
1024
|
+
|
|
1025
|
+
if not src_number and not src_string:
|
|
1026
|
+
raise ValueError(
|
|
1027
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1028
|
+
f"or string (got '{src_key.dtype}'")
|
|
1029
|
+
|
|
1030
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1031
|
+
raise ValueError(
|
|
1032
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1033
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1034
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1035
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
1028
1036
|
|
|
1029
1037
|
return self
|
|
1030
1038
|
|
|
@@ -20,6 +20,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
|
20
20
|
'string[python]': Dtype.string,
|
|
21
21
|
'string[pyarrow]': Dtype.string,
|
|
22
22
|
'binary': Dtype.binary,
|
|
23
|
+
'binary[python]': Dtype.binary,
|
|
24
|
+
'binary[pyarrow]': Dtype.binary,
|
|
23
25
|
}
|
|
24
26
|
|
|
25
27
|
|
|
@@ -50,7 +52,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
50
52
|
ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
|
|
51
53
|
|
|
52
54
|
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
53
|
-
if pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
55
|
+
if (pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
56
|
+
or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
|
|
54
57
|
elem_dtype = ser.dtype.pyarrow_dtype.value_type
|
|
55
58
|
if pa.types.is_integer(elem_dtype):
|
|
56
59
|
return Dtype.intlist
|
|
@@ -40,7 +40,7 @@ def contains_multicategorical(
|
|
|
40
40
|
sep = max(candidates, key=candidates.get) # type: ignore
|
|
41
41
|
ser = ser.str.split(sep)
|
|
42
42
|
|
|
43
|
-
num_unique_multi = ser.explode().nunique()
|
|
43
|
+
num_unique_multi = ser.astype('object').explode().nunique()
|
|
44
44
|
|
|
45
45
|
if dtype.is_list():
|
|
46
46
|
return num_unique_multi <= MAX_CAT
|
|
@@ -3,6 +3,8 @@ import warnings
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
|
|
6
|
+
from kumoai.experimental.rfm.base.utils import to_datetime
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def infer_time_column(
|
|
8
10
|
df: pd.DataFrame,
|
|
@@ -43,11 +45,11 @@ def infer_time_column(
|
|
|
43
45
|
with warnings.catch_warnings():
|
|
44
46
|
warnings.filterwarnings('ignore', message='Could not infer format')
|
|
45
47
|
min_timestamp_dict = {
|
|
46
|
-
key:
|
|
48
|
+
key: to_datetime(df[key].iloc[:10_000])
|
|
47
49
|
for key in candidates
|
|
48
50
|
}
|
|
49
51
|
min_timestamp_dict = {
|
|
50
|
-
key: value.min()
|
|
52
|
+
key: value.min()
|
|
51
53
|
for key, value in min_timestamp_dict.items()
|
|
52
54
|
}
|
|
53
55
|
min_timestamp_dict = {
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import json
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from urllib.request import urlopen
|
|
5
|
+
|
|
6
|
+
import pooch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm import Graph
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
11
|
+
|
|
12
|
+
# Prefix carried by every dataset name in the RelBench download registry:
PREFIX = 'rel-'
# Local directory where downloaded RelBench archives are cached:
CACHE_DIR = pooch.os_cache('relbench')
# JSON file mapping RelBench file names to their hashes:
HASH_URL = (
    'https://raw.githubusercontent.com/snap-stanford/relbench/main/'
    'relbench/datasets/hashes.json'
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache
def get_registry() -> pooch.Pooch:
    r"""Returns a :mod:`pooch` registry covering every RelBench file.

    The file hashes are downloaded from ``HASH_URL`` on the first call. The
    resulting registry is memoized by :func:`functools.lru_cache`, so later
    calls reuse it without touching the network again.
    """
    with urlopen(HASH_URL) as response:
        file_hashes = json.load(response)

    download_url = 'https://relbench.stanford.edu/download/'
    return pooch.create(path=CACHE_DIR, base_url=download_url,
                        registry=file_hashes)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def from_relbench(dataset: str, verbose: bool = True) -> Graph:
    r"""Downloads (if not yet cached) and loads a RelBench database into a
    :class:`Graph` of :class:`LocalTable` objects.

    Args:
        dataset: The RelBench dataset name. Matching is case-insensitive and
            the ``rel-`` prefix is optional.
        verbose: Whether to display a download progress bar.

    Returns:
        A graph with one table per parquet file of the dataset database.
        Primary keys, time columns and links are taken from the
        ``pkey_col``, ``time_col`` and ``fkey_col_to_pkey_table`` entries of
        each file's schema metadata. Missing entries are treated as absent.

    Raises:
        ValueError: If ``dataset`` is not listed in the RelBench registry.
    """
    # `pyarrow.parquet` is a submodule that `import pyarrow` does not load.
    # `pa.parquet` only works if some other import happened to load it first.
    import pyarrow.parquet as pq

    dataset = dataset.lower()
    if dataset.startswith(PREFIX):
        dataset = dataset[len(PREFIX):]

    registry = get_registry()

    # Registry keys look like '<PREFIX><dataset>/<file>'. A dataset may own
    # several files, so collect each dataset name only once.
    datasets = sorted({
        key.split('/')[0][len(PREFIX):]
        for key in registry.registry if key.startswith(PREFIX)
    })
    if dataset not in datasets:
        matches = difflib.get_close_matches(dataset, datasets, n=1)
        hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
        raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
                         f"datasets are {str(datasets)[1:-1]}.")

    registry.fetch(
        f'{PREFIX}{dataset}/db.zip',
        processor=pooch.Unzip(extract_dir='.'),
        progressbar=verbose,
    )

    graph = Graph(tables=[])
    edges: list[tuple[str, str, str]] = []
    db_dir = CACHE_DIR / f'{PREFIX}{dataset}' / 'db'
    # `glob` yields files in arbitrary order. Sorting keeps the table order
    # deterministic across file systems.
    for path in sorted(db_dir.glob('*.parquet')):
        data = pq.read_table(path)
        # `Schema.metadata` is `None` when a file carries no metadata.
        raw_metadata = data.schema.metadata or {}
        metadata = {
            key.decode('utf-8'): json.loads(value.decode('utf-8'))
            for key, value in raw_metadata.items()
            if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
        }

        table = LocalTable(
            df=data.to_pandas(),
            name=path.stem,
            primary_key=metadata.get('pkey_col'),
            time_column=metadata.get('time_col'),
        )
        graph.add_table(table)

        # Links are collected and applied only after all tables exist, since
        # a foreign key may point to a table that has not been read yet.
        fkey_dict = metadata.get('fkey_col_to_pkey_table') or {}
        edges.extend([
            (path.stem, fkey, dst_table)
            for fkey, dst_table in fkey_dict.items()
        ])

    for edge in edges:
        graph.link(*edge)

    return graph
|