kumoai 2.13.0.dev202511161731__cp312-cp312-macosx_11_0_arm64.whl → 2.13.0.dev202512011731__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kumoai/__init__.py +6 -9
- kumoai/_version.py +1 -1
- kumoai/client/client.py +9 -13
- kumoai/connector/utils.py +23 -2
- kumoai/experimental/rfm/__init__.py +162 -46
- kumoai/experimental/rfm/backend/__init__.py +0 -0
- kumoai/experimental/rfm/backend/local/__init__.py +38 -0
- kumoai/experimental/rfm/backend/local/table.py +151 -0
- kumoai/experimental/rfm/backend/sqlite/__init__.py +23 -0
- kumoai/experimental/rfm/backend/sqlite/table.py +117 -0
- kumoai/experimental/rfm/base/__init__.py +7 -0
- kumoai/experimental/rfm/base/column.py +66 -0
- kumoai/experimental/rfm/{local_table.py → base/table.py} +67 -139
- kumoai/experimental/rfm/{local_graph.py → graph.py} +44 -30
- kumoai/experimental/rfm/local_graph_sampler.py +0 -2
- kumoai/experimental/rfm/local_graph_store.py +12 -11
- kumoai/experimental/rfm/rfm.py +25 -14
- kumoai/experimental/rfm/sagemaker.py +138 -0
- kumoai/spcs.py +1 -3
- kumoai/testing/decorators.py +1 -1
- {kumoai-2.13.0.dev202511161731.dist-info → kumoai-2.13.0.dev202512011731.dist-info}/METADATA +9 -2
- {kumoai-2.13.0.dev202511161731.dist-info → kumoai-2.13.0.dev202512011731.dist-info}/RECORD +25 -17
- {kumoai-2.13.0.dev202511161731.dist-info → kumoai-2.13.0.dev202512011731.dist-info}/WHEEL +0 -0
- {kumoai-2.13.0.dev202511161731.dist-info → kumoai-2.13.0.dev202512011731.dist-info}/licenses/LICENSE +0 -0
- {kumoai-2.13.0.dev202511161731.dist-info → kumoai-2.13.0.dev202512011731.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from kumoapi.typing import Dtype, Stype
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(init=False, repr=False, eq=False)
class Column:
    r"""Metadata of a single column of a table.

    The column name and data type are fixed at construction time and exposed
    through read-only properties. The semantic type (:obj:`stype`) can be
    re-assigned later; every assignment is validated in
    :meth:`__setattr__`.

    Args:
        name: The name of the column.
        dtype: The data type of the column (coerced via ``Dtype(dtype)``).
        stype: The semantic type of the column (coerced via
            ``Stype(stype)``).
        is_primary_key: Whether this column is the primary key of its table.
        is_time_column: Whether this column is the time column of its table.
        is_end_time_column: Whether this column is the end time column of its
            table.

    Raises:
        ValueError: If the semantic type does not support the data type, or
            if it conflicts with the primary key/time column role.
    """
    stype: Stype

    def __init__(
        self,
        name: str,
        dtype: Dtype,
        stype: Stype,
        is_primary_key: bool = False,
        is_time_column: bool = False,
        is_end_time_column: bool = False,
    ) -> None:
        self._name = name
        self._dtype = Dtype(dtype)
        self._is_primary_key = is_primary_key
        self._is_time_column = is_time_column
        self._is_end_time_column = is_end_time_column
        # Needs to be assigned last: the validation in `__setattr__` reads
        # the data type and the role flags set above.
        self.stype = Stype(stype)

    @property
    def name(self) -> str:
        r"""The name of this column."""
        return self._name

    @property
    def dtype(self) -> Dtype:
        r"""The data type of this column."""
        return self._dtype

    def __setattr__(self, key: str, val: Any) -> None:
        # Only assignments to `stype` are validated; every other attribute
        # is stored as-is.
        if key == 'stype':
            if isinstance(val, str):
                val = Stype(val)
            assert isinstance(val, Stype)
            if not val.supports_dtype(self.dtype):
                raise ValueError(f"Column '{self.name}' received an "
                                 f"incompatible semantic type (got "
                                 f"dtype='{self.dtype}' and stype='{val}')")
            if self._is_primary_key and val != Stype.ID:
                raise ValueError(f"Primary key '{self.name}' must have 'ID' "
                                 f"semantic type (got '{val}')")
            if self._is_time_column and val != Stype.timestamp:
                raise ValueError(f"Time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")
            if self._is_end_time_column and val != Stype.timestamp:
                raise ValueError(f"End time column '{self.name}' must have "
                                 f"'timestamp' semantic type (got '{val}')")

        super().__setattr__(key, val)

    def __hash__(self) -> int:
        return hash((self.name, self.stype, self.dtype))

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Column):
            return False
        # Compare the identifying fields directly rather than their hashes:
        # equal hashes do not imply equal values (hash collisions).
        return ((self.name, self.stype, self.dtype) == (other.name,
                                                        other.stype,
                                                        other.dtype))

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(name={self.name}, '
                f'stype={self.stype}, dtype={self.dtype})')
