kumoai 2.14.0.dev202512211732__cp313-cp313-win_amd64.whl → 2.15.0.dev202601151732__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +23 -26
- kumoai/_version.py +1 -1
- kumoai/client/client.py +6 -0
- kumoai/client/jobs.py +26 -0
- kumoai/connector/utils.py +21 -7
- kumoai/experimental/rfm/__init__.py +24 -22
- kumoai/experimental/rfm/backend/local/graph_store.py +12 -21
- kumoai/experimental/rfm/backend/local/sampler.py +0 -3
- kumoai/experimental/rfm/backend/local/table.py +24 -25
- kumoai/experimental/rfm/backend/snow/sampler.py +235 -80
- kumoai/experimental/rfm/backend/snow/table.py +146 -70
- kumoai/experimental/rfm/backend/sqlite/sampler.py +196 -89
- kumoai/experimental/rfm/backend/sqlite/table.py +85 -55
- kumoai/experimental/rfm/base/__init__.py +6 -9
- kumoai/experimental/rfm/base/column.py +95 -11
- kumoai/experimental/rfm/base/expression.py +44 -0
- kumoai/experimental/rfm/base/mapper.py +69 -0
- kumoai/experimental/rfm/base/sampler.py +28 -18
- kumoai/experimental/rfm/base/source.py +1 -1
- kumoai/experimental/rfm/base/sql_sampler.py +320 -19
- kumoai/experimental/rfm/base/table.py +256 -109
- kumoai/experimental/rfm/base/utils.py +36 -0
- kumoai/experimental/rfm/graph.py +130 -110
- kumoai/experimental/rfm/infer/dtype.py +7 -2
- kumoai/experimental/rfm/infer/multicategorical.py +1 -1
- kumoai/experimental/rfm/infer/time_col.py +4 -2
- kumoai/experimental/rfm/relbench.py +76 -0
- kumoai/experimental/rfm/rfm.py +540 -306
- kumoai/experimental/rfm/task_table.py +292 -0
- kumoai/kumolib.cp313-win_amd64.pyd +0 -0
- kumoai/pquery/training_table.py +16 -2
- kumoai/testing/snow.py +3 -3
- kumoai/trainer/distilled_trainer.py +175 -0
- kumoai/utils/display.py +87 -0
- kumoai/utils/progress_logger.py +15 -2
- kumoai/utils/sql.py +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/METADATA +2 -2
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/RECORD +41 -36
- kumoai/experimental/rfm/base/column_expression.py +0 -50
- kumoai/experimental/rfm/base/sql_table.py +0 -229
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/WHEEL +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.14.0.dev202512211732.dist-info → kumoai-2.15.0.dev202601151732.dist-info}/top_level.txt +0 -0
kumoai/experimental/rfm/graph.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import contextlib
|
|
2
4
|
import copy
|
|
3
5
|
import io
|
|
@@ -16,14 +18,11 @@ from kumoapi.typing import Stype
|
|
|
16
18
|
from typing_extensions import Self
|
|
17
19
|
|
|
18
20
|
from kumoai import in_notebook, in_snowflake_notebook
|
|
19
|
-
from kumoai.experimental.rfm.base import
|
|
20
|
-
|
|
21
|
-
DataBackend,
|
|
22
|
-
SQLTable,
|
|
23
|
-
Table,
|
|
24
|
-
)
|
|
21
|
+
from kumoai.experimental.rfm.base import ColumnSpec, DataBackend, Table
|
|
22
|
+
from kumoai.experimental.rfm.infer import infer_time_column
|
|
25
23
|
from kumoai.graph import Edge
|
|
26
24
|
from kumoai.mixin import CastMixin
|
|
25
|
+
from kumoai.utils import display
|
|
27
26
|
|
|
28
27
|
if TYPE_CHECKING:
|
|
29
28
|
import graphviz
|
|
@@ -103,27 +102,24 @@ class Graph:
|
|
|
103
102
|
for table in tables:
|
|
104
103
|
self.add_table(table)
|
|
105
104
|
|
|
106
|
-
for table in tables:
|
|
107
|
-
if not
|
|
108
|
-
continue
|
|
109
|
-
if '_source_column_dict' not in table.__dict__:
|
|
105
|
+
for table in tables: # Use links from source metadata:
|
|
106
|
+
if not any(column.is_source for column in table.columns):
|
|
110
107
|
continue
|
|
111
108
|
for fkey in table._source_foreign_key_dict.values():
|
|
112
109
|
if fkey.name not in table:
|
|
113
110
|
continue
|
|
114
|
-
if not table[fkey.name].
|
|
111
|
+
if not table[fkey.name].is_source:
|
|
115
112
|
continue
|
|
116
113
|
dst_table_names = [
|
|
117
114
|
table.name for table in self.tables.values()
|
|
118
|
-
if
|
|
119
|
-
and table._source_name == fkey.dst_table
|
|
115
|
+
if table.source_name == fkey.dst_table
|
|
120
116
|
]
|
|
121
117
|
if len(dst_table_names) != 1:
|
|
122
118
|
continue
|
|
123
119
|
dst_table = self[dst_table_names[0]]
|
|
124
120
|
if dst_table._primary_key != fkey.primary_key:
|
|
125
121
|
continue
|
|
126
|
-
if not dst_table[fkey.primary_key].
|
|
122
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
127
123
|
continue
|
|
128
124
|
self.link(table.name, fkey.name, dst_table.name)
|
|
129
125
|
|
|
@@ -420,8 +416,9 @@ class Graph:
|
|
|
420
416
|
assert isinstance(connection, Connection)
|
|
421
417
|
|
|
422
418
|
with connection.cursor() as cursor:
|
|
423
|
-
|
|
424
|
-
|
|
419
|
+
sql = (f"SELECT SYSTEM$READ_YAML_FROM_SEMANTIC_VIEW("
|
|
420
|
+
f"'{semantic_view_name}')")
|
|
421
|
+
cursor.execute(sql)
|
|
425
422
|
cfg = yaml.safe_load(cursor.fetchone()[0])
|
|
426
423
|
|
|
427
424
|
graph = cls(tables=[])
|
|
@@ -444,9 +441,8 @@ class Graph:
|
|
|
444
441
|
f"'{table_name}' since composite primary keys "
|
|
445
442
|
f"are not yet supported")
|
|
446
443
|
|
|
447
|
-
columns: list[
|
|
444
|
+
columns: list[ColumnSpec] = []
|
|
448
445
|
unsupported_columns: list[str] = []
|
|
449
|
-
column_expression_specs: list[ColumnExpressionSpec] = []
|
|
450
446
|
for column_cfg in chain(
|
|
451
447
|
table_cfg.get('dimensions', []),
|
|
452
448
|
table_cfg.get('time_dimensions', []),
|
|
@@ -457,13 +453,13 @@ class Graph:
|
|
|
457
453
|
column_data_type = column_cfg.get('data_type', None)
|
|
458
454
|
|
|
459
455
|
if column_expr is None:
|
|
460
|
-
columns.append(column_name)
|
|
456
|
+
columns.append(ColumnSpec(name=column_name))
|
|
461
457
|
continue
|
|
462
458
|
|
|
463
459
|
column_expr = column_expr.replace(f'{table_name}.', '')
|
|
464
460
|
|
|
465
461
|
if column_expr == column_name:
|
|
466
|
-
columns.append(column_name)
|
|
462
|
+
columns.append(ColumnSpec(name=column_name))
|
|
467
463
|
continue
|
|
468
464
|
|
|
469
465
|
# Drop expressions that reference other tables (for now):
|
|
@@ -471,12 +467,12 @@ class Graph:
|
|
|
471
467
|
unsupported_columns.append(column_name)
|
|
472
468
|
continue
|
|
473
469
|
|
|
474
|
-
|
|
470
|
+
column = ColumnSpec(
|
|
475
471
|
name=column_name,
|
|
476
472
|
expr=column_expr,
|
|
477
|
-
dtype=SnowTable.
|
|
473
|
+
dtype=SnowTable._to_dtype(column_data_type),
|
|
478
474
|
)
|
|
479
|
-
|
|
475
|
+
columns.append(column)
|
|
480
476
|
|
|
481
477
|
if len(unsupported_columns) == 1:
|
|
482
478
|
msgs.append(f"Failed to add column '{unsupported_columns[0]}' "
|
|
@@ -494,12 +490,21 @@ class Graph:
|
|
|
494
490
|
database=database,
|
|
495
491
|
schema=schema,
|
|
496
492
|
columns=columns,
|
|
497
|
-
column_expressions=column_expression_specs,
|
|
498
493
|
primary_key=primary_key,
|
|
499
494
|
)
|
|
500
495
|
|
|
501
496
|
# TODO Add a way to register time columns without heuristic usage.
|
|
502
|
-
|
|
497
|
+
time_candidates = [
|
|
498
|
+
column_cfg['name']
|
|
499
|
+
for column_cfg in table_cfg.get('time_dimensions', [])
|
|
500
|
+
if table.has_column(column_cfg['name'])
|
|
501
|
+
and table[column_cfg['name']].stype == Stype.timestamp
|
|
502
|
+
]
|
|
503
|
+
if time_column := infer_time_column(
|
|
504
|
+
df=table._get_sample_df(),
|
|
505
|
+
candidates=time_candidates,
|
|
506
|
+
):
|
|
507
|
+
table.time_column = time_column
|
|
503
508
|
|
|
504
509
|
graph.add_table(table)
|
|
505
510
|
|
|
@@ -546,6 +551,35 @@ class Graph:
|
|
|
546
551
|
|
|
547
552
|
return graph
|
|
548
553
|
|
|
554
|
+
@classmethod
|
|
555
|
+
def from_relbench(
|
|
556
|
+
cls,
|
|
557
|
+
dataset: str,
|
|
558
|
+
verbose: bool = True,
|
|
559
|
+
) -> Graph:
|
|
560
|
+
r"""Loads a `RelBench <https://relbench.stanford.edu>`_ dataset into a
|
|
561
|
+
:class:`Graph` instance.
|
|
562
|
+
|
|
563
|
+
.. code-block:: python
|
|
564
|
+
|
|
565
|
+
>>> # doctest: +SKIP
|
|
566
|
+
>>> import kumoai.experimental.rfm as rfm
|
|
567
|
+
|
|
568
|
+
>>> graph = rfm.Graph.from_relbench("f1")
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
dataset: The RelBench dataset name.
|
|
572
|
+
verbose: Whether to print verbose output.
|
|
573
|
+
"""
|
|
574
|
+
from kumoai.experimental.rfm.relbench import from_relbench
|
|
575
|
+
graph = from_relbench(dataset, verbose=verbose)
|
|
576
|
+
|
|
577
|
+
if verbose:
|
|
578
|
+
graph.print_metadata()
|
|
579
|
+
graph.print_links()
|
|
580
|
+
|
|
581
|
+
return graph
|
|
582
|
+
|
|
549
583
|
# Backend #################################################################
|
|
550
584
|
|
|
551
585
|
@property
|
|
@@ -627,28 +661,28 @@ class Graph:
|
|
|
627
661
|
r"""Returns a :class:`pandas.DataFrame` object containing metadata
|
|
628
662
|
information about the tables in this graph.
|
|
629
663
|
|
|
630
|
-
The returned dataframe has columns ``
|
|
631
|
-
``
|
|
632
|
-
view of the properties of the tables of this graph.
|
|
664
|
+
The returned dataframe has columns ``"Name"``, ``"Primary Key"``,
|
|
665
|
+
``"Time Column"``, and ``"End Time Column"``, which provide an
|
|
666
|
+
aggregated view of the properties of the tables of this graph.
|
|
633
667
|
|
|
634
668
|
Example:
|
|
635
669
|
>>> # doctest: +SKIP
|
|
636
670
|
>>> import kumoai.experimental.rfm as rfm
|
|
637
671
|
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
638
672
|
>>> graph.metadata # doctest: +SKIP
|
|
639
|
-
|
|
640
|
-
0 users
|
|
673
|
+
Name Primary Key Time Column End Time Column
|
|
674
|
+
0 users user_id - -
|
|
641
675
|
"""
|
|
642
676
|
tables = list(self.tables.values())
|
|
643
677
|
|
|
644
678
|
return pd.DataFrame({
|
|
645
|
-
'
|
|
679
|
+
'Name':
|
|
646
680
|
pd.Series(dtype=str, data=[t.name for t in tables]),
|
|
647
|
-
'
|
|
681
|
+
'Primary Key':
|
|
648
682
|
pd.Series(dtype=str, data=[t._primary_key or '-' for t in tables]),
|
|
649
|
-
'
|
|
683
|
+
'Time Column':
|
|
650
684
|
pd.Series(dtype=str, data=[t._time_column or '-' for t in tables]),
|
|
651
|
-
'
|
|
685
|
+
'End Time Column':
|
|
652
686
|
pd.Series(
|
|
653
687
|
dtype=str,
|
|
654
688
|
data=[t._end_time_column or '-' for t in tables],
|
|
@@ -657,24 +691,8 @@ class Graph:
|
|
|
657
691
|
|
|
658
692
|
def print_metadata(self) -> None:
|
|
659
693
|
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
st.markdown("### 🗂️ Graph Metadata")
|
|
663
|
-
st.dataframe(self.metadata, hide_index=True)
|
|
664
|
-
elif in_notebook():
|
|
665
|
-
from IPython.display import Markdown, display
|
|
666
|
-
display(Markdown("### 🗂️ Graph Metadata"))
|
|
667
|
-
df = self.metadata
|
|
668
|
-
try:
|
|
669
|
-
if hasattr(df.style, 'hide'):
|
|
670
|
-
display(df.style.hide(axis='index')) # pandas=2
|
|
671
|
-
else:
|
|
672
|
-
display(df.style.hide_index()) # pandas<1.3
|
|
673
|
-
except ImportError:
|
|
674
|
-
print(df.to_string(index=False)) # missing jinja2
|
|
675
|
-
else:
|
|
676
|
-
print("🗂️ Graph Metadata:")
|
|
677
|
-
print(self.metadata.to_string(index=False))
|
|
694
|
+
display.title("🗂️ Graph Metadata")
|
|
695
|
+
display.dataframe(self.metadata)
|
|
678
696
|
|
|
679
697
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
680
698
|
r"""Infers metadata for all tables in the graph.
|
|
@@ -703,40 +721,21 @@ class Graph:
|
|
|
703
721
|
|
|
704
722
|
def print_links(self) -> None:
|
|
705
723
|
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
706
|
-
edges = [(
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
st.markdown("*No links registered*")
|
|
720
|
-
elif in_notebook():
|
|
721
|
-
from IPython.display import Markdown, display
|
|
722
|
-
display(Markdown("### 🕸️ Graph Links (FK ↔️ PK)"))
|
|
723
|
-
if len(edges) > 0:
|
|
724
|
-
display(
|
|
725
|
-
Markdown('\n'.join([
|
|
726
|
-
f"- `{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
727
|
-
for edge in edges
|
|
728
|
-
])))
|
|
729
|
-
else:
|
|
730
|
-
display(Markdown("*No links registered*"))
|
|
724
|
+
edges = sorted([(
|
|
725
|
+
edge.dst_table,
|
|
726
|
+
self[edge.dst_table]._primary_key,
|
|
727
|
+
edge.src_table,
|
|
728
|
+
edge.fkey,
|
|
729
|
+
) for edge in self.edges])
|
|
730
|
+
|
|
731
|
+
display.title("🕸️ Graph Links (FK ↔️ PK)")
|
|
732
|
+
if len(edges) > 0:
|
|
733
|
+
display.unordered_list(items=[
|
|
734
|
+
f"`{edge[2]}.{edge[3]}` ↔️ `{edge[0]}.{edge[1]}`"
|
|
735
|
+
for edge in edges
|
|
736
|
+
])
|
|
731
737
|
else:
|
|
732
|
-
|
|
733
|
-
if len(edges) > 0:
|
|
734
|
-
print('\n'.join([
|
|
735
|
-
f"• {edge[2]}.{edge[3]} ↔️ {edge[0]}.{edge[1]}"
|
|
736
|
-
for edge in edges
|
|
737
|
-
]))
|
|
738
|
-
else:
|
|
739
|
-
print("No links registered")
|
|
738
|
+
display.italic("No links registered")
|
|
740
739
|
|
|
741
740
|
def link(
|
|
742
741
|
self,
|
|
@@ -843,6 +842,30 @@ class Graph:
|
|
|
843
842
|
"""
|
|
844
843
|
known_edges = {(edge.src_table, edge.fkey) for edge in self.edges}
|
|
845
844
|
|
|
845
|
+
for table in self.tables.values(): # Use links from source metadata:
|
|
846
|
+
if not any(column.is_source for column in table.columns):
|
|
847
|
+
continue
|
|
848
|
+
for fkey in table._source_foreign_key_dict.values():
|
|
849
|
+
if fkey.name not in table:
|
|
850
|
+
continue
|
|
851
|
+
if not table[fkey.name].is_source:
|
|
852
|
+
continue
|
|
853
|
+
if (table.name, fkey.name) in known_edges:
|
|
854
|
+
continue
|
|
855
|
+
dst_table_names = [
|
|
856
|
+
table.name for table in self.tables.values()
|
|
857
|
+
if table.source_name == fkey.dst_table
|
|
858
|
+
]
|
|
859
|
+
if len(dst_table_names) != 1:
|
|
860
|
+
continue
|
|
861
|
+
dst_table = self[dst_table_names[0]]
|
|
862
|
+
if dst_table._primary_key != fkey.primary_key:
|
|
863
|
+
continue
|
|
864
|
+
if not dst_table[fkey.primary_key].is_source:
|
|
865
|
+
continue
|
|
866
|
+
self.link(table.name, fkey.name, dst_table.name)
|
|
867
|
+
known_edges.add((table.name, fkey.name))
|
|
868
|
+
|
|
846
869
|
# A list of primary key candidates (+score) for every column:
|
|
847
870
|
candidate_dict: dict[
|
|
848
871
|
tuple[str, str],
|
|
@@ -942,13 +965,8 @@ class Graph:
|
|
|
942
965
|
if score < 5.0:
|
|
943
966
|
continue
|
|
944
967
|
|
|
945
|
-
candidate_dict[(
|
|
946
|
-
|
|
947
|
-
src_key.name,
|
|
948
|
-
)].append((
|
|
949
|
-
dst_table.name,
|
|
950
|
-
score,
|
|
951
|
-
))
|
|
968
|
+
candidate_dict[(src_table.name, src_key.name)].append(
|
|
969
|
+
(dst_table.name, score))
|
|
952
970
|
|
|
953
971
|
for (src_table_name, src_key_name), scores in candidate_dict.items():
|
|
954
972
|
scores.sort(key=lambda x: x[-1], reverse=True)
|
|
@@ -1007,24 +1025,26 @@ class Graph:
|
|
|
1007
1025
|
f"either the primary key or the link before "
|
|
1008
1026
|
f"before proceeding.")
|
|
1009
1027
|
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
+
if self.backend == DataBackend.LOCAL:
|
|
1029
|
+
# Check that fkey/pkey have valid and consistent data types:
|
|
1030
|
+
assert src_key.dtype is not None
|
|
1031
|
+
src_number = src_key.dtype.is_int() or src_key.dtype.is_float()
|
|
1032
|
+
src_string = src_key.dtype.is_string()
|
|
1033
|
+
assert dst_key.dtype is not None
|
|
1034
|
+
dst_number = dst_key.dtype.is_int() or dst_key.dtype.is_float()
|
|
1035
|
+
dst_string = dst_key.dtype.is_string()
|
|
1036
|
+
|
|
1037
|
+
if not src_number and not src_string:
|
|
1038
|
+
raise ValueError(
|
|
1039
|
+
f"{edge} is invalid as foreign key must be a number "
|
|
1040
|
+
f"or string (got '{src_key.dtype}'")
|
|
1041
|
+
|
|
1042
|
+
if src_number != dst_number or src_string != dst_string:
|
|
1043
|
+
raise ValueError(
|
|
1044
|
+
f"{edge} is invalid as foreign key '{fkey}' and "
|
|
1045
|
+
f"primary key '{dst_key.name}' have incompatible data "
|
|
1046
|
+
f"types (got foreign key data type '{src_key.dtype}' "
|
|
1047
|
+
f"and primary key data type '{dst_key.dtype}')")
|
|
1028
1048
|
|
|
1029
1049
|
return self
|
|
1030
1050
|
|
|
@@ -3,6 +3,8 @@ import pandas as pd
|
|
|
3
3
|
import pyarrow as pa
|
|
4
4
|
from kumoapi.typing import Dtype
|
|
5
5
|
|
|
6
|
+
from kumoai.experimental.rfm.base.utils import is_datetime
|
|
7
|
+
|
|
6
8
|
PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
7
9
|
'bool': Dtype.bool,
|
|
8
10
|
'boolean': Dtype.bool,
|
|
@@ -20,6 +22,8 @@ PANDAS_TO_DTYPE: dict[str, Dtype] = {
|
|
|
20
22
|
'string[python]': Dtype.string,
|
|
21
23
|
'string[pyarrow]': Dtype.string,
|
|
22
24
|
'binary': Dtype.binary,
|
|
25
|
+
'binary[python]': Dtype.binary,
|
|
26
|
+
'binary[pyarrow]': Dtype.binary,
|
|
23
27
|
}
|
|
24
28
|
|
|
25
29
|
|
|
@@ -32,7 +36,7 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
32
36
|
Returns:
|
|
33
37
|
The data type.
|
|
34
38
|
"""
|
|
35
|
-
if
|
|
39
|
+
if is_datetime(ser):
|
|
36
40
|
return Dtype.date
|
|
37
41
|
if pd.api.types.is_timedelta64_dtype(ser.dtype):
|
|
38
42
|
return Dtype.timedelta
|
|
@@ -50,7 +54,8 @@ def infer_dtype(ser: pd.Series) -> Dtype:
|
|
|
50
54
|
ser = pd.Series(arr, dtype=pd.ArrowDtype(arr.type))
|
|
51
55
|
|
|
52
56
|
if isinstance(ser.dtype, pd.ArrowDtype):
|
|
53
|
-
if pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
57
|
+
if (pa.types.is_list(ser.dtype.pyarrow_dtype)
|
|
58
|
+
or pa.types.is_fixed_size_list(ser.dtype.pyarrow_dtype)):
|
|
54
59
|
elem_dtype = ser.dtype.pyarrow_dtype.value_type
|
|
55
60
|
if pa.types.is_integer(elem_dtype):
|
|
56
61
|
return Dtype.intlist
|
|
@@ -40,7 +40,7 @@ def contains_multicategorical(
|
|
|
40
40
|
sep = max(candidates, key=candidates.get) # type: ignore
|
|
41
41
|
ser = ser.str.split(sep)
|
|
42
42
|
|
|
43
|
-
num_unique_multi = ser.explode().nunique()
|
|
43
|
+
num_unique_multi = ser.astype('object').explode().nunique()
|
|
44
44
|
|
|
45
45
|
if dtype.is_list():
|
|
46
46
|
return num_unique_multi <= MAX_CAT
|
|
@@ -3,6 +3,8 @@ import warnings
|
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
|
|
6
|
+
from kumoai.experimental.rfm.base.utils import to_datetime
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def infer_time_column(
|
|
8
10
|
df: pd.DataFrame,
|
|
@@ -43,11 +45,11 @@ def infer_time_column(
|
|
|
43
45
|
with warnings.catch_warnings():
|
|
44
46
|
warnings.filterwarnings('ignore', message='Could not infer format')
|
|
45
47
|
min_timestamp_dict = {
|
|
46
|
-
key:
|
|
48
|
+
key: to_datetime(df[key].iloc[:10_000])
|
|
47
49
|
for key in candidates
|
|
48
50
|
}
|
|
49
51
|
min_timestamp_dict = {
|
|
50
|
-
key: value.min()
|
|
52
|
+
key: value.min()
|
|
51
53
|
for key, value in min_timestamp_dict.items()
|
|
52
54
|
}
|
|
53
55
|
min_timestamp_dict = {
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import json
|
|
3
|
+
from functools import lru_cache
|
|
4
|
+
from urllib.request import urlopen
|
|
5
|
+
|
|
6
|
+
import pooch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
|
|
9
|
+
from kumoai.experimental.rfm import Graph
|
|
10
|
+
from kumoai.experimental.rfm.backend.local import LocalTable
|
|
11
|
+
|
|
12
|
+
PREFIX = 'rel-'
|
|
13
|
+
CACHE_DIR = pooch.os_cache('relbench')
|
|
14
|
+
HASH_URL = ('https://raw.githubusercontent.com/snap-stanford/relbench/main/'
|
|
15
|
+
'relbench/datasets/hashes.json')
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@lru_cache
|
|
19
|
+
def get_registry() -> pooch.Pooch:
|
|
20
|
+
with urlopen(HASH_URL) as r:
|
|
21
|
+
hashes = json.load(r)
|
|
22
|
+
|
|
23
|
+
return pooch.create(
|
|
24
|
+
path=CACHE_DIR,
|
|
25
|
+
base_url='https://relbench.stanford.edu/download/',
|
|
26
|
+
registry=hashes,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def from_relbench(dataset: str, verbose: bool = True) -> Graph:
|
|
31
|
+
dataset = dataset.lower()
|
|
32
|
+
if dataset.startswith(PREFIX):
|
|
33
|
+
dataset = dataset[len(PREFIX):]
|
|
34
|
+
|
|
35
|
+
registry = get_registry()
|
|
36
|
+
|
|
37
|
+
datasets = [key.split('/')[0][len(PREFIX):] for key in registry.registry]
|
|
38
|
+
if dataset not in datasets:
|
|
39
|
+
matches = difflib.get_close_matches(dataset, datasets, n=1)
|
|
40
|
+
hint = f" Did you mean '{matches[0]}'?" if len(matches) > 0 else ''
|
|
41
|
+
raise ValueError(f"Unknown RelBench dataset '{dataset}'.{hint} Valid "
|
|
42
|
+
f"datasets are {str(datasets)[1:-1]}.")
|
|
43
|
+
|
|
44
|
+
registry.fetch(
|
|
45
|
+
f'{PREFIX}{dataset}/db.zip',
|
|
46
|
+
processor=pooch.Unzip(extract_dir='.'),
|
|
47
|
+
progressbar=verbose,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
graph = Graph(tables=[])
|
|
51
|
+
edges: list[tuple[str, str, str]] = []
|
|
52
|
+
for path in (CACHE_DIR / f'{PREFIX}{dataset}' / 'db').glob('*.parquet'):
|
|
53
|
+
data = pa.parquet.read_table(path)
|
|
54
|
+
metadata = {
|
|
55
|
+
key.decode('utf-8'): json.loads(value.decode('utf-8'))
|
|
56
|
+
for key, value in data.schema.metadata.items()
|
|
57
|
+
if key in [b"fkey_col_to_pkey_table", b"pkey_col", b"time_col"]
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
table = LocalTable(
|
|
61
|
+
df=data.to_pandas(),
|
|
62
|
+
name=path.stem,
|
|
63
|
+
primary_key=metadata['pkey_col'],
|
|
64
|
+
time_column=metadata['time_col'],
|
|
65
|
+
)
|
|
66
|
+
graph.add_table(table)
|
|
67
|
+
|
|
68
|
+
edges.extend([
|
|
69
|
+
(path.stem, fkey, dst_table)
|
|
70
|
+
for fkey, dst_table in metadata['fkey_col_to_pkey_table'].items()
|
|
71
|
+
])
|
|
72
|
+
|
|
73
|
+
for edge in edges:
|
|
74
|
+
graph.link(*edge)
|
|
75
|
+
|
|
76
|
+
return graph
|