|
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
from
|
|
2
|
-
from typing import
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict, List, Optional, Sequence
|
|
3
3
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
from kumoapi.source_table import UnavailableSourceTable
|
|
@@ -9,107 +9,17 @@ from kumoapi.typing import Dtype, Stype
|
|
|
9
9
|
from typing_extensions import Self
|
|
10
10
|
|
|
11
11
|
from kumoai import in_notebook
|
|
12
|
-
from kumoai.experimental.rfm import
|
|
12
|
+
from kumoai.experimental.rfm.base import Column
|
|
13
13
|
|
|
14
14
|
|
|
15
|
-
|
|
16
|
-
class
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def __init__(
|
|
20
|
-
self,
|
|
21
|
-
name: str,
|
|
22
|
-
dtype: Dtype,
|
|
23
|
-
stype: Stype,
|
|
24
|
-
is_primary_key: bool = False,
|
|
25
|
-
is_time_column: bool = False,
|
|
26
|
-
is_end_time_column: bool = False,
|
|
27
|
-
) -> None:
|
|
28
|
-
self._name = name
|
|
29
|
-
self._dtype = Dtype(dtype)
|
|
30
|
-
self._is_primary_key = is_primary_key
|
|
31
|
-
self._is_time_column = is_time_column
|
|
32
|
-
self._is_end_time_column = is_end_time_column
|
|
33
|
-
self.stype = Stype(stype)
|
|
34
|
-
|
|
35
|
-
@property
|
|
36
|
-
def name(self) -> str:
|
|
37
|
-
return self._name
|
|
38
|
-
|
|
39
|
-
@property
|
|
40
|
-
def dtype(self) -> Dtype:
|
|
41
|
-
return self._dtype
|
|
42
|
-
|
|
43
|
-
def __setattr__(self, key: str, val: Any) -> None:
|
|
44
|
-
if key == 'stype':
|
|
45
|
-
if isinstance(val, str):
|
|
46
|
-
val = Stype(val)
|
|
47
|
-
assert isinstance(val, Stype)
|
|
48
|
-
if not val.supports_dtype(self.dtype):
|
|
49
|
-
raise ValueError(f"Column '{self.name}' received an "
|
|
50
|
-
f"incompatible semantic type (got "
|
|
51
|
-
f"dtype='{self.dtype}' and stype='{val}')")
|
|
52
|
-
if self._is_primary_key and val != Stype.ID:
|
|
53
|
-
raise ValueError(f"Primary key '{self.name}' must have 'ID' "
|
|
54
|
-
f"semantic type (got '{val}')")
|
|
55
|
-
if self._is_time_column and val != Stype.timestamp:
|
|
56
|
-
raise ValueError(f"Time column '{self.name}' must have "
|
|
57
|
-
f"'timestamp' semantic type (got '{val}')")
|
|
58
|
-
if self._is_end_time_column and val != Stype.timestamp:
|
|
59
|
-
raise ValueError(f"End time column '{self.name}' must have "
|
|
60
|
-
f"'timestamp' semantic type (got '{val}')")
|
|
61
|
-
|
|
62
|
-
super().__setattr__(key, val)
|
|
63
|
-
|
|
64
|
-
def __hash__(self) -> int:
|
|
65
|
-
return hash((self.name, self.stype, self.dtype))
|
|
66
|
-
|
|
67
|
-
def __eq__(self, other: Any) -> bool:
|
|
68
|
-
if not isinstance(other, Column):
|
|
69
|
-
return False
|
|
70
|
-
return hash(self) == hash(other)
|
|
71
|
-
|
|
72
|
-
def __repr__(self) -> str:
|
|
73
|
-
return (f'{self.__class__.__name__}(name={self.name}, '
|
|
74
|
-
f'stype={self.stype}, dtype={self.dtype})')
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
class LocalTable:
|
|
78
|
-
r"""A table backed by a :class:`pandas.DataFrame`.
|
|
79
|
-
|
|
80
|
-
A :class:`LocalTable` fully specifies the relevant metadata, *i.e.*
|
|
81
|
-
selected columns, column semantic types, primary keys and time columns.
|
|
82
|
-
:class:`LocalTable` is used to create a :class:`LocalGraph`.
|
|
83
|
-
|
|
84
|
-
.. code-block:: python
|
|
85
|
-
|
|
86
|
-
import pandas as pd
|
|
87
|
-
import kumoai.experimental.rfm as rfm
|
|
88
|
-
|
|
89
|
-
# Load data from a CSV file:
|
|
90
|
-
df = pd.read_csv("data.csv")
|
|
91
|
-
|
|
92
|
-
# Create a table from a `pandas.DataFrame` and infer its metadata ...
|
|
93
|
-
table = rfm.LocalTable(df, name="my_table").infer_metadata()
|
|
94
|
-
|
|
95
|
-
# ... or create a table explicitly:
|
|
96
|
-
table = rfm.LocalTable(
|
|
97
|
-
df=df,
|
|
98
|
-
name="my_table",
|
|
99
|
-
primary_key="id",
|
|
100
|
-
time_column="time",
|
|
101
|
-
end_time_column=None,
|
|
102
|
-
)
|
|
103
|
-
|
|
104
|
-
# Verify metadata:
|
|
105
|
-
table.print_metadata()
|
|
106
|
-
|
|
107
|
-
# Change the semantic type of a column:
|
|
108
|
-
table[column].stype = "text"
|
|
15
|
+
class Table(ABC):
|
|
16
|
+
r"""A :class:`Table` fully specifies the relevant metadata of a single
|
|
17
|
+
table, *i.e.* its selected columns, data types, semantic types, primary
|
|
18
|
+
keys and time columns.
|
|
109
19
|
|
|
110
20
|
Args:
|
|
111
|
-
|
|
112
|
-
|
|
21
|
+
name: The name of this table.
|
|
22
|
+
columns: The selected columns of this table.
|
|
113
23
|
primary_key: The name of the primary key of this table, if it exists.
|
|
114
24
|
time_column: The name of the time column of this table, if it exists.
|
|
115
25
|
end_time_column: The name of the end time column of this table, if it
|
|
@@ -117,46 +27,40 @@ class LocalTable:
|
|
|
117
27
|
"""
|
|
118
28
|
def __init__(
|
|
119
29
|
self,
|
|
120
|
-
df: pd.DataFrame,
|
|
121
30
|
name: str,
|
|
31
|
+
columns: Optional[Sequence[str]] = None,
|
|
122
32
|
primary_key: Optional[str] = None,
|
|
123
33
|
time_column: Optional[str] = None,
|
|
124
34
|
end_time_column: Optional[str] = None,
|
|
125
35
|
) -> None:
|
|
126
36
|
|
|
127
|
-
if df.empty:
|
|
128
|
-
raise ValueError("Data frame must have at least one row")
|
|
129
|
-
if isinstance(df.columns, pd.MultiIndex):
|
|
130
|
-
raise ValueError("Data frame must not have a multi-index")
|
|
131
|
-
if not df.columns.is_unique:
|
|
132
|
-
raise ValueError("Data frame must have unique column names")
|
|
133
|
-
if any(col == '' for col in df.columns):
|
|
134
|
-
raise ValueError("Data frame must have non-empty column names")
|
|
135
|
-
|
|
136
|
-
df = df.copy(deep=False)
|
|
137
|
-
|
|
138
|
-
self._data = df
|
|
139
37
|
self._name = name
|
|
140
38
|
self._primary_key: Optional[str] = None
|
|
141
39
|
self._time_column: Optional[str] = None
|
|
142
40
|
self._end_time_column: Optional[str] = None
|
|
143
41
|
|
|
144
42
|
self._columns: Dict[str, Column] = {}
|
|
145
|
-
for column_name in
|
|
43
|
+
for column_name in columns or []:
|
|
146
44
|
self.add_column(column_name)
|
|
147
45
|
|
|
148
46
|
if primary_key is not None:
|
|
47
|
+
if primary_key not in self:
|
|
48
|
+
self.add_column(primary_key)
|
|
149
49
|
self.primary_key = primary_key
|
|
150
50
|
|
|
151
51
|
if time_column is not None:
|
|
52
|
+
if time_column not in self:
|
|
53
|
+
self.add_column(time_column)
|
|
152
54
|
self.time_column = time_column
|
|
153
55
|
|
|
154
56
|
if end_time_column is not None:
|
|
57
|
+
if end_time_column not in self:
|
|
58
|
+
self.add_column(end_time_column)
|
|
155
59
|
self.end_time_column = end_time_column
|
|
156
60
|
|
|
157
61
|
@property
|
|
158
62
|
def name(self) -> str:
|
|
159
|
-
r"""The name of
|
|
63
|
+
r"""The name of this table."""
|
|
160
64
|
return self._name
|
|
161
65
|
|
|
162
66
|
# Data column #############################################################
|
|
@@ -200,24 +104,25 @@ class LocalTable:
|
|
|
200
104
|
raise KeyError(f"Column '{name}' already exists in table "
|
|
201
105
|
f"'{self.name}'")
|
|
202
106
|
|
|
203
|
-
if
|
|
204
|
-
raise KeyError(f"Column '{name}' does not exist in the
|
|
205
|
-
f"
|
|
107
|
+
if not self._has_source_column(name):
|
|
108
|
+
raise KeyError(f"Column '{name}' does not exist in the underlying "
|
|
109
|
+
f"source table")
|
|
206
110
|
|
|
207
111
|
try:
|
|
208
|
-
dtype =
|
|
112
|
+
dtype = self._get_source_dtype(name)
|
|
209
113
|
except Exception as e:
|
|
210
|
-
raise RuntimeError(f"
|
|
211
|
-
f"table '{self.name}'
|
|
212
|
-
f"
|
|
213
|
-
f"
|
|
114
|
+
raise RuntimeError(f"Could not obtain data type for column "
|
|
115
|
+
f"'{name}' in table '{self.name}'. Change "
|
|
116
|
+
f"the data type of the column in the source "
|
|
117
|
+
f"table or remove it from the table.") from e
|
|
118
|
+
|
|
214
119
|
try:
|
|
215
|
-
stype =
|
|
120
|
+
stype = self._get_source_stype(name, dtype)
|
|
216
121
|
except Exception as e:
|
|
217
|
-
raise RuntimeError(f"
|
|
218
|
-
f"in table '{self.name}'
|
|
219
|
-
f"
|
|
220
|
-
f"
|
|
122
|
+
raise RuntimeError(f"Could not obtain semantic type for column "
|
|
123
|
+
f"'{name}' in table '{self.name}'. Change "
|
|
124
|
+
f"the data type of the column in the source "
|
|
125
|
+
f"table or remove it from the table.") from e
|
|
221
126
|
|
|
222
127
|
self._columns[name] = Column(
|
|
223
128
|
name=name,
|
|
@@ -432,12 +337,14 @@ class LocalTable:
|
|
|
432
337
|
})
|
|
433
338
|
|
|
434
339
|
def print_metadata(self) -> None:
|
|
435
|
-
r"""Prints the :meth:`~
|
|
340
|
+
r"""Prints the :meth:`~metadata` of this table."""
|
|
341
|
+
num_rows = self._num_rows()
|
|
342
|
+
# Row-count suffix for the header; empty when `_num_rows()` returned `None`.
num_rows_repr = f' ({num_rows:,} rows)' if num_rows is not None else ''
|
|
343
|
+
|
|
436
344
|
if in_notebook():
|
|
437
345
|
from IPython.display import Markdown, display
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
f"({len(self._data):,} rows)"))
|
|
346
|
+
md_repr = f"### 🏷️ Metadata of Table `{self.name}`{num_rows_repr}"
|
|
347
|
+
display(Markdown(md_repr))
|
|
441
348
|
df = self.metadata
|
|
442
349
|
try:
|
|
443
350
|
if hasattr(df.style, 'hide'):
|
|
@@ -447,8 +354,7 @@ class LocalTable:
|
|
|
447
354
|
except ImportError:
|
|
448
355
|
print(df.to_string(index=False)) # missing jinja2
|
|
449
356
|
else:
|
|
450
|
-
print(f"🏷️ Metadata of Table '{self.name}'
|
|
451
|
-
f"({len(self._data):,} rows):")
|
|
357
|
+
print(f"🏷️ Metadata of Table '{self.name}'{num_rows_repr}")
|
|
452
358
|
print(self.metadata.to_string(index=False))
|
|
453
359
|
|
|
454
360
|
def infer_metadata(self, verbose: bool = True) -> Self:
|
|
@@ -478,11 +384,7 @@ class LocalTable:
|
|
|
478
384
|
column.name for column in self.columns if is_candidate(column)
|
|
479
385
|
]
|
|
480
386
|
|
|
481
|
-
if primary_key :=
|
|
482
|
-
table_name=self.name,
|
|
483
|
-
df=self._data,
|
|
484
|
-
candidates=candidates,
|
|
485
|
-
):
|
|
387
|
+
if primary_key := self._infer_primary_key(candidates):
|
|
486
388
|
self.primary_key = primary_key
|
|
487
389
|
logs.append(f"primary key '{primary_key}'")
|
|
488
390
|
|
|
@@ -493,7 +395,7 @@ class LocalTable:
|
|
|
493
395
|
if column.stype == Stype.timestamp
|
|
494
396
|
and column.name != self._end_time_column
|
|
495
397
|
]
|
|
496
|
-
if time_column :=
|
|
398
|
+
if time_column := self._infer_time_column(candidates):
|
|
497
399
|
self.time_column = time_column
|
|
498
400
|
logs.append(f"time column '{time_column}'")
|
|
499
401
|
|
|
@@ -543,3 +445,29 @@ class LocalTable:
|
|
|
543
445
|
f' time_column={self._time_column},\n'
|
|
544
446
|
f' end_time_column={self._end_time_column},\n'
|
|
545
447
|
f')')
|
|
448
|
+
|
|
449
|
+
# Abstract method #########################################################
|
|
450
|
+
|
|
451
|
+
@abstractmethod
|
|
452
|
+
def _has_source_column(self, name: str) -> bool:
|
|
453
|
+
pass
|
|
454
|
+
|
|
455
|
+
@abstractmethod
|
|
456
|
+
def _get_source_dtype(self, name: str) -> Dtype:
|
|
457
|
+
pass
|
|
458
|
+
|
|
459
|
+
@abstractmethod
|
|
460
|
+
def _get_source_stype(self, name: str, dtype: Dtype) -> Stype:
|
|
461
|
+
pass
|
|
462
|
+
|
|
463
|
+
@abstractmethod
|
|
464
|
+
def _infer_primary_key(self, candidates: List[str]) -> Optional[str]:
|
|
465
|
+
pass
|
|
466
|
+
|
|
467
|
+
@abstractmethod
|
|
468
|
+
def _infer_time_column(self, candidates: List[str]) -> Optional[str]:
|
|
469
|
+
pass
|
|
470
|
+
|
|
471
|
+
@abstractmethod
|
|
472
|
+
def _num_rows(self) -> Optional[int]:
|
|
473
|
+
pass
|
|
@@ -3,7 +3,7 @@ import io
|
|
|
3
3
|
import warnings
|
|
4
4
|
from collections import defaultdict
|
|
5
5
|
from importlib.util import find_spec
|
|
6
|
-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
6
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
|
|
7
7
|
|
|
8
8
|
import pandas as pd
|
|
9
9
|
from kumoapi.graph import ColumnKey, ColumnKeyGroup, GraphDefinition
|
|
@@ -12,19 +12,19 @@ from kumoapi.typing import Stype
|
|
|
12
12
|
from typing_extensions import Self
|
|
13
13
|
|
|
14
14
|
from kumoai import in_notebook
|
|
15
|
-
from kumoai.experimental.rfm import
|
|
15
|
+
from kumoai.experimental.rfm import Table
|
|
16
16
|
from kumoai.graph import Edge
|
|
17
17
|
|
|
18
18
|
if TYPE_CHECKING:
|
|
19
19
|
import graphviz
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
class
|
|
23
|
-
r"""A graph of :class:`
|
|
22
|
+
class Graph:
|
|
23
|
+
r"""A graph of :class:`Table` objects, akin to relationships between
|
|
24
24
|
tables in a relational database.
|
|
25
25
|
|
|
26
26
|
Creating a graph is the final step of data definition; after a
|
|
27
|
-
:class:`
|
|
27
|
+
:class:`Graph` is created, you can use it to initialize the
|
|
28
28
|
Kumo Relational Foundation Model (:class:`KumoRFM`).
|
|
29
29
|
|
|
30
30
|
.. code-block:: python
|
|
@@ -44,7 +44,7 @@ class LocalGraph:
|
|
|
44
44
|
>>> table3 = rfm.LocalTable(name="table3", data=df3)
|
|
45
45
|
|
|
46
46
|
>>> # Create a graph from a dictionary of tables:
|
|
47
|
-
>>> graph = rfm.
|
|
47
|
+
>>> graph = rfm.Graph({
|
|
48
48
|
... "table1": table1,
|
|
49
49
|
... "table2": table2,
|
|
50
50
|
... "table3": table3,
|
|
@@ -75,11 +75,11 @@ class LocalGraph:
|
|
|
75
75
|
|
|
76
76
|
def __init__(
|
|
77
77
|
self,
|
|
78
|
-
tables:
|
|
79
|
-
edges: Optional[
|
|
78
|
+
tables: Sequence[Table],
|
|
79
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
80
80
|
) -> None:
|
|
81
81
|
|
|
82
|
-
self._tables: Dict[str,
|
|
82
|
+
self._tables: Dict[str, Table] = {}
|
|
83
83
|
self._edges: List[Edge] = []
|
|
84
84
|
|
|
85
85
|
for table in tables:
|
|
@@ -94,11 +94,11 @@ class LocalGraph:
|
|
|
94
94
|
def from_data(
|
|
95
95
|
cls,
|
|
96
96
|
df_dict: Dict[str, pd.DataFrame],
|
|
97
|
-
edges: Optional[
|
|
97
|
+
edges: Optional[Sequence[Edge]] = None,
|
|
98
98
|
infer_metadata: bool = True,
|
|
99
99
|
verbose: bool = True,
|
|
100
100
|
) -> Self:
|
|
101
|
-
r"""Creates a :class:`
|
|
101
|
+
r"""Creates a :class:`Graph` from a dictionary of
|
|
102
102
|
:class:`pandas.DataFrame` objects.
|
|
103
103
|
|
|
104
104
|
Automatically infers table metadata and links.
|
|
@@ -115,7 +115,7 @@ class LocalGraph:
|
|
|
115
115
|
>>> df3 = pd.DataFrame(...)
|
|
116
116
|
|
|
117
117
|
>>> # Create a graph from a dictionary of data frames:
|
|
118
|
-
>>> graph = rfm.
|
|
118
|
+
>>> graph = rfm.Graph.from_data({
|
|
119
119
|
... "table1": df1,
|
|
120
120
|
... "table2": df2,
|
|
121
121
|
... "table3": df3,
|
|
@@ -148,13 +148,14 @@ class LocalGraph:
|
|
|
148
148
|
>>> df1 = pd.DataFrame(...)
|
|
149
149
|
>>> df2 = pd.DataFrame(...)
|
|
150
150
|
>>> df3 = pd.DataFrame(...)
|
|
151
|
-
>>> graph = rfm.
|
|
151
|
+
>>> graph = rfm.Graph.from_data(df_dict={
|
|
152
152
|
... "table1": df1,
|
|
153
153
|
... "table2": df2,
|
|
154
154
|
... "table3": df3,
|
|
155
155
|
... })
|
|
156
156
|
>>> graph.validate()
|
|
157
157
|
"""
|
|
158
|
+
from kumoai.experimental.rfm import LocalTable
|
|
158
159
|
tables = [LocalTable(df, name) for name, df in df_dict.items()]
|
|
159
160
|
|
|
160
161
|
graph = cls(tables, edges=edges or [])
|
|
@@ -175,7 +176,7 @@ class LocalGraph:
|
|
|
175
176
|
"""
|
|
176
177
|
return name in self.tables
|
|
177
178
|
|
|
178
|
-
def table(self, name: str) ->
|
|
179
|
+
def table(self, name: str) -> Table:
|
|
179
180
|
r"""Returns the table with name ``name`` in the graph.
|
|
180
181
|
|
|
181
182
|
Raises:
|
|
@@ -186,11 +187,11 @@ class LocalGraph:
|
|
|
186
187
|
return self.tables[name]
|
|
187
188
|
|
|
188
189
|
@property
|
|
189
|
-
def tables(self) -> Dict[str,
|
|
190
|
+
def tables(self) -> Dict[str, Table]:
|
|
190
191
|
r"""Returns the dictionary of table objects."""
|
|
191
192
|
return self._tables
|
|
192
193
|
|
|
193
|
-
def add_table(self, table:
|
|
194
|
+
def add_table(self, table: Table) -> Self:
|
|
194
195
|
r"""Adds a table to the graph.
|
|
195
196
|
|
|
196
197
|
Args:
|
|
@@ -199,11 +200,21 @@ class LocalGraph:
|
|
|
199
200
|
Raises:
|
|
200
201
|
KeyError: If a table with the same name already exists in the
|
|
201
202
|
graph.
|
|
203
|
+
ValueError: If the table belongs to a different backend than the
|
|
204
|
+
rest of the tables in the graph.
|
|
202
205
|
"""
|
|
203
206
|
if table.name in self._tables:
|
|
204
207
|
raise KeyError(f"Cannot add table with name '{table.name}' to "
|
|
205
208
|
f"this graph; table names must be globally unique.")
|
|
206
209
|
|
|
210
|
+
if len(self._tables) > 0:
|
|
211
|
+
cls = next(iter(self._tables.values())).__class__
|
|
212
|
+
if table.__class__ != cls:
|
|
213
|
+
raise ValueError(f"Cannot register a "
|
|
214
|
+
f"'{table.__class__.__name__}' to this "
|
|
215
|
+
f"graph since other tables are of type "
|
|
216
|
+
f"'{cls.__name__}'.")
|
|
217
|
+
|
|
207
218
|
self._tables[table.name] = table
|
|
208
219
|
|
|
209
220
|
return self
|
|
@@ -241,7 +252,7 @@ class LocalGraph:
|
|
|
241
252
|
Example:
|
|
242
253
|
>>> # doctest: +SKIP
|
|
243
254
|
>>> import kumoai.experimental.rfm as rfm
|
|
244
|
-
>>> graph = rfm.
|
|
255
|
+
>>> graph = rfm.Graph(tables=...).infer_metadata()
|
|
245
256
|
>>> graph.metadata # doctest: +SKIP
|
|
246
257
|
name primary_key time_column end_time_column
|
|
247
258
|
0 users user_id - -
|
|
@@ -263,7 +274,7 @@ class LocalGraph:
|
|
|
263
274
|
})
|
|
264
275
|
|
|
265
276
|
def print_metadata(self) -> None:
|
|
266
|
-
r"""Prints the :meth:`~
|
|
277
|
+
r"""Prints the :meth:`~Graph.metadata` of the graph."""
|
|
267
278
|
if in_notebook():
|
|
268
279
|
from IPython.display import Markdown, display
|
|
269
280
|
display(Markdown('### 🗂️ Graph Metadata'))
|
|
@@ -287,7 +298,7 @@ class LocalGraph:
|
|
|
287
298
|
|
|
288
299
|
Note:
|
|
289
300
|
For more information, please see
|
|
290
|
-
:meth:`kumoai.experimental.rfm.
|
|
301
|
+
:meth:`kumoai.experimental.rfm.Table.infer_metadata`.
|
|
291
302
|
"""
|
|
292
303
|
for table in self.tables.values():
|
|
293
304
|
table.infer_metadata(verbose=False)
|
|
@@ -305,7 +316,7 @@ class LocalGraph:
|
|
|
305
316
|
return self._edges
|
|
306
317
|
|
|
307
318
|
def print_links(self) -> None:
|
|
308
|
-
r"""Prints the :meth:`~
|
|
319
|
+
r"""Prints the :meth:`~Graph.edges` of the graph."""
|
|
309
320
|
edges = [(edge.dst_table, self[edge.dst_table]._primary_key,
|
|
310
321
|
edge.src_table, edge.fkey) for edge in self.edges]
|
|
311
322
|
edges = sorted(edges)
|
|
@@ -333,9 +344,9 @@ class LocalGraph:
|
|
|
333
344
|
|
|
334
345
|
def link(
|
|
335
346
|
self,
|
|
336
|
-
src_table: Union[str,
|
|
347
|
+
src_table: Union[str, Table],
|
|
337
348
|
fkey: str,
|
|
338
|
-
dst_table: Union[str,
|
|
349
|
+
dst_table: Union[str, Table],
|
|
339
350
|
) -> Self:
|
|
340
351
|
r"""Links two tables (``src_table`` and ``dst_table``) from the foreign
|
|
341
352
|
key ``fkey`` in the source table to the primary key in the destination
|
|
@@ -358,11 +369,11 @@ class LocalGraph:
|
|
|
358
369
|
table does not exist in the graph, if the source key does not
|
|
359
370
|
exist in the source table.
|
|
360
371
|
"""
|
|
361
|
-
if isinstance(src_table,
|
|
372
|
+
if isinstance(src_table, Table):
|
|
362
373
|
src_table = src_table.name
|
|
363
374
|
assert isinstance(src_table, str)
|
|
364
375
|
|
|
365
|
-
if isinstance(dst_table,
|
|
376
|
+
if isinstance(dst_table, Table):
|
|
366
377
|
dst_table = dst_table.name
|
|
367
378
|
assert isinstance(dst_table, str)
|
|
368
379
|
|
|
@@ -396,9 +407,9 @@ class LocalGraph:
|
|
|
396
407
|
|
|
397
408
|
def unlink(
|
|
398
409
|
self,
|
|
399
|
-
src_table: Union[str,
|
|
410
|
+
src_table: Union[str, Table],
|
|
400
411
|
fkey: str,
|
|
401
|
-
dst_table: Union[str,
|
|
412
|
+
dst_table: Union[str, Table],
|
|
402
413
|
) -> Self:
|
|
403
414
|
r"""Removes an :class:`~kumoai.graph.Edge` from the graph.
|
|
404
415
|
|
|
@@ -410,11 +421,11 @@ class LocalGraph:
|
|
|
410
421
|
Raises:
|
|
411
422
|
ValueError: if the edge is not present in the graph.
|
|
412
423
|
"""
|
|
413
|
-
if isinstance(src_table,
|
|
424
|
+
if isinstance(src_table, Table):
|
|
414
425
|
src_table = src_table.name
|
|
415
426
|
assert isinstance(src_table, str)
|
|
416
427
|
|
|
417
|
-
if isinstance(dst_table,
|
|
428
|
+
if isinstance(dst_table, Table):
|
|
418
429
|
dst_table = dst_table.name
|
|
419
430
|
assert isinstance(dst_table, str)
|
|
420
431
|
|
|
@@ -528,7 +539,10 @@ class LocalGraph:
|
|
|
528
539
|
score += 1.0
|
|
529
540
|
|
|
530
541
|
# Cardinality ratio:
|
|
531
|
-
|
|
542
|
+
src_num_rows = src_table._num_rows()
|
|
543
|
+
dst_num_rows = dst_table._num_rows()
|
|
544
|
+
if (src_num_rows is not None and dst_num_rows is not None
|
|
545
|
+
and src_num_rows > dst_num_rows):
|
|
532
546
|
score += 1.0
|
|
533
547
|
|
|
534
548
|
if score < 5.0:
|
|
@@ -790,7 +804,7 @@ class LocalGraph:
|
|
|
790
804
|
def __contains__(self, name: str) -> bool:
|
|
791
805
|
return self.has_table(name)
|
|
792
806
|
|
|
793
|
-
def __getitem__(self, name: str) ->
|
|
807
|
+
def __getitem__(self, name: str) -> Table:
|
|
794
808
|
return self.table(name)
|
|
795
809
|
|
|
796
810
|
def __delitem__(self, name: str) -> None:
|
|
@@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Tuple
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import pandas as pd
|
|
5
|
-
from kumoapi.model_plan import RunMode
|
|
6
5
|
from kumoapi.rfm.context import EdgeLayout, Link, Subgraph, Table
|
|
7
6
|
from kumoapi.typing import Stype
|
|
8
7
|
|
|
@@ -33,7 +32,6 @@ class LocalGraphSampler:
|
|
|
33
32
|
entity_table_names: Tuple[str, ...],
|
|
34
33
|
node: np.ndarray,
|
|
35
34
|
time: np.ndarray,
|
|
36
|
-
run_mode: RunMode,
|
|
37
35
|
num_neighbors: List[int],
|
|
38
36
|
exclude_cols_dict: Dict[str, List[str]],
|
|
39
37
|
) -> Subgraph:
